mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 01:47:33 +00:00
feat(eventstore): increase parallel write capabilities (#5940)
This implementation increases parallel write capabilities of the eventstore. Please have a look at the technical advisories: [05](https://zitadel.com/docs/support/advisory/a10005) and [06](https://zitadel.com/docs/support/advisory/a10006). The implementation of eventstore.push is rewritten and stored events are migrated to a new table `eventstore.events2`. If you are using cockroach: make sure that the database user of ZITADEL has `VIEWACTIVITY` grant. This is used to query events.
This commit is contained in:
@@ -1,51 +0,0 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/internal/repository"
|
||||
z_sql "github.com/zitadel/zitadel/internal/eventstore/v1/internal/repository/sql"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
// Eventstore is the read-side API of the v1 event store.
type Eventstore interface {
	// Health reports whether the underlying repository is reachable.
	Health(ctx context.Context) error
	// FilterEvents returns all stored events matching the given search query.
	FilterEvents(ctx context.Context, searchQuery *models.SearchQuery) (events []*models.Event, err error)
	// Subscribe registers a subscription for the given aggregate types.
	// NOTE(review): Subscription is declared elsewhere in this package — not visible in this chunk.
	Subscribe(aggregates ...models.AggregateType) *Subscription
	// InstanceIDs returns the instance ids found by the search query.
	InstanceIDs(ctx context.Context, searchQuery *models.SearchQuery) ([]string, error)
}
|
||||
|
||||
// Compile-time assertion that *eventstore satisfies Eventstore.
var _ Eventstore = (*eventstore)(nil)

// eventstore is the default Eventstore implementation; every call is
// delegated to the underlying repository.
type eventstore struct {
	repo repository.Repository
}
|
||||
|
||||
func Start(db *database.DB, allowOrderByCreationDate bool) (Eventstore, error) {
|
||||
return &eventstore{
|
||||
repo: z_sql.Start(db, allowOrderByCreationDate),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (es *eventstore) FilterEvents(ctx context.Context, searchQuery *models.SearchQuery) (_ []*models.Event, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
if err := searchQuery.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return es.repo.Filter(ctx, models.FactoryFromSearchQuery(searchQuery))
|
||||
}
|
||||
|
||||
// Health delegates the health check to the underlying repository.
func (es *eventstore) Health(ctx context.Context) error {
	return es.repo.Health(ctx)
}
|
||||
|
||||
func (es *eventstore) InstanceIDs(ctx context.Context, searchQuery *models.SearchQuery) ([]string, error) {
|
||||
if err := searchQuery.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return es.repo.InstanceIDs(ctx, models.FactoryFromSearchQuery(searchQuery))
|
||||
}
|
@@ -1,3 +0,0 @@
|
||||
package v1
|
||||
|
||||
//go:generate mockgen -package mock -destination ./mock/eventstore.mock.go github.com/zitadel/zitadel/internal/eventstore Eventstore
|
@@ -1,3 +0,0 @@
|
||||
package repository
|
||||
|
||||
//go:generate mockgen -package mock -destination ./mock/repository.mock.go github.com/zitadel/zitadel/internal/eventstore/internal/repository Repository
|
@@ -1,34 +0,0 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
)
|
||||
|
||||
// NewMock returns a MockRepository bound to a gomock controller for t.
func NewMock(t *testing.T) *MockRepository {
	return NewMockRepository(gomock.NewController(t))
}

// ExpectFilter registers at most one expected Filter call for query,
// returning eventAmount zero-value events.
// NOTE(review): the real Filter argument is a *SearchQueryFactory while the
// expectation matches on *SearchQuery — confirm gomock matching works for
// the callers of this helper.
func (m *MockRepository) ExpectFilter(query *models.SearchQuery, eventAmount int) *MockRepository {
	events := make([]*models.Event, eventAmount)
	m.EXPECT().Filter(context.Background(), query).Return(events, nil).MaxTimes(1)
	return m
}

// ExpectFilterFail registers at most one expected Filter call for query
// that fails with err.
func (m *MockRepository) ExpectFilterFail(query *models.SearchQuery, err error) *MockRepository {
	m.EXPECT().Filter(context.Background(), query).Return(nil, err).MaxTimes(1)
	return m
}

// ExpectPush registers at most one expected PushAggregates call for the
// given aggregates that succeeds.
func (m *MockRepository) ExpectPush(aggregates ...*models.Aggregate) *MockRepository {
	m.EXPECT().PushAggregates(context.Background(), aggregates).Return(nil).MaxTimes(1)
	return m
}

// ExpectPushError registers at most one expected PushAggregates call for the
// given aggregates that fails with err.
func (m *MockRepository) ExpectPushError(err error, aggregates ...*models.Aggregate) *MockRepository {
	m.EXPECT().PushAggregates(context.Background(), aggregates).Return(err).MaxTimes(1)
	return m
}
|
@@ -1,98 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/zitadel/zitadel/internal/eventstore/internal/repository (interfaces: Repository)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
models "github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
reflect "reflect"
|
||||
)
|
||||
|
||||
// MockRepository is a mock of Repository interface
// NOTE(review): generated by MockGen — regenerate with `go generate` instead
// of hand-editing. It does not implement Repository.InstanceIDs (declared in
// the interface in this file); confirm the mock is up to date.
type MockRepository struct {
	ctrl     *gomock.Controller
	recorder *MockRepositoryMockRecorder
}

// MockRepositoryMockRecorder is the mock recorder for MockRepository
type MockRepositoryMockRecorder struct {
	mock *MockRepository
}

// NewMockRepository creates a new mock instance
func NewMockRepository(ctrl *gomock.Controller) *MockRepository {
	mock := &MockRepository{ctrl: ctrl}
	mock.recorder = &MockRepositoryMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockRepository) EXPECT() *MockRepositoryMockRecorder {
	return m.recorder
}

// Filter mocks base method
func (m *MockRepository) Filter(arg0 context.Context, arg1 *models.SearchQueryFactory) ([]*models.Event, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Filter", arg0, arg1)
	ret0, _ := ret[0].([]*models.Event)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// Filter indicates an expected call of Filter
func (mr *MockRepositoryMockRecorder) Filter(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Filter", reflect.TypeOf((*MockRepository)(nil).Filter), arg0, arg1)
}

// Health mocks base method
func (m *MockRepository) Health(arg0 context.Context) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Health", arg0)
	ret0, _ := ret[0].(error)
	return ret0
}

// Health indicates an expected call of Health
func (mr *MockRepositoryMockRecorder) Health(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockRepository)(nil).Health), arg0)
}

// LatestSequence mocks base method
func (m *MockRepository) LatestSequence(arg0 context.Context, arg1 *models.SearchQueryFactory) (uint64, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "LatestSequence", arg0, arg1)
	ret0, _ := ret[0].(uint64)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// LatestSequence indicates an expected call of LatestSequence
func (mr *MockRepositoryMockRecorder) LatestSequence(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LatestSequence", reflect.TypeOf((*MockRepository)(nil).LatestSequence), arg0, arg1)
}

// PushAggregates mocks base method
func (m *MockRepository) PushAggregates(arg0 context.Context, arg1 ...*models.Aggregate) error {
	m.ctrl.T.Helper()
	varargs := []interface{}{arg0}
	for _, a := range arg1 {
		varargs = append(varargs, a)
	}
	ret := m.ctrl.Call(m, "PushAggregates", varargs...)
	ret0, _ := ret[0].(error)
	return ret0
}

// PushAggregates indicates an expected call of PushAggregates
func (mr *MockRepositoryMockRecorder) PushAggregates(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	varargs := append([]interface{}{arg0}, arg1...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PushAggregates", reflect.TypeOf((*MockRepository)(nil).PushAggregates), varargs...)
}
|
@@ -1,18 +0,0 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
)
|
||||
|
||||
// Repository abstracts the storage backend of the v1 eventstore.
type Repository interface {
	// Health reports whether the storage is reachable.
	Health(ctx context.Context) error

	// Filter returns all events matching the given search query
	Filter(ctx context.Context, searchQuery *models.SearchQueryFactory) (events []*models.Event, err error)
	// LatestSequence returns the latest sequence found by the search query
	LatestSequence(ctx context.Context, queryFactory *models.SearchQueryFactory) (uint64, error)
	// InstanceIDs returns the instance ids found by the search query
	InstanceIDs(ctx context.Context, queryFactory *models.SearchQueryFactory) ([]string, error)
}
|
@@ -1,12 +0,0 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
)
|
||||
|
||||
// Start wires a SQL repository around the given database client.
// allowOrderByCreationDate toggles ordering filtered events by creation date
// (in addition to the event sequence) when queries are built.
func Start(client *database.DB, allowOrderByCreationDate bool) *SQL {
	return &SQL{
		client:                   client,
		allowOrderByCreationDate: allowOrderByCreationDate,
	}
}
|
@@ -1,208 +0,0 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"regexp"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
)
|
||||
|
||||
const (
	// selectEscaped is the regex-escaped prefix of the event SELECT statement,
	// shared by the expected-query patterns below.
	selectEscaped = `SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore\.events AS OF SYSTEM TIME '-1 ms' WHERE \( aggregate_type = \$1`
)

// Expected query patterns for sqlmock. regexp.MustCompile(...).String()
// returns the pattern unchanged; the compile only validates the regex at
// package init.
var (
	eventColumns                    = []string{"creation_date", "event_type", "event_sequence", "previous_aggregate_sequence", "event_data", "editor_service", "editor_user", "resource_owner", "instance_id", "aggregate_type", "aggregate_id", "aggregate_version"}
	expectedFilterEventsLimitFormat = regexp.MustCompile(selectEscaped + ` \) ORDER BY creation_date, event_sequence LIMIT \$2`).String()
	expectedFilterEventsDescFormat  = regexp.MustCompile(selectEscaped + ` \) ORDER BY creation_date DESC, event_sequence DESC`).String()
	expectedFilterEventsAggregateIDLimit = regexp.MustCompile(selectEscaped + ` AND aggregate_id = \$2 \) ORDER BY creation_date, event_sequence LIMIT \$3`).String()
	// NOTE(review): identical to expectedFilterEventsAggregateIDLimit — the
	// name suggests it should also match an aggregate-type condition; confirm.
	expectedFilterEventsAggregateIDTypeLimit = regexp.MustCompile(selectEscaped + ` AND aggregate_id = \$2 \) ORDER BY creation_date, event_sequence LIMIT \$3`).String()
	expectedGetAllEvents                     = regexp.MustCompile(selectEscaped + ` \) ORDER BY creation_date, event_sequence`).String()

	expectedInsertStatement = regexp.MustCompile(`INSERT INTO eventstore\.events ` +
		`\(event_type, aggregate_type, aggregate_id, aggregate_version, creation_date, event_data, editor_user, editor_service, resource_owner, instance_id, previous_aggregate_sequence, previous_aggregate_type_sequence\) ` +
		`SELECT \$1, \$2, \$3, \$4, COALESCE\(\$5, now\(\)\), \$6, \$7, \$8, \$9, \$10, \$11 ` +
		`WHERE EXISTS \(` +
		`SELECT 1 FROM eventstore\.events WHERE aggregate_type = \$12 AND aggregate_id = \$13 HAVING MAX\(event_sequence\) = \$14 OR \(\$14::BIGINT IS NULL AND COUNT\(\*\) = 0\)\) ` +
		`RETURNING event_sequence, creation_date`).String()
)
|
||||
|
||||
// dbMock bundles a stub *sql.DB with its sqlmock expectation recorder.
type dbMock struct {
	sqlClient *sql.DB
	mock      sqlmock.Sqlmock
}

// close releases the stub database connection.
func (db *dbMock) close() {
	db.sqlClient.Close()
}
|
||||
|
||||
func mockDB(t *testing.T) *dbMock {
|
||||
mockDB := dbMock{}
|
||||
var err error
|
||||
mockDB.sqlClient, mockDB.mock, err = sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("error occured while creating stub db %v", err)
|
||||
}
|
||||
|
||||
mockDB.mock.MatchExpectationsInOrder(true)
|
||||
|
||||
return &mockDB
|
||||
}
|
||||
|
||||
func (db *dbMock) expectBegin(err error) *dbMock {
|
||||
if err != nil {
|
||||
db.mock.ExpectBegin().WillReturnError(err)
|
||||
} else {
|
||||
db.mock.ExpectBegin()
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *dbMock) expectSavepoint() *dbMock {
|
||||
db.mock.ExpectExec("SAVEPOINT").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *dbMock) expectReleaseSavepoint(err error) *dbMock {
|
||||
expectation := db.mock.ExpectExec("RELEASE SAVEPOINT")
|
||||
if err == nil {
|
||||
expectation.WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
} else {
|
||||
expectation.WillReturnError(err)
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *dbMock) expectCommit(err error) *dbMock {
|
||||
if err != nil {
|
||||
db.mock.ExpectCommit().WillReturnError(err)
|
||||
} else {
|
||||
db.mock.ExpectCommit()
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *dbMock) expectRollback(err error) *dbMock {
|
||||
if err != nil {
|
||||
db.mock.ExpectRollback().WillReturnError(err)
|
||||
} else {
|
||||
db.mock.ExpectRollback()
|
||||
}
|
||||
return db
|
||||
}
|
||||
|
||||
// expectInsertEvent registers the expected event-insert query for e,
// returning returnedSequence and the current UTC timestamp.
func (db *dbMock) expectInsertEvent(e *models.Event, returnedSequence uint64) *dbMock {
	db.mock.ExpectQuery(expectedInsertStatement).
		WithArgs(
			e.Type, e.AggregateType, e.AggregateID, e.AggregateVersion, sqlmock.AnyArg(), Data(e.Data), e.EditorUser, e.EditorService, e.ResourceOwner, e.InstanceID, Sequence(e.PreviousSequence),
			e.AggregateType, e.AggregateID, Sequence(e.PreviousSequence), Sequence(e.PreviousSequence),
		).
		WillReturnRows(
			sqlmock.NewRows([]string{"event_sequence", "creation_date"}).
				AddRow(returnedSequence, time.Now().UTC()),
		)

	return db
}

// expectInsertEventError registers the expected event-insert query for e and
// fails it with sql.ErrTxDone.
func (db *dbMock) expectInsertEventError(e *models.Event) *dbMock {
	db.mock.ExpectQuery(expectedInsertStatement).
		WithArgs(
			e.Type, e.AggregateType, e.AggregateID, e.AggregateVersion, sqlmock.AnyArg(), Data(e.Data), e.EditorUser, e.EditorService, e.ResourceOwner, e.InstanceID, Sequence(e.PreviousSequence),
			e.AggregateType, e.AggregateID, Sequence(e.PreviousSequence), Sequence(e.PreviousSequence),
		).
		WillReturnError(sql.ErrTxDone)

	return db
}
|
||||
|
||||
// expectFilterEventsLimit expects a filter query with a LIMIT and returns
// eventCount synthetic event rows, wrapped in BEGIN/COMMIT.
func (db *dbMock) expectFilterEventsLimit(aggregateType string, limit uint64, eventCount int) *dbMock {
	rows := sqlmock.NewRows(eventColumns)
	for i := 0; i < eventCount; i++ {
		rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0")
	}
	db.mock.ExpectBegin()
	db.mock.ExpectQuery(expectedFilterEventsLimitFormat).
		WithArgs(aggregateType, limit).
		WillReturnRows(rows)
	db.mock.ExpectCommit()
	return db
}

// expectFilterEventsDesc expects a descending filter query and returns
// eventCount synthetic event rows in reverse order.
func (db *dbMock) expectFilterEventsDesc(aggregateType string, eventCount int) *dbMock {
	rows := sqlmock.NewRows(eventColumns)
	for i := eventCount; i > 0; i-- {
		rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0")
	}
	db.mock.ExpectBegin()
	db.mock.ExpectQuery(expectedFilterEventsDescFormat).
		WillReturnRows(rows)
	db.mock.ExpectCommit()
	return db
}

// expectFilterEventsAggregateIDLimit expects a filter query restricted by
// aggregate id with a LIMIT and returns `limit` synthetic event rows.
func (db *dbMock) expectFilterEventsAggregateIDLimit(aggregateType, aggregateID string, limit uint64) *dbMock {
	rows := sqlmock.NewRows(eventColumns)
	for i := limit; i > 0; i-- {
		rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0")
	}
	db.mock.ExpectBegin()
	db.mock.ExpectQuery(expectedFilterEventsAggregateIDLimit).
		WithArgs(aggregateType, aggregateID, limit).
		WillReturnRows(rows)
	db.mock.ExpectCommit()
	return db
}

// expectFilterEventsAggregateIDTypeLimit is like
// expectFilterEventsAggregateIDLimit but uses the aggregate-id-and-type
// pattern. NOTE(review): the two expected patterns are currently identical.
func (db *dbMock) expectFilterEventsAggregateIDTypeLimit(aggregateType, aggregateID string, limit uint64) *dbMock {
	rows := sqlmock.NewRows(eventColumns)
	for i := limit; i > 0; i-- {
		rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0")
	}
	db.mock.ExpectBegin()
	db.mock.ExpectQuery(expectedFilterEventsAggregateIDTypeLimit).
		WithArgs(aggregateType, aggregateID, limit).
		WillReturnRows(rows)
	db.mock.ExpectCommit()
	return db
}
|
||||
|
||||
// expectFilterEventsError expects the unrestricted filter query and fails it
// with returnedErr, followed by a ROLLBACK.
func (db *dbMock) expectFilterEventsError(returnedErr error) *dbMock {
	db.mock.ExpectBegin()
	db.mock.ExpectQuery(expectedGetAllEvents).
		WillReturnError(returnedErr)
	db.mock.ExpectRollback()
	return db
}

// expectLatestSequenceFilter expects the MAX(event_sequence) query for
// aggregateType and returns the given sequence.
func (db *dbMock) expectLatestSequenceFilter(aggregateType string, sequence Sequence) *dbMock {
	db.mock.ExpectBegin()
	db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events AS OF SYSTEM TIME '-1 ms' WHERE \( aggregate_type = \$1 \)`).
		WithArgs(aggregateType).
		WillReturnRows(sqlmock.NewRows([]string{"max_sequence"}).AddRow(sequence))
	db.mock.ExpectCommit()
	return db
}

// expectLatestSequenceFilterError expects the MAX(event_sequence) query for
// aggregateType and fails it with err. No rollback is expected
// (see the commented-out line below).
func (db *dbMock) expectLatestSequenceFilterError(aggregateType string, err error) *dbMock {
	db.mock.ExpectBegin()
	db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events AS OF SYSTEM TIME '-1 ms' WHERE \( aggregate_type = \$1 \)`).
		WithArgs(aggregateType).WillReturnError(err)
	// db.mock.ExpectRollback()
	return db
}

// expectPrepareInsert expects the insert statement to be prepared; a non-nil
// err makes the prepare fail.
func (db *dbMock) expectPrepareInsert(err error) *dbMock {
	prepare := db.mock.ExpectPrepare(expectedInsertStatement)
	if err != nil {
		prepare.WillReturnError(err)
	}

	return db
}
|
@@ -1,99 +0,0 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
errs "github.com/zitadel/zitadel/internal/errors"
|
||||
es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
// Querier is the minimal query surface used by this package.
// NOTE(review): no usage is visible in this chunk — confirm it is still needed.
type Querier interface {
	Query(query string, args ...interface{}) (*sql.Rows, error)
}

// Filter returns all events matching the search query. It logs a warning
// (including a stack trace) when the query does not filter by instance id.
func (db *SQL) Filter(ctx context.Context, searchQuery *es_models.SearchQueryFactory) (events []*es_models.Event, err error) {
	if !searchQuery.InstanceFiltered {
		logging.WithFields("stack", string(debug.Stack())).Warn("instanceid not filtered")
	}
	return db.filter(ctx, db.client, searchQuery)
}
|
||||
|
||||
// filter builds the SQL statement from the search query factory, executes it
// against db, and scans each row into an event via the returned rowScanner.
// An empty built query signals an invalid factory.
func (server *SQL) filter(ctx context.Context, db *database.DB, searchQuery *es_models.SearchQueryFactory) (events []*es_models.Event, err error) {
	query, limit, values, rowScanner := server.buildQuery(ctx, db, searchQuery)
	if query == "" {
		return nil, errs.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
	}

	// Pre-size the result with the query limit; the closure below appends to
	// the named result so it is visible after QueryContext returns.
	events = make([]*es_models.Event, 0, limit)
	err = db.QueryContext(ctx,
		func(rows *sql.Rows) error {
			for rows.Next() {
				event := new(es_models.Event)
				err := rowScanner(rows.Scan, event)
				if err != nil {
					return err
				}

				events = append(events, event)
			}
			return nil
		},
		query, values...,
	)
	if err != nil {
		logging.New().WithError(err).Info("query failed")
		return nil, errs.ThrowInternal(err, "SQL-IJuyR", "unable to filter events")
	}
	return events, nil
}
|
||||
|
||||
// LatestSequence returns the highest event sequence matched by the query
// factory. When no row is found (sql.ErrNoRows) it returns (0, nil).
// NOTE(review): error ID "SQL-rWeBw" is also used in filter — confirm IDs
// are meant to be unique.
func (db *SQL) LatestSequence(ctx context.Context, queryFactory *es_models.SearchQueryFactory) (uint64, error) {
	query, _, values, rowScanner := db.buildQuery(ctx, db.client, queryFactory)
	if query == "" {
		return 0, errs.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
	}
	sequence := new(Sequence)
	err := db.client.QueryRowContext(ctx, func(row *sql.Row) error {
		return rowScanner(row.Scan, sequence)
	}, query, values...)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		logging.New().WithError(err).WithField("traceID", tracing.TraceIDFromCtx(ctx)).Info("query failed")
		return 0, errs.ThrowInternal(err, "SQL-Yczyx", "unable to filter latest sequence")
	}
	return uint64(*sequence), nil
}
|
||||
|
||||
// InstanceIDs executes the instance-id query built from the factory and
// scans each row into the named result slice.
func (db *SQL) InstanceIDs(ctx context.Context, queryFactory *es_models.SearchQueryFactory) (ids []string, err error) {
	query, _, values, rowScanner := db.buildQuery(ctx, db.client, queryFactory)
	if query == "" {
		return nil, errs.ThrowInvalidArgument(nil, "SQL-Sfwg2", "invalid query factory")
	}

	err = db.client.QueryContext(ctx,
		func(rows *sql.Rows) error {
			for rows.Next() {
				var id string
				err := rowScanner(rows.Scan, &id)
				if err != nil {
					return err
				}

				ids = append(ids, id)
			}
			return nil
		},
		query, values...)
	if err != nil {
		logging.New().WithError(err).Info("query failed")
		return nil, errs.ThrowInternal(err, "SQL-Sfg3r", "unable to filter instance ids")
	}

	return ids, nil
}
|
@@ -1,243 +0,0 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"math"
|
||||
"testing"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
)
|
||||
|
||||
// mockEvents is a test fixture carrying events and the owning test.
// NOTE(review): the fields are never read in the tests visible here —
// possibly vestigial.
type mockEvents struct {
	events []*es_models.Event
	t      *testing.T
}
|
||||
|
||||
// TestSQL_Filter is a table test for SQL.Filter: each case sets sqlmock
// expectations on a stub db, runs Filter, and checks error/length results
// plus that all expectations were met.
func TestSQL_Filter(t *testing.T) {
	type fields struct {
		client *dbMock
	}
	type args struct {
		events      *mockEvents
		searchQuery *es_models.SearchQueryFactory
	}
	type res struct {
		wantErr   bool
		isErrFunc func(error) bool
		eventsLen int
	}
	tests := []struct {
		name   string
		fields fields
		args   args
		res    res
	}{
		{
			name: "only limit filter",
			fields: fields{
				client: mockDB(t).expectFilterEventsLimit("user", 34, 3),
			},
			args: args{
				events:      &mockEvents{t: t},
				searchQuery: es_models.NewSearchQueryFactory().Limit(34).AddQuery().AggregateTypes("user").Factory(),
			},
			res: res{
				eventsLen: 3,
				wantErr:   false,
			},
		},
		{
			name: "only desc filter",
			fields: fields{
				client: mockDB(t).expectFilterEventsDesc("user", 34),
			},
			args: args{
				events:      &mockEvents{t: t},
				searchQuery: es_models.NewSearchQueryFactory().OrderDesc().AddQuery().AggregateTypes("user").Factory(),
			},
			res: res{
				eventsLen: 34,
				wantErr:   false,
			},
		},
		{
			name: "no events found",
			fields: fields{
				client: mockDB(t).expectFilterEventsError(sql.ErrNoRows),
			},
			args: args{
				events:      &mockEvents{t: t},
				searchQuery: es_models.NewSearchQueryFactory().AddQuery().AggregateTypes("nonAggregate").Factory(),
			},
			res: res{
				wantErr:   true,
				isErrFunc: errors.IsInternal,
			},
		},
		{
			name: "filter fails because sql internal error",
			fields: fields{
				client: mockDB(t).expectFilterEventsError(sql.ErrConnDone),
			},
			args: args{
				events:      &mockEvents{t: t},
				searchQuery: es_models.NewSearchQueryFactory().AddQuery().AggregateTypes("user").Factory(),
			},
			res: res{
				wantErr:   true,
				isErrFunc: errors.IsInternal,
			},
		},
		{
			name: "filter by aggregate id",
			fields: fields{
				client: mockDB(t).expectFilterEventsAggregateIDLimit("user", "hop", 5),
			},
			args: args{
				events:      &mockEvents{t: t},
				searchQuery: es_models.NewSearchQueryFactory().Limit(5).AddQuery().AggregateTypes("user").AggregateIDs("hop").Factory(),
			},
			res: res{
				wantErr:   false,
				isErrFunc: nil,
			},
		},
		{
			name: "filter by aggregate id and aggregate type",
			fields: fields{
				client: mockDB(t).expectFilterEventsAggregateIDTypeLimit("user", "hop", 5),
			},
			args: args{
				events:      &mockEvents{t: t},
				searchQuery: es_models.NewSearchQueryFactory().Limit(5).AddQuery().AggregateTypes("user").AggregateIDs("hop").Factory(),
			},
			res: res{
				wantErr:   false,
				isErrFunc: nil,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			sql := &SQL{
				client:                   &database.DB{DB: tt.fields.client.sqlClient, Database: new(testDB)},
				allowOrderByCreationDate: true,
			}
			events, err := sql.Filter(context.Background(), tt.args.searchQuery)
			if (err != nil) != tt.res.wantErr {
				t.Errorf("SQL.Filter() error = %v, wantErr %v", err, tt.res.wantErr)
			}

			// eventsLen == 0 means "don't check the length" for error cases.
			if tt.res.eventsLen != 0 && len(events) != tt.res.eventsLen {
				t.Errorf("events has wrong length got: %d want %d", len(events), tt.res.eventsLen)
			}
			if tt.res.wantErr && !tt.res.isErrFunc(err) {
				t.Errorf("got wrong error %v", err)
			}
			if err := tt.fields.client.mock.ExpectationsWereMet(); err != nil {
				t.Errorf("there were unfulfilled expectations: %s", err)
			}
			tt.fields.client.close()
		})
	}
}
|
||||
|
||||
// TestSQL_LatestSequence is a table test for SQL.LatestSequence, covering an
// invalid factory, the no-rows case (expected to yield 0 without error), a
// query error, and a successful max-sequence lookup.
func TestSQL_LatestSequence(t *testing.T) {
	type fields struct {
		client *dbMock
	}
	type args struct {
		searchQuery *es_models.SearchQueryFactory
	}
	type res struct {
		wantErr   bool
		isErrFunc func(error) bool
		sequence  uint64
	}
	tests := []struct {
		name   string
		fields fields
		args   args
		res    res
	}{
		{
			name: "invalid query factory",
			args: args{
				searchQuery: nil,
			},
			fields: fields{
				client: mockDB(t),
			},
			res: res{
				wantErr:   true,
				isErrFunc: errors.IsErrorInvalidArgument,
			},
		},
		{
			name: "no events for aggregate",
			args: args{
				searchQuery: es_models.NewSearchQueryFactory().Columns(es_models.Columns_Max_Sequence).AddQuery().AggregateTypes("idiot").Factory(),
			},
			fields: fields{
				client: mockDB(t).expectLatestSequenceFilterError("idiot", sql.ErrNoRows),
			},
			res: res{
				wantErr:  false,
				sequence: 0,
			},
		},
		{
			name: "sql query error",
			args: args{
				searchQuery: es_models.NewSearchQueryFactory().Columns(es_models.Columns_Max_Sequence).AddQuery().AggregateTypes("idiot").Factory(),
			},
			fields: fields{
				client: mockDB(t).expectLatestSequenceFilterError("idiot", sql.ErrConnDone),
			},
			res: res{
				wantErr:   true,
				isErrFunc: errors.IsInternal,
				sequence:  0,
			},
		},
		{
			name: "events for aggregate found",
			args: args{
				searchQuery: es_models.NewSearchQueryFactory().Columns(es_models.Columns_Max_Sequence).AddQuery().AggregateTypes("user").Factory(),
			},
			fields: fields{
				client: mockDB(t).expectLatestSequenceFilter("user", math.MaxUint64),
			},
			res: res{
				wantErr:  false,
				sequence: math.MaxUint64,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			sql := &SQL{
				client: &database.DB{DB: tt.fields.client.sqlClient, Database: new(testDB)},
			}

			sequence, err := sql.LatestSequence(context.Background(), tt.args.searchQuery)
			if (err != nil) != tt.res.wantErr {
				t.Errorf("SQL.Filter() error = %v, wantErr %v", err, tt.res.wantErr)
			}

			if tt.res.sequence != sequence {
				t.Errorf("events has wrong length got: %d want %d", sequence, tt.res.sequence)
			}
			if tt.res.wantErr && !tt.res.isErrFunc(err) {
				t.Errorf("got wrong error %v", err)
			}
			if err := tt.fields.client.mock.ExpectationsWereMet(); err != nil {
				t.Errorf("there were unfulfilled expectations: %s", err)
			}
			tt.fields.client.close()
		})
	}
}
|
@@ -1,238 +0,0 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/call"
|
||||
"github.com/zitadel/zitadel/internal/database/dialect"
|
||||
z_errors "github.com/zitadel/zitadel/internal/errors"
|
||||
es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
)
|
||||
|
||||
const (
	// selectStmt is the base SELECT over eventstore.events used when the
	// query asks for full event rows; WHERE/ORDER BY/LIMIT are appended by
	// buildQuery.
	selectStmt = "SELECT" +
		" creation_date" +
		", event_type" +
		", event_sequence" +
		", previous_aggregate_sequence" +
		", event_data" +
		", editor_service" +
		", editor_user" +
		", resource_owner" +
		", instance_id" +
		", aggregate_type" +
		", aggregate_id" +
		", aggregate_version" +
		" FROM eventstore.events"
)
|
||||
|
||||
// buildQuery assembles the SQL statement, positional values, and row scanner
// for the given query factory. An empty query string signals an invalid
// factory; callers must check for it.
// NOTE(review): the receiver name `sql` shadows the database/sql package
// inside this method — harmless here, but worth renaming on a future pass.
func (sql *SQL) buildQuery(ctx context.Context, db dialect.Database, queryFactory *es_models.SearchQueryFactory) (query string, limit uint64, values []interface{}, rowScanner func(s scan, dest interface{}) error) {
	searchQuery, err := queryFactory.Build()
	if err != nil {
		logging.New().WithError(err).Warn("search query factory invalid")
		return "", 0, nil, nil
	}
	query, rowScanner = prepareColumns(searchQuery.Columns)
	where, values := prepareCondition(searchQuery.Filters)
	if where == "" || query == "" {
		return "", 0, nil, nil
	}

	// Append a dialect-specific time-travel clause (e.g. AS OF SYSTEM TIME)
	// before the WHERE clause.
	if travel := db.Timetravel(call.Took(ctx)); travel != "" {
		query += travel
	}
	query += where

	// Ordering only applies when full event rows are selected. Ordering by
	// creation_date is gated behind the allowOrderByCreationDate flag.
	if searchQuery.Columns == es_models.Columns_Event {
		var order string
		if sql.allowOrderByCreationDate {
			order = " ORDER BY creation_date, event_sequence"
			if searchQuery.Desc {
				order = " ORDER BY creation_date DESC, event_sequence DESC"
			}
		} else {
			order = " ORDER BY event_sequence"
			if searchQuery.Desc {
				order = " ORDER BY event_sequence DESC"
			}
		}
		query += order
	}

	if searchQuery.Limit > 0 {
		values = append(values, searchQuery.Limit)
		query += " LIMIT ?"
	}

	// Convert "?" placeholders into numbered "$n" placeholders.
	query = numberPlaceholder(query, "?", "$")

	return query, searchQuery.Limit, values, rowScanner
}
|
||||
|
||||
// prepareCondition turns the two-level filter list into a WHERE clause:
// filters within an inner slice are ANDed, the groups are ORed. It returns
// ("", nil) when any single filter cannot be translated into a condition.
func prepareCondition(filters [][]*es_models.Filter) (clause string, values []interface{}) {
	values = make([]interface{}, 0, len(filters))
	clauses := make([]string, len(filters))

	if len(filters) == 0 {
		// Empty clause plus empty value slice — buildQuery treats the empty
		// clause as "invalid".
		return clause, values
	}
	for i, filter := range filters {
		subClauses := make([]string, 0, len(filter))
		for _, f := range filter {
			value := f.GetValue()

			subClauses = append(subClauses, getCondition(f))
			if subClauses[len(subClauses)-1] == "" {
				return "", nil
			}
			values = append(values, value)
		}
		clauses[i] = "( " + strings.Join(subClauses, " AND ") + " )"
	}
	return " WHERE " + strings.Join(clauses, " OR "), values
}

// scan abstracts sql.Row.Scan / sql.Rows.Scan so row scanners can serve both.
type scan func(dest ...interface{}) error
||||
|
||||
func prepareColumns(columns es_models.Columns) (string, func(s scan, dest interface{}) error) {
|
||||
switch columns {
|
||||
case es_models.Columns_Max_Sequence:
|
||||
return "SELECT MAX(event_sequence) FROM eventstore.events", func(row scan, dest interface{}) (err error) {
|
||||
sequence, ok := dest.(*Sequence)
|
||||
if !ok {
|
||||
return z_errors.ThrowInvalidArgument(nil, "SQL-NBjA9", "type must be sequence")
|
||||
}
|
||||
err = row(sequence)
|
||||
if err == nil || errors.Is(err, sql.ErrNoRows) {
|
||||
return nil
|
||||
}
|
||||
return z_errors.ThrowInternal(err, "SQL-bN5xg", "something went wrong")
|
||||
}
|
||||
case es_models.Columns_InstanceIDs:
|
||||
return "SELECT DISTINCT instance_id FROM eventstore.events", func(row scan, dest interface{}) (err error) {
|
||||
instanceID, ok := dest.(*string)
|
||||
if !ok {
|
||||
return z_errors.ThrowInvalidArgument(nil, "SQL-Fef5h", "type must be *string]")
|
||||
}
|
||||
err = row(instanceID)
|
||||
if err != nil {
|
||||
logging.New().WithError(err).Warn("unable to scan row")
|
||||
return z_errors.ThrowInternal(err, "SQL-SFef3", "unable to scan row")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
case es_models.Columns_Event:
|
||||
return selectStmt, func(row scan, dest interface{}) (err error) {
|
||||
event, ok := dest.(*es_models.Event)
|
||||
if !ok {
|
||||
return z_errors.ThrowInvalidArgument(nil, "SQL-4GP6F", "type must be event")
|
||||
}
|
||||
var previousSequence Sequence
|
||||
data := make(Data, 0)
|
||||
|
||||
err = row(
|
||||
&event.CreationDate,
|
||||
&event.Type,
|
||||
&event.Sequence,
|
||||
&previousSequence,
|
||||
&data,
|
||||
&event.EditorService,
|
||||
&event.EditorUser,
|
||||
&event.ResourceOwner,
|
||||
&event.InstanceID,
|
||||
&event.AggregateType,
|
||||
&event.AggregateID,
|
||||
&event.AggregateVersion,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
logging.New().WithError(err).Warn("unable to scan row")
|
||||
return z_errors.ThrowInternal(err, "SQL-J0hFS", "unable to scan row")
|
||||
}
|
||||
|
||||
event.PreviousSequence = uint64(previousSequence)
|
||||
|
||||
event.Data = make([]byte, len(data))
|
||||
copy(event.Data, data)
|
||||
|
||||
return nil
|
||||
}
|
||||
default:
|
||||
return "", nil
|
||||
}
|
||||
}
|
||||
|
||||
func numberPlaceholder(query, old, new string) string {
|
||||
for i, hasChanged := 1, true; hasChanged; i++ {
|
||||
newQuery := strings.Replace(query, old, new+strconv.Itoa(i), 1)
|
||||
hasChanged = query != newQuery
|
||||
query = newQuery
|
||||
}
|
||||
return query
|
||||
}
|
||||
|
||||
func getCondition(filter *es_models.Filter) (condition string) {
|
||||
field := getField(filter.GetField())
|
||||
operation := getOperation(filter.GetOperation())
|
||||
if field == "" || operation == "" {
|
||||
return ""
|
||||
}
|
||||
format := getConditionFormat(filter.GetOperation())
|
||||
|
||||
return fmt.Sprintf(format, field, operation)
|
||||
}
|
||||
|
||||
func getConditionFormat(operation es_models.Operation) string {
|
||||
switch operation {
|
||||
case es_models.Operation_In:
|
||||
return "%s %s ANY(?)"
|
||||
case es_models.Operation_NotIn:
|
||||
return "%s %s ALL(?)"
|
||||
}
|
||||
return "%s %s ?"
|
||||
}
|
||||
|
||||
func getField(field es_models.Field) string {
|
||||
switch field {
|
||||
case es_models.Field_AggregateID:
|
||||
return "aggregate_id"
|
||||
case es_models.Field_AggregateType:
|
||||
return "aggregate_type"
|
||||
case es_models.Field_LatestSequence:
|
||||
return "event_sequence"
|
||||
case es_models.Field_ResourceOwner:
|
||||
return "resource_owner"
|
||||
case es_models.Field_InstanceID:
|
||||
return "instance_id"
|
||||
case es_models.Field_EditorService:
|
||||
return "editor_service"
|
||||
case es_models.Field_EditorUser:
|
||||
return "editor_user"
|
||||
case es_models.Field_EventType:
|
||||
return "event_type"
|
||||
case es_models.Field_CreationDate:
|
||||
return "creation_date"
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func getOperation(operation es_models.Operation) string {
|
||||
switch operation {
|
||||
case es_models.Operation_Equals, es_models.Operation_In:
|
||||
return "="
|
||||
case es_models.Operation_Greater:
|
||||
return ">"
|
||||
case es_models.Operation_Less:
|
||||
return "<"
|
||||
case es_models.Operation_NotIn:
|
||||
return "<>"
|
||||
}
|
||||
return ""
|
||||
}
|
@@ -1,504 +0,0 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
)
|
||||
|
||||
func Test_numberPlaceholder(t *testing.T) {
|
||||
type args struct {
|
||||
query string
|
||||
old string
|
||||
new string
|
||||
}
|
||||
type res struct {
|
||||
query string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "no replaces",
|
||||
args: args{
|
||||
new: "$",
|
||||
old: "?",
|
||||
query: "SELECT * FROM eventstore.events",
|
||||
},
|
||||
res: res{
|
||||
query: "SELECT * FROM eventstore.events",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "two replaces",
|
||||
args: args{
|
||||
new: "$",
|
||||
old: "?",
|
||||
query: "SELECT * FROM eventstore.events WHERE aggregate_type = ? AND LIMIT = ?",
|
||||
},
|
||||
res: res{
|
||||
query: "SELECT * FROM eventstore.events WHERE aggregate_type = $1 AND LIMIT = $2",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := numberPlaceholder(tt.args.query, tt.args.old, tt.args.new); got != tt.res.query {
|
||||
t.Errorf("numberPlaceholder() = %v, want %v", got, tt.res.query)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_getOperation(t *testing.T) {
|
||||
t.Run("all ops", func(t *testing.T) {
|
||||
for op, expected := range map[es_models.Operation]string{
|
||||
es_models.Operation_Equals: "=",
|
||||
es_models.Operation_In: "=",
|
||||
es_models.Operation_Greater: ">",
|
||||
es_models.Operation_Less: "<",
|
||||
es_models.Operation(-1): "",
|
||||
} {
|
||||
if got := getOperation(op); got != expected {
|
||||
t.Errorf("getOperation() = %v, want %v", got, expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_getField(t *testing.T) {
|
||||
t.Run("all fields", func(t *testing.T) {
|
||||
for field, expected := range map[es_models.Field]string{
|
||||
es_models.Field_AggregateType: "aggregate_type",
|
||||
es_models.Field_AggregateID: "aggregate_id",
|
||||
es_models.Field_LatestSequence: "event_sequence",
|
||||
es_models.Field_ResourceOwner: "resource_owner",
|
||||
es_models.Field_InstanceID: "instance_id",
|
||||
es_models.Field_EditorService: "editor_service",
|
||||
es_models.Field_EditorUser: "editor_user",
|
||||
es_models.Field_EventType: "event_type",
|
||||
es_models.Field(-1): "",
|
||||
} {
|
||||
if got := getField(field); got != expected {
|
||||
t.Errorf("getField() = %v, want %v", got, expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func Test_getConditionFormat(t *testing.T) {
|
||||
type args struct {
|
||||
operation es_models.Operation
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "no in operation",
|
||||
args: args{
|
||||
operation: es_models.Operation_Equals,
|
||||
},
|
||||
want: "%s %s ?",
|
||||
},
|
||||
{
|
||||
name: "in operation",
|
||||
args: args{
|
||||
operation: es_models.Operation_In,
|
||||
},
|
||||
want: "%s %s ANY(?)",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := getConditionFormat(tt.args.operation); got != tt.want {
|
||||
t.Errorf("prepareConditionFormat() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_getCondition(t *testing.T) {
|
||||
type args struct {
|
||||
filter *es_models.Filter
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "equals",
|
||||
args: args{filter: es_models.NewFilter(es_models.Field_AggregateID, "", es_models.Operation_Equals)},
|
||||
want: "aggregate_id = ?",
|
||||
},
|
||||
{
|
||||
name: "greater",
|
||||
args: args{filter: es_models.NewFilter(es_models.Field_LatestSequence, 0, es_models.Operation_Greater)},
|
||||
want: "event_sequence > ?",
|
||||
},
|
||||
{
|
||||
name: "less",
|
||||
args: args{filter: es_models.NewFilter(es_models.Field_LatestSequence, 5000, es_models.Operation_Less)},
|
||||
want: "event_sequence < ?",
|
||||
},
|
||||
{
|
||||
name: "in list",
|
||||
args: args{filter: es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"movies", "actors"}, es_models.Operation_In)},
|
||||
want: "aggregate_type = ANY(?)",
|
||||
},
|
||||
{
|
||||
name: "invalid operation",
|
||||
args: args{filter: es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"movies", "actors"}, es_models.Operation(-1))},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "invalid field",
|
||||
args: args{filter: es_models.NewFilter(es_models.Field(-1), []es_models.AggregateType{"movies", "actors"}, es_models.Operation_Equals)},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "invalid field and operation",
|
||||
args: args{filter: es_models.NewFilter(es_models.Field(-1), []es_models.AggregateType{"movies", "actors"}, es_models.Operation(-1))},
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := getCondition(tt.args.filter); got != tt.want {
|
||||
t.Errorf("getCondition() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_prepareColumns(t *testing.T) {
|
||||
type args struct {
|
||||
columns es_models.Columns
|
||||
dest interface{}
|
||||
dbErr error
|
||||
}
|
||||
type res struct {
|
||||
query string
|
||||
dbRow []interface{}
|
||||
expected interface{}
|
||||
dbErr func(error) bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "invalid columns",
|
||||
args: args{columns: es_models.Columns(-1)},
|
||||
res: res{
|
||||
query: "",
|
||||
dbErr: func(err error) bool { return err == nil },
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "max column",
|
||||
args: args{
|
||||
columns: es_models.Columns_Max_Sequence,
|
||||
dest: new(Sequence),
|
||||
},
|
||||
res: res{
|
||||
query: "SELECT MAX(event_sequence) FROM eventstore.events",
|
||||
dbRow: []interface{}{Sequence(5)},
|
||||
expected: Sequence(5),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "max sequence wrong dest type",
|
||||
args: args{
|
||||
columns: es_models.Columns_Max_Sequence,
|
||||
dest: new(uint64),
|
||||
},
|
||||
res: res{
|
||||
query: "SELECT MAX(event_sequence) FROM eventstore.events",
|
||||
dbErr: errors.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "event",
|
||||
args: args{
|
||||
columns: es_models.Columns_Event,
|
||||
dest: new(es_models.Event),
|
||||
},
|
||||
res: res{
|
||||
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
|
||||
dbRow: []interface{}{time.Time{}, es_models.EventType(""), uint64(5), Sequence(0), Data(nil), "", "", "", "", es_models.AggregateType("user"), "hodor", es_models.Version("")},
|
||||
expected: es_models.Event{AggregateID: "hodor", AggregateType: "user", Sequence: 5, Data: make(Data, 0)},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "event wrong dest type",
|
||||
args: args{
|
||||
columns: es_models.Columns_Event,
|
||||
dest: new(uint64),
|
||||
},
|
||||
res: res{
|
||||
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
|
||||
dbErr: errors.IsErrorInvalidArgument,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "event query error",
|
||||
args: args{
|
||||
columns: es_models.Columns_Event,
|
||||
dest: new(es_models.Event),
|
||||
dbErr: sql.ErrConnDone,
|
||||
},
|
||||
res: res{
|
||||
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
|
||||
dbErr: errors.IsInternal,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
query, rowScanner := prepareColumns(tt.args.columns)
|
||||
if query != tt.res.query {
|
||||
t.Errorf("prepareColumns() got = %v, want %v", query, tt.res.query)
|
||||
}
|
||||
if tt.res.query == "" && rowScanner != nil {
|
||||
t.Errorf("row scanner should be nil")
|
||||
}
|
||||
if rowScanner == nil {
|
||||
return
|
||||
}
|
||||
err := rowScanner(prepareTestScan(tt.args.dbErr, tt.res.dbRow), tt.args.dest)
|
||||
if tt.res.dbErr != nil {
|
||||
if !tt.res.dbErr(err) {
|
||||
t.Errorf("wrong error type in rowScanner got: %v", err)
|
||||
}
|
||||
} else {
|
||||
if !reflect.DeepEqual(reflect.Indirect(reflect.ValueOf(tt.args.dest)).Interface(), tt.res.expected) {
|
||||
t.Errorf("unexpected result from rowScanner want: %v got: %v", tt.res.dbRow, tt.args.dest)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func prepareTestScan(err error, res []interface{}) scan {
|
||||
return func(dests ...interface{}) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(dests) != len(res) {
|
||||
return errors.ThrowInvalidArgumentf(nil, "SQL-NML1q", "expected len %d got %d", len(res), len(dests))
|
||||
}
|
||||
for i, r := range res {
|
||||
reflect.ValueOf(dests[i]).Elem().Set(reflect.ValueOf(r))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func Test_prepareCondition(t *testing.T) {
|
||||
type args struct {
|
||||
filters [][]*es_models.Filter
|
||||
}
|
||||
type res struct {
|
||||
clause string
|
||||
values []interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "nil filters",
|
||||
args: args{
|
||||
filters: nil,
|
||||
},
|
||||
res: res{
|
||||
clause: "",
|
||||
values: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty filters",
|
||||
args: args{
|
||||
filters: [][]*es_models.Filter{},
|
||||
},
|
||||
res: res{
|
||||
clause: "",
|
||||
values: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid condition",
|
||||
args: args{
|
||||
filters: [][]*es_models.Filter{
|
||||
{
|
||||
es_models.NewFilter(es_models.Field_AggregateID, "wrong", es_models.Operation(-1)),
|
||||
},
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
clause: "",
|
||||
values: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "array as condition value",
|
||||
args: args{
|
||||
filters: [][]*es_models.Filter{
|
||||
{
|
||||
es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"user", "org"}, es_models.Operation_In),
|
||||
},
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
clause: " WHERE ( aggregate_type = ANY(?) )",
|
||||
values: []interface{}{[]es_models.AggregateType{"user", "org"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple filters",
|
||||
args: args{
|
||||
filters: [][]*es_models.Filter{
|
||||
{
|
||||
es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"user", "org"}, es_models.Operation_In),
|
||||
es_models.NewFilter(es_models.Field_AggregateID, "1234", es_models.Operation_Equals),
|
||||
es_models.NewFilter(es_models.Field_EventType, []es_models.EventType{"user.created", "org.created"}, es_models.Operation_In),
|
||||
},
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
clause: " WHERE ( aggregate_type = ANY(?) AND aggregate_id = ? AND event_type = ANY(?) )",
|
||||
values: []interface{}{[]es_models.AggregateType{"user", "org"}, "1234", []es_models.EventType{"user.created", "org.created"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotClause, gotValues := prepareCondition(tt.args.filters)
|
||||
if gotClause != tt.res.clause {
|
||||
t.Errorf("prepareCondition() gotClause = %v, want %v", gotClause, tt.res.clause)
|
||||
}
|
||||
if len(gotValues) != len(tt.res.values) {
|
||||
t.Errorf("wrong length of gotten values got = %d, want %d", len(gotValues), len(tt.res.values))
|
||||
return
|
||||
}
|
||||
for i, value := range gotValues {
|
||||
if !reflect.DeepEqual(value, tt.res.values[i]) {
|
||||
t.Errorf("prepareCondition() gotValues = %v, want %v", gotValues, tt.res.values)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_buildQuery(t *testing.T) {
|
||||
type args struct {
|
||||
queryFactory *es_models.SearchQueryFactory
|
||||
}
|
||||
type res struct {
|
||||
query string
|
||||
limit uint64
|
||||
values []interface{}
|
||||
rowScanner bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "invalid query factory",
|
||||
args: args{
|
||||
queryFactory: nil,
|
||||
},
|
||||
res: res{
|
||||
query: "",
|
||||
limit: 0,
|
||||
rowScanner: false,
|
||||
values: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with order by desc",
|
||||
args: args{
|
||||
queryFactory: es_models.NewSearchQueryFactory().OrderDesc().AddQuery().AggregateTypes("user").Factory(),
|
||||
},
|
||||
res: res{
|
||||
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events AS OF SYSTEM TIME '-1 ms' WHERE ( aggregate_type = $1 ) ORDER BY creation_date DESC, event_sequence DESC",
|
||||
rowScanner: true,
|
||||
values: []interface{}{es_models.AggregateType("user")},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with limit",
|
||||
args: args{
|
||||
queryFactory: es_models.NewSearchQueryFactory().Limit(5).AddQuery().AggregateTypes("user").Factory(),
|
||||
},
|
||||
res: res{
|
||||
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events AS OF SYSTEM TIME '-1 ms' WHERE ( aggregate_type = $1 ) ORDER BY creation_date, event_sequence LIMIT $2",
|
||||
rowScanner: true,
|
||||
values: []interface{}{es_models.AggregateType("user"), uint64(5)},
|
||||
limit: 5,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "with limit and order by desc",
|
||||
args: args{
|
||||
queryFactory: es_models.NewSearchQueryFactory().Limit(5).OrderDesc().AddQuery().AggregateTypes("user").Factory(),
|
||||
},
|
||||
res: res{
|
||||
query: "SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events AS OF SYSTEM TIME '-1 ms' WHERE ( aggregate_type = $1 ) ORDER BY creation_date DESC, event_sequence DESC LIMIT $2",
|
||||
rowScanner: true,
|
||||
values: []interface{}{es_models.AggregateType("user"), uint64(5)},
|
||||
limit: 5,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
ctx := context.Background()
|
||||
db := new(testDB)
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotQuery, gotLimit, gotValues, gotRowScanner := (&SQL{allowOrderByCreationDate: true}).buildQuery(ctx, db, tt.args.queryFactory)
|
||||
if gotQuery != tt.res.query {
|
||||
t.Errorf("buildQuery() gotQuery = %v, want %v", gotQuery, tt.res.query)
|
||||
}
|
||||
if gotLimit != tt.res.limit {
|
||||
t.Errorf("buildQuery() gotLimit = %v, want %v", gotLimit, tt.res.limit)
|
||||
}
|
||||
if len(gotValues) != len(tt.res.values) {
|
||||
t.Errorf("wrong length of gotten values got = %d, want %d", len(gotValues), len(tt.res.values))
|
||||
return
|
||||
}
|
||||
for i, value := range gotValues {
|
||||
if !reflect.DeepEqual(value, tt.res.values[i]) {
|
||||
t.Errorf("prepareCondition() gotValues = %v, want %v", gotValues, tt.res.values)
|
||||
}
|
||||
}
|
||||
if (tt.res.rowScanner && gotRowScanner == nil) || (!tt.res.rowScanner && gotRowScanner != nil) {
|
||||
t.Errorf("rowScanner should be nil==%v got nil==%v", tt.res.rowScanner, gotRowScanner == nil)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type testDB struct{}
|
||||
|
||||
func (_ *testDB) Timetravel(time.Duration) string { return " AS OF SYSTEM TIME '-1 ms' " }
|
||||
|
||||
func (*testDB) DatabaseName() string { return "db" }
|
||||
|
||||
func (*testDB) Username() string { return "user" }
|
||||
|
||||
func (*testDB) Type() string { return "type" }
|
@@ -1,16 +0,0 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
)
|
||||
|
||||
type SQL struct {
|
||||
client *database.DB
|
||||
allowOrderByCreationDate bool
|
||||
}
|
||||
|
||||
func (db *SQL) Health(ctx context.Context) error {
|
||||
return db.client.Ping()
|
||||
}
|
@@ -1,47 +0,0 @@
|
||||
package sql
|
||||
|
||||
import "database/sql/driver"
|
||||
|
||||
// Data represents a byte array that may be null.
|
||||
// Data implements the sql.Scanner interface
|
||||
type Data []byte
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (data *Data) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
*data = nil
|
||||
return nil
|
||||
}
|
||||
*data = Data(value.([]byte))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (data Data) Value() (driver.Value, error) {
|
||||
if len(data) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return []byte(data), nil
|
||||
}
|
||||
|
||||
// Sequence represents a number that may be null.
|
||||
// Sequence implements the sql.Scanner interface
|
||||
type Sequence uint64
|
||||
|
||||
// Scan implements the Scanner interface.
|
||||
func (seq *Sequence) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
*seq = 0
|
||||
return nil
|
||||
}
|
||||
*seq = Sequence(value.(int64))
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the driver Valuer interface.
|
||||
func (seq Sequence) Value() (driver.Value, error) {
|
||||
if seq == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return int64(seq), nil
|
||||
}
|
@@ -1,47 +0,0 @@
|
||||
package locker
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/cockroach-go/v2/crdb"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
insertStmtFormat = "INSERT INTO %s" +
|
||||
" (locker_id, locked_until, view_name, instance_id) VALUES ($1, now()+$2::INTERVAL, $3, $4)" +
|
||||
" ON CONFLICT (view_name, instance_id)" +
|
||||
" DO UPDATE SET locker_id = $1, locked_until = now()+$2::INTERVAL" +
|
||||
" WHERE locks.view_name = $3 AND locks.instance_id = $4 AND (locks.locker_id = $1 OR locks.locked_until < now())"
|
||||
millisecondsAsSeconds = int64(time.Second / time.Millisecond)
|
||||
)
|
||||
|
||||
type lock struct {
|
||||
LockerID string `gorm:"column:locker_id;primary_key"`
|
||||
LockedUntil time.Time `gorm:"column:locked_until"`
|
||||
ViewName string `gorm:"column:view_name;primary_key"`
|
||||
}
|
||||
|
||||
func Renew(dbClient *sql.DB, lockTable, lockerID, viewModel, instanceID string, waitTime time.Duration) error {
|
||||
return crdb.ExecuteTx(context.Background(), dbClient, nil, func(tx *sql.Tx) error {
|
||||
insert := fmt.Sprintf(insertStmtFormat, lockTable)
|
||||
result, err := tx.Exec(insert,
|
||||
lockerID, waitTime, viewModel, instanceID)
|
||||
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
|
||||
if rows, _ := result.RowsAffected(); rows == 0 {
|
||||
return caos_errs.ThrowAlreadyExists(nil, "SPOOL-lso0e", "view already locked")
|
||||
}
|
||||
logging.LogWithFields("LOCKE-lOgbg", "view", viewModel, "locker", lockerID).Debug("locker changed")
|
||||
return nil
|
||||
})
|
||||
}
|
@@ -1,126 +0,0 @@
|
||||
package locker
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
)
|
||||
|
||||
type dbMock struct {
|
||||
db *sql.DB
|
||||
mock sqlmock.Sqlmock
|
||||
}
|
||||
|
||||
func mockDB(t *testing.T) *dbMock {
|
||||
mockDB := dbMock{}
|
||||
var err error
|
||||
mockDB.db, mockDB.mock, err = sqlmock.New()
|
||||
if err != nil {
|
||||
t.Fatalf("error occured while creating stub db %v", err)
|
||||
}
|
||||
|
||||
mockDB.mock.MatchExpectationsInOrder(true)
|
||||
|
||||
return &mockDB
|
||||
}
|
||||
|
||||
func (db *dbMock) expectCommit() *dbMock {
|
||||
db.mock.ExpectCommit()
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *dbMock) expectRollback() *dbMock {
|
||||
db.mock.ExpectRollback()
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *dbMock) expectBegin() *dbMock {
|
||||
db.mock.ExpectBegin()
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *dbMock) expectSavepoint() *dbMock {
|
||||
db.mock.ExpectExec("SAVEPOINT").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *dbMock) expectReleaseSavepoint() *dbMock {
|
||||
db.mock.ExpectExec("RELEASE SAVEPOINT").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func (db *dbMock) expectRenew(lockerID, view, instanceID string, affectedRows int64) *dbMock {
|
||||
query := db.mock.
|
||||
ExpectExec(`INSERT INTO table\.locks \(locker_id, locked_until, view_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\) ON CONFLICT \(view_name, instance_id\) DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL WHERE locks\.view_name = \$3 AND locks\.instance_id = \$4 AND \(locks\.locker_id = \$1 OR locks\.locked_until < now\(\)\)`).
|
||||
WithArgs(lockerID, sqlmock.AnyArg(), view, instanceID).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
|
||||
if affectedRows == 0 {
|
||||
query.WillReturnResult(sqlmock.NewResult(0, 0))
|
||||
} else {
|
||||
query.WillReturnResult(sqlmock.NewResult(1, affectedRows))
|
||||
}
|
||||
|
||||
return db
|
||||
}
|
||||
|
||||
func Test_locker_Renew(t *testing.T) {
|
||||
type fields struct {
|
||||
db *dbMock
|
||||
}
|
||||
type args struct {
|
||||
tableName string
|
||||
lockerID string
|
||||
viewModel string
|
||||
instanceID string
|
||||
waitTime time.Duration
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "renew succeeded",
|
||||
fields: fields{
|
||||
db: mockDB(t).
|
||||
expectBegin().
|
||||
expectSavepoint().
|
||||
expectRenew("locker", "view", "instanceID", 1).
|
||||
expectReleaseSavepoint().
|
||||
expectCommit(),
|
||||
},
|
||||
args: args{tableName: "table.locks", lockerID: "locker", viewModel: "view", instanceID: "instanceID", waitTime: 1 * time.Second},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "renew now rows updated",
|
||||
fields: fields{
|
||||
db: mockDB(t).
|
||||
expectBegin().
|
||||
expectSavepoint().
|
||||
expectRenew("locker", "view", "instanceID", 0).
|
||||
expectRollback(),
|
||||
},
|
||||
args: args{tableName: "table.locks", lockerID: "locker", viewModel: "view", instanceID: "instanceID", waitTime: 1 * time.Second},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := Renew(tt.fields.db.db, tt.args.tableName, tt.args.lockerID, tt.args.viewModel, tt.args.instanceID, tt.args.waitTime); (err != nil) != tt.wantErr {
|
||||
t.Errorf("locker.Renew() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if err := tt.fields.db.mock.ExpectationsWereMet(); err != nil {
|
||||
t.Errorf("not all database expectations met: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@@ -1,84 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/zitadel/zitadel/internal/eventstore/v1 (interfaces: Eventstore)
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
v1 "github.com/zitadel/zitadel/internal/eventstore/v1"
|
||||
models "github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockEventstore is a mock of Eventstore interface.
|
||||
type MockEventstore struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockEventstoreMockRecorder
|
||||
}
|
||||
|
||||
// MockEventstoreMockRecorder is the mock recorder for MockEventstore.
|
||||
type MockEventstoreMockRecorder struct {
|
||||
mock *MockEventstore
|
||||
}
|
||||
|
||||
// NewMockEventstore creates a new mock instance.
|
||||
func NewMockEventstore(ctrl *gomock.Controller) *MockEventstore {
|
||||
mock := &MockEventstore{ctrl: ctrl}
|
||||
mock.recorder = &MockEventstoreMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockEventstore) EXPECT() *MockEventstoreMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// FilterEvents mocks base method.
|
||||
func (m *MockEventstore) FilterEvents(arg0 context.Context, arg1 *models.SearchQuery) ([]*models.Event, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "FilterEvents", arg0, arg1)
|
||||
ret0, _ := ret[0].([]*models.Event)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// FilterEvents indicates an expected call of FilterEvents.
|
||||
func (mr *MockEventstoreMockRecorder) FilterEvents(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterEvents", reflect.TypeOf((*MockEventstore)(nil).FilterEvents), arg0, arg1)
|
||||
}
|
||||
|
||||
// Health mocks base method.
|
||||
func (m *MockEventstore) Health(arg0 context.Context) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Health", arg0)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Health indicates an expected call of Health.
|
||||
func (mr *MockEventstoreMockRecorder) Health(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockEventstore)(nil).Health), arg0)
|
||||
}
|
||||
|
||||
// Subscribe mocks base method.
|
||||
func (m *MockEventstore) Subscribe(arg0 ...models.AggregateType) *v1.Subscription {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []interface{}{}
|
||||
for _, a := range arg0 {
|
||||
varargs = append(varargs, a)
|
||||
}
|
||||
ret := m.ctrl.Call(m, "Subscribe", varargs...)
|
||||
ret0, _ := ret[0].(*v1.Subscription)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Subscribe indicates an expected call of Subscribe.
|
||||
func (mr *MockEventstoreMockRecorder) Subscribe(arg0 ...interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Subscribe", reflect.TypeOf((*MockEventstore)(nil).Subscribe), arg0...)
|
||||
}
|
@@ -1,99 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
)
|
||||
|
||||
type AggregateType string
|
||||
|
||||
func (at AggregateType) String() string {
|
||||
return string(at)
|
||||
}
|
||||
|
||||
type Aggregates []*Aggregate
|
||||
|
||||
type Aggregate struct {
|
||||
ID string
|
||||
typ AggregateType
|
||||
PreviousSequence uint64
|
||||
version Version
|
||||
|
||||
editorService string
|
||||
editorUser string
|
||||
resourceOwner string
|
||||
instanceID string
|
||||
Events []*Event
|
||||
Precondition *precondition
|
||||
}
|
||||
|
||||
func (a *Aggregate) Type() AggregateType {
|
||||
return a.typ
|
||||
}
|
||||
|
||||
type precondition struct {
|
||||
Query *SearchQuery
|
||||
Validation func(...*Event) error
|
||||
}
|
||||
|
||||
func (a *Aggregate) AppendEvent(typ EventType, payload interface{}) (*Aggregate, error) {
|
||||
if string(typ) == "" {
|
||||
return a, errors.ThrowInvalidArgument(nil, "MODEL-TGoCb", "no event type")
|
||||
}
|
||||
data, err := eventData(payload)
|
||||
if err != nil {
|
||||
return a, err
|
||||
}
|
||||
|
||||
e := &Event{
|
||||
CreationDate: time.Now(),
|
||||
Data: data,
|
||||
Type: typ,
|
||||
AggregateID: a.ID,
|
||||
AggregateType: a.typ,
|
||||
AggregateVersion: a.version,
|
||||
EditorService: a.editorService,
|
||||
EditorUser: a.editorUser,
|
||||
ResourceOwner: a.resourceOwner,
|
||||
InstanceID: a.instanceID,
|
||||
}
|
||||
|
||||
a.Events = append(a.Events, e)
|
||||
return a, nil
|
||||
}
|
||||
|
||||
func (a *Aggregate) SetPrecondition(query *SearchQuery, validateFunc func(...*Event) error) *Aggregate {
|
||||
a.Precondition = &precondition{Query: query, Validation: validateFunc}
|
||||
return a
|
||||
}
|
||||
|
||||
func (a *Aggregate) Validate() error {
|
||||
if a == nil {
|
||||
return errors.ThrowPreconditionFailed(nil, "MODEL-yi5AC", "aggregate is nil")
|
||||
}
|
||||
if a.ID == "" {
|
||||
return errors.ThrowPreconditionFailed(nil, "MODEL-FSjKV", "id not set")
|
||||
}
|
||||
if string(a.typ) == "" {
|
||||
return errors.ThrowPreconditionFailed(nil, "MODEL-aj4t2", "type not set")
|
||||
}
|
||||
if err := a.version.Validate(); err != nil {
|
||||
return errors.ThrowPreconditionFailed(err, "MODEL-PupjX", "invalid version")
|
||||
}
|
||||
|
||||
if a.editorService == "" {
|
||||
return errors.ThrowPreconditionFailed(nil, "MODEL-clYbY", "editor service not set")
|
||||
}
|
||||
if a.editorUser == "" {
|
||||
return errors.ThrowPreconditionFailed(nil, "MODEL-Xcssi", "editor user not set")
|
||||
}
|
||||
if a.resourceOwner == "" {
|
||||
return errors.ThrowPreconditionFailed(nil, "MODEL-eBYUW", "resource owner not set")
|
||||
}
|
||||
if a.Precondition != nil && (a.Precondition.Query == nil || a.Precondition.Validation == nil) {
|
||||
return errors.ThrowPreconditionFailed(nil, "MODEL-EEUvA", "invalid precondition")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@@ -1,59 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
)
|
||||
|
||||
type AggregateCreator struct {
|
||||
serviceName string
|
||||
}
|
||||
|
||||
func NewAggregateCreator(serviceName string) *AggregateCreator {
|
||||
return &AggregateCreator{serviceName: serviceName}
|
||||
}
|
||||
|
||||
type option func(*Aggregate)
|
||||
|
||||
func (c *AggregateCreator) NewAggregate(ctx context.Context, id string, typ AggregateType, version Version, previousSequence uint64, opts ...option) (*Aggregate, error) {
|
||||
ctxData := authz.GetCtxData(ctx)
|
||||
instance := authz.GetInstance(ctx)
|
||||
editorUser := ctxData.UserID
|
||||
resourceOwner := ctxData.OrgID
|
||||
instanceID := instance.InstanceID()
|
||||
|
||||
aggregate := &Aggregate{
|
||||
ID: id,
|
||||
typ: typ,
|
||||
PreviousSequence: previousSequence,
|
||||
version: version,
|
||||
Events: make([]*Event, 0, 2),
|
||||
editorService: c.serviceName,
|
||||
editorUser: editorUser,
|
||||
resourceOwner: resourceOwner,
|
||||
instanceID: instanceID,
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt(aggregate)
|
||||
}
|
||||
|
||||
if err := aggregate.Validate(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return aggregate, nil
|
||||
}
|
||||
|
||||
func OverwriteEditorUser(userID string) func(*Aggregate) {
|
||||
return func(a *Aggregate) {
|
||||
a.editorUser = userID
|
||||
}
|
||||
}
|
||||
|
||||
func OverwriteResourceOwner(resourceOwner string) func(*Aggregate) {
|
||||
return func(a *Aggregate) {
|
||||
a.resourceOwner = resourceOwner
|
||||
}
|
||||
}
|
@@ -1,118 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestAggregateCreator_NewAggregate(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
id string
|
||||
typ AggregateType
|
||||
version Version
|
||||
opts []option
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
creator *AggregateCreator
|
||||
args args
|
||||
want *Aggregate
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no ctxdata and no options",
|
||||
creator: &AggregateCreator{serviceName: "admin"},
|
||||
wantErr: true,
|
||||
want: nil,
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
id: "hodor",
|
||||
typ: "user",
|
||||
version: "v1.0.0",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no id error",
|
||||
creator: &AggregateCreator{serviceName: "admin"},
|
||||
wantErr: true,
|
||||
want: nil,
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
typ: "user",
|
||||
version: "v1.0.0",
|
||||
opts: []option{
|
||||
OverwriteEditorUser("hodor"),
|
||||
OverwriteResourceOwner("org"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "no type error",
|
||||
creator: &AggregateCreator{serviceName: "admin"},
|
||||
wantErr: true,
|
||||
want: nil,
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
id: "hodor",
|
||||
version: "v1.0.0",
|
||||
opts: []option{
|
||||
OverwriteEditorUser("hodor"),
|
||||
OverwriteResourceOwner("org"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid version error",
|
||||
creator: &AggregateCreator{serviceName: "admin"},
|
||||
wantErr: true,
|
||||
want: nil,
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
id: "hodor",
|
||||
typ: "user",
|
||||
opts: []option{
|
||||
OverwriteEditorUser("hodor"),
|
||||
OverwriteResourceOwner("org"),
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "create ok",
|
||||
creator: &AggregateCreator{serviceName: "admin"},
|
||||
wantErr: false,
|
||||
want: &Aggregate{
|
||||
ID: "hodor",
|
||||
Events: make([]*Event, 0, 2),
|
||||
typ: "user",
|
||||
version: "v1.0.0",
|
||||
editorService: "admin",
|
||||
editorUser: "hodor",
|
||||
resourceOwner: "org",
|
||||
},
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
id: "hodor",
|
||||
typ: "user",
|
||||
version: "v1.0.0",
|
||||
opts: []option{
|
||||
OverwriteEditorUser("hodor"),
|
||||
OverwriteResourceOwner("org"),
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.creator.NewAggregate(tt.args.ctx, tt.args.id, tt.args.typ, tt.args.version, 0, tt.args.opts...)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("AggregateCreator.NewAggregate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("AggregateCreator.NewAggregate() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@@ -1,310 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
)
|
||||
|
||||
func TestAggregate_AppendEvent(t *testing.T) {
|
||||
type fields struct {
|
||||
aggregate *Aggregate
|
||||
}
|
||||
type args struct {
|
||||
typ EventType
|
||||
payload interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want *Aggregate
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "no event type error",
|
||||
fields: fields{aggregate: &Aggregate{}},
|
||||
args: args{},
|
||||
want: &Aggregate{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid payload error",
|
||||
fields: fields{aggregate: &Aggregate{}},
|
||||
args: args{typ: "user", payload: 134},
|
||||
want: &Aggregate{},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "event added",
|
||||
fields: fields{aggregate: &Aggregate{Events: []*Event{}}},
|
||||
args: args{typ: "user.deactivated"},
|
||||
want: &Aggregate{Events: []*Event{
|
||||
{Type: "user.deactivated"},
|
||||
}},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "event added",
|
||||
fields: fields{aggregate: &Aggregate{Events: []*Event{
|
||||
{},
|
||||
}}},
|
||||
args: args{typ: "user.deactivated"},
|
||||
want: &Aggregate{Events: []*Event{
|
||||
{},
|
||||
{Type: "user.deactivated"},
|
||||
}},
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.fields.aggregate.AppendEvent(tt.args.typ, tt.args.payload)
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Aggregate.AppendEvent() error = %v, wantErr %v", err, tt.wantErr)
|
||||
return
|
||||
}
|
||||
if len(tt.fields.aggregate.Events) != len(got.Events) {
|
||||
t.Errorf("events len should be %d but was %d", len(tt.fields.aggregate.Events), len(got.Events))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAggregate_Validate(t *testing.T) {
|
||||
type fields struct {
|
||||
aggregate *Aggregate
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "aggregate nil error",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "aggregate empty error",
|
||||
wantErr: true,
|
||||
fields: fields{aggregate: &Aggregate{}},
|
||||
},
|
||||
{
|
||||
name: "no id error",
|
||||
wantErr: true,
|
||||
fields: fields{aggregate: &Aggregate{
|
||||
typ: "user",
|
||||
version: "v1.0.0",
|
||||
editorService: "svc",
|
||||
editorUser: "hodor",
|
||||
resourceOwner: "org",
|
||||
PreviousSequence: 5,
|
||||
Events: []*Event{
|
||||
{
|
||||
AggregateType: "user",
|
||||
AggregateVersion: "v1.0.0",
|
||||
EditorService: "management",
|
||||
EditorUser: "hodor",
|
||||
ResourceOwner: "org",
|
||||
Type: "born",
|
||||
}},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "no type error",
|
||||
wantErr: true,
|
||||
fields: fields{aggregate: &Aggregate{
|
||||
ID: "aggID",
|
||||
version: "v1.0.0",
|
||||
editorService: "svc",
|
||||
editorUser: "hodor",
|
||||
resourceOwner: "org",
|
||||
PreviousSequence: 5,
|
||||
Events: []*Event{
|
||||
{
|
||||
AggregateID: "hodor",
|
||||
AggregateVersion: "v1.0.0",
|
||||
EditorService: "management",
|
||||
EditorUser: "hodor",
|
||||
ResourceOwner: "org",
|
||||
Type: "born",
|
||||
}},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "invalid version error",
|
||||
wantErr: true,
|
||||
fields: fields{aggregate: &Aggregate{
|
||||
ID: "aggID",
|
||||
typ: "user",
|
||||
editorService: "svc",
|
||||
editorUser: "hodor",
|
||||
resourceOwner: "org",
|
||||
PreviousSequence: 5,
|
||||
Events: []*Event{
|
||||
{
|
||||
AggregateID: "hodor",
|
||||
AggregateType: "user",
|
||||
EditorService: "management",
|
||||
EditorUser: "hodor",
|
||||
ResourceOwner: "org",
|
||||
Type: "born",
|
||||
}},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "no query in precondition error",
|
||||
wantErr: true,
|
||||
fields: fields{aggregate: &Aggregate{
|
||||
ID: "aggID",
|
||||
typ: "user",
|
||||
version: "v1.0.0",
|
||||
editorService: "svc",
|
||||
editorUser: "hodor",
|
||||
resourceOwner: "org",
|
||||
PreviousSequence: 5,
|
||||
Precondition: &precondition{
|
||||
Validation: func(...*Event) error { return nil },
|
||||
},
|
||||
Events: []*Event{
|
||||
{
|
||||
AggregateID: "hodor",
|
||||
AggregateType: "user",
|
||||
AggregateVersion: "v1.0.0",
|
||||
EditorService: "management",
|
||||
EditorUser: "hodor",
|
||||
ResourceOwner: "org",
|
||||
Type: "born",
|
||||
}},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "no func in precondition error",
|
||||
wantErr: true,
|
||||
fields: fields{aggregate: &Aggregate{
|
||||
ID: "aggID",
|
||||
typ: "user",
|
||||
version: "v1.0.0",
|
||||
editorService: "svc",
|
||||
editorUser: "hodor",
|
||||
resourceOwner: "org",
|
||||
PreviousSequence: 5,
|
||||
Precondition: &precondition{
|
||||
Query: NewSearchQuery().AddQuery().AggregateIDFilter("hodor").SearchQuery(),
|
||||
},
|
||||
Events: []*Event{
|
||||
{
|
||||
AggregateID: "hodor",
|
||||
AggregateType: "user",
|
||||
AggregateVersion: "v1.0.0",
|
||||
EditorService: "management",
|
||||
EditorUser: "hodor",
|
||||
ResourceOwner: "org",
|
||||
Type: "born",
|
||||
}},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "validation without precondition ok",
|
||||
wantErr: false,
|
||||
fields: fields{aggregate: &Aggregate{
|
||||
ID: "aggID",
|
||||
typ: "user",
|
||||
version: "v1.0.0",
|
||||
editorService: "svc",
|
||||
editorUser: "hodor",
|
||||
resourceOwner: "org",
|
||||
PreviousSequence: 5,
|
||||
Events: []*Event{
|
||||
{
|
||||
AggregateID: "hodor",
|
||||
AggregateType: "user",
|
||||
AggregateVersion: "v1.0.0",
|
||||
EditorService: "management",
|
||||
EditorUser: "hodor",
|
||||
ResourceOwner: "org",
|
||||
Type: "born",
|
||||
}},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "validation with precondition ok",
|
||||
wantErr: false,
|
||||
fields: fields{aggregate: &Aggregate{
|
||||
ID: "aggID",
|
||||
typ: "user",
|
||||
version: "v1.0.0",
|
||||
editorService: "svc",
|
||||
editorUser: "hodor",
|
||||
resourceOwner: "org",
|
||||
PreviousSequence: 5,
|
||||
Precondition: &precondition{
|
||||
Validation: func(...*Event) error { return nil },
|
||||
Query: NewSearchQuery().AddQuery().AggregateIDFilter("hodor").SearchQuery(),
|
||||
},
|
||||
Events: []*Event{
|
||||
{
|
||||
AggregateID: "hodor",
|
||||
AggregateType: "user",
|
||||
AggregateVersion: "v1.0.0",
|
||||
EditorService: "management",
|
||||
EditorUser: "hodor",
|
||||
ResourceOwner: "org",
|
||||
Type: "born",
|
||||
}},
|
||||
}},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := tt.fields.aggregate.Validate()
|
||||
if (err != nil) != tt.wantErr {
|
||||
t.Errorf("Aggregate.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if tt.wantErr && !errors.IsPreconditionFailed(err) {
|
||||
t.Errorf("error must extend precondition failed: %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAggregate_SetPrecondition(t *testing.T) {
|
||||
type fields struct {
|
||||
aggregate *Aggregate
|
||||
}
|
||||
type args struct {
|
||||
query *SearchQuery
|
||||
validateFunc func(...*Event) error
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want *Aggregate
|
||||
}{
|
||||
{
|
||||
name: "set precondition",
|
||||
fields: fields{aggregate: &Aggregate{}},
|
||||
args: args{
|
||||
query: &SearchQuery{},
|
||||
validateFunc: func(...*Event) error { return nil },
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
got := tt.fields.aggregate.SetPrecondition(tt.args.query, tt.args.validateFunc)
|
||||
if got.Precondition == nil {
|
||||
t.Error("precondition must not be nil")
|
||||
t.FailNow()
|
||||
}
|
||||
if got.Precondition.Query == nil {
|
||||
t.Error("query of precondition must not be nil")
|
||||
}
|
||||
if got.Precondition.Validation == nil {
|
||||
t.Error("precondition func must not be nil")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@@ -6,6 +6,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
type EventType string
|
||||
@@ -14,23 +15,85 @@ func (et EventType) String() string {
|
||||
return string(et)
|
||||
}
|
||||
|
||||
var _ eventstore.Event = (*Event)(nil)
|
||||
|
||||
type Event struct {
|
||||
ID string
|
||||
Sequence uint64
|
||||
Seq uint64
|
||||
Pos float64
|
||||
CreationDate time.Time
|
||||
Type EventType
|
||||
Typ eventstore.EventType
|
||||
PreviousSequence uint64
|
||||
Data []byte
|
||||
|
||||
AggregateID string
|
||||
AggregateType AggregateType
|
||||
AggregateVersion Version
|
||||
EditorService string
|
||||
EditorUser string
|
||||
AggregateType eventstore.AggregateType
|
||||
AggregateVersion eventstore.Version
|
||||
Service string
|
||||
User string
|
||||
ResourceOwner string
|
||||
InstanceID string
|
||||
}
|
||||
|
||||
// Aggregate implements [eventstore.Event]
|
||||
func (e *Event) Aggregate() *eventstore.Aggregate {
|
||||
return &eventstore.Aggregate{
|
||||
ID: e.AggregateID,
|
||||
Type: e.AggregateType,
|
||||
ResourceOwner: e.ResourceOwner,
|
||||
InstanceID: e.InstanceID,
|
||||
// Version: eventstore.Version(e.AggregateVersion),
|
||||
}
|
||||
}
|
||||
|
||||
// CreatedAt implements [eventstore.Event]
|
||||
func (e *Event) CreatedAt() time.Time {
|
||||
return e.CreationDate
|
||||
}
|
||||
|
||||
// DataAsBytes implements [eventstore.Event]
|
||||
func (e *Event) DataAsBytes() []byte {
|
||||
return e.Data
|
||||
}
|
||||
|
||||
// Unmarshal implements [eventstore.Event]
|
||||
func (e *Event) Unmarshal(ptr any) error {
|
||||
if len(e.Data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(e.Data, ptr)
|
||||
}
|
||||
|
||||
// EditorService implements [eventstore.Event]
|
||||
func (e *Event) EditorService() string {
|
||||
return e.Service
|
||||
}
|
||||
|
||||
// Creator implements [eventstore.action]
|
||||
func (e *Event) Creator() string {
|
||||
return e.User
|
||||
}
|
||||
|
||||
// Sequence implements [eventstore.Event]
|
||||
func (e *Event) Sequence() uint64 {
|
||||
return e.Seq
|
||||
}
|
||||
|
||||
// Position implements [eventstore.Event]
|
||||
func (e *Event) Position() float64 {
|
||||
return e.Pos
|
||||
}
|
||||
|
||||
// Type implements [eventstore.action]
|
||||
func (e *Event) Type() eventstore.EventType {
|
||||
return e.Typ
|
||||
}
|
||||
|
||||
// Type implements [eventstore.action]
|
||||
func (e *Event) Revision() uint16 {
|
||||
return 0
|
||||
}
|
||||
|
||||
func eventData(i interface{}) ([]byte, error) {
|
||||
switch v := i.(type) {
|
||||
case []byte:
|
||||
@@ -63,7 +126,7 @@ func (e *Event) Validate() error {
|
||||
if e == nil {
|
||||
return errors.ThrowPreconditionFailed(nil, "MODEL-oEAG4", "event is nil")
|
||||
}
|
||||
if string(e.Type) == "" {
|
||||
if string(e.Typ) == "" {
|
||||
return errors.ThrowPreconditionFailed(nil, "MODEL-R2sB0", "type not defined")
|
||||
}
|
||||
|
||||
@@ -74,13 +137,12 @@ func (e *Event) Validate() error {
|
||||
return errors.ThrowPreconditionFailed(nil, "MODEL-EzdyK", "aggregate type not set")
|
||||
}
|
||||
if err := e.AggregateVersion.Validate(); err != nil {
|
||||
return err
|
||||
return errors.ThrowPreconditionFailed(err, "MODEL-KO71q", "version invalid")
|
||||
}
|
||||
|
||||
if e.EditorService == "" {
|
||||
if e.Service == "" {
|
||||
return errors.ThrowPreconditionFailed(nil, "MODEL-4Yqik", "editor service not set")
|
||||
}
|
||||
if e.EditorUser == "" {
|
||||
if e.User == "" {
|
||||
return errors.ThrowPreconditionFailed(nil, "MODEL-L3NHO", "editor user not set")
|
||||
}
|
||||
if e.ResourceOwner == "" {
|
||||
|
@@ -95,10 +95,10 @@ func TestEvent_Validate(t *testing.T) {
|
||||
fields: fields{event: &Event{
|
||||
AggregateType: "user",
|
||||
AggregateVersion: "v1.0.0",
|
||||
EditorService: "management",
|
||||
EditorUser: "hodor",
|
||||
Service: "management",
|
||||
User: "hodor",
|
||||
ResourceOwner: "org",
|
||||
Type: "born",
|
||||
Typ: "born",
|
||||
}},
|
||||
wantErr: true,
|
||||
},
|
||||
@@ -107,10 +107,10 @@ func TestEvent_Validate(t *testing.T) {
|
||||
fields: fields{event: &Event{
|
||||
AggregateID: "hodor",
|
||||
AggregateVersion: "v1.0.0",
|
||||
EditorService: "management",
|
||||
EditorUser: "hodor",
|
||||
Service: "management",
|
||||
User: "hodor",
|
||||
ResourceOwner: "org",
|
||||
Type: "born",
|
||||
Typ: "born",
|
||||
}},
|
||||
wantErr: true,
|
||||
},
|
||||
@@ -119,10 +119,10 @@ func TestEvent_Validate(t *testing.T) {
|
||||
fields: fields{event: &Event{
|
||||
AggregateID: "hodor",
|
||||
AggregateType: "user",
|
||||
EditorService: "management",
|
||||
EditorUser: "hodor",
|
||||
Service: "management",
|
||||
User: "hodor",
|
||||
ResourceOwner: "org",
|
||||
Type: "born",
|
||||
Typ: "born",
|
||||
}},
|
||||
wantErr: true,
|
||||
},
|
||||
@@ -132,9 +132,9 @@ func TestEvent_Validate(t *testing.T) {
|
||||
AggregateID: "hodor",
|
||||
AggregateType: "user",
|
||||
AggregateVersion: "v1.0.0",
|
||||
EditorUser: "hodor",
|
||||
User: "hodor",
|
||||
ResourceOwner: "org",
|
||||
Type: "born",
|
||||
Typ: "born",
|
||||
}},
|
||||
wantErr: true,
|
||||
},
|
||||
@@ -144,9 +144,9 @@ func TestEvent_Validate(t *testing.T) {
|
||||
AggregateID: "hodor",
|
||||
AggregateType: "user",
|
||||
AggregateVersion: "v1.0.0",
|
||||
EditorService: "management",
|
||||
Service: "management",
|
||||
ResourceOwner: "org",
|
||||
Type: "born",
|
||||
Typ: "born",
|
||||
}},
|
||||
wantErr: true,
|
||||
},
|
||||
@@ -156,9 +156,9 @@ func TestEvent_Validate(t *testing.T) {
|
||||
AggregateID: "hodor",
|
||||
AggregateType: "user",
|
||||
AggregateVersion: "v1.0.0",
|
||||
EditorService: "management",
|
||||
EditorUser: "hodor",
|
||||
Type: "born",
|
||||
Service: "management",
|
||||
User: "hodor",
|
||||
Typ: "born",
|
||||
}},
|
||||
wantErr: true,
|
||||
},
|
||||
@@ -168,8 +168,8 @@ func TestEvent_Validate(t *testing.T) {
|
||||
AggregateID: "hodor",
|
||||
AggregateType: "user",
|
||||
AggregateVersion: "v1.0.0",
|
||||
EditorService: "management",
|
||||
EditorUser: "hodor",
|
||||
Service: "management",
|
||||
User: "hodor",
|
||||
ResourceOwner: "org",
|
||||
}},
|
||||
wantErr: true,
|
||||
@@ -180,10 +180,10 @@ func TestEvent_Validate(t *testing.T) {
|
||||
AggregateID: "hodor",
|
||||
AggregateType: "user",
|
||||
AggregateVersion: "v1.0.0",
|
||||
EditorService: "management",
|
||||
EditorUser: "hodor",
|
||||
Service: "management",
|
||||
User: "hodor",
|
||||
ResourceOwner: "org",
|
||||
Type: "born",
|
||||
Typ: "born",
|
||||
}},
|
||||
wantErr: false,
|
||||
},
|
||||
|
@@ -1,15 +0,0 @@
|
||||
package models
|
||||
|
||||
type Field int32
|
||||
|
||||
const (
|
||||
Field_AggregateType Field = 1 + iota
|
||||
Field_AggregateID
|
||||
Field_LatestSequence
|
||||
Field_ResourceOwner
|
||||
Field_EditorService
|
||||
Field_EditorUser
|
||||
Field_EventType
|
||||
Field_CreationDate
|
||||
Field_InstanceID
|
||||
)
|
@@ -1,46 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
)
|
||||
|
||||
type Filter struct {
|
||||
field Field
|
||||
value interface{}
|
||||
operation Operation
|
||||
}
|
||||
|
||||
//NewFilter is used in tests. Use searchQuery.*Filter() instead
|
||||
func NewFilter(field Field, value interface{}, operation Operation) *Filter {
|
||||
return &Filter{
|
||||
field: field,
|
||||
value: value,
|
||||
operation: operation,
|
||||
}
|
||||
}
|
||||
|
||||
func (f *Filter) GetField() Field {
|
||||
return f.field
|
||||
}
|
||||
func (f *Filter) GetOperation() Operation {
|
||||
return f.operation
|
||||
}
|
||||
func (f *Filter) GetValue() interface{} {
|
||||
return f.value
|
||||
}
|
||||
|
||||
func (f *Filter) Validate() error {
|
||||
if f == nil {
|
||||
return errors.ThrowPreconditionFailed(nil, "MODEL-z6KcG", "filter is nil")
|
||||
}
|
||||
if f.field <= 0 {
|
||||
return errors.ThrowPreconditionFailed(nil, "MODEL-zw62U", "field not definded")
|
||||
}
|
||||
if f.value == nil {
|
||||
return errors.ThrowPreconditionFailed(nil, "MODEL-GJ9ct", "no value definded")
|
||||
}
|
||||
if f.operation <= 0 {
|
||||
return errors.ThrowPreconditionFailed(nil, "MODEL-RrQTy", "operation not definded")
|
||||
}
|
||||
return nil
|
||||
}
|
@@ -1,104 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestNewFilter(t *testing.T) {
|
||||
type args struct {
|
||||
field Field
|
||||
value interface{}
|
||||
operation Operation
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *Filter
|
||||
}{
|
||||
{
|
||||
name: "aggregateID equals",
|
||||
args: args{
|
||||
field: Field_AggregateID,
|
||||
value: "hodor",
|
||||
operation: Operation_Equals,
|
||||
},
|
||||
want: &Filter{field: Field_AggregateID, operation: Operation_Equals, value: "hodor"},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := NewFilter(tt.args.field, tt.args.value, tt.args.operation); !reflect.DeepEqual(got, tt.want) {
|
||||
t.Errorf("NewFilter() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFilter_Validate(t *testing.T) {
|
||||
type fields struct {
|
||||
field Field
|
||||
value interface{}
|
||||
operation Operation
|
||||
isNil bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "correct filter",
|
||||
fields: fields{
|
||||
field: Field_LatestSequence,
|
||||
operation: Operation_Greater,
|
||||
value: uint64(235),
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "filter is nil",
|
||||
fields: fields{isNil: true},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no field error",
|
||||
fields: fields{
|
||||
operation: Operation_Greater,
|
||||
value: uint64(235),
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no value error",
|
||||
fields: fields{
|
||||
field: Field_LatestSequence,
|
||||
operation: Operation_Greater,
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "no operation error",
|
||||
fields: fields{
|
||||
field: Field_LatestSequence,
|
||||
value: uint64(235),
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
var f *Filter
|
||||
if !tt.fields.isNil {
|
||||
f = &Filter{
|
||||
field: tt.fields.field,
|
||||
value: tt.fields.value,
|
||||
operation: tt.fields.operation,
|
||||
}
|
||||
}
|
||||
if err := f.Validate(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Filter.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@@ -2,6 +2,8 @@ package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
type ObjectRoot struct {
|
||||
@@ -13,25 +15,25 @@ type ObjectRoot struct {
|
||||
ChangeDate time.Time `json:"-"`
|
||||
}
|
||||
|
||||
func (o *ObjectRoot) AppendEvent(event *Event) {
|
||||
func (o *ObjectRoot) AppendEvent(event eventstore.Event) {
|
||||
if o.AggregateID == "" {
|
||||
o.AggregateID = event.AggregateID
|
||||
} else if o.AggregateID != event.AggregateID {
|
||||
o.AggregateID = event.Aggregate().ID
|
||||
} else if o.AggregateID != event.Aggregate().ID {
|
||||
return
|
||||
}
|
||||
if o.ResourceOwner == "" {
|
||||
o.ResourceOwner = event.ResourceOwner
|
||||
o.ResourceOwner = event.Aggregate().ResourceOwner
|
||||
}
|
||||
if o.InstanceID == "" {
|
||||
o.InstanceID = event.InstanceID
|
||||
o.InstanceID = event.Aggregate().InstanceID
|
||||
}
|
||||
|
||||
o.ChangeDate = event.CreationDate
|
||||
o.ChangeDate = event.CreatedAt()
|
||||
if o.CreationDate.IsZero() {
|
||||
o.CreationDate = o.ChangeDate
|
||||
}
|
||||
|
||||
o.Sequence = event.Sequence
|
||||
o.Sequence = event.Sequence()
|
||||
}
|
||||
func (o *ObjectRoot) IsZero() bool {
|
||||
return o.AggregateID == ""
|
||||
|
@@ -27,7 +27,7 @@ func TestObjectRoot_AppendEvent(t *testing.T) {
|
||||
args{
|
||||
&Event{
|
||||
AggregateID: "aggID",
|
||||
Sequence: 34555,
|
||||
Seq: 34555,
|
||||
CreationDate: time.Now(),
|
||||
},
|
||||
true,
|
||||
@@ -44,7 +44,7 @@ func TestObjectRoot_AppendEvent(t *testing.T) {
|
||||
args{
|
||||
&Event{
|
||||
AggregateID: "agg",
|
||||
Sequence: 34555425,
|
||||
Seq: 34555425,
|
||||
CreationDate: time.Now(),
|
||||
PreviousSequence: 22,
|
||||
},
|
||||
@@ -70,8 +70,8 @@ func TestObjectRoot_AppendEvent(t *testing.T) {
|
||||
t.Error("creationDate and changedate should differ")
|
||||
}
|
||||
}
|
||||
if o.Sequence != tt.args.event.Sequence {
|
||||
t.Errorf("sequence not equal to event: event: %d root: %d", tt.args.event.Sequence, o.Sequence)
|
||||
if o.Sequence != tt.args.event.Seq {
|
||||
t.Errorf("sequence not equal to event: event: %d root: %d", tt.args.event.Seq, o.Sequence)
|
||||
}
|
||||
if !o.ChangeDate.Equal(tt.args.event.CreationDate) {
|
||||
t.Errorf("changedate should be equal to event creation date: event: %v root: %v", tt.args.event.CreationDate, o.ChangeDate)
|
||||
|
@@ -1,11 +0,0 @@
|
||||
package models
|
||||
|
||||
type Operation int32
|
||||
|
||||
const (
|
||||
Operation_Equals Operation = 1 + iota
|
||||
Operation_Greater
|
||||
Operation_Less
|
||||
Operation_In
|
||||
Operation_NotIn
|
||||
)
|
@@ -1,296 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
)
|
||||
|
||||
type SearchQueryFactory struct {
|
||||
columns Columns
|
||||
limit uint64
|
||||
desc bool
|
||||
queries []*query
|
||||
|
||||
InstanceFiltered bool
|
||||
}
|
||||
|
||||
type query struct {
|
||||
desc bool
|
||||
aggregateTypes []AggregateType
|
||||
aggregateIDs []string
|
||||
sequenceFrom uint64
|
||||
sequenceTo uint64
|
||||
eventTypes []EventType
|
||||
resourceOwner string
|
||||
instanceID string
|
||||
ignoredInstanceIDs []string
|
||||
creationDate time.Time
|
||||
factory *SearchQueryFactory
|
||||
}
|
||||
|
||||
type searchQuery struct {
|
||||
Columns Columns
|
||||
Limit uint64
|
||||
Desc bool
|
||||
Filters [][]*Filter
|
||||
}
|
||||
|
||||
type Columns int32
|
||||
|
||||
const (
|
||||
Columns_Event = iota
|
||||
Columns_Max_Sequence
|
||||
Columns_InstanceIDs
|
||||
// insert new columns-types before this columnsCount because count is needed for validation
|
||||
columnsCount
|
||||
)
|
||||
|
||||
// FactoryFromSearchQuery is deprecated because it's for migration purposes. use NewSearchQueryFactory
|
||||
func FactoryFromSearchQuery(q *SearchQuery) *SearchQueryFactory {
|
||||
factory := &SearchQueryFactory{
|
||||
columns: q.Columns,
|
||||
desc: q.Desc,
|
||||
limit: q.Limit,
|
||||
queries: make([]*query, len(q.Queries)),
|
||||
}
|
||||
|
||||
for i, qq := range q.Queries {
|
||||
factory.queries[i] = &query{factory: factory}
|
||||
for _, filter := range qq.Filters {
|
||||
switch filter.field {
|
||||
case Field_AggregateType:
|
||||
factory.queries[i] = factory.queries[i].aggregateTypesMig(filter.value.([]AggregateType)...)
|
||||
case Field_AggregateID:
|
||||
if aggregateID, ok := filter.value.(string); ok {
|
||||
factory.queries[i] = factory.queries[i].AggregateIDs(aggregateID)
|
||||
} else if aggregateIDs, ok := filter.value.([]string); ok {
|
||||
factory.queries[i] = factory.queries[i].AggregateIDs(aggregateIDs...)
|
||||
}
|
||||
case Field_LatestSequence:
|
||||
if filter.operation == Operation_Greater {
|
||||
factory.queries[i] = factory.queries[i].SequenceGreater(filter.value.(uint64))
|
||||
} else {
|
||||
factory.queries[i] = factory.queries[i].SequenceLess(filter.value.(uint64))
|
||||
}
|
||||
case Field_ResourceOwner:
|
||||
factory.queries[i] = factory.queries[i].ResourceOwner(filter.value.(string))
|
||||
case Field_InstanceID:
|
||||
factory.InstanceFiltered = true
|
||||
if filter.operation == Operation_Equals {
|
||||
factory.queries[i] = factory.queries[i].InstanceID(filter.value.(string))
|
||||
} else if filter.operation == Operation_NotIn {
|
||||
factory.queries[i] = factory.queries[i].IgnoredInstanceIDs(filter.value.([]string)...)
|
||||
}
|
||||
case Field_EventType:
|
||||
factory.queries[i] = factory.queries[i].EventTypes(filter.value.([]EventType)...)
|
||||
case Field_EditorService, Field_EditorUser:
|
||||
logging.WithFields("value", filter.value).Panic("field not converted to factory")
|
||||
case Field_CreationDate:
|
||||
factory.queries[i] = factory.queries[i].CreationDateNewer(filter.value.(time.Time))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return factory
|
||||
}
|
||||
|
||||
func NewSearchQueryFactory() *SearchQueryFactory {
|
||||
return &SearchQueryFactory{}
|
||||
}
|
||||
|
||||
func (factory *SearchQueryFactory) Columns(columns Columns) *SearchQueryFactory {
|
||||
factory.columns = columns
|
||||
return factory
|
||||
}
|
||||
|
||||
func (factory *SearchQueryFactory) Limit(limit uint64) *SearchQueryFactory {
|
||||
factory.limit = limit
|
||||
return factory
|
||||
}
|
||||
|
||||
func (factory *SearchQueryFactory) OrderDesc() *SearchQueryFactory {
|
||||
factory.desc = true
|
||||
return factory
|
||||
}
|
||||
|
||||
func (factory *SearchQueryFactory) OrderAsc() *SearchQueryFactory {
|
||||
factory.desc = false
|
||||
return factory
|
||||
}
|
||||
|
||||
func (factory *SearchQueryFactory) AddQuery() *query {
|
||||
q := &query{factory: factory}
|
||||
factory.queries = append(factory.queries, q)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *query) Factory() *SearchQueryFactory {
|
||||
return q.factory
|
||||
}
|
||||
|
||||
func (q *query) SequenceGreater(sequence uint64) *query {
|
||||
q.sequenceFrom = sequence
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *query) SequenceLess(sequence uint64) *query {
|
||||
q.sequenceTo = sequence
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *query) AggregateTypes(types ...AggregateType) *query {
|
||||
q.aggregateTypes = types
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *query) AggregateIDs(ids ...string) *query {
|
||||
q.aggregateIDs = ids
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *query) aggregateTypesMig(types ...AggregateType) *query {
|
||||
q.aggregateTypes = types
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *query) EventTypes(types ...EventType) *query {
|
||||
q.eventTypes = types
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *query) ResourceOwner(resourceOwner string) *query {
|
||||
q.resourceOwner = resourceOwner
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *query) InstanceID(instanceID string) *query {
|
||||
q.instanceID = instanceID
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *query) IgnoredInstanceIDs(instanceIDs ...string) *query {
|
||||
q.ignoredInstanceIDs = instanceIDs
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *query) CreationDateNewer(time time.Time) *query {
|
||||
q.creationDate = time
|
||||
return q
|
||||
}
|
||||
|
||||
func (factory *SearchQueryFactory) Build() (*searchQuery, error) {
|
||||
if factory == nil ||
|
||||
len(factory.queries) < 1 ||
|
||||
(factory.columns < 0 || factory.columns >= columnsCount) {
|
||||
return nil, errors.ThrowPreconditionFailed(nil, "MODEL-tGAD3", "factory invalid")
|
||||
}
|
||||
filters := make([][]*Filter, len(factory.queries))
|
||||
|
||||
for i, query := range factory.queries {
|
||||
for _, f := range []func() *Filter{
|
||||
query.aggregateTypeFilter,
|
||||
query.aggregateIDFilter,
|
||||
query.sequenceFromFilter,
|
||||
query.sequenceToFilter,
|
||||
query.eventTypeFilter,
|
||||
query.resourceOwnerFilter,
|
||||
query.instanceIDFilter,
|
||||
query.ignoredInstanceIDsFilter,
|
||||
query.creationDateNewerFilter,
|
||||
} {
|
||||
if filter := f(); filter != nil {
|
||||
filters[i] = append(filters[i], filter)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &searchQuery{
|
||||
Columns: factory.columns,
|
||||
Limit: factory.limit,
|
||||
Desc: factory.desc,
|
||||
Filters: filters,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (q *query) aggregateIDFilter() *Filter {
|
||||
if len(q.aggregateIDs) < 1 {
|
||||
return nil
|
||||
}
|
||||
if len(q.aggregateIDs) == 1 {
|
||||
return NewFilter(Field_AggregateID, q.aggregateIDs[0], Operation_Equals)
|
||||
}
|
||||
return NewFilter(Field_AggregateID, q.aggregateIDs, Operation_In)
|
||||
}
|
||||
|
||||
func (q *query) eventTypeFilter() *Filter {
|
||||
if len(q.eventTypes) < 1 {
|
||||
return nil
|
||||
}
|
||||
if len(q.eventTypes) == 1 {
|
||||
return NewFilter(Field_EventType, q.eventTypes[0], Operation_Equals)
|
||||
}
|
||||
return NewFilter(Field_EventType, q.eventTypes, Operation_In)
|
||||
}
|
||||
|
||||
func (q *query) aggregateTypeFilter() *Filter {
|
||||
if len(q.aggregateTypes) < 1 {
|
||||
return nil
|
||||
}
|
||||
if len(q.aggregateTypes) == 1 {
|
||||
return NewFilter(Field_AggregateType, q.aggregateTypes[0], Operation_Equals)
|
||||
}
|
||||
return NewFilter(Field_AggregateType, q.aggregateTypes, Operation_In)
|
||||
}
|
||||
|
||||
func (q *query) sequenceFromFilter() *Filter {
|
||||
if q.sequenceFrom == 0 {
|
||||
return nil
|
||||
}
|
||||
sortOrder := Operation_Greater
|
||||
if q.factory.desc {
|
||||
sortOrder = Operation_Less
|
||||
}
|
||||
return NewFilter(Field_LatestSequence, q.sequenceFrom, sortOrder)
|
||||
}
|
||||
|
||||
func (q *query) sequenceToFilter() *Filter {
|
||||
if q.sequenceTo == 0 {
|
||||
return nil
|
||||
}
|
||||
sortOrder := Operation_Less
|
||||
if q.factory.desc {
|
||||
sortOrder = Operation_Greater
|
||||
}
|
||||
return NewFilter(Field_LatestSequence, q.sequenceTo, sortOrder)
|
||||
}
|
||||
|
||||
func (q *query) resourceOwnerFilter() *Filter {
|
||||
if q.resourceOwner == "" {
|
||||
return nil
|
||||
}
|
||||
return NewFilter(Field_ResourceOwner, q.resourceOwner, Operation_Equals)
|
||||
}
|
||||
|
||||
func (q *query) instanceIDFilter() *Filter {
|
||||
if q.instanceID == "" {
|
||||
return nil
|
||||
}
|
||||
return NewFilter(Field_InstanceID, q.instanceID, Operation_Equals)
|
||||
}
|
||||
|
||||
func (q *query) ignoredInstanceIDsFilter() *Filter {
|
||||
if len(q.ignoredInstanceIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return NewFilter(Field_InstanceID, q.ignoredInstanceIDs, Operation_NotIn)
|
||||
}
|
||||
|
||||
func (q *query) creationDateNewerFilter() *Filter {
|
||||
if q.creationDate.IsZero() {
|
||||
return nil
|
||||
}
|
||||
return NewFilter(Field_CreationDate, q.creationDate, Operation_Greater)
|
||||
}
|
@@ -1,148 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
)
|
||||
|
||||
//SearchQuery is deprecated. Use SearchQueryFactory
|
||||
type SearchQuery struct {
|
||||
Columns Columns
|
||||
Limit uint64
|
||||
Desc bool
|
||||
Filters []*Filter
|
||||
Queries []*Query
|
||||
}
|
||||
|
||||
type Query struct {
|
||||
searchQuery *SearchQuery
|
||||
Filters []*Filter
|
||||
}
|
||||
|
||||
//NewSearchQuery is deprecated. Use SearchQueryFactory
|
||||
func NewSearchQuery() *SearchQuery {
|
||||
return &SearchQuery{
|
||||
Filters: make([]*Filter, 0, 4),
|
||||
Queries: make([]*Query, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (q *SearchQuery) SetColumn(columns Columns) *SearchQuery {
|
||||
q.Columns = columns
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *SearchQuery) AddQuery() *Query {
|
||||
query := &Query{
|
||||
searchQuery: q,
|
||||
}
|
||||
q.Queries = append(q.Queries, query)
|
||||
|
||||
return query
|
||||
}
|
||||
|
||||
//SearchQuery returns the SearchQuery of the sub query
|
||||
func (q *Query) SearchQuery() *SearchQuery {
|
||||
return q.searchQuery
|
||||
}
|
||||
func (q *Query) setFilter(filter *Filter) *Query {
|
||||
for i, f := range q.Filters {
|
||||
if f.field == filter.field && f.field != Field_LatestSequence {
|
||||
q.Filters[i] = filter
|
||||
return q
|
||||
}
|
||||
}
|
||||
q.Filters = append(q.Filters, filter)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *SearchQuery) SetLimit(limit uint64) *SearchQuery {
|
||||
q.Limit = limit
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *SearchQuery) OrderDesc() *SearchQuery {
|
||||
q.Desc = true
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *SearchQuery) OrderAsc() *SearchQuery {
|
||||
q.Desc = false
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *Query) AggregateIDFilter(id string) *Query {
|
||||
return q.setFilter(NewFilter(Field_AggregateID, id, Operation_Equals))
|
||||
}
|
||||
|
||||
func (q *Query) AggregateIDsFilter(ids ...string) *Query {
|
||||
return q.setFilter(NewFilter(Field_AggregateID, ids, Operation_In))
|
||||
}
|
||||
|
||||
func (q *Query) AggregateTypeFilter(types ...AggregateType) *Query {
|
||||
return q.setFilter(NewFilter(Field_AggregateType, types, Operation_In))
|
||||
}
|
||||
|
||||
func (q *Query) EventTypesFilter(types ...EventType) *Query {
|
||||
return q.setFilter(NewFilter(Field_EventType, types, Operation_In))
|
||||
}
|
||||
|
||||
func (q *Query) LatestSequenceFilter(sequence uint64) *Query {
|
||||
if sequence == 0 {
|
||||
return q
|
||||
}
|
||||
sortOrder := Operation_Greater
|
||||
return q.setFilter(NewFilter(Field_LatestSequence, sequence, sortOrder))
|
||||
}
|
||||
|
||||
func (q *Query) SequenceBetween(from, to uint64) *Query {
|
||||
q.setFilter(NewFilter(Field_LatestSequence, from, Operation_Greater))
|
||||
q.setFilter(NewFilter(Field_LatestSequence, to, Operation_Less))
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *Query) ResourceOwnerFilter(resourceOwner string) *Query {
|
||||
return q.setFilter(NewFilter(Field_ResourceOwner, resourceOwner, Operation_Equals))
|
||||
}
|
||||
|
||||
func (q *Query) InstanceIDFilter(instanceID string) *Query {
|
||||
return q.setFilter(NewFilter(Field_InstanceID, instanceID, Operation_Equals))
|
||||
}
|
||||
|
||||
func (q *Query) ExcludedInstanceIDsFilter(instanceIDs ...string) *Query {
|
||||
return q.setFilter(NewFilter(Field_InstanceID, instanceIDs, Operation_NotIn))
|
||||
}
|
||||
|
||||
func (q *Query) CreationDateNewerFilter(time time.Time) *Query {
|
||||
return q.setFilter(NewFilter(Field_CreationDate, time, Operation_Greater))
|
||||
}
|
||||
|
||||
func (q *SearchQuery) setFilter(filter *Filter) *SearchQuery {
|
||||
for i, f := range q.Filters {
|
||||
if f.field == filter.field && f.field != Field_LatestSequence {
|
||||
q.Filters[i] = filter
|
||||
return q
|
||||
}
|
||||
}
|
||||
q.Filters = append(q.Filters, filter)
|
||||
return q
|
||||
}
|
||||
|
||||
func (q *SearchQuery) Validate() error {
|
||||
if q == nil {
|
||||
return errors.ThrowPreconditionFailed(nil, "MODEL-J5xQi", "search query is nil")
|
||||
}
|
||||
if len(q.Queries) == 0 {
|
||||
return errors.ThrowPreconditionFailed(nil, "MODEL-pF3DR", "no filters set")
|
||||
}
|
||||
for _, query := range q.Queries {
|
||||
for _, filter := range query.Filters {
|
||||
if err := filter.Validate(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
@@ -1,65 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestSearchQuery_setFilter(t *testing.T) {
|
||||
type fields struct {
|
||||
query *SearchQuery
|
||||
}
|
||||
type args struct {
|
||||
filters []*Filter
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
want *SearchQuery
|
||||
}{
|
||||
{
|
||||
name: "set idFilter",
|
||||
fields: fields{query: NewSearchQuery()},
|
||||
args: args{filters: []*Filter{
|
||||
{field: Field_AggregateID, operation: Operation_Equals, value: "hodor"},
|
||||
}},
|
||||
want: &SearchQuery{Filters: []*Filter{
|
||||
{field: Field_AggregateID, operation: Operation_Equals, value: "hodor"},
|
||||
}},
|
||||
},
|
||||
{
|
||||
name: "overwrite idFilter",
|
||||
fields: fields{query: NewSearchQuery()},
|
||||
args: args{filters: []*Filter{
|
||||
{field: Field_AggregateID, operation: Operation_Equals, value: "hodor"},
|
||||
{field: Field_AggregateID, operation: Operation_Equals, value: "ursli"},
|
||||
}},
|
||||
want: &SearchQuery{Filters: []*Filter{
|
||||
{field: Field_AggregateID, operation: Operation_Equals, value: "ursli"},
|
||||
}},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := tt.fields.query
|
||||
for _, filter := range tt.args.filters {
|
||||
got = got.setFilter(filter)
|
||||
}
|
||||
for _, wantFilter := range tt.want.Filters {
|
||||
found := false
|
||||
for _, gotFilter := range got.Filters {
|
||||
if gotFilter.field == wantFilter.field {
|
||||
found = true
|
||||
if !reflect.DeepEqual(wantFilter, gotFilter) {
|
||||
t.Errorf("filter not as expected: want: %v got %v", wantFilter, gotFilter)
|
||||
}
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("filter field %v not found", wantFilter.field)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@@ -1,590 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
)
|
||||
|
||||
func testSetColumns(columns Columns) func(factory *SearchQueryFactory) *SearchQueryFactory {
|
||||
return func(factory *SearchQueryFactory) *SearchQueryFactory {
|
||||
factory = factory.Columns(columns)
|
||||
return factory
|
||||
}
|
||||
}
|
||||
|
||||
func testSetLimit(limit uint64) func(factory *SearchQueryFactory) *SearchQueryFactory {
|
||||
return func(factory *SearchQueryFactory) *SearchQueryFactory {
|
||||
factory = factory.Limit(limit)
|
||||
return factory
|
||||
}
|
||||
}
|
||||
|
||||
func testAddQuery(queryFuncs ...func(*query) *query) func(*SearchQueryFactory) *SearchQueryFactory {
|
||||
return func(builder *SearchQueryFactory) *SearchQueryFactory {
|
||||
query := builder.AddQuery()
|
||||
for _, queryFunc := range queryFuncs {
|
||||
queryFunc(query)
|
||||
}
|
||||
return query.Factory()
|
||||
}
|
||||
}
|
||||
|
||||
func testSetSequence(sequence uint64) func(*query) *query {
|
||||
return func(q *query) *query {
|
||||
q.SequenceGreater(sequence)
|
||||
return q
|
||||
}
|
||||
}
|
||||
|
||||
func testSetAggregateIDs(aggregateIDs ...string) func(*query) *query {
|
||||
return func(q *query) *query {
|
||||
q.AggregateIDs(aggregateIDs...)
|
||||
return q
|
||||
}
|
||||
}
|
||||
|
||||
func testSetAggregateTypes(aggregateTypes ...AggregateType) func(*query) *query {
|
||||
return func(q *query) *query {
|
||||
q.AggregateTypes(aggregateTypes...)
|
||||
return q
|
||||
}
|
||||
}
|
||||
|
||||
func testSetEventTypes(eventTypes ...EventType) func(*query) *query {
|
||||
return func(q *query) *query {
|
||||
q.EventTypes(eventTypes...)
|
||||
return q
|
||||
}
|
||||
}
|
||||
|
||||
func testSetResourceOwner(resourceOwner string) func(*query) *query {
|
||||
return func(q *query) *query {
|
||||
q.ResourceOwner(resourceOwner)
|
||||
return q
|
||||
}
|
||||
}
|
||||
|
||||
func testSetSortOrder(asc bool) func(factory *SearchQueryFactory) *SearchQueryFactory {
|
||||
return func(factory *SearchQueryFactory) *SearchQueryFactory {
|
||||
if asc {
|
||||
factory = factory.OrderAsc()
|
||||
} else {
|
||||
factory = factory.OrderDesc()
|
||||
}
|
||||
return factory
|
||||
}
|
||||
}
|
||||
|
||||
func assertFactory(t *testing.T, want, got *SearchQueryFactory) {
|
||||
t.Helper()
|
||||
|
||||
if got.columns != want.columns {
|
||||
t.Errorf("wrong column: got: %v want: %v", got.columns, want.columns)
|
||||
}
|
||||
if got.desc != want.desc {
|
||||
t.Errorf("wrong desc: got: %v want: %v", got.desc, want.desc)
|
||||
}
|
||||
if got.limit != want.limit {
|
||||
t.Errorf("wrong limit: got: %v want: %v", got.limit, want.limit)
|
||||
}
|
||||
if len(got.queries) != len(want.queries) {
|
||||
t.Errorf("wrong length of queries: got: %v want: %v", len(got.queries), len(want.queries))
|
||||
}
|
||||
|
||||
for i, query := range got.queries {
|
||||
assertQuery(t, i, want.queries[i], query)
|
||||
}
|
||||
}
|
||||
|
||||
func assertQuery(t *testing.T, i int, want, got *query) {
|
||||
t.Helper()
|
||||
|
||||
if !reflect.DeepEqual(got.aggregateIDs, want.aggregateIDs) {
|
||||
t.Errorf("wrong aggregateIDs in query %d : got: %v want: %v", i, got.aggregateIDs, want.aggregateIDs)
|
||||
}
|
||||
if !reflect.DeepEqual(got.aggregateTypes, want.aggregateTypes) {
|
||||
t.Errorf("wrong aggregateTypes in query %d : got: %v want: %v", i, got.aggregateTypes, want.aggregateTypes)
|
||||
}
|
||||
if got.sequenceFrom != want.sequenceFrom {
|
||||
t.Errorf("wrong sequenceFrom in query %d : got: %v want: %v", i, got.sequenceFrom, want.sequenceFrom)
|
||||
}
|
||||
if got.sequenceTo != want.sequenceTo {
|
||||
t.Errorf("wrong sequenceTo in query %d : got: %v want: %v", i, got.sequenceTo, want.sequenceTo)
|
||||
}
|
||||
if !reflect.DeepEqual(got.eventTypes, want.eventTypes) {
|
||||
t.Errorf("wrong eventTypes in query %d : got: %v want: %v", i, got.eventTypes, want.eventTypes)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchQueryFactorySetters(t *testing.T) {
|
||||
type args struct {
|
||||
setters []func(*SearchQueryFactory) *SearchQueryFactory
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res *SearchQueryFactory
|
||||
}{
|
||||
{
|
||||
name: "New factory",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{},
|
||||
},
|
||||
res: &SearchQueryFactory{},
|
||||
},
|
||||
{
|
||||
name: "set columns",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testSetColumns(Columns_Max_Sequence)},
|
||||
},
|
||||
res: &SearchQueryFactory{
|
||||
columns: Columns_Max_Sequence,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set limit",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testSetLimit(100)},
|
||||
},
|
||||
res: &SearchQueryFactory{
|
||||
limit: 100,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set sequence",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testAddQuery(testSetSequence(90))},
|
||||
},
|
||||
res: &SearchQueryFactory{
|
||||
queries: []*query{
|
||||
{
|
||||
sequenceFrom: 90,
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set aggregateTypes",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testAddQuery(testSetAggregateTypes("user", "org"))},
|
||||
},
|
||||
res: &SearchQueryFactory{
|
||||
queries: []*query{
|
||||
{
|
||||
aggregateTypes: []AggregateType{"user", "org"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set aggregateIDs",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testAddQuery(testSetAggregateIDs("1235", "09824"))},
|
||||
},
|
||||
res: &SearchQueryFactory{
|
||||
queries: []*query{
|
||||
{
|
||||
aggregateIDs: []string{"1235", "09824"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set eventTypes",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testAddQuery(testSetEventTypes("user.created", "user.updated"))},
|
||||
},
|
||||
res: &SearchQueryFactory{
|
||||
queries: []*query{
|
||||
{
|
||||
eventTypes: []EventType{"user.created", "user.updated"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "set resource owner",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testAddQuery(testSetResourceOwner("hodor"))},
|
||||
},
|
||||
res: &SearchQueryFactory{
|
||||
queries: []*query{
|
||||
{
|
||||
resourceOwner: "hodor",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "default search query",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{testAddQuery(testSetAggregateTypes("user"), testSetAggregateIDs("1235", "024")), testSetSortOrder(false)},
|
||||
},
|
||||
res: &SearchQueryFactory{
|
||||
desc: true,
|
||||
queries: []*query{
|
||||
{
|
||||
aggregateTypes: []AggregateType{"user"},
|
||||
aggregateIDs: []string{"1235", "024"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
factory := NewSearchQueryFactory()
|
||||
for _, setter := range tt.args.setters {
|
||||
factory = setter(factory)
|
||||
}
|
||||
assertFactory(t, tt.res, factory)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSearchQueryFactoryBuild(t *testing.T) {
|
||||
type args struct {
|
||||
setters []func(*SearchQueryFactory) *SearchQueryFactory
|
||||
}
|
||||
type res struct {
|
||||
isErr func(err error) bool
|
||||
query *searchQuery
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "no aggregate types",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{},
|
||||
},
|
||||
res: res{
|
||||
isErr: errors.IsPreconditionFailed,
|
||||
query: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid column (too low)",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
|
||||
testSetColumns(Columns(-1)),
|
||||
testAddQuery(testSetAggregateTypes("user")),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
isErr: errors.IsPreconditionFailed,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid column (too high)",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
|
||||
testSetColumns(columnsCount),
|
||||
testAddQuery(testSetAggregateTypes("user")),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
isErr: errors.IsPreconditionFailed,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter aggregate type",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
|
||||
testAddQuery(testSetAggregateTypes("user")),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
isErr: nil,
|
||||
query: &searchQuery{
|
||||
Columns: 0,
|
||||
Desc: false,
|
||||
Limit: 0,
|
||||
Filters: [][]*Filter{
|
||||
{
|
||||
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter aggregate types",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
|
||||
testAddQuery(testSetAggregateTypes("user", "org")),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
isErr: nil,
|
||||
query: &searchQuery{
|
||||
Columns: 0,
|
||||
Desc: false,
|
||||
Limit: 0,
|
||||
Filters: [][]*Filter{
|
||||
{
|
||||
NewFilter(Field_AggregateType, []AggregateType{"user", "org"}, Operation_In),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter aggregate type, limit, desc",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
|
||||
testSetLimit(5),
|
||||
testSetSortOrder(false),
|
||||
testAddQuery(
|
||||
testSetAggregateTypes("user"),
|
||||
testSetSequence(100),
|
||||
),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
isErr: nil,
|
||||
query: &searchQuery{
|
||||
Columns: 0,
|
||||
Desc: true,
|
||||
Limit: 5,
|
||||
Filters: [][]*Filter{
|
||||
{
|
||||
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
|
||||
NewFilter(Field_LatestSequence, uint64(100), Operation_Less),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter aggregate type, limit, asc",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
|
||||
testSetLimit(5),
|
||||
testSetSortOrder(true),
|
||||
testAddQuery(
|
||||
testSetSequence(100),
|
||||
testSetAggregateTypes("user"),
|
||||
),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
isErr: nil,
|
||||
query: &searchQuery{
|
||||
Columns: 0,
|
||||
Desc: false,
|
||||
Limit: 5,
|
||||
Filters: [][]*Filter{
|
||||
{
|
||||
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
|
||||
NewFilter(Field_LatestSequence, uint64(100), Operation_Greater),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter aggregate type, limit, desc, max event sequence cols",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
|
||||
testSetLimit(5),
|
||||
testSetSortOrder(false),
|
||||
testSetColumns(Columns_Max_Sequence),
|
||||
testAddQuery(
|
||||
testSetSequence(100),
|
||||
testSetAggregateTypes("user"),
|
||||
),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
isErr: nil,
|
||||
query: &searchQuery{
|
||||
Columns: Columns_Max_Sequence,
|
||||
Desc: true,
|
||||
Limit: 5,
|
||||
Filters: [][]*Filter{
|
||||
{
|
||||
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
|
||||
NewFilter(Field_LatestSequence, uint64(100), Operation_Less),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter aggregate type and aggregate id",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
|
||||
testAddQuery(
|
||||
testSetAggregateIDs("1234"),
|
||||
testSetAggregateTypes("user"),
|
||||
),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
isErr: nil,
|
||||
query: &searchQuery{
|
||||
Columns: 0,
|
||||
Desc: false,
|
||||
Limit: 0,
|
||||
Filters: [][]*Filter{
|
||||
{
|
||||
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
|
||||
NewFilter(Field_AggregateID, "1234", Operation_Equals),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter aggregate type and aggregate ids",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
|
||||
testAddQuery(
|
||||
testSetAggregateIDs("1234", "0815"),
|
||||
testSetAggregateTypes("user"),
|
||||
),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
isErr: nil,
|
||||
query: &searchQuery{
|
||||
Columns: 0,
|
||||
Desc: false,
|
||||
Limit: 0,
|
||||
Filters: [][]*Filter{
|
||||
{
|
||||
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
|
||||
NewFilter(Field_AggregateID, []string{"1234", "0815"}, Operation_In),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter aggregate type and sequence greater",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
|
||||
testAddQuery(
|
||||
testSetSequence(8),
|
||||
testSetAggregateTypes("user"),
|
||||
),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
isErr: nil,
|
||||
query: &searchQuery{
|
||||
Columns: 0,
|
||||
Desc: false,
|
||||
Limit: 0,
|
||||
Filters: [][]*Filter{
|
||||
{
|
||||
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
|
||||
NewFilter(Field_LatestSequence, uint64(8), Operation_Greater),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter aggregate type and event type",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
|
||||
testAddQuery(
|
||||
testSetAggregateTypes("user"),
|
||||
testSetEventTypes("user.created"),
|
||||
),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
isErr: nil,
|
||||
query: &searchQuery{
|
||||
Columns: 0,
|
||||
Desc: false,
|
||||
Limit: 0,
|
||||
Filters: [][]*Filter{
|
||||
{
|
||||
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
|
||||
NewFilter(Field_EventType, EventType("user.created"), Operation_Equals),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter aggregate type and event types",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
|
||||
testAddQuery(
|
||||
testSetAggregateTypes("user"),
|
||||
testSetEventTypes("user.created", "user.changed"),
|
||||
),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
isErr: nil,
|
||||
query: &searchQuery{
|
||||
Columns: 0,
|
||||
Desc: false,
|
||||
Limit: 0,
|
||||
Filters: [][]*Filter{
|
||||
{
|
||||
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
|
||||
NewFilter(Field_EventType, []EventType{"user.created", "user.changed"}, Operation_In),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "filter aggregate type resource owner",
|
||||
args: args{
|
||||
setters: []func(*SearchQueryFactory) *SearchQueryFactory{
|
||||
testAddQuery(
|
||||
testSetAggregateTypes("user"),
|
||||
testSetResourceOwner("hodor"),
|
||||
),
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
isErr: nil,
|
||||
query: &searchQuery{
|
||||
Columns: 0,
|
||||
Desc: false,
|
||||
Limit: 0,
|
||||
Filters: [][]*Filter{
|
||||
{
|
||||
NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
|
||||
NewFilter(Field_ResourceOwner, "hodor", Operation_Equals),
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
factory := NewSearchQueryFactory()
|
||||
for _, f := range tt.args.setters {
|
||||
factory = f(factory)
|
||||
}
|
||||
query, err := factory.Build()
|
||||
if tt.res.isErr != nil && !tt.res.isErr(err) {
|
||||
t.Errorf("wrong error: %v", err)
|
||||
return
|
||||
}
|
||||
if err != nil && tt.res.isErr == nil {
|
||||
t.Errorf("no error expected: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(query, tt.res.query) {
|
||||
t.Errorf("NewSearchQueryFactory() = %v, want %v", factory, tt.res)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@@ -1,22 +0,0 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
)
|
||||
|
||||
var versionRegexp = regexp.MustCompile(`^v[0-9]+(\.[0-9]+){0,2}$`)
|
||||
|
||||
type Version string
|
||||
|
||||
func (v Version) Validate() error {
|
||||
if !versionRegexp.MatchString(string(v)) {
|
||||
return errors.ThrowPreconditionFailed(nil, "MODEL-luDuS", "version is not semver")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (v Version) String() string {
|
||||
return string(v)
|
||||
}
|
@@ -1,39 +0,0 @@
|
||||
package models
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestVersion_Validate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
v Version
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
"correct version",
|
||||
"v1.23.23",
|
||||
false,
|
||||
},
|
||||
{
|
||||
"no v prefix",
|
||||
"1.2.2",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"letters in version",
|
||||
"v1.as.3",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"no version",
|
||||
"",
|
||||
true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.v.Validate(); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Version.Validate() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@@ -1,93 +0,0 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
v1 "github.com/zitadel/zitadel/internal/eventstore/v1"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
)
|
||||
|
||||
const (
|
||||
eventLimit = 10000
|
||||
)
|
||||
|
||||
type Handler interface {
|
||||
ViewModel() string
|
||||
EventQuery(ctx context.Context, instanceIDs []string) (*models.SearchQuery, error)
|
||||
Reduce(*models.Event) error
|
||||
OnError(event *models.Event, err error) error
|
||||
OnSuccess(instanceIDs []string) error
|
||||
MinimumCycleDuration() time.Duration
|
||||
LockDuration() time.Duration
|
||||
QueryLimit() uint64
|
||||
|
||||
AggregateTypes() []models.AggregateType
|
||||
CurrentSequence(ctx context.Context, instanceID string) (uint64, error)
|
||||
Eventstore() v1.Eventstore
|
||||
|
||||
Subscription() *v1.Subscription
|
||||
}
|
||||
|
||||
func ReduceEvent(ctx context.Context, handler Handler, event *models.Event) {
|
||||
defer func() {
|
||||
err := recover()
|
||||
|
||||
if err != nil {
|
||||
handler.Subscription().Unsubscribe()
|
||||
logging.WithFields(
|
||||
"cause", err,
|
||||
"stack", string(debug.Stack()),
|
||||
"sequence", event.Sequence,
|
||||
"instance", event.InstanceID,
|
||||
).Error("reduce panicked")
|
||||
}
|
||||
}()
|
||||
currentSequence, err := handler.CurrentSequence(ctx, event.InstanceID)
|
||||
if err != nil {
|
||||
logging.WithError(err).Warn("unable to get current sequence")
|
||||
return
|
||||
}
|
||||
|
||||
searchQuery := models.NewSearchQuery().
|
||||
AddQuery().
|
||||
AggregateTypeFilter(handler.AggregateTypes()...).
|
||||
SequenceBetween(currentSequence, event.Sequence).
|
||||
InstanceIDFilter(event.InstanceID).
|
||||
SearchQuery().
|
||||
SetLimit(eventLimit)
|
||||
|
||||
unprocessedEvents, err := handler.Eventstore().FilterEvents(ctx, searchQuery)
|
||||
if err != nil {
|
||||
logging.WithFields("sequence", event.Sequence).Warn("filter failed")
|
||||
return
|
||||
}
|
||||
|
||||
for _, unprocessedEvent := range unprocessedEvents {
|
||||
currentSequence, err := handler.CurrentSequence(ctx, unprocessedEvent.InstanceID)
|
||||
if err != nil {
|
||||
logging.WithError(err).Warn("unable to get current sequence")
|
||||
return
|
||||
}
|
||||
if unprocessedEvent.Sequence < currentSequence {
|
||||
logging.WithFields(
|
||||
"unprocessed", unprocessedEvent.Sequence,
|
||||
"current", currentSequence,
|
||||
"view", handler.ViewModel()).
|
||||
Warn("sequence not matching")
|
||||
return
|
||||
}
|
||||
|
||||
err = handler.Reduce(unprocessedEvent)
|
||||
logging.WithFields("sequence", unprocessedEvent.Sequence).OnError(err).Warn("reduce failed")
|
||||
}
|
||||
if len(unprocessedEvents) == eventLimit {
|
||||
logging.WithFields("sequence", event.Sequence).Warn("didnt process event")
|
||||
return
|
||||
}
|
||||
err = handler.Reduce(event)
|
||||
logging.WithFields("sequence", event.Sequence).OnError(err).Warn("reduce failed")
|
||||
}
|
@@ -1,36 +0,0 @@
|
||||
package sdk
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
)
|
||||
|
||||
var (
	// compile-time checks that appendEventError satisfies both interfaces
	_ AppendEventError = (*appendEventError)(nil)
	_ errors.Error     = (*appendEventError)(nil)
)

// AppendEventError marks errors that occurred while appending events to an
// aggregate. Callers detect it via IsAppendEventError.
type AppendEventError interface {
	error
	IsAppendEventError()
}

// appendEventError is the default AppendEventError implementation, wrapping
// the project-wide CaosError for id/message/parent handling.
type appendEventError struct {
	*errors.CaosError
}

// ThrowAppendEventError wraps parent into an AppendEventError with the given
// error id and message.
func ThrowAppendEventError(parent error, id, message string) error {
	return &appendEventError{errors.CreateCaosError(parent, id, message)}
}

// ThrowAggregaterf is like ThrowAppendEventError but formats the message.
// NOTE(review): the name suggests "aggregater" but it actually creates an
// AppendEventError — kept as-is for backward compatibility.
func ThrowAggregaterf(parent error, id, format string, a ...interface{}) error {
	return ThrowAppendEventError(parent, id, fmt.Sprintf(format, a...))
}

// IsAppendEventError marks appendEventError as an AppendEventError.
func (err *appendEventError) IsAppendEventError() {}

// IsAppendEventError reports whether err is an AppendEventError.
func IsAppendEventError(err error) bool {
	_, ok := err.(AppendEventError)
	return ok
}
|
@@ -1,31 +0,0 @@
|
||||
package sdk
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestAppendEventError verifies that *appendEventError can be carried in an
// empty interface and recovered via type assertion.
func TestAppendEventError(t *testing.T) {
	var err interface{}
	err = new(appendEventError)
	_, ok := err.(*appendEventError)
	assert.True(t, ok)
}

// TestThrowAppendEventErrorf verifies the formatting helper produces an
// *appendEventError.
func TestThrowAppendEventErrorf(t *testing.T) {
	err := ThrowAggregaterf(nil, "id", "msg")
	_, ok := err.(*appendEventError)
	assert.True(t, ok)
}

// TestIsAppendEventError checks the predicate for a matching and a
// non-matching error.
func TestIsAppendEventError(t *testing.T) {
	err := ThrowAppendEventError(nil, "id", "msg")
	ok := IsAppendEventError(err)
	assert.True(t, ok)

	err = errors.New("i am found")
	ok = IsAppendEventError(err)
	assert.False(t, ok)
}
|
@@ -1,27 +0,0 @@
|
||||
package sdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
)
|
||||
|
||||
// filterFunc loads the events matching a search query.
type filterFunc func(context.Context, *es_models.SearchQuery) ([]*es_models.Event, error)

// appendFunc applies loaded events to an aggregate or read model.
type appendFunc func(...*es_models.Event) error

// AggregateFunc builds the aggregate to be written to the eventstore.
type AggregateFunc func(context.Context) (*es_models.Aggregate, error)
|
||||
|
||||
func Filter(ctx context.Context, filter filterFunc, appender appendFunc, query *es_models.SearchQuery) error {
|
||||
events, err := filter(ctx, query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(events) == 0 {
|
||||
return errors.ThrowNotFound(nil, "EVENT-8due3", "no events found")
|
||||
}
|
||||
err = appender(events...)
|
||||
if err != nil {
|
||||
return ThrowAppendEventError(err, "SDK-awiWK", "Errors.Internal")
|
||||
}
|
||||
return nil
|
||||
}
|
@@ -1,77 +0,0 @@
|
||||
package sdk
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
)
|
||||
|
||||
func TestFilter(t *testing.T) {
|
||||
type args struct {
|
||||
filter filterFunc
|
||||
appender appendFunc
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
wantErr func(error) bool
|
||||
}{
|
||||
{
|
||||
name: "filter error",
|
||||
args: args{
|
||||
filter: func(context.Context, *es_models.SearchQuery) ([]*es_models.Event, error) {
|
||||
return nil, errors.ThrowInternal(nil, "test-46VX2", "test error")
|
||||
},
|
||||
appender: nil,
|
||||
},
|
||||
wantErr: errors.IsInternal,
|
||||
},
|
||||
{
|
||||
name: "no events found",
|
||||
args: args{
|
||||
filter: func(context.Context, *es_models.SearchQuery) ([]*es_models.Event, error) {
|
||||
return []*es_models.Event{}, nil
|
||||
},
|
||||
appender: nil,
|
||||
},
|
||||
wantErr: errors.IsNotFound,
|
||||
},
|
||||
{
|
||||
name: "append fails",
|
||||
args: args{
|
||||
filter: func(context.Context, *es_models.SearchQuery) ([]*es_models.Event, error) {
|
||||
return []*es_models.Event{&es_models.Event{}}, nil
|
||||
},
|
||||
appender: func(...*es_models.Event) error {
|
||||
return errors.ThrowInvalidArgument(nil, "SDK-DhBzl", "test error")
|
||||
},
|
||||
},
|
||||
wantErr: IsAppendEventError,
|
||||
},
|
||||
{
|
||||
name: "filter correct",
|
||||
args: args{
|
||||
filter: func(context.Context, *es_models.SearchQuery) ([]*es_models.Event, error) {
|
||||
return []*es_models.Event{&es_models.Event{}}, nil
|
||||
},
|
||||
appender: func(...*es_models.Event) error {
|
||||
return nil
|
||||
},
|
||||
},
|
||||
wantErr: nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := Filter(context.Background(), tt.args.filter, tt.args.appender, nil)
|
||||
if tt.wantErr == nil && err != nil {
|
||||
t.Errorf("no error expected %v", err)
|
||||
}
|
||||
if tt.wantErr != nil && !tt.wantErr(err) {
|
||||
t.Errorf("no error has wrong type %v", err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@@ -1,42 +0,0 @@
|
||||
package spooler
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
v1 "github.com/zitadel/zitadel/internal/eventstore/v1"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/query"
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
)
|
||||
|
||||
// Config holds everything needed to construct a Spooler.
type Config struct {
	Eventstore          v1.Eventstore          // v1 eventstore used to filter events and instance ids
	EventstoreV2        *eventstore.Eventstore // v2 eventstore used for the scheduler-succeeded marker events
	Locker              Locker                 // distributed lock so each view runs on one worker at a time
	ViewHandlers        []query.Handler        // projections to be spooled
	ConcurrentWorkers   int                    // number of worker goroutines; < 1 disables spooling
	ConcurrentInstances int                    // batch size of instances processed per handler iteration
}
|
||||
|
||||
func (c *Config) New() *Spooler {
|
||||
lockID, err := id.SonyFlakeGenerator().Next()
|
||||
logging.OnError(err).Panic("unable to generate lockID")
|
||||
|
||||
//shuffle the handlers for better balance when running multiple pods
|
||||
rand.Shuffle(len(c.ViewHandlers), func(i, j int) {
|
||||
c.ViewHandlers[i], c.ViewHandlers[j] = c.ViewHandlers[j], c.ViewHandlers[i]
|
||||
})
|
||||
|
||||
return &Spooler{
|
||||
handlers: c.ViewHandlers,
|
||||
lockID: lockID,
|
||||
eventstore: c.Eventstore,
|
||||
esV2: c.EventstoreV2,
|
||||
locker: c.Locker,
|
||||
queue: make(chan *spooledHandler, len(c.ViewHandlers)),
|
||||
workers: c.ConcurrentWorkers,
|
||||
concurrentInstances: c.ConcurrentInstances,
|
||||
}
|
||||
}
|
@@ -1,3 +0,0 @@
|
||||
package spooler
|
||||
|
||||
//go:generate mockgen -source spooler.go -destination ./mock/spooler.go -package mock
|
@@ -1,49 +0,0 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: spooler.go
|
||||
|
||||
// Package mock is a generated GoMock package.
|
||||
package mock
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
time "time"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
// MockLocker is a mock of Locker interface.
|
||||
type MockLocker struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockLockerMockRecorder
|
||||
}
|
||||
|
||||
// MockLockerMockRecorder is the mock recorder for MockLocker.
|
||||
type MockLockerMockRecorder struct {
|
||||
mock *MockLocker
|
||||
}
|
||||
|
||||
// NewMockLocker creates a new mock instance.
|
||||
func NewMockLocker(ctrl *gomock.Controller) *MockLocker {
|
||||
mock := &MockLocker{ctrl: ctrl}
|
||||
mock.recorder = &MockLockerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockLocker) EXPECT() *MockLockerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Renew mocks base method.
|
||||
func (m *MockLocker) Renew(lockerID, viewModel, instanceID string, waitTime time.Duration) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Renew", lockerID, viewModel, instanceID, waitTime)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// Renew indicates an expected call of Renew.
|
||||
func (mr *MockLockerMockRecorder) Renew(lockerID, viewModel, instanceID, waitTime interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Renew", reflect.TypeOf((*MockLocker)(nil).Renew), lockerID, viewModel, instanceID, waitTime)
|
||||
}
|
@@ -1,292 +0,0 @@
|
||||
package spooler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"runtime/debug"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler"
|
||||
v1 "github.com/zitadel/zitadel/internal/eventstore/v1"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/query"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/view/repository"
|
||||
)
|
||||
|
||||
const (
	// systemID is the instance id used for the spooler's own locks.
	systemID = "system"
	// schedulerSucceeded is pushed once a projection completed its first full run.
	schedulerSucceeded = eventstore.EventType("system.projections.scheduler.succeeded")
	// aggregateType and aggregateID identify the synthetic system aggregate
	// on which scheduler-succeeded events are stored.
	aggregateType = eventstore.AggregateType("system")
	aggregateID   = "SYSTEM"
)
|
||||
|
||||
// Spooler distributes the registered view handlers onto a bounded set of
// worker goroutines and periodically re-queues each handler.
type Spooler struct {
	handlers            []query.Handler
	locker              Locker
	lockID              string // unique id of this spooler process; used as lock owner
	eventstore          v1.Eventstore
	esV2                *eventstore.Eventstore
	workers             int
	queue               chan *spooledHandler
	concurrentInstances int
}

// Locker serializes handling of a view across processes.
type Locker interface {
	// Renew (re-)acquires the lock on viewModel for lockerID and holds it
	// for waitTime; it returns an error when the lock is held elsewhere.
	Renew(lockerID, viewModel, instanceID string, waitTime time.Duration) error
}
|
||||
|
||||
// spooledHandler wraps a query.Handler with the state needed to run it on a
// spooler worker.
type spooledHandler struct {
	query.Handler
	locker     Locker
	queuedAt   time.Time // when the handler was (re-)queued; drives the requeue delay
	eventstore v1.Eventstore
	esV2       *eventstore.Eventstore
	// concurrentInstances is the batch size used when iterating instances.
	concurrentInstances int
	// succeededOnce is true once the projection completed a full catch-up;
	// afterwards only recently active instances are scanned.
	succeededOnce bool
}
|
||||
|
||||
// Start spins up the configured number of worker goroutines and enqueues all
// handlers. It returns immediately; workers run until the process exits.
// With fewer than one worker, spooling is disabled.
func (s *Spooler) Start() {
	defer logging.WithFields("lockerID", s.lockID, "workers", s.workers).Info("spooler started")
	if s.workers < 1 {
		return
	}

	for i := 0; i < s.workers; i++ {
		go func(workerIdx int) {
			workerID := s.lockID + "--" + strconv.Itoa(workerIdx)
			for task := range s.queue {
				// re-enqueue in the background so the handler runs again
				// after its minimum cycle duration, then process the current
				// cycle synchronously on this worker.
				go requeueTask(task, s.queue)
				task.load(workerID)
			}
		}(i)
	}
	// fill the queue asynchronously; the channel is buffered to hold every handler
	go func() {
		for _, handler := range s.handlers {
			s.queue <- &spooledHandler{Handler: handler, locker: s.locker, queuedAt: time.Now(), eventstore: s.eventstore, esV2: s.esV2, concurrentInstances: s.concurrentInstances}
		}
	}()
}
|
||||
|
||||
func requeueTask(task *spooledHandler, queue chan<- *spooledHandler) {
|
||||
time.Sleep(task.MinimumCycleDuration() - time.Since(task.queuedAt))
|
||||
task.queuedAt = time.Now()
|
||||
queue <- task
|
||||
}
|
||||
|
||||
func (s *spooledHandler) hasSucceededOnce(ctx context.Context) (bool, error) {
|
||||
events, err := s.esV2.Filter(ctx, eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
|
||||
AddQuery().
|
||||
AggregateTypes(aggregateType).
|
||||
AggregateIDs(aggregateID).
|
||||
EventTypes(schedulerSucceeded).
|
||||
EventData(map[string]interface{}{
|
||||
"name": s.ViewModel(),
|
||||
}).
|
||||
Builder(),
|
||||
)
|
||||
return len(events) > 0 && err == nil, err
|
||||
}
|
||||
|
||||
// setSucceededOnce stores the scheduler-succeeded marker event for this view
// and caches the result in s.succeededOnce.
func (s *spooledHandler) setSucceededOnce(ctx context.Context) error {
	_, err := s.esV2.Push(ctx, &handler.ProjectionSucceededEvent{
		BaseEvent: *eventstore.NewBaseEventForPush(ctx,
			eventstore.NewAggregate(ctx, aggregateID, aggregateType, "v1"),
			schedulerSucceeded,
		),
		Name: s.ViewModel(),
	})
	s.succeededOnce = err == nil
	return err
}
|
||||
|
||||
// load runs one spool cycle for the handler: it acquires the lock, determines
// the instances with pending events, processes them in batches of
// concurrentInstances and records the first fully successful run. Completion
// (or an error) is signaled via errs, which awaitError turns into a context
// cancellation; load returns once the context is done.
func (s *spooledHandler) load(workerID string) {
	errs := make(chan error)
	defer func() {
		close(errs)
		// recover from panicking reducers so one broken projection cannot
		// kill the whole worker
		err := recover()

		if err != nil {
			logging.WithFields(
				"cause", err,
				"stack", string(debug.Stack()),
			).Error("reduce panicked")
		}
	}()
	ctx, cancel := context.WithCancel(context.Background())
	go s.awaitError(cancel, errs, workerID)
	hasLocked := s.lock(ctx, errs, workerID)

	if <-hasLocked {
		if !s.succeededOnce {
			var err error
			s.succeededOnce, err = s.hasSucceededOnce(ctx)
			if err != nil {
				logging.WithFields("view", s.ViewModel()).OnError(err).Debug("initial lock failed for first schedule")
				errs <- err
				return
			}
		}

		instanceIDQuery := models.NewSearchQuery().SetColumn(models.Columns_InstanceIDs).AddQuery().ExcludedInstanceIDsFilter("")
		// NOTE(review): both branches of this loop break after one
		// iteration, so it effectively runs once per load call.
		for {
			if s.succeededOnce {
				// since we have at least one successful run, we can restrict it to events not older than
				// twice the requeue time (just to be sure not to miss an event)
				instanceIDQuery = instanceIDQuery.CreationDateNewerFilter(time.Now().Add(-2 * s.MinimumCycleDuration()))
			}
			ids, err := s.eventstore.InstanceIDs(ctx, instanceIDQuery.SearchQuery())
			if err != nil {
				errs <- err
				break
			}
			// process the instances in batches of concurrentInstances
			for i := 0; i < len(ids); i = i + s.concurrentInstances {
				max := i + s.concurrentInstances
				if max > len(ids) {
					max = len(ids)
				}
				err = s.processInstances(ctx, workerID, ids[i:max])
				if err != nil {
					errs <- err
				}
			}
			if ctx.Err() == nil {
				if !s.succeededOnce {
					err = s.setSucceededOnce(ctx)
					logging.WithFields("view", s.ViewModel()).OnError(err).Warn("unable to push first schedule succeeded")
				}
				// signal a clean finish; awaitError cancels the context
				errs <- nil
			}
			break
		}
	}
	// wait until awaitError has canceled the context before closing errs
	<-ctx.Done()
}
|
||||
|
||||
func (s *spooledHandler) processInstances(ctx context.Context, workerID string, ids []string) error {
|
||||
for {
|
||||
processCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
events, err := s.query(processCtx, ids)
|
||||
if err != nil {
|
||||
cancel()
|
||||
return err
|
||||
}
|
||||
if len(events) == 0 {
|
||||
cancel()
|
||||
return nil
|
||||
}
|
||||
err = s.process(processCtx, events, workerID, ids)
|
||||
cancel()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if uint64(len(events)) < s.QueryLimit() {
|
||||
// no more events to process
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *spooledHandler) awaitError(cancel func(), errs chan error, workerID string) {
|
||||
select {
|
||||
case err := <-errs:
|
||||
cancel()
|
||||
logging.OnError(err).WithField("view", s.ViewModel()).WithField("worker", workerID).Debug("load canceled")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *spooledHandler) process(ctx context.Context, events []*models.Event, workerID string, instanceIDs []string) error {
|
||||
for i, event := range events {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logging.WithFields("view", s.ViewModel(), "worker", workerID, "traceID", tracing.TraceIDFromCtx(ctx)).Debug("context canceled")
|
||||
return nil
|
||||
default:
|
||||
if err := s.Reduce(event); err != nil {
|
||||
err = s.OnError(event, err)
|
||||
if err == nil {
|
||||
continue
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
return s.process(ctx, events[i:], workerID, instanceIDs)
|
||||
}
|
||||
}
|
||||
}
|
||||
err := s.OnSuccess(instanceIDs)
|
||||
logging.WithFields("view", s.ViewModel(), "worker", workerID, "traceID", tracing.TraceIDFromCtx(ctx)).OnError(err).Warn("could not process on success func")
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *spooledHandler) query(ctx context.Context, instanceIDs []string) ([]*models.Event, error) {
|
||||
query, err := s.EventQuery(ctx, instanceIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
query.Limit = s.QueryLimit()
|
||||
return s.eventstore.FilterEvents(ctx, query)
|
||||
}
|
||||
|
||||
// lock ensures the lock on the database for this handler's view and keeps
// renewing it for as long as ctx lives.
// The returned channel delivers exactly one value — whether the initial
// acquisition succeeded — and is closed when ctx is done or an error
// occurred during locking. Renewal errors are forwarded to errs.
func (s *spooledHandler) lock(ctx context.Context, errs chan<- error, workerID string) chan bool {
	renewTimer := time.After(0)
	locked := make(chan bool)

	go func(locked chan bool) {
		var firstLock sync.Once
		defer close(locked)
		for {
			select {
			case <-ctx.Done():
				return
			case <-renewTimer:
				err := s.locker.Renew(workerID, s.ViewModel(), systemID, s.LockDuration())
				// only the very first renew result is reported to the caller
				firstLock.Do(func() {
					locked <- err == nil
				})
				if err == nil {
					renewTimer = time.After(s.LockDuration())
					continue
				}

				// do not report renewal errors once the context ended
				if ctx.Err() == nil {
					errs <- err
				}
				return
			}
		}
	}(locked)

	return locked
}
|
||||
|
||||
func HandleError(event *models.Event, failedErr error,
|
||||
latestFailedEvent func(sequence uint64, instanceID string) (*repository.FailedEvent, error),
|
||||
processFailedEvent func(*repository.FailedEvent) error,
|
||||
processSequence func(*models.Event) error,
|
||||
errorCountUntilSkip uint64) error {
|
||||
failedEvent, err := latestFailedEvent(event.Sequence, event.InstanceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
failedEvent.FailureCount++
|
||||
failedEvent.ErrMsg = failedErr.Error()
|
||||
failedEvent.InstanceID = event.InstanceID
|
||||
failedEvent.LastFailed = time.Now()
|
||||
err = processFailedEvent(failedEvent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if errorCountUntilSkip <= failedEvent.FailureCount {
|
||||
return processSequence(event)
|
||||
}
|
||||
return failedErr
|
||||
}
|
||||
|
||||
// HandleSuccess updates the spooler run timestamp for the processed
// instances after a successful run.
func HandleSuccess(updateSpoolerRunTimestamp func([]string) error, instanceIDs []string) error {
	return updateSpoolerRunTimestamp(instanceIDs)
}
|
@@ -1,514 +0,0 @@
|
||||
package spooler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
v1 "github.com/zitadel/zitadel/internal/eventstore/v1"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/query"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/spooler/mock"
|
||||
"github.com/zitadel/zitadel/internal/view/repository"
|
||||
)
|
||||
|
||||
var (
	// testNow is a fixed timestamp shared by the failed-event fixtures.
	testNow = time.Now()
)

// testHandler is a configurable query.Handler stub used throughout the
// spooler tests; sleeps and errors simulate slow or failing projections.
type testHandler struct {
	cycleDuration time.Duration // minimum cycle duration reported to the spooler
	processSleep  time.Duration // how long Reduce blocks per event
	processError  error         // error returned by every Reduce call
	queryError    error         // error returned by EventQuery
	viewModel     string
	bulkLimit     uint64
	maxErrCount   int // counts OnError invocations; doubles as the retry counter checked by tests
}

func (h *testHandler) AggregateTypes() []models.AggregateType {
	return nil
}

func (h *testHandler) CurrentSequence(ctx context.Context, instanceID string) (uint64, error) {
	return 0, nil
}

func (h *testHandler) Eventstore() v1.Eventstore {
	return nil
}

func (h *testHandler) ViewModel() string {
	return h.viewModel
}

func (h *testHandler) Subscription() *v1.Subscription {
	return nil
}

func (h *testHandler) EventQuery(ctx context.Context, instanceIDs []string) (*models.SearchQuery, error) {
	if h.queryError != nil {
		return nil, h.queryError
	}
	return &models.SearchQuery{}, nil
}

// Reduce simulates work by sleeping, then returns the configured error.
func (h *testHandler) Reduce(*models.Event) error {
	<-time.After(h.processSleep)
	return h.processError
}

// OnError swallows the error after two retries, mirroring a projection that
// eventually skips a poisoned event.
func (h *testHandler) OnError(event *models.Event, err error) error {
	if h.maxErrCount == 2 {
		return nil
	}
	h.maxErrCount++
	return err
}

func (h *testHandler) OnSuccess([]string) error {
	return nil
}

func (h *testHandler) MinimumCycleDuration() time.Duration {
	return h.cycleDuration
}

func (h *testHandler) LockDuration() time.Duration {
	return h.cycleDuration / 2
}

func (h *testHandler) QueryLimit() uint64 {
	return h.bulkLimit
}
|
||||
|
||||
// eventstoreStub is a minimal v1.Eventstore implementation returning canned
// events or a fixed error.
type eventstoreStub struct {
	events []*models.Event
	err    error
}

func (es *eventstoreStub) Subscribe(...models.AggregateType) *v1.Subscription { return nil }

func (es *eventstoreStub) Health(ctx context.Context) error {
	return nil
}

func (es *eventstoreStub) AggregateCreator() *models.AggregateCreator {
	return nil
}

// FilterEvents returns the stubbed error if set, otherwise the canned events.
func (es *eventstoreStub) FilterEvents(ctx context.Context, in *models.SearchQuery) ([]*models.Event, error) {
	if es.err != nil {
		return nil, es.err
	}
	return es.events, nil
}

func (es *eventstoreStub) PushAggregates(ctx context.Context, in ...*models.Aggregate) error {
	return nil
}

func (es *eventstoreStub) LatestSequence(ctx context.Context, in *models.SearchQueryFactory) (uint64, error) {
	return 0, nil
}

func (es *eventstoreStub) InstanceIDs(ctx context.Context, in *models.SearchQuery) ([]string, error) {
	return nil, nil
}

func (es *eventstoreStub) V2() *eventstore.Eventstore {
	return nil
}
|
||||
|
||||
func TestSpooler_process(t *testing.T) {
|
||||
type fields struct {
|
||||
currentHandler *testHandler
|
||||
}
|
||||
type args struct {
|
||||
timeout time.Duration
|
||||
events []*models.Event
|
||||
instanceIDs []string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
wantErr bool
|
||||
wantRetries int
|
||||
}{
|
||||
{
|
||||
name: "process all events",
|
||||
fields: fields{
|
||||
currentHandler: &testHandler{},
|
||||
},
|
||||
args: args{
|
||||
timeout: 0,
|
||||
events: []*models.Event{{}, {}},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "deadline exeeded",
|
||||
fields: fields{
|
||||
currentHandler: &testHandler{processSleep: 501 * time.Millisecond},
|
||||
},
|
||||
args: args{
|
||||
timeout: 1 * time.Second,
|
||||
events: []*models.Event{{}, {}, {}, {}},
|
||||
},
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "process error",
|
||||
fields: fields{
|
||||
currentHandler: &testHandler{processSleep: 1 * time.Second, processError: fmt.Errorf("i am an error")},
|
||||
},
|
||||
args: args{
|
||||
events: []*models.Event{{}, {}},
|
||||
},
|
||||
wantErr: false,
|
||||
wantRetries: 2,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &spooledHandler{
|
||||
Handler: tt.fields.currentHandler,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
var start time.Time
|
||||
if tt.args.timeout > 0 {
|
||||
ctx, _ = context.WithTimeout(ctx, tt.args.timeout)
|
||||
start = time.Now()
|
||||
}
|
||||
|
||||
if err := s.process(ctx, tt.args.events, "test", tt.args.instanceIDs); (err != nil) != tt.wantErr {
|
||||
t.Errorf("Spooler.process() error = %v, wantErr %v", err, tt.wantErr)
|
||||
}
|
||||
if tt.fields.currentHandler.maxErrCount != tt.wantRetries {
|
||||
t.Errorf("Spooler.process() wrong retry count got: %d want %d", tt.fields.currentHandler.maxErrCount, tt.wantRetries)
|
||||
}
|
||||
|
||||
elapsed := time.Since(start).Round(1 * time.Second)
|
||||
if tt.args.timeout != 0 && elapsed != tt.args.timeout {
|
||||
t.Errorf("wrong timeout wanted %v elapsed %v since %v", tt.args.timeout, elapsed, time.Since(start))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpooler_awaitError(t *testing.T) {
|
||||
type fields struct {
|
||||
currentHandler query.Handler
|
||||
err error
|
||||
canceled bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
}{
|
||||
{
|
||||
"no error",
|
||||
fields{
|
||||
err: nil,
|
||||
currentHandler: &testHandler{processSleep: 500 * time.Millisecond},
|
||||
canceled: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
"with error",
|
||||
fields{
|
||||
err: fmt.Errorf("hodor"),
|
||||
currentHandler: &testHandler{processSleep: 500 * time.Millisecond},
|
||||
canceled: false,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
s := &spooledHandler{
|
||||
Handler: tt.fields.currentHandler,
|
||||
}
|
||||
c := make(chan interface{})
|
||||
errs := make(chan error)
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
|
||||
go func() {
|
||||
s.awaitError(cancel, errs, "test")
|
||||
c <- nil
|
||||
}()
|
||||
errs <- tt.fields.err
|
||||
|
||||
<-c
|
||||
if ctx.Err() == nil {
|
||||
t.Error("cancel function was not called")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSpooler_load checks if load terminates
|
||||
func TestSpooler_load(t *testing.T) {
|
||||
type fields struct {
|
||||
currentHandler query.Handler
|
||||
locker *testLocker
|
||||
eventstore v1.Eventstore
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
}{
|
||||
{
|
||||
"lock exists",
|
||||
fields{
|
||||
currentHandler: &testHandler{processSleep: 500 * time.Millisecond, viewModel: "testView1", cycleDuration: 1 * time.Second, bulkLimit: 10},
|
||||
locker: newTestLocker(t, "testID", "testView1").expectRenew(t, fmt.Errorf("lock already exists"), 500*time.Millisecond),
|
||||
},
|
||||
},
|
||||
{
|
||||
"lock fails",
|
||||
fields{
|
||||
currentHandler: &testHandler{processSleep: 100 * time.Millisecond, viewModel: "testView2", cycleDuration: 1 * time.Second, bulkLimit: 10},
|
||||
locker: newTestLocker(t, "testID", "testView2").expectRenew(t, fmt.Errorf("fail"), 500*time.Millisecond),
|
||||
eventstore: &eventstoreStub{events: []*models.Event{{}}},
|
||||
},
|
||||
},
|
||||
{
|
||||
"query fails",
|
||||
fields{
|
||||
currentHandler: &testHandler{processSleep: 100 * time.Millisecond, viewModel: "testView3", queryError: fmt.Errorf("query fail"), cycleDuration: 1 * time.Second, bulkLimit: 10},
|
||||
locker: newTestLocker(t, "testID", "testView3").expectRenew(t, nil, 500*time.Millisecond),
|
||||
eventstore: &eventstoreStub{err: fmt.Errorf("fail")},
|
||||
},
|
||||
},
|
||||
{
|
||||
"process event fails",
|
||||
fields{
|
||||
currentHandler: &testHandler{processError: fmt.Errorf("oups"), processSleep: 100 * time.Millisecond, viewModel: "testView4", cycleDuration: 500 * time.Millisecond, bulkLimit: 10},
|
||||
locker: newTestLocker(t, "testID", "testView4").expectRenew(t, nil, 250*time.Millisecond),
|
||||
eventstore: &eventstoreStub{events: []*models.Event{{}}},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer tt.fields.locker.finish()
|
||||
s := &spooledHandler{
|
||||
Handler: tt.fields.currentHandler,
|
||||
locker: tt.fields.locker.mock,
|
||||
eventstore: tt.fields.eventstore,
|
||||
}
|
||||
s.load("test-worker")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSpooler_lock(t *testing.T) {
|
||||
type fields struct {
|
||||
currentHandler query.Handler
|
||||
locker *testLocker
|
||||
expectsErr bool
|
||||
}
|
||||
type args struct {
|
||||
deadline time.Time
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
args args
|
||||
}{
|
||||
{
|
||||
"renew correct",
|
||||
fields{
|
||||
currentHandler: &testHandler{cycleDuration: 1 * time.Second, viewModel: "testView"},
|
||||
locker: newTestLocker(t, "testID", "testView").expectRenew(t, nil, 500*time.Millisecond),
|
||||
expectsErr: false,
|
||||
},
|
||||
args{
|
||||
deadline: time.Now().Add(1 * time.Second),
|
||||
},
|
||||
},
|
||||
{
|
||||
"renew fails",
|
||||
fields{
|
||||
currentHandler: &testHandler{cycleDuration: 900 * time.Millisecond, viewModel: "testView"},
|
||||
locker: newTestLocker(t, "testID", "testView").expectRenew(t, fmt.Errorf("renew failed"), 450*time.Millisecond),
|
||||
expectsErr: true,
|
||||
},
|
||||
args{
|
||||
deadline: time.Now().Add(5 * time.Second),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
defer tt.fields.locker.finish()
|
||||
s := &spooledHandler{
|
||||
Handler: tt.fields.currentHandler,
|
||||
locker: tt.fields.locker.mock,
|
||||
}
|
||||
|
||||
errs := make(chan error, 1)
|
||||
defer close(errs)
|
||||
ctx, _ := context.WithDeadline(context.Background(), tt.args.deadline)
|
||||
|
||||
locked := s.lock(ctx, errs, "test-worker")
|
||||
|
||||
if tt.fields.expectsErr {
|
||||
lock := <-locked
|
||||
err := <-errs
|
||||
if err == nil {
|
||||
t.Error("No error in error queue")
|
||||
}
|
||||
if lock {
|
||||
t.Error("lock should have failed")
|
||||
}
|
||||
} else {
|
||||
lock := <-locked
|
||||
if !lock {
|
||||
t.Error("lock should be true")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testLocker wraps a gomock Locker together with the identifiers the tests
// expect.
type testLocker struct {
	mock     *mock.MockLocker
	lockerID string
	viewName string
	ctrl     *gomock.Controller
}

func newTestLocker(t *testing.T, lockerID, viewName string) *testLocker {
	ctrl := gomock.NewController(t)
	return &testLocker{mock.NewMockLocker(ctrl), lockerID, viewName, ctrl}
}

// expectRenew registers a Renew expectation (1-3 calls) that checks the
// requested lock duration and returns err.
func (l *testLocker) expectRenew(t *testing.T, err error, waitTime time.Duration) *testLocker {
	t.Helper()
	l.mock.EXPECT().Renew(gomock.Any(), l.viewName, gomock.Any(), gomock.Any()).DoAndReturn(
		func(_, _, _ string, gotten time.Duration) error {
			t.Helper()
			if waitTime-gotten != 0 {
				t.Errorf("expected waittime %v got %v", waitTime, gotten)
			}
			return err
		}).MinTimes(1).MaxTimes(3)

	return l
}

// finish asserts that all mock expectations were met.
func (l *testLocker) finish() {
	l.ctrl.Finish()
}
|
||||
|
||||
func TestHandleError(t *testing.T) {
|
||||
type args struct {
|
||||
event *models.Event
|
||||
failedErr error
|
||||
latestFailedEvent func(sequence uint64, instanceID string) (*repository.FailedEvent, error)
|
||||
errorCountUntilSkip uint64
|
||||
}
|
||||
type res struct {
|
||||
wantErr bool
|
||||
shouldProcessSequence bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "should process sequence already too high",
|
||||
args: args{
|
||||
event: &models.Event{Sequence: 30000000},
|
||||
failedErr: errors.ThrowInternal(nil, "SPOOL-Wk53B", "this was wrong"),
|
||||
latestFailedEvent: func(s uint64, instanceID string) (*repository.FailedEvent, error) {
|
||||
return &repository.FailedEvent{
|
||||
ErrMsg: "blub",
|
||||
FailedSequence: s - 1,
|
||||
FailureCount: 6,
|
||||
ViewName: "super.table",
|
||||
InstanceID: instanceID,
|
||||
LastFailed: testNow,
|
||||
}, nil
|
||||
},
|
||||
errorCountUntilSkip: 5,
|
||||
},
|
||||
res: res{
|
||||
shouldProcessSequence: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "should process sequence after this event too high",
|
||||
args: args{
|
||||
event: &models.Event{Sequence: 30000000},
|
||||
failedErr: errors.ThrowInternal(nil, "SPOOL-Wk53B", "this was wrong"),
|
||||
latestFailedEvent: func(s uint64, instanceID string) (*repository.FailedEvent, error) {
|
||||
return &repository.FailedEvent{
|
||||
ErrMsg: "blub",
|
||||
FailedSequence: s - 1,
|
||||
FailureCount: 5,
|
||||
ViewName: "super.table",
|
||||
InstanceID: instanceID,
|
||||
LastFailed: testNow,
|
||||
}, nil
|
||||
},
|
||||
errorCountUntilSkip: 6,
|
||||
},
|
||||
res: res{
|
||||
shouldProcessSequence: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "should not process sequence",
|
||||
args: args{
|
||||
event: &models.Event{Sequence: 30000000},
|
||||
failedErr: errors.ThrowInternal(nil, "SPOOL-Wk53B", "this was wrong"),
|
||||
latestFailedEvent: func(s uint64, instanceID string) (*repository.FailedEvent, error) {
|
||||
return &repository.FailedEvent{
|
||||
ErrMsg: "blub",
|
||||
FailedSequence: s - 1,
|
||||
FailureCount: 3,
|
||||
ViewName: "super.table",
|
||||
InstanceID: instanceID,
|
||||
LastFailed: testNow,
|
||||
}, nil
|
||||
},
|
||||
errorCountUntilSkip: 5,
|
||||
},
|
||||
res: res{
|
||||
shouldProcessSequence: false,
|
||||
wantErr: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
processedSequence := false
|
||||
err := HandleError(
|
||||
tt.args.event,
|
||||
tt.args.failedErr,
|
||||
tt.args.latestFailedEvent,
|
||||
func(*repository.FailedEvent) error {
|
||||
return nil
|
||||
},
|
||||
func(*models.Event) error {
|
||||
processedSequence = true
|
||||
return nil
|
||||
},
|
||||
tt.args.errorCountUntilSkip)
|
||||
|
||||
if (err != nil) != tt.res.wantErr {
|
||||
t.Errorf("HandleError() error = %v, wantErr %v", err, tt.res.wantErr)
|
||||
}
|
||||
if tt.res.shouldProcessSequence != processedSequence {
|
||||
t.Error("should not process sequence")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@@ -1,74 +0,0 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
)
|
||||
|
||||
var (
	// subscriptions maps aggregate types to their active subscribers.
	// Guarded by subsMutext.
	subscriptions map[models.AggregateType][]*Subscription = map[models.AggregateType][]*Subscription{}
	// subsMutext guards subscriptions. NOTE(review): the name contains a
	// typo ("Mutext"); kept to avoid touching every usage.
	subsMutext sync.Mutex
)

// Subscription delivers pushed events of the subscribed aggregate types on
// its Events channel until Unsubscribe is called.
type Subscription struct {
	Events     chan *models.Event
	aggregates []models.AggregateType
}
|
||||
|
||||
func (es *eventstore) Subscribe(aggregates ...models.AggregateType) *Subscription {
|
||||
events := make(chan *models.Event, 100)
|
||||
sub := &Subscription{
|
||||
Events: events,
|
||||
aggregates: aggregates,
|
||||
}
|
||||
|
||||
subsMutext.Lock()
|
||||
defer subsMutext.Unlock()
|
||||
|
||||
for _, aggregate := range aggregates {
|
||||
_, ok := subscriptions[aggregate]
|
||||
if !ok {
|
||||
subscriptions[aggregate] = make([]*Subscription, 0, 1)
|
||||
}
|
||||
subscriptions[aggregate] = append(subscriptions[aggregate], sub)
|
||||
}
|
||||
|
||||
return sub
|
||||
}
|
||||
|
||||
// Notify fans the given events out to every subscription registered for the
// event's aggregate type. The send blocks when a subscriber's channel
// buffer is full.
// NOTE(review): a nil or closed entry in the subscribers slice would make
// the send panic — this relies on Unsubscribe fully removing the
// subscription from the map; confirm that invariant holds.
func Notify(events []*models.Event) {
	subsMutext.Lock()
	defer subsMutext.Unlock()
	for _, event := range events {
		subs, ok := subscriptions[event.AggregateType]
		if !ok {
			continue
		}
		for _, sub := range subs {
			sub.Events <- event
		}
	}
}
|
||||
|
||||
func (s *Subscription) Unsubscribe() {
|
||||
subsMutext.Lock()
|
||||
defer subsMutext.Unlock()
|
||||
for _, aggregate := range s.aggregates {
|
||||
subs, ok := subscriptions[aggregate]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
for i := len(subs) - 1; i >= 0; i-- {
|
||||
if subs[i] == s {
|
||||
subs[i] = subs[len(subs)-1]
|
||||
subs[len(subs)-1] = nil
|
||||
subs = subs[:len(subs)-1]
|
||||
}
|
||||
}
|
||||
}
|
||||
_, ok := <-s.Events
|
||||
if ok {
|
||||
close(s.Events)
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user