fix: move v2 pkgs (#1331)

* fix: move eventstore pkgs

* fix: move eventstore pkgs

* fix: remove v2 view

* fix: remove v2 view
This commit is contained in:
Fabi
2021-02-23 15:13:04 +01:00
committed by GitHub
parent 57b277bc7c
commit d8e42744b4
797 changed files with 2116 additions and 2224 deletions

View File

@@ -0,0 +1,28 @@
package v1
import (
"github.com/caos/zitadel/internal/cache/config"
eventstore2 "github.com/caos/zitadel/internal/eventstore"
sql_v2 "github.com/caos/zitadel/internal/eventstore/repository/sql"
"github.com/caos/zitadel/internal/eventstore/v1/internal/repository/sql"
"github.com/caos/zitadel/internal/eventstore/v1/models"
)
// Config holds everything needed to start a v1 Eventstore: the SQL
// repository configuration, the name of the calling service (passed to the
// aggregate creator) and an optional cache configuration.
type Config struct {
	Repository  sql.Config
	ServiceName string
	Cache       *config.CacheConfig
}
// Start wires up a v1 Eventstore. It opens the SQL repository and reuses
// its database client to back the embedded v2 eventstore, so both
// generations share one connection.
func Start(conf Config) (Eventstore, error) {
	sqlRepo, dbClient, err := sql.Start(conf.Repository)
	if err != nil {
		return nil, err
	}
	es := &eventstore{
		repo:             sqlRepo,
		aggregateCreator: models.NewAggregateCreator(conf.ServiceName),
		esV2:             eventstore2.NewEventstore(sql_v2.NewCRDB(dbClient)),
	}
	return es, nil
}

View File

@@ -0,0 +1,75 @@
package v1
import (
"context"
eventstore2 "github.com/caos/zitadel/internal/eventstore"
"github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore/v1/internal/repository"
"github.com/caos/zitadel/internal/eventstore/v1/models"
)
// Eventstore is the public surface of the v1 event store. It combines the
// write path (PushAggregates), the read path (FilterEvents, LatestSequence)
// and subscriptions, and exposes the embedded v2 store via V2.
type Eventstore interface {
	AggregateCreator() *models.AggregateCreator
	Health(ctx context.Context) error
	PushAggregates(ctx context.Context, aggregates ...*models.Aggregate) error
	FilterEvents(ctx context.Context, searchQuery *models.SearchQuery) (events []*models.Event, err error)
	LatestSequence(ctx context.Context, searchQuery *models.SearchQueryFactory) (uint64, error)
	V2() *eventstore2.Eventstore
	Subscribe(aggregates ...models.AggregateType) *Subscription
}

// compile-time assertion that *eventstore implements Eventstore.
var _ Eventstore = (*eventstore)(nil)

// eventstore is the default Eventstore implementation, backed by a
// repository.Repository and a parallel v2 eventstore instance.
type eventstore struct {
	repo             repository.Repository
	aggregateCreator *models.AggregateCreator
	esV2             *eventstore2.Eventstore
}

// AggregateCreator returns the creator used to build new aggregates for
// this service.
func (es *eventstore) AggregateCreator() *models.AggregateCreator {
	return es.aggregateCreator
}
// PushAggregates validates all aggregates and their events and stores them
// through the underlying repository. Subscribers are notified
// asynchronously after a successful push.
func (es *eventstore) PushAggregates(ctx context.Context, aggregates ...*models.Aggregate) (err error) {
	for _, agg := range aggregates {
		if len(agg.Events) == 0 {
			return errors.ThrowInvalidArgument(nil, "EVENT-cNhIj", "no events in aggregate")
		}
		for _, e := range agg.Events {
			if validateErr := e.Validate(); validateErr != nil {
				return errors.ThrowInvalidArgument(validateErr, "EVENT-tzIhl", "validate event failed")
			}
		}
	}
	if err = es.repo.PushAggregates(ctx, aggregates...); err != nil {
		return err
	}
	go notify(aggregates)
	return nil
}
// FilterEvents validates the search query and returns all events matching it.
func (es *eventstore) FilterEvents(ctx context.Context, searchQuery *models.SearchQuery) ([]*models.Event, error) {
	if err := searchQuery.Validate(); err != nil {
		return nil, err
	}
	factory := models.FactoryFromSearchQuery(searchQuery)
	return es.repo.Filter(ctx, factory)
}
// LatestSequence returns the highest event sequence matching the query.
// It works on a copy of the factory so the caller's instance is never
// mutated.
func (es *eventstore) LatestSequence(ctx context.Context, queryFactory *models.SearchQueryFactory) (uint64, error) {
	factory := *queryFactory
	factory = *(&factory).Columns(models.Columns_Max_Sequence)
	factory = *(&factory).SequenceGreater(0)
	return es.repo.LatestSequence(ctx, &factory)
}
// Health reports whether the underlying repository is reachable.
func (es *eventstore) Health(ctx context.Context) error {
	return es.repo.Health(ctx)
}

// V2 exposes the embedded v2 eventstore.
func (es *eventstore) V2() *eventstore2.Eventstore {
	return es.esV2
}

View File

@@ -0,0 +1,3 @@
package v1
//go:generate mockgen -package mock -destination ./mock/eventstore.mock.go github.com/caos/zitadel/internal/eventstore/v1 Eventstore

View File

@@ -0,0 +1,3 @@
package repository
//go:generate mockgen -package mock -destination ./mock/repository.mock.go github.com/caos/zitadel/internal/eventstore/v1/internal/repository Repository

View File

@@ -0,0 +1,34 @@
package mock
import (
"context"
"testing"
"github.com/caos/zitadel/internal/eventstore/v1/models"
gomock "github.com/golang/mock/gomock"
)
// NewMock returns a MockRepository whose gomock controller is bound to t.
func NewMock(t *testing.T) *MockRepository {
	return NewMockRepository(gomock.NewController(t))
}

// ExpectFilter expects at most one Filter call for the given query and
// returns eventAmount zero-value events.
func (m *MockRepository) ExpectFilter(query *models.SearchQuery, eventAmount int) *MockRepository {
	events := make([]*models.Event, eventAmount)
	m.EXPECT().Filter(context.Background(), query).Return(events, nil).MaxTimes(1)
	return m
}

// ExpectFilterFail expects at most one Filter call for the given query and
// fails it with err.
func (m *MockRepository) ExpectFilterFail(query *models.SearchQuery, err error) *MockRepository {
	m.EXPECT().Filter(context.Background(), query).Return(nil, err).MaxTimes(1)
	return m
}

// ExpectPush expects at most one successful PushAggregates call.
// Note: the slice is passed as a single recorder argument; gomock matches a
// lone slice against the variadic parameters.
func (m *MockRepository) ExpectPush(aggregates ...*models.Aggregate) *MockRepository {
	m.EXPECT().PushAggregates(context.Background(), aggregates).Return(nil).MaxTimes(1)
	return m
}

// ExpectPushError expects at most one PushAggregates call and fails it with err.
func (m *MockRepository) ExpectPushError(err error, aggregates ...*models.Aggregate) *MockRepository {
	m.EXPECT().PushAggregates(context.Background(), aggregates).Return(err).MaxTimes(1)
	return m
}

View File

@@ -0,0 +1,98 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/caos/zitadel/internal/eventstore/internal/repository (interfaces: Repository)
// Package mock is a generated GoMock package.
package mock
import (
context "context"
models "github.com/caos/zitadel/internal/eventstore/v1/models"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// MockRepository is a mock of the Repository interface.
type MockRepository struct {
	ctrl     *gomock.Controller
	recorder *MockRepositoryMockRecorder
}

// MockRepositoryMockRecorder is the mock recorder for MockRepository.
type MockRepositoryMockRecorder struct {
	mock *MockRepository
}

// NewMockRepository creates a new mock instance.
func NewMockRepository(ctrl *gomock.Controller) *MockRepository {
	mock := &MockRepository{ctrl: ctrl}
	mock.recorder = &MockRepositoryMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockRepository) EXPECT() *MockRepositoryMockRecorder {
	return m.recorder
}

// Filter mocks base method.
func (m *MockRepository) Filter(arg0 context.Context, arg1 *models.SearchQueryFactory) ([]*models.Event, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Filter", arg0, arg1)
	ret0, _ := ret[0].([]*models.Event)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// Filter indicates an expected call of Filter.
func (mr *MockRepositoryMockRecorder) Filter(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Filter", reflect.TypeOf((*MockRepository)(nil).Filter), arg0, arg1)
}

// Health mocks base method.
func (m *MockRepository) Health(arg0 context.Context) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Health", arg0)
	ret0, _ := ret[0].(error)
	return ret0
}

// Health indicates an expected call of Health.
func (mr *MockRepositoryMockRecorder) Health(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockRepository)(nil).Health), arg0)
}

// LatestSequence mocks base method.
func (m *MockRepository) LatestSequence(arg0 context.Context, arg1 *models.SearchQueryFactory) (uint64, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "LatestSequence", arg0, arg1)
	ret0, _ := ret[0].(uint64)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// LatestSequence indicates an expected call of LatestSequence.
func (mr *MockRepositoryMockRecorder) LatestSequence(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LatestSequence", reflect.TypeOf((*MockRepository)(nil).LatestSequence), arg0, arg1)
}

// PushAggregates mocks base method.
func (m *MockRepository) PushAggregates(arg0 context.Context, arg1 ...*models.Aggregate) error {
	m.ctrl.T.Helper()
	varargs := []interface{}{arg0}
	for _, a := range arg1 {
		varargs = append(varargs, a)
	}
	ret := m.ctrl.Call(m, "PushAggregates", varargs...)
	ret0, _ := ret[0].(error)
	return ret0
}

// PushAggregates indicates an expected call of PushAggregates.
func (mr *MockRepositoryMockRecorder) PushAggregates(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	varargs := append([]interface{}{arg0}, arg1...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PushAggregates", reflect.TypeOf((*MockRepository)(nil).PushAggregates), varargs...)
}

View File

@@ -0,0 +1,19 @@
package repository
import (
"context"
"github.com/caos/zitadel/internal/eventstore/v1/models"
)
// Repository abstracts the storage layer of the v1 eventstore.
type Repository interface {
	// Health reports whether the storage is reachable.
	Health(ctx context.Context) error
	// PushAggregates adds all events of the given aggregates to the eventstreams of the aggregates.
	// This call is transaction safe. The transaction will be rolled back if one event fails.
	PushAggregates(ctx context.Context, aggregates ...*models.Aggregate) error
	// Filter returns all events matching the given search query.
	Filter(ctx context.Context, searchQuery *models.SearchQueryFactory) (events []*models.Event, err error)
	// LatestSequence returns the latest sequence found by the search query.
	LatestSequence(ctx context.Context, queryFactory *models.SearchQueryFactory) (uint64, error)
}

View File

@@ -0,0 +1,24 @@
package sql
import (
"database/sql"
_ "github.com/lib/pq"
"github.com/caos/zitadel/internal/config/types"
"github.com/caos/zitadel/internal/errors"
)
// Config bundles the database connection settings for the SQL repository.
type Config struct {
	SQL types.SQL
}
// Start opens the database connection described by conf and returns the
// SQL repository together with the raw *sql.DB client, so callers can
// share the connection with other components.
func Start(conf Config) (*SQL, *sql.DB, error) {
	db, err := conf.SQL.Start()
	if err != nil {
		return nil, nil, errors.ThrowPreconditionFailed(err, "SQL-9qBtr", "unable to open database connection")
	}
	repo := &SQL{client: db}
	return repo, db, nil
}

View File

@@ -0,0 +1,193 @@
package sql
import (
"database/sql"
"regexp"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/caos/zitadel/internal/eventstore/v1/models"
)
const (
	// selectEscaped is the regexp-escaped base SELECT that the expected
	// filter queries below are composed from.
	selectEscaped = `SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore\.events WHERE aggregate_type = \$1`
)

var (
	// eventColumns mirrors the column order of selectEscaped.
	eventColumns = []string{"creation_date", "event_type", "event_sequence", "previous_sequence", "event_data", "editor_service", "editor_user", "resource_owner", "aggregate_type", "aggregate_id", "aggregate_version"}
	// NOTE: regexp.MustCompile(...).String() returns the pattern source
	// unchanged; the compile serves only as an init-time validity check.
	expectedFilterEventsLimitFormat = regexp.MustCompile(selectEscaped + ` ORDER BY event_sequence LIMIT \$2`).String()
	expectedFilterEventsDescFormat  = regexp.MustCompile(selectEscaped + ` ORDER BY event_sequence DESC`).String()
	expectedFilterEventsAggregateIDLimit = regexp.MustCompile(selectEscaped + ` AND aggregate_id = \$2 ORDER BY event_sequence LIMIT \$3`).String()
	// NOTE(review): identical to expectedFilterEventsAggregateIDLimit —
	// presumably because the aggregate type is already part of the base
	// WHERE clause; confirm the "type" variant is intentionally the same.
	expectedFilterEventsAggregateIDTypeLimit = regexp.MustCompile(selectEscaped + ` AND aggregate_id = \$2 ORDER BY event_sequence LIMIT \$3`).String()
	expectedGetAllEvents = regexp.MustCompile(selectEscaped + ` ORDER BY event_sequence`).String()
	// expectedInsertStatement is the escaped counterpart of insertStmt,
	// including the optimistic-concurrency guard in WHERE EXISTS.
	expectedInsertStatement = regexp.MustCompile(`INSERT INTO eventstore\.events ` +
		`\(event_type, aggregate_type, aggregate_id, aggregate_version, creation_date, event_data, editor_user, editor_service, resource_owner, previous_sequence\) ` +
		`SELECT \$1, \$2, \$3, \$4, COALESCE\(\$5, now\(\)\), \$6, \$7, \$8, \$9, \$10 ` +
		`WHERE EXISTS \(` +
		`SELECT 1 FROM eventstore\.events WHERE aggregate_type = \$11 AND aggregate_id = \$12 HAVING MAX\(event_sequence\) = \$13 OR \(\$14::BIGINT IS NULL AND COUNT\(\*\) = 0\)\) ` +
		`RETURNING event_sequence, creation_date`).String()
)

// dbMock bundles a stub database connection with its sqlmock controller.
type dbMock struct {
	sqlClient *sql.DB
	mock      sqlmock.Sqlmock
}

// close releases the stub database connection.
func (db *dbMock) close() {
	db.sqlClient.Close()
}
// mockDB creates a sqlmock-backed dbMock for tests. Expectations are
// matched strictly in the order they are declared.
func mockDB(t *testing.T) *dbMock {
	mockDB := dbMock{}
	var err error
	mockDB.sqlClient, mockDB.mock, err = sqlmock.New()
	if err != nil {
		// fixed typo in the failure message ("occured" -> "occurred")
		t.Fatalf("error occurred while creating stub db %v", err)
	}
	mockDB.mock.MatchExpectationsInOrder(true)
	return &mockDB
}
// expectBegin registers an expected transaction begin; a non-nil err makes
// the begin fail with that error.
func (db *dbMock) expectBegin(err error) *dbMock {
	expectation := db.mock.ExpectBegin()
	if err != nil {
		expectation.WillReturnError(err)
	}
	return db
}

// expectSavepoint registers a successful SAVEPOINT statement.
func (db *dbMock) expectSavepoint() *dbMock {
	db.mock.ExpectExec("SAVEPOINT").WillReturnResult(sqlmock.NewResult(1, 1))
	return db
}

// expectReleaseSavepoint registers a RELEASE SAVEPOINT statement that
// either succeeds or fails with err.
func (db *dbMock) expectReleaseSavepoint(err error) *dbMock {
	expectation := db.mock.ExpectExec("RELEASE SAVEPOINT")
	if err != nil {
		expectation.WillReturnError(err)
	} else {
		expectation.WillReturnResult(sqlmock.NewResult(1, 1))
	}
	return db
}

// expectCommit registers an expected commit; a non-nil err makes it fail.
func (db *dbMock) expectCommit(err error) *dbMock {
	expectation := db.mock.ExpectCommit()
	if err != nil {
		expectation.WillReturnError(err)
	}
	return db
}

// expectRollback registers an expected rollback; a non-nil err makes it fail.
func (db *dbMock) expectRollback(err error) *dbMock {
	expectation := db.mock.ExpectRollback()
	if err != nil {
		expectation.WillReturnError(err)
	}
	return db
}
// expectInsertEvent expects one insert for e and answers with
// returnedSequence. The argument list is order-critical: the first ten
// values bind the INSERT columns, the last four bind the concurrency
// guard in the WHERE EXISTS clause.
func (db *dbMock) expectInsertEvent(e *models.Event, returnedSequence uint64) *dbMock {
	db.mock.ExpectQuery(expectedInsertStatement).
		WithArgs(
			e.Type, e.AggregateType, e.AggregateID, e.AggregateVersion, sqlmock.AnyArg(), Data(e.Data), e.EditorUser, e.EditorService, e.ResourceOwner, Sequence(e.PreviousSequence),
			e.AggregateType, e.AggregateID, Sequence(e.PreviousSequence), Sequence(e.PreviousSequence),
		).
		WillReturnRows(
			sqlmock.NewRows([]string{"event_sequence", "creation_date"}).
				AddRow(returnedSequence, time.Now().UTC()),
		)
	return db
}

// expectInsertEventError expects one insert for e and fails it with
// sql.ErrTxDone.
func (db *dbMock) expectInsertEventError(e *models.Event) *dbMock {
	db.mock.ExpectQuery(expectedInsertStatement).
		WithArgs(
			e.Type, e.AggregateType, e.AggregateID, e.AggregateVersion, sqlmock.AnyArg(), Data(e.Data), e.EditorUser, e.EditorService, e.ResourceOwner, Sequence(e.PreviousSequence),
			e.AggregateType, e.AggregateID, Sequence(e.PreviousSequence), Sequence(e.PreviousSequence),
		).
		WillReturnError(sql.ErrTxDone)
	return db
}

// expectFilterEventsLimit expects a limited filter query and returns
// eventCount synthetic rows with ascending sequences.
func (db *dbMock) expectFilterEventsLimit(aggregateType string, limit uint64, eventCount int) *dbMock {
	rows := sqlmock.NewRows(eventColumns)
	for i := 0; i < eventCount; i++ {
		rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "aggType", "aggID", "v1.0.0")
	}
	db.mock.ExpectQuery(expectedFilterEventsLimitFormat).
		WithArgs(aggregateType, limit).
		WillReturnRows(rows)
	return db
}

// expectFilterEventsDesc expects a descending filter query and returns
// eventCount synthetic rows with descending sequences.
func (db *dbMock) expectFilterEventsDesc(aggregateType string, eventCount int) *dbMock {
	rows := sqlmock.NewRows(eventColumns)
	for i := eventCount; i > 0; i-- {
		rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "aggType", "aggID", "v1.0.0")
	}
	db.mock.ExpectQuery(expectedFilterEventsDescFormat).
		WillReturnRows(rows)
	return db
}

// expectFilterEventsAggregateIDLimit expects a filter by aggregate id with
// a limit and returns `limit` synthetic rows.
func (db *dbMock) expectFilterEventsAggregateIDLimit(aggregateType, aggregateID string, limit uint64) *dbMock {
	rows := sqlmock.NewRows(eventColumns)
	for i := limit; i > 0; i-- {
		rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "aggType", "aggID", "v1.0.0")
	}
	db.mock.ExpectQuery(expectedFilterEventsAggregateIDLimit).
		WithArgs(aggregateType, aggregateID, limit).
		WillReturnRows(rows)
	return db
}

// expectFilterEventsAggregateIDTypeLimit expects a filter by aggregate id
// and type with a limit and returns `limit` synthetic rows.
func (db *dbMock) expectFilterEventsAggregateIDTypeLimit(aggregateType, aggregateID string, limit uint64) *dbMock {
	rows := sqlmock.NewRows(eventColumns)
	for i := limit; i > 0; i-- {
		rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "aggType", "aggID", "v1.0.0")
	}
	db.mock.ExpectQuery(expectedFilterEventsAggregateIDTypeLimit).
		WithArgs(aggregateType, aggregateID, limit).
		WillReturnRows(rows)
	return db
}

// expectFilterEventsError expects an unrestricted filter query and fails
// it with returnedErr.
func (db *dbMock) expectFilterEventsError(returnedErr error) *dbMock {
	db.mock.ExpectQuery(expectedGetAllEvents).
		WillReturnError(returnedErr)
	return db
}

// expectLatestSequenceFilter expects a MAX(event_sequence) query and
// returns the given sequence.
func (db *dbMock) expectLatestSequenceFilter(aggregateType string, sequence Sequence) *dbMock {
	db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events WHERE aggregate_type = \$1`).
		WithArgs(aggregateType).
		WillReturnRows(sqlmock.NewRows([]string{"max_sequence"}).AddRow(sequence))
	return db
}

// expectLatestSequenceFilterError expects a MAX(event_sequence) query and
// fails it with err.
func (db *dbMock) expectLatestSequenceFilterError(aggregateType string, err error) *dbMock {
	db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events WHERE aggregate_type = \$1`).
		WithArgs(aggregateType).WillReturnError(err)
	return db
}

// expectPrepareInsert expects the insert statement to be prepared; a
// non-nil err makes the prepare fail.
func (db *dbMock) expectPrepareInsert(err error) *dbMock {
	prepare := db.mock.ExpectPrepare(expectedInsertStatement)
	if err != nil {
		prepare.WillReturnError(err)
	}
	return db
}

View File

@@ -0,0 +1,62 @@
package sql
import (
"context"
"database/sql"
"github.com/caos/logging"
"github.com/caos/zitadel/internal/errors"
es_models "github.com/caos/zitadel/internal/eventstore/v1/models"
"github.com/caos/zitadel/internal/telemetry/tracing"
)
// Querier is the minimal query surface shared by *sql.DB and *sql.Tx, so
// filter can run both inside and outside a transaction.
type Querier interface {
	Query(query string, args ...interface{}) (*sql.Rows, error)
}

// Filter returns all events matching the search query.
// NOTE(review): ctx is not propagated to the database query — Querier has
// no context-aware method; confirm before relying on cancellation here.
func (db *SQL) Filter(ctx context.Context, searchQuery *es_models.SearchQueryFactory) (events []*es_models.Event, err error) {
	return filter(db.client, searchQuery)
}
// filter builds the SQL statement for searchQuery, executes it on the
// given querier and scans every returned row into an event.
func filter(querier Querier, searchQuery *es_models.SearchQueryFactory) (events []*es_models.Event, err error) {
	query, limit, values, rowScanner := buildQuery(searchQuery)
	if query == "" {
		return nil, errors.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
	}
	rows, err := querier.Query(query, values...)
	if err != nil {
		logging.Log("SQL-HP3Uk").WithError(err).Info("query failed")
		return nil, errors.ThrowInternal(err, "SQL-IJuyR", "unable to filter events")
	}
	defer rows.Close()

	events = make([]*es_models.Event, 0, limit)
	for rows.Next() {
		event := new(es_models.Event)
		if err := rowScanner(rows.Scan, event); err != nil {
			return nil, err
		}
		events = append(events, event)
	}
	// rows.Next returning false can hide an iteration error; surface it
	// instead of silently returning a truncated result set.
	if err := rows.Err(); err != nil {
		logging.Log("SQL-wBx2n").WithError(err).Info("rows iteration failed")
		return nil, errors.ThrowInternal(err, "SQL-p9Vbd", "unable to filter events")
	}
	return events, nil
}
// LatestSequence executes the query built from queryFactory and scans the
// single resulting row into a sequence number.
func (db *SQL) LatestSequence(ctx context.Context, queryFactory *es_models.SearchQueryFactory) (uint64, error) {
	stmt, _, args, scan := buildQuery(queryFactory)
	if stmt == "" {
		return 0, errors.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
	}
	var sequence Sequence
	row := db.client.QueryRow(stmt, args...)
	if err := scan(row.Scan, &sequence); err != nil {
		logging.Log("SQL-WsxTg").WithError(err).WithField("traceID", tracing.TraceIDFromCtx(ctx)).Info("query failed")
		return 0, errors.ThrowInternal(err, "SQL-Yczyx", "unable to filter latest sequence")
	}
	return uint64(sequence), nil
}

View File

@@ -0,0 +1,233 @@
package sql
import (
"context"
"database/sql"
"math"
"testing"
"github.com/caos/zitadel/internal/errors"
es_models "github.com/caos/zitadel/internal/eventstore/v1/models"
)
// TestSQL_Filter exercises the read path of the SQL repository against a
// stubbed database: limit and ordering variants, aggregate-id filters, and
// error translation for sql.ErrNoRows / sql.ErrConnDone.
func TestSQL_Filter(t *testing.T) {
	type fields struct {
		client *dbMock
	}
	type args struct {
		events      *mockEvents
		searchQuery *es_models.SearchQueryFactory
	}
	type res struct {
		wantErr   bool
		isErrFunc func(error) bool
		eventsLen int
	}
	tests := []struct {
		name   string
		fields fields
		args   args
		res    res
	}{
		{
			name: "only limit filter",
			fields: fields{
				client: mockDB(t).expectFilterEventsLimit("user", 34, 3),
			},
			args: args{
				events:      &mockEvents{t: t},
				searchQuery: es_models.NewSearchQueryFactory("user").Limit(34),
			},
			res: res{
				eventsLen: 3,
				wantErr:   false,
			},
		},
		{
			name: "only desc filter",
			fields: fields{
				client: mockDB(t).expectFilterEventsDesc("user", 34),
			},
			args: args{
				events:      &mockEvents{t: t},
				searchQuery: es_models.NewSearchQueryFactory("user").OrderDesc(),
			},
			res: res{
				eventsLen: 34,
				wantErr:   false,
			},
		},
		{
			// sql.ErrNoRows is mapped to an internal error by filter
			name: "no events found",
			fields: fields{
				client: mockDB(t).expectFilterEventsError(sql.ErrNoRows),
			},
			args: args{
				events:      &mockEvents{t: t},
				searchQuery: es_models.NewSearchQueryFactory("nonAggregate"),
			},
			res: res{
				wantErr:   true,
				isErrFunc: errors.IsInternal,
			},
		},
		{
			name: "filter fails because sql internal error",
			fields: fields{
				client: mockDB(t).expectFilterEventsError(sql.ErrConnDone),
			},
			args: args{
				events:      &mockEvents{t: t},
				searchQuery: es_models.NewSearchQueryFactory("user"),
			},
			res: res{
				wantErr:   true,
				isErrFunc: errors.IsInternal,
			},
		},
		{
			name: "filter by aggregate id",
			fields: fields{
				client: mockDB(t).expectFilterEventsAggregateIDLimit("user", "hop", 5),
			},
			args: args{
				events:      &mockEvents{t: t},
				searchQuery: es_models.NewSearchQueryFactory("user").Limit(5).AggregateIDs("hop"),
			},
			res: res{
				wantErr:   false,
				isErrFunc: nil,
			},
		},
		{
			name: "filter by aggregate id and aggregate type",
			fields: fields{
				client: mockDB(t).expectFilterEventsAggregateIDTypeLimit("user", "hop", 5),
			},
			args: args{
				events:      &mockEvents{t: t},
				searchQuery: es_models.NewSearchQueryFactory("user").Limit(5).AggregateIDs("hop"),
			},
			res: res{
				wantErr:   false,
				isErrFunc: nil,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// note: the local `sql` shadows the database/sql import inside
			// this closure
			sql := &SQL{
				client: tt.fields.client.sqlClient,
			}
			events, err := sql.Filter(context.Background(), tt.args.searchQuery)
			if (err != nil) != tt.res.wantErr {
				t.Errorf("SQL.Filter() error = %v, wantErr %v", err, tt.res.wantErr)
			}
			if tt.res.eventsLen != 0 && len(events) != tt.res.eventsLen {
				t.Errorf("events has wrong length got: %d want %d", len(events), tt.res.eventsLen)
			}
			if tt.res.wantErr && !tt.res.isErrFunc(err) {
				t.Errorf("got wrong error %v", err)
			}
			if err := tt.fields.client.mock.ExpectationsWereMet(); err != nil {
				t.Errorf("there were unfulfilled expectations: %s", err)
			}
			tt.fields.client.close()
		})
	}
}
// TestSQL_LatestSequence verifies sequence lookups: an invalid (nil)
// factory, sql.ErrNoRows treated as sequence 0, query errors mapped to
// internal errors, and a found sequence returned unchanged.
func TestSQL_LatestSequence(t *testing.T) {
	type fields struct {
		client *dbMock
	}
	type args struct {
		searchQuery *es_models.SearchQueryFactory
	}
	type res struct {
		wantErr   bool
		isErrFunc func(error) bool
		sequence  uint64
	}
	tests := []struct {
		name   string
		fields fields
		args   args
		res    res
	}{
		{
			name: "invalid query factory",
			args: args{
				searchQuery: nil,
			},
			fields: fields{
				client: mockDB(t),
			},
			res: res{
				wantErr:   true,
				isErrFunc: errors.IsErrorInvalidArgument,
			},
		},
		{
			// sql.ErrNoRows is not an error here: no events means sequence 0
			name: "no events for aggregate",
			args: args{
				searchQuery: es_models.NewSearchQueryFactory("idiot").Columns(es_models.Columns_Max_Sequence),
			},
			fields: fields{
				client: mockDB(t).expectLatestSequenceFilterError("idiot", sql.ErrNoRows),
			},
			res: res{
				wantErr:  false,
				sequence: 0,
			},
		},
		{
			name: "sql query error",
			args: args{
				searchQuery: es_models.NewSearchQueryFactory("idiot").Columns(es_models.Columns_Max_Sequence),
			},
			fields: fields{
				client: mockDB(t).expectLatestSequenceFilterError("idiot", sql.ErrConnDone),
			},
			res: res{
				wantErr:   true,
				isErrFunc: errors.IsInternal,
				sequence:  0,
			},
		},
		{
			name: "events for aggregate found",
			args: args{
				searchQuery: es_models.NewSearchQueryFactory("user").Columns(es_models.Columns_Max_Sequence),
			},
			fields: fields{
				client: mockDB(t).expectLatestSequenceFilter("user", math.MaxUint64),
			},
			res: res{
				wantErr:  false,
				sequence: math.MaxUint64,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// note: the local `sql` shadows the database/sql import inside
			// this closure
			sql := &SQL{
				client: tt.fields.client.sqlClient,
			}
			sequence, err := sql.LatestSequence(context.Background(), tt.args.searchQuery)
			if (err != nil) != tt.res.wantErr {
				t.Errorf("SQL.Filter() error = %v, wantErr %v", err, tt.res.wantErr)
			}
			if tt.res.sequence != sequence {
				t.Errorf("events has wrong length got: %d want %d", sequence, tt.res.sequence)
			}
			if tt.res.wantErr && !tt.res.isErrFunc(err) {
				t.Errorf("got wrong error %v", err)
			}
			if err := tt.fields.client.mock.ExpectationsWereMet(); err != nil {
				t.Errorf("there were unfulfilled expectations: %s", err)
			}
			tt.fields.client.close()
		})
	}
}

View File

@@ -0,0 +1,90 @@
package sql
import (
"context"
"database/sql"
"errors"
"github.com/caos/logging"
caos_errs "github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore/v1/models"
"github.com/caos/zitadel/internal/telemetry/tracing"
"github.com/cockroachdb/cockroach-go/v2/crdb"
)
const (
	// insertStmt appends an event to its aggregate's stream while enforcing
	// optimistic concurrency: the row is only inserted when the aggregate's
	// current MAX(event_sequence) equals the expected previous sequence
	// ($13), or when no previous sequence was supplied ($14 IS NULL) and
	// the stream is still empty. It returns the assigned sequence and
	// creation date.
	insertStmt = "INSERT INTO eventstore.events " +
		"(event_type, aggregate_type, aggregate_id, aggregate_version, creation_date, event_data, editor_user, editor_service, resource_owner, previous_sequence) " +
		"SELECT $1, $2, $3, $4, COALESCE($5, now()), $6, $7, $8, $9, $10 " +
		"WHERE EXISTS (" +
		"SELECT 1 FROM eventstore.events WHERE aggregate_type = $11 AND aggregate_id = $12 HAVING MAX(event_sequence) = $13 OR ($14::BIGINT IS NULL AND COUNT(*) = 0)) " +
		"RETURNING event_sequence, creation_date"
)
// PushAggregates stores the events of all aggregates in a single database
// transaction. crdb.ExecuteTx may retry the whole closure on retryable
// errors, so the insert statement is re-prepared on each attempt.
func (db *SQL) PushAggregates(ctx context.Context, aggregates ...*models.Aggregate) (err error) {
	err = crdb.ExecuteTx(ctx, db.client, nil, func(tx *sql.Tx) error {
		stmt, err := tx.Prepare(insertStmt)
		if err != nil {
			// NOTE(review): ExecuteTx already rolls back when the closure
			// returns an error; the explicit Rollback calls here look
			// redundant — confirm before removing.
			tx.Rollback()
			logging.Log("SQL-9ctx5").WithError(err).WithField("traceID", tracing.TraceIDFromCtx(ctx)).Warn("prepare failed")
			return caos_errs.ThrowInternal(err, "SQL-juCgA", "prepare failed")
		}
		for _, aggregate := range aggregates {
			// verify the aggregate's optional precondition before inserting
			err = precondtion(tx, aggregate)
			if err != nil {
				tx.Rollback()
				return err
			}
			err = insertEvents(stmt, Sequence(aggregate.PreviousSequence), aggregate.Events)
			if err != nil {
				tx.Rollback()
				return err
			}
		}
		return nil
	})
	// NOTE(review): this relies on CaosError matching via errors.Is against
	// a zero-value target — presumably CaosError implements Is; confirm.
	if err != nil && !errors.Is(err, &caos_errs.CaosError{}) {
		err = caos_errs.ThrowInternal(err, "SQL-DjgtG", "unable to store events")
	}
	return err
}
// precondtion verifies the aggregate's optional precondition inside the
// given transaction: it loads the relevant events and runs the configured
// validation on them. (Name kept as-is; callers depend on it.)
func precondtion(tx *sql.Tx, aggregate *models.Aggregate) error {
	if aggregate.Precondition == nil {
		return nil
	}
	events, err := filter(tx, models.FactoryFromSearchQuery(aggregate.Precondition.Query))
	if err != nil {
		return caos_errs.ThrowPreconditionFailed(err, "SQL-oBPxB", "filter failed")
	}
	return aggregate.Precondition.Validation(events...)
}
// insertEvents writes the events of one aggregate in order. The argument
// list is order-critical: the first ten values bind the INSERT columns,
// the last four bind the concurrency guard ($11-$14). previousSequence is
// overwritten from the RETURNING clause after every insert so the next
// event chains onto the sequence the database actually assigned.
func insertEvents(stmt *sql.Stmt, previousSequence Sequence, events []*models.Event) error {
	for _, event := range events {
		// a zero creation date is sent as NULL so COALESCE($5, now()) applies
		creationDate := sql.NullTime{Time: event.CreationDate, Valid: !event.CreationDate.IsZero()}
		err := stmt.QueryRow(event.Type, event.AggregateType, event.AggregateID, event.AggregateVersion, creationDate, Data(event.Data), event.EditorUser, event.EditorService, event.ResourceOwner, previousSequence,
			event.AggregateType, event.AggregateID, previousSequence, previousSequence).Scan(&previousSequence, &event.CreationDate)
		if err != nil {
			logging.LogWithFields("SQL-5M0sd",
				"aggregate", event.AggregateType,
				"previousSequence", previousSequence,
				"aggregateId", event.AggregateID,
				"aggregateType", event.AggregateType,
				"eventType", event.Type).WithError(err).Info("query failed")
			return caos_errs.ThrowInternal(err, "SQL-SBP37", "unable to create event")
		}
		event.Sequence = uint64(previousSequence)
	}
	return nil
}

View File

@@ -0,0 +1,443 @@
package sql
import (
"context"
"database/sql"
"errors"
"reflect"
"runtime"
"testing"
z_errors "github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore/v1/models"
)
// mockEvents carries test events together with the owning *testing.T.
type mockEvents struct {
	events []*models.Event
	t      *testing.T
}

// TestSQL_PushAggregates drives the write path against a stubbed database:
// prepare/release/commit failures, precondition failures, multi-event and
// multi-aggregate pushes, and that assigned sequences are written back to
// the pushed events.
func TestSQL_PushAggregates(t *testing.T) {
	type fields struct {
		client *dbMock
	}
	type args struct {
		aggregates []*models.Aggregate
	}
	tests := []struct {
		name              string
		fields            fields
		args              args
		isError           func(error) bool
		shouldCheckEvents bool
	}{
		{
			name: "no aggregates",
			fields: fields{
				client: mockDB(t).
					expectBegin(nil).
					expectSavepoint().
					expectPrepareInsert(nil).
					expectReleaseSavepoint(nil).
					expectCommit(nil),
			},
			args:              args{aggregates: []*models.Aggregate{}},
			shouldCheckEvents: false,
			isError:           noErr,
		},
		{
			name: "prepare fails",
			fields: fields{
				client: mockDB(t).
					expectBegin(nil).
					expectSavepoint().
					expectPrepareInsert(sql.ErrConnDone).
					expectReleaseSavepoint(nil).
					expectCommit(nil),
			},
			args:              args{aggregates: []*models.Aggregate{}},
			shouldCheckEvents: false,
			isError:           func(err error) bool { return errors.Is(err, sql.ErrConnDone) },
		},
		{
			name: "no aggregates release fails",
			fields: fields{
				client: mockDB(t).
					expectBegin(nil).
					expectSavepoint().
					expectPrepareInsert(nil).
					expectReleaseSavepoint(sql.ErrConnDone).
					expectCommit(nil),
			},
			args:              args{aggregates: []*models.Aggregate{}},
			isError:           z_errors.IsInternal,
			shouldCheckEvents: false,
		},
		{
			name: "aggregate precondtion fails",
			fields: fields{
				client: mockDB(t).
					expectBegin(nil).
					expectSavepoint().
					expectPrepareInsert(nil).
					expectFilterEventsError(z_errors.CreateCaosError(nil, "SQL-IzJOf", "err")).
					expectRollback(nil),
			},
			args:              args{aggregates: []*models.Aggregate{aggregateWithPrecondition(&models.Aggregate{}, models.NewSearchQuery().SetLimit(1), nil)}},
			isError:           z_errors.IsPreconditionFailed,
			shouldCheckEvents: false,
		},
		{
			// the second event's previous sequence must chain on the
			// sequence returned for the first insert (45 -> 46)
			name: "one aggregate two events success",
			fields: fields{
				client: mockDB(t).
					expectBegin(nil).
					expectSavepoint().
					expectPrepareInsert(nil).
					expectInsertEvent(&models.Event{
						AggregateID:      "aggID",
						AggregateType:    "aggType",
						EditorService:    "svc",
						EditorUser:       "usr",
						ResourceOwner:    "ro",
						PreviousSequence: 34,
						Type:             "eventTyp",
						AggregateVersion: "v0.0.1",
					}, 45).
					expectInsertEvent(&models.Event{
						AggregateID:      "aggID",
						AggregateType:    "aggType",
						EditorService:    "svc2",
						EditorUser:       "usr2",
						ResourceOwner:    "ro2",
						PreviousSequence: 45,
						Type:             "eventTyp",
						AggregateVersion: "v0.0.1",
					}, 46).
					expectReleaseSavepoint(nil).
					expectCommit(nil),
			},
			args: args{
				aggregates: []*models.Aggregate{
					{
						PreviousSequence: 34,
						Events: []*models.Event{
							{
								AggregateID:      "aggID",
								AggregateType:    "aggType",
								AggregateVersion: "v0.0.1",
								EditorService:    "svc",
								EditorUser:       "usr",
								ResourceOwner:    "ro",
								Type:             "eventTyp",
							},
							{
								AggregateID:      "aggID",
								AggregateType:    "aggType",
								AggregateVersion: "v0.0.1",
								EditorService:    "svc2",
								EditorUser:       "usr2",
								ResourceOwner:    "ro2",
								Type:             "eventTyp",
							},
						},
					},
				},
			},
			shouldCheckEvents: true,
			isError:           noErr,
		},
		{
			name: "two aggregates one event per aggregate success",
			fields: fields{
				client: mockDB(t).
					expectBegin(nil).
					expectSavepoint().
					expectPrepareInsert(nil).
					expectInsertEvent(&models.Event{
						AggregateID:      "aggID",
						AggregateType:    "aggType",
						EditorService:    "svc",
						EditorUser:       "usr",
						ResourceOwner:    "ro",
						PreviousSequence: 34,
						Type:             "eventTyp",
						AggregateVersion: "v0.0.1",
					}, 47).
					expectInsertEvent(&models.Event{
						AggregateID:      "aggID2",
						AggregateType:    "aggType2",
						EditorService:    "svc",
						EditorUser:       "usr",
						ResourceOwner:    "ro",
						PreviousSequence: 40,
						Type:             "eventTyp",
						AggregateVersion: "v0.0.1",
					}, 48).
					expectReleaseSavepoint(nil).
					expectCommit(nil),
			},
			args: args{
				aggregates: []*models.Aggregate{
					{
						PreviousSequence: 34,
						Events: []*models.Event{
							{
								AggregateID:      "aggID",
								AggregateType:    "aggType",
								AggregateVersion: "v0.0.1",
								EditorService:    "svc",
								EditorUser:       "usr",
								ResourceOwner:    "ro",
								Type:             "eventTyp",
							},
						},
					},
					{
						PreviousSequence: 40,
						Events: []*models.Event{
							{
								AggregateID:      "aggID2",
								AggregateType:    "aggType2",
								AggregateVersion: "v0.0.1",
								EditorService:    "svc",
								EditorUser:       "usr",
								ResourceOwner:    "ro",
								Type:             "eventTyp",
							},
						},
					},
				},
			},
			shouldCheckEvents: true,
			isError:           noErr,
		},
		{
			name: "first event fails no action with second event",
			fields: fields{
				client: mockDB(t).
					expectBegin(nil).
					expectSavepoint().
					expectPrepareInsert(nil).
					expectInsertEventError(&models.Event{
						AggregateID:      "aggID",
						AggregateType:    "aggType",
						EditorService:    "svc",
						EditorUser:       "usr",
						ResourceOwner:    "ro",
						PreviousSequence: 34,
						Data:             []byte("{}"),
						Type:             "eventTyp",
						AggregateVersion: "v0.0.1",
					}).
					expectReleaseSavepoint(nil).
					expectRollback(nil),
			},
			args: args{
				aggregates: []*models.Aggregate{
					{
						Events: []*models.Event{
							{
								AggregateID:      "aggID",
								AggregateType:    "aggType",
								AggregateVersion: "v0.0.1",
								EditorService:    "svc",
								EditorUser:       "usr",
								ResourceOwner:    "ro",
								Type:             "eventTyp",
								PreviousSequence: 34,
							},
							{
								AggregateID:      "aggID",
								AggregateType:    "aggType",
								AggregateVersion: "v0.0.1",
								EditorService:    "svc",
								EditorUser:       "usr",
								ResourceOwner:    "ro",
								Type:             "eventTyp",
								PreviousSequence: 0,
							},
						},
					},
				},
			},
			isError:           z_errors.IsInternal,
			shouldCheckEvents: false,
		},
		{
			name: "one event, release savepoint fails",
			fields: fields{
				client: mockDB(t).
					expectBegin(nil).
					expectPrepareInsert(nil).
					expectSavepoint().
					expectInsertEvent(&models.Event{
						AggregateID:      "aggID",
						AggregateType:    "aggType",
						EditorService:    "svc",
						EditorUser:       "usr",
						ResourceOwner:    "ro",
						PreviousSequence: 34,
						Type:             "eventTyp",
						Data:             []byte("{}"),
						AggregateVersion: "v0.0.1",
					}, 47).
					expectReleaseSavepoint(sql.ErrConnDone).
					expectCommit(nil).
					expectRollback(nil),
			},
			args: args{
				aggregates: []*models.Aggregate{
					{
						Events: []*models.Event{
							{
								AggregateID:      "aggID",
								AggregateType:    "aggType",
								AggregateVersion: "v0.0.1",
								EditorService:    "svc",
								EditorUser:       "usr",
								ResourceOwner:    "ro",
								Type:             "eventTyp",
								PreviousSequence: 34,
							},
						},
					},
				},
			},
			isError:           z_errors.IsInternal,
			shouldCheckEvents: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// note: the local `sql` shadows the database/sql import inside
			// this closure
			sql := &SQL{
				client: tt.fields.client.sqlClient,
			}
			err := sql.PushAggregates(context.Background(), tt.args.aggregates...)
			if err != nil && !tt.isError(err) {
				t.Errorf("wrong error type = %v, errFunc %s", err, functionName(tt.isError))
			}
			if !tt.shouldCheckEvents {
				return
			}
			for _, aggregate := range tt.args.aggregates {
				for _, event := range aggregate.Events {
					if event.Sequence == 0 {
						t.Error("sequence of returned event is not set")
					}
					if event.AggregateType == "" || event.AggregateID == "" {
						t.Error("aggregate of event is not set")
					}
				}
			}
			if err := tt.fields.client.mock.ExpectationsWereMet(); err != nil {
				t.Errorf("not all database expectaions met: %s", err)
			}
		})
	}
}
// noErr is an isError predicate for the table-driven tests above: it
// accepts only a nil error.
func noErr(err error) bool {
	return err == nil
}
// functionName resolves the fully qualified name of the function value i;
// it is used to label error-predicate functions in test failure output.
func functionName(i interface{}) string {
	pc := reflect.ValueOf(i).Pointer()
	return runtime.FuncForPC(pc).Name()
}
// Test_precondtion (name matches the production function `precondtion`)
// verifies that an aggregate's precondition query is executed inside an
// open transaction and that filter failures surface as PreconditionFailed.
func Test_precondtion(t *testing.T) {
	type fields struct {
		client *dbMock
	}
	type args struct {
		aggregate *models.Aggregate
	}
	tests := []struct {
		name   string
		fields fields
		args   args
		isErr  func(error) bool // nil means: no error expected
	}{
		{
			// aggregates without a precondition pass through untouched
			name: "no precondition",
			fields: fields{
				client: mockDB(t).
					expectBegin(nil),
			},
			args: args{
				aggregate: &models.Aggregate{},
			},
		},
		{
			// the validation func itself rejects the (empty) result set
			name: "precondition fails",
			fields: fields{
				client: mockDB(t).
					expectBegin(nil).expectFilterEventsLimit("test", 5, 0),
			},
			args: args{
				aggregate: aggregateWithPrecondition(&models.Aggregate{}, models.NewSearchQuery().SetLimit(5).AggregateTypeFilter("test"), validationFunc(z_errors.ThrowPreconditionFailed(nil, "SQL-LBIKm", "err"))),
			},
			isErr: z_errors.IsPreconditionFailed,
		},
		{
			// a failing filter query is mapped to PreconditionFailed as well
			name: "precondition with filter error",
			fields: fields{
				client: mockDB(t).
					expectBegin(nil).expectFilterEventsError(z_errors.ThrowInternal(nil, "SQL-ac9EW", "err")),
			},
			args: args{
				aggregate: aggregateWithPrecondition(&models.Aggregate{}, models.NewSearchQuery().SetLimit(5).AggregateTypeFilter("test"), validationFunc(z_errors.CreateCaosError(nil, "SQL-LBIKm", "err"))),
			},
			isErr: z_errors.IsPreconditionFailed,
		},
		{
			name: "precondition no events",
			fields: fields{
				client: mockDB(t).
					expectBegin(nil).expectFilterEventsLimit("test", 5, 0),
			},
			args: args{
				aggregate: aggregateWithPrecondition(&models.Aggregate{}, models.NewSearchQuery().SetLimit(5).AggregateTypeFilter("test"), validationFunc(nil)),
			},
		},
		{
			name: "precondition with events",
			fields: fields{
				client: mockDB(t).
					expectBegin(nil).expectFilterEventsLimit("test", 5, 3),
			},
			args: args{
				aggregate: aggregateWithPrecondition(&models.Aggregate{}, models.NewSearchQuery().SetLimit(5).AggregateTypeFilter("test"), validationFunc(nil)),
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// precondtion expects a running transaction, so open one on the mock
			tx, err := tt.fields.client.sqlClient.Begin()
			if err != nil {
				t.Errorf("unable to start tx %v", err)
				t.FailNow()
			}
			err = precondtion(tx, tt.args.aggregate)
			if (err != nil) && (tt.isErr == nil) {
				t.Errorf("no error expected got: %v", err)
			}
			if tt.isErr != nil && !tt.isErr(err) {
				t.Errorf("precondtion() wrong error %T, %v", err, err)
			}
		})
	}
}
// aggregateWithPrecondition attaches the query/validation pair to the given
// aggregate; a helper for compact test-case construction.
func aggregateWithPrecondition(aggregate *models.Aggregate, query *models.SearchQuery, precondition func(...*models.Event) error) *models.Aggregate {
	return aggregate.SetPrecondition(query, precondition)
}
// validationFunc builds a precondition validator that ignores the loaded
// events and always yields the fixed err (nil for an always-passing check).
func validationFunc(err error) func(events ...*models.Event) error {
	return func(_ ...*models.Event) error {
		return err
	}
}

View File

@@ -0,0 +1,198 @@
package sql
import (
"database/sql"
"errors"
"fmt"
"strconv"
"strings"
"github.com/caos/logging"
z_errors "github.com/caos/zitadel/internal/errors"
es_models "github.com/caos/zitadel/internal/eventstore/v1/models"
"github.com/lib/pq"
)
const (
	// selectStmt lists all columns of eventstore.events in exactly the order
	// the event row scanner (see prepareColumns) reads them.
	selectStmt = "SELECT" +
		" creation_date" +
		", event_type" +
		", event_sequence" +
		", previous_sequence" +
		", event_data" +
		", editor_service" +
		", editor_user" +
		", resource_owner" +
		", aggregate_type" +
		", aggregate_id" +
		", aggregate_version" +
		" FROM eventstore.events"
)
// buildQuery translates a SearchQueryFactory into a postgres-style SQL
// statement, its bound values, the requested limit and a row scanner
// matching the selected columns. An invalid factory or unmappable filters
// yield all-zero results.
func buildQuery(queryFactory *es_models.SearchQueryFactory) (query string, limit uint64, values []interface{}, rowScanner func(s scan, dest interface{}) error) {
	searchQuery, err := queryFactory.Build()
	if err != nil {
		logging.Log("SQL-cshKu").WithError(err).Warn("search query factory invalid")
		return "", 0, nil, nil
	}

	query, rowScanner = prepareColumns(searchQuery.Columns)
	where, values := prepareCondition(searchQuery.Filters)
	if query == "" || where == "" {
		return "", 0, nil, nil
	}
	query += where

	// a MAX(event_sequence) query aggregates to one row; ordering applies
	// only when full event rows are selected
	if searchQuery.Columns != es_models.Columns_Max_Sequence {
		order := " ORDER BY event_sequence"
		if searchQuery.Desc {
			order += " DESC"
		}
		query += order
	}

	if searchQuery.Limit > 0 {
		values = append(values, searchQuery.Limit)
		query += " LIMIT ?"
	}

	// rewrite generic `?` placeholders into numbered `$1..$n`
	query = numberPlaceholder(query, "?", "$")
	return query, searchQuery.Limit, values, rowScanner
}
// prepareCondition renders the filters as an AND-joined WHERE clause with
// `?` placeholders and collects the matching argument values.
// No filters return empty results; an unmappable filter returns ("", nil).
func prepareCondition(filters []*es_models.Filter) (clause string, values []interface{}) {
	values = make([]interface{}, len(filters))
	clauses := make([]string, len(filters))

	if len(filters) == 0 {
		return clause, values
	}

	for i, filter := range filters {
		condition := getCondition(filter)
		if condition == "" {
			return "", nil
		}
		clauses[i] = condition

		value := filter.GetValue()
		switch value.(type) {
		case []bool, []float64, []int64, []string, []es_models.AggregateType, []es_models.EventType,
			*[]bool, *[]float64, *[]int64, *[]string, *[]es_models.AggregateType, *[]es_models.EventType:
			// list values are compared via ANY(...) and must be wrapped for
			// the postgres driver
			value = pq.Array(value)
		}
		values[i] = value
	}

	return " WHERE " + strings.Join(clauses, " AND "), values
}
// scan abstracts sql.Row.Scan / sql.Rows.Scan so row scanners can be unit
// tested without a database (see prepareTestScan in the tests).
type scan func(dest ...interface{}) error
// prepareColumns returns the SELECT statement for the requested column set
// together with a scanner that maps one result row into dest.
// Unknown column sets yield ("", nil).
func prepareColumns(columns es_models.Columns) (string, func(s scan, dest interface{}) error) {
	switch columns {
	case es_models.Columns_Max_Sequence:
		return "SELECT MAX(event_sequence) FROM eventstore.events", scanMaxSequence
	case es_models.Columns_Event:
		return selectStmt, scanEvent
	default:
		return "", nil
	}
}

// scanMaxSequence reads a MAX(event_sequence) row into a *Sequence.
// An empty result set is not an error: dest keeps its zero value.
func scanMaxSequence(row scan, dest interface{}) error {
	sequence, ok := dest.(*Sequence)
	if !ok {
		return z_errors.ThrowInvalidArgument(nil, "SQL-NBjA9", "type must be sequence")
	}
	err := row(sequence)
	if err != nil && !errors.Is(err, sql.ErrNoRows) {
		return z_errors.ThrowInternal(err, "SQL-bN5xg", "something went wrong")
	}
	return nil
}

// scanEvent reads one full event row (column order of selectStmt) into a
// *es_models.Event.
func scanEvent(row scan, dest interface{}) error {
	event, ok := dest.(*es_models.Event)
	if !ok {
		return z_errors.ThrowInvalidArgument(nil, "SQL-4GP6F", "type must be event")
	}
	var previousSequence Sequence
	data := make(Data, 0)
	if err := row(
		&event.CreationDate,
		&event.Type,
		&event.Sequence,
		&previousSequence,
		&data,
		&event.EditorService,
		&event.EditorUser,
		&event.ResourceOwner,
		&event.AggregateType,
		&event.AggregateID,
		&event.AggregateVersion,
	); err != nil {
		logging.Log("SQL-kn1Sw").WithError(err).Warn("unable to scan row")
		return z_errors.ThrowInternal(err, "SQL-J0hFS", "unable to scan row")
	}
	event.PreviousSequence = uint64(previousSequence)
	// detach the payload from the scanner's scratch buffer
	event.Data = make([]byte, len(data))
	copy(event.Data, data)
	return nil
}
// numberPlaceholder rewrites each occurrence of old in query, left to right,
// into new followed by its 1-based position (e.g. "?" -> "$1", "$2", ...).
func numberPlaceholder(query, old, new string) string {
	i := 1
	for {
		replaced := strings.Replace(query, old, new+strconv.Itoa(i), 1)
		if replaced == query {
			return query
		}
		query = replaced
		i++
	}
}
// getCondition renders a single filter as a SQL condition with a `?`
// placeholder; unknown fields or operations render as "".
func getCondition(filter *es_models.Filter) string {
	field := getField(filter.GetField())
	operation := getOperation(filter.GetOperation())
	if field == "" || operation == "" {
		return ""
	}
	return fmt.Sprintf(getConditionFormat(filter.GetOperation()), field, operation)
}
// getConditionFormat returns the fmt pattern for one condition; IN-style
// operations compare against ANY of an array parameter.
func getConditionFormat(operation es_models.Operation) string {
	switch operation {
	case es_models.Operation_In:
		return "%s %s ANY(?)"
	default:
		return "%s %s ?"
	}
}
// getField maps a search field to its column in eventstore.events;
// unknown fields map to "".
func getField(field es_models.Field) string {
	switch field {
	case es_models.Field_AggregateID:
		return "aggregate_id"
	case es_models.Field_AggregateType:
		return "aggregate_type"
	case es_models.Field_LatestSequence:
		return "event_sequence"
	case es_models.Field_ResourceOwner:
		return "resource_owner"
	case es_models.Field_EditorService:
		return "editor_service"
	case es_models.Field_EditorUser:
		return "editor_user"
	case es_models.Field_EventType:
		return "event_type"
	default:
		return ""
	}
}
// getOperation maps a search operation to its SQL comparison operator;
// unknown operations map to "". IN uses "=" because it is rendered as
// "= ANY(?)" by getConditionFormat.
func getOperation(operation es_models.Operation) string {
	switch operation {
	case es_models.Operation_Equals, es_models.Operation_In:
		return "="
	case es_models.Operation_Greater:
		return ">"
	case es_models.Operation_Less:
		return "<"
	default:
		return ""
	}
}

View File

@@ -0,0 +1,485 @@
package sql
import (
"database/sql"
"reflect"
"testing"
"time"
"github.com/caos/zitadel/internal/errors"
es_models "github.com/caos/zitadel/internal/eventstore/v1/models"
"github.com/lib/pq"
)
// Test_numberPlaceholder checks `?` -> `$n` placeholder rewriting for the
// no-op and multi-placeholder cases.
func Test_numberPlaceholder(t *testing.T) {
	tests := []struct {
		name  string
		query string
		want  string
	}{
		{
			name:  "no replaces",
			query: "SELECT * FROM eventstore.events",
			want:  "SELECT * FROM eventstore.events",
		},
		{
			name:  "two replaces",
			query: "SELECT * FROM eventstore.events WHERE aggregate_type = ? AND LIMIT = ?",
			want:  "SELECT * FROM eventstore.events WHERE aggregate_type = $1 AND LIMIT = $2",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := numberPlaceholder(tt.query, "?", "$"); got != tt.want {
				t.Errorf("numberPlaceholder() = %v, want %v", got, tt.want)
			}
		})
	}
}
// Test_getOperation covers every defined operation plus the unknown-value
// fallback to "".
func Test_getOperation(t *testing.T) {
	t.Run("all ops", func(t *testing.T) {
		cases := []struct {
			op   es_models.Operation
			want string
		}{
			{es_models.Operation_Equals, "="},
			{es_models.Operation_In, "="},
			{es_models.Operation_Greater, ">"},
			{es_models.Operation_Less, "<"},
			{es_models.Operation(-1), ""},
		}
		for _, c := range cases {
			if got := getOperation(c.op); got != c.want {
				t.Errorf("getOperation() = %v, want %v", got, c.want)
			}
		}
	})
}
// Test_getField covers every defined field-to-column mapping plus the
// unknown-value fallback to "".
func Test_getField(t *testing.T) {
	t.Run("all fields", func(t *testing.T) {
		cases := []struct {
			field es_models.Field
			want  string
		}{
			{es_models.Field_AggregateType, "aggregate_type"},
			{es_models.Field_AggregateID, "aggregate_id"},
			{es_models.Field_LatestSequence, "event_sequence"},
			{es_models.Field_ResourceOwner, "resource_owner"},
			{es_models.Field_EditorService, "editor_service"},
			{es_models.Field_EditorUser, "editor_user"},
			{es_models.Field_EventType, "event_type"},
			{es_models.Field(-1), ""},
		}
		for _, c := range cases {
			if got := getField(c.field); got != c.want {
				t.Errorf("getField() = %v, want %v", got, c.want)
			}
		}
	})
}
// Test_getConditionFormat checks that IN operations get the ANY(?) pattern
// while everything else gets the plain placeholder pattern.
func Test_getConditionFormat(t *testing.T) {
	cases := []struct {
		name      string
		operation es_models.Operation
		want      string
	}{
		{name: "no in operation", operation: es_models.Operation_Equals, want: "%s %s ?"},
		{name: "in operation", operation: es_models.Operation_In, want: "%s %s ANY(?)"},
	}
	for _, tt := range cases {
		t.Run(tt.name, func(t *testing.T) {
			if got := getConditionFormat(tt.operation); got != tt.want {
				t.Errorf("prepareConditionFormat() = %v, want %v", got, tt.want)
			}
		})
	}
}
// Test_getCondition renders single filters into SQL conditions, including
// the ""-fallbacks for unknown fields and operations.
func Test_getCondition(t *testing.T) {
	type args struct {
		filter *es_models.Filter
	}
	tests := []struct {
		name string
		args args
		want string
	}{
		{
			name: "equals",
			args: args{filter: es_models.NewFilter(es_models.Field_AggregateID, "", es_models.Operation_Equals)},
			want: "aggregate_id = ?",
		},
		{
			name: "greater",
			args: args{filter: es_models.NewFilter(es_models.Field_LatestSequence, 0, es_models.Operation_Greater)},
			want: "event_sequence > ?",
		},
		{
			name: "less",
			args: args{filter: es_models.NewFilter(es_models.Field_LatestSequence, 5000, es_models.Operation_Less)},
			want: "event_sequence < ?",
		},
		{
			// list operands render as "= ANY(?)" so the driver can bind an array
			name: "in list",
			args: args{filter: es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"movies", "actors"}, es_models.Operation_In)},
			want: "aggregate_type = ANY(?)",
		},
		{
			name: "invalid operation",
			args: args{filter: es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"movies", "actors"}, es_models.Operation(-1))},
			want: "",
		},
		{
			name: "invalid field",
			args: args{filter: es_models.NewFilter(es_models.Field(-1), []es_models.AggregateType{"movies", "actors"}, es_models.Operation_Equals)},
			want: "",
		},
		{
			name: "invalid field and operation",
			args: args{filter: es_models.NewFilter(es_models.Field(-1), []es_models.AggregateType{"movies", "actors"}, es_models.Operation(-1))},
			want: "",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := getCondition(tt.args.filter); got != tt.want {
				t.Errorf("getCondition() = %v, want %v", got, tt.want)
			}
		})
	}
}
// Test_prepareColumns drives the statement/row-scanner pairs returned by
// prepareColumns through a fake scan (prepareTestScan) and checks both the
// generated SQL and the values/errors the scanner produces.
func Test_prepareColumns(t *testing.T) {
	type args struct {
		columns es_models.Columns
		dest    interface{}
		dbErr   error // error the fake scan should return
	}
	type res struct {
		query    string
		dbRow    []interface{} // canned row handed to the scanner
		expected interface{}   // value *dest must hold afterwards
		dbErr    func(error) bool
	}
	tests := []struct {
		name string
		args args
		res  res
	}{
		{
			// unknown column sets must yield no query and no scanner
			name: "invalid columns",
			args: args{columns: es_models.Columns(-1)},
			res: res{
				query: "",
				dbErr: func(err error) bool { return err == nil },
			},
		},
		{
			name: "max column",
			args: args{
				columns: es_models.Columns_Max_Sequence,
				dest:    new(Sequence),
			},
			res: res{
				query:    "SELECT MAX(event_sequence) FROM eventstore.events",
				dbRow:    []interface{}{Sequence(5)},
				expected: Sequence(5),
			},
		},
		{
			// the scanner must reject destinations of the wrong type
			name: "max sequence wrong dest type",
			args: args{
				columns: es_models.Columns_Max_Sequence,
				dest:    new(uint64),
			},
			res: res{
				query: "SELECT MAX(event_sequence) FROM eventstore.events",
				dbErr: errors.IsErrorInvalidArgument,
			},
		},
		{
			name: "event",
			args: args{
				columns: es_models.Columns_Event,
				dest:    new(es_models.Event),
			},
			res: res{
				query:    "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
				dbRow:    []interface{}{time.Time{}, es_models.EventType(""), uint64(5), Sequence(0), Data(nil), "", "", "", es_models.AggregateType("user"), "hodor", es_models.Version("")},
				expected: es_models.Event{AggregateID: "hodor", AggregateType: "user", Sequence: 5, Data: make(Data, 0)},
			},
		},
		{
			name: "event wrong dest type",
			args: args{
				columns: es_models.Columns_Event,
				dest:    new(uint64),
			},
			res: res{
				query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
				dbErr: errors.IsErrorInvalidArgument,
			},
		},
		{
			// driver-level failures must surface as internal errors
			name: "event query error",
			args: args{
				columns: es_models.Columns_Event,
				dest:    new(es_models.Event),
				dbErr:   sql.ErrConnDone,
			},
			res: res{
				query: "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events",
				dbErr: errors.IsInternal,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			query, rowScanner := prepareColumns(tt.args.columns)
			if query != tt.res.query {
				t.Errorf("prepareColumns() got = %v, want %v", query, tt.res.query)
			}
			if tt.res.query == "" && rowScanner != nil {
				t.Errorf("row scanner should be nil")
			}
			if rowScanner == nil {
				return
			}
			err := rowScanner(prepareTestScan(tt.args.dbErr, tt.res.dbRow), tt.args.dest)
			if tt.res.dbErr != nil {
				if !tt.res.dbErr(err) {
					t.Errorf("wrong error type in rowScanner got: %v", err)
				}
			} else {
				// compare the scanned value behind the dest pointer
				if !reflect.DeepEqual(reflect.Indirect(reflect.ValueOf(tt.args.dest)).Interface(), tt.res.expected) {
					t.Errorf("unexpected result from rowScanner want: %v got: %v", tt.res.dbRow, tt.args.dest)
				}
			}
		})
	}
}
// prepareTestScan fakes a database scan: it either fails with err or copies
// the canned row values res into the scan destinations via reflection.
func prepareTestScan(err error, res []interface{}) scan {
	return func(dests ...interface{}) error {
		if err != nil {
			return err
		}
		if len(dests) != len(res) {
			return errors.ThrowInvalidArgumentf(nil, "SQL-NML1q", "expected len %d got %d", len(res), len(dests))
		}
		for i := range res {
			reflect.ValueOf(dests[i]).Elem().Set(reflect.ValueOf(res[i]))
		}
		return nil
	}
}
// Test_prepareCondition checks WHERE-clause rendering: empty input, a
// filter that cannot be mapped, array wrapping for the postgres driver,
// and AND-joining of multiple filters.
func Test_prepareCondition(t *testing.T) {
	type args struct {
		filters []*es_models.Filter
	}
	type res struct {
		clause string
		values []interface{}
	}
	tests := []struct {
		name string
		args args
		res  res
	}{
		{
			name: "nil filters",
			args: args{
				filters: nil,
			},
			res: res{
				clause: "",
				values: nil,
			},
		},
		{
			name: "empty filters",
			args: args{
				filters: []*es_models.Filter{},
			},
			res: res{
				clause: "",
				values: nil,
			},
		},
		{
			// an unmappable filter invalidates the whole condition
			name: "invalid condition",
			args: args{
				filters: []*es_models.Filter{
					es_models.NewFilter(es_models.Field_AggregateID, "wrong", es_models.Operation(-1)),
				},
			},
			res: res{
				clause: "",
				values: nil,
			},
		},
		{
			// slice operands must be wrapped with pq.Array for ANY(?)
			name: "array as condition value",
			args: args{
				filters: []*es_models.Filter{
					es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"user", "org"}, es_models.Operation_In),
				},
			},
			res: res{
				clause: " WHERE aggregate_type = ANY(?)",
				values: []interface{}{pq.Array([]es_models.AggregateType{"user", "org"})},
			},
		},
		{
			name: "multiple filters",
			args: args{
				filters: []*es_models.Filter{
					es_models.NewFilter(es_models.Field_AggregateType, []es_models.AggregateType{"user", "org"}, es_models.Operation_In),
					es_models.NewFilter(es_models.Field_AggregateID, "1234", es_models.Operation_Equals),
					es_models.NewFilter(es_models.Field_EventType, []es_models.EventType{"user.created", "org.created"}, es_models.Operation_In),
				},
			},
			res: res{
				clause: " WHERE aggregate_type = ANY(?) AND aggregate_id = ? AND event_type = ANY(?)",
				values: []interface{}{pq.Array([]es_models.AggregateType{"user", "org"}), "1234", pq.Array([]es_models.EventType{"user.created", "org.created"})},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			gotClause, gotValues := prepareCondition(tt.args.filters)
			if gotClause != tt.res.clause {
				t.Errorf("prepareCondition() gotClause = %v, want %v", gotClause, tt.res.clause)
			}
			// compare lengths first so a size mismatch doesn't panic below
			if len(gotValues) != len(tt.res.values) {
				t.Errorf("wrong length of gotten values got = %d, want %d", len(gotValues), len(tt.res.values))
				return
			}
			for i, value := range gotValues {
				if !reflect.DeepEqual(value, tt.res.values[i]) {
					t.Errorf("prepareCondition() gotValues = %v, want %v", gotValues, tt.res.values)
				}
			}
		})
	}
}
// Test_buildQuery checks the assembled SQL (numbered placeholders, ORDER BY,
// LIMIT), the bound values and the presence of a row scanner for different
// SearchQueryFactory configurations.
func Test_buildQuery(t *testing.T) {
	type args struct {
		queryFactory *es_models.SearchQueryFactory
	}
	type res struct {
		query      string
		limit      uint64
		values     []interface{}
		rowScanner bool // whether a non-nil scanner is expected
	}
	tests := []struct {
		name string
		args args
		res  res
	}{
		{
			// a nil factory must produce all-zero results
			name: "invalid query factory",
			args: args{
				queryFactory: nil,
			},
			res: res{
				query:      "",
				limit:      0,
				rowScanner: false,
				values:     nil,
			},
		},
		{
			name: "with order by desc",
			args: args{
				queryFactory: es_models.NewSearchQueryFactory("user").OrderDesc(),
			},
			res: res{
				query:      "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC",
				rowScanner: true,
				values:     []interface{}{es_models.AggregateType("user")},
			},
		},
		{
			// the limit is appended both as value ($2) and returned separately
			name: "with limit",
			args: args{
				queryFactory: es_models.NewSearchQueryFactory("user").Limit(5),
			},
			res: res{
				query:      "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence LIMIT $2",
				rowScanner: true,
				values:     []interface{}{es_models.AggregateType("user"), uint64(5)},
				limit:      5,
			},
		},
		{
			name: "with limit and order by desc",
			args: args{
				queryFactory: es_models.NewSearchQueryFactory("user").Limit(5).OrderDesc(),
			},
			res: res{
				query:      "SELECT creation_date, event_type, event_sequence, previous_sequence, event_data, editor_service, editor_user, resource_owner, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC LIMIT $2",
				rowScanner: true,
				values:     []interface{}{es_models.AggregateType("user"), uint64(5)},
				limit:      5,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			gotQuery, gotLimit, gotValues, gotRowScanner := buildQuery(tt.args.queryFactory)
			if gotQuery != tt.res.query {
				t.Errorf("buildQuery() gotQuery = %v, want %v", gotQuery, tt.res.query)
			}
			if gotLimit != tt.res.limit {
				t.Errorf("buildQuery() gotLimit = %v, want %v", gotLimit, tt.res.limit)
			}
			if len(gotValues) != len(tt.res.values) {
				t.Errorf("wrong length of gotten values got = %d, want %d", len(gotValues), len(tt.res.values))
				return
			}
			for i, value := range gotValues {
				if !reflect.DeepEqual(value, tt.res.values[i]) {
					t.Errorf("prepareCondition() gotValues = %v, want %v", gotValues, tt.res.values)
				}
			}
			// the scanner itself cannot be compared; only nil-ness is asserted
			if (tt.res.rowScanner && gotRowScanner == nil) || (!tt.res.rowScanner && gotRowScanner != nil) {
				t.Errorf("rowScanner should be nil==%v got nil==%v", tt.res.rowScanner, gotRowScanner == nil)
			}
		})
	}
}

View File

@@ -0,0 +1,17 @@
package sql
import (
"context"
"database/sql"
//sql import
_ "github.com/lib/pq"
)
// SQL is the v1 eventstore repository backed by a plain database handle
// (postgres/cockroach via lib/pq).
type SQL struct {
	client *sql.DB
}
// Health reports whether the database is reachable. It honors ctx
// cancellation and deadlines (the previous Ping call ignored ctx entirely).
func (db *SQL) Health(ctx context.Context) error {
	return db.client.PingContext(ctx)
}

View File

@@ -0,0 +1,47 @@
package sql
import "database/sql/driver"
// Data represents a byte array that may be null.
// Data implements the sql.Scanner interface
type Data []byte

// Scan implements the Scanner interface.
// It copies the incoming bytes: per the database/sql.Scanner contract the
// driver may reuse the source buffer after Scan returns, so retaining it
// directly (as the previous implementation did) can corrupt the payload.
func (data *Data) Scan(value interface{}) error {
	if value == nil {
		*data = nil
		return nil
	}
	b := value.([]byte)
	*data = make(Data, len(b))
	copy(*data, b)
	return nil
}
// Value implements the driver Valuer interface.
// Empty payloads are persisted as SQL NULL.
func (data Data) Value() (driver.Value, error) {
	if len(data) > 0 {
		return []byte(data), nil
	}
	return nil, nil
}
// Sequence represents a number that may be null.
// Sequence implements the sql.Scanner interface
type Sequence uint64

// Scan implements the Scanner interface.
// A NULL column resets the sequence to zero; database/sql delivers integer
// columns as int64.
func (seq *Sequence) Scan(value interface{}) error {
	if value == nil {
		*seq = 0
	} else {
		*seq = Sequence(value.(int64))
	}
	return nil
}
// Value implements the driver Valuer interface.
// A zero sequence is persisted as SQL NULL.
func (seq Sequence) Value() (driver.Value, error) {
	if seq > 0 {
		return int64(seq), nil
	}
	return nil, nil
}

View File

@@ -0,0 +1,48 @@
package locker
import (
"context"
"database/sql"
"fmt"
"time"
"github.com/caos/logging"
caos_errs "github.com/caos/zitadel/internal/errors"
"github.com/cockroachdb/cockroach-go/v2/crdb"
)
const (
	// insertStmtFormat upserts a lock row (%s is the lock table name): the
	// INSERT wins for an unknown view; on conflict the UPDATE only applies
	// when this locker already holds the lock or the previous lock expired —
	// otherwise zero rows are affected and the renewal is treated as lost.
	insertStmtFormat = "INSERT INTO %s" +
		" (locker_id, locked_until, view_name) VALUES ($1, now()+$2::INTERVAL, $3)" +
		" ON CONFLICT (view_name)" +
		" DO UPDATE SET locker_id = $4, locked_until = now()+$5::INTERVAL" +
		" WHERE locks.view_name = $6 AND (locks.locker_id = $7 OR locks.locked_until < now())"
	// millisecondsAsSeconds converts waitTime.Milliseconds() into whole
	// seconds for the INTERVAL parameters above.
	millisecondsAsSeconds = int64(time.Second / time.Millisecond)
)
// lock mirrors one row of the lock table (gorm column mapping); the pair
// (locker_id, view_name) forms the primary key.
type lock struct {
	LockerID    string    `gorm:"column:locker_id;primary_key"`
	LockedUntil time.Time `gorm:"column:locked_until"`
	ViewName    string    `gorm:"column:view_name;primary_key"`
}
// Renew tries to acquire or extend the lock on viewModel in lockTable for
// lockerID, valid for waitTime. The upsert only affects a row when the view
// is unknown, already held by this locker, or the previous lock expired;
// otherwise an AlreadyExists error is returned.
func Renew(dbClient *sql.DB, lockTable, lockerID, viewModel string, waitTime time.Duration) error {
	return crdb.ExecuteTx(context.Background(), dbClient, nil, func(tx *sql.Tx) error {
		insert := fmt.Sprintf(insertStmtFormat, lockTable)
		result, err := tx.Exec(insert,
			lockerID, waitTime.Milliseconds()/millisecondsAsSeconds, viewModel,
			lockerID, waitTime.Milliseconds()/millisecondsAsSeconds,
			viewModel, lockerID)
		if err != nil {
			// crdb.ExecuteTx rolls back (or retries) when the closure returns
			// an error; calling tx.Rollback() here as well — as the previous
			// version did — violates its contract and breaks the retry loop.
			return err
		}

		if rows, _ := result.RowsAffected(); rows == 0 {
			return caos_errs.ThrowAlreadyExists(nil, "SPOOL-lso0e", "view already locked")
		}
		logging.LogWithFields("LOCKE-lOgbg", "view", viewModel, "locker", lockerID).Debug("locker changed")
		return nil
	})
}

View File

@@ -0,0 +1,125 @@
package locker
import (
"database/sql"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
)
// dbMock bundles a sqlmock-backed *sql.DB with its expectation recorder for
// the locker tests.
type dbMock struct {
	db   *sql.DB
	mock sqlmock.Sqlmock
}
// mockDB creates a sqlmock-backed database stub with strict in-order
// expectation matching; it fails the test when the stub cannot be created.
func mockDB(t *testing.T) *dbMock {
	db, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("error occured while creating stub db %v", err)
	}
	mock.MatchExpectationsInOrder(true)
	return &dbMock{db: db, mock: mock}
}
// expectCommit registers an expected transaction commit.
func (db *dbMock) expectCommit() *dbMock {
	db.mock.ExpectCommit()
	return db
}

// expectRollback registers an expected transaction rollback.
func (db *dbMock) expectRollback() *dbMock {
	db.mock.ExpectRollback()
	return db
}

// expectBegin registers an expected transaction begin.
func (db *dbMock) expectBegin() *dbMock {
	db.mock.ExpectBegin()
	return db
}

// expectSavepoint registers an expected SAVEPOINT statement (issued by
// crdb.ExecuteTx) that succeeds.
func (db *dbMock) expectSavepoint() *dbMock {
	db.mock.ExpectExec("SAVEPOINT").WillReturnResult(sqlmock.NewResult(1, 1))
	return db
}

// expectReleaseSavepoint registers an expected RELEASE SAVEPOINT statement
// that succeeds.
func (db *dbMock) expectReleaseSavepoint() *dbMock {
	db.mock.ExpectExec("RELEASE SAVEPOINT").WillReturnResult(sqlmock.NewResult(1, 1))
	return db
}
// expectRenew registers the lock-upsert expectation; affectedRows controls
// whether the renewal is reported as successful (>0) or lost (0).
// (The previous version chained an initial WillReturnResult that was always
// overwritten by the if/else below — that dead call is removed.)
func (db *dbMock) expectRenew(lockerID, view string, affectedRows int64) *dbMock {
	query := db.mock.
		ExpectExec(`INSERT INTO table\.locks \(locker_id, locked_until, view_name\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\) ON CONFLICT \(view_name\) DO UPDATE SET locker_id = \$4, locked_until = now\(\)\+\$5::INTERVAL WHERE locks\.view_name = \$6 AND \(locks\.locker_id = \$7 OR locks\.locked_until < now\(\)\)`).
		WithArgs(lockerID, sqlmock.AnyArg(), view, lockerID, sqlmock.AnyArg(), view, lockerID)
	if affectedRows == 0 {
		query.WillReturnResult(sqlmock.NewResult(0, 0))
	} else {
		query.WillReturnResult(sqlmock.NewResult(1, affectedRows))
	}
	return db
}
// Test_locker_Renew checks the transaction/savepoint sequence that
// crdb.ExecuteTx drives around the lock upsert, for both a successful
// renewal and one lost to another locker (zero rows affected).
func Test_locker_Renew(t *testing.T) {
	type fields struct {
		db *dbMock
	}
	type args struct {
		tableName string
		lockerID  string
		viewModel string
		waitTime  time.Duration
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		wantErr bool
	}{
		{
			name: "renew succeeded",
			fields: fields{
				db: mockDB(t).
					expectBegin().
					expectSavepoint().
					expectRenew("locker", "view", 1).
					expectReleaseSavepoint().
					expectCommit(),
			},
			args:    args{tableName: "table.locks", lockerID: "locker", viewModel: "view", waitTime: 1 * time.Second},
			wantErr: false,
		},
		{
			// zero affected rows means the lock is held elsewhere: rollback
			name: "renew now rows updated",
			fields: fields{
				db: mockDB(t).
					expectBegin().
					expectSavepoint().
					expectRenew("locker", "view", 0).
					expectRollback(),
			},
			args:    args{tableName: "table.locks", lockerID: "locker", viewModel: "view", waitTime: 1 * time.Second},
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := Renew(tt.fields.db.db, tt.args.tableName, tt.args.lockerID, tt.args.viewModel, tt.args.waitTime); (err != nil) != tt.wantErr {
				t.Errorf("locker.Renew() error = %v, wantErr %v", err, tt.wantErr)
			}
			if err := tt.fields.db.mock.ExpectationsWereMet(); err != nil {
				t.Errorf("not all database expectations met: %v", err)
			}
		})
	}
}

View File

@@ -0,0 +1,146 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/caos/zitadel/internal/eventstore (interfaces: Eventstore)
// Package mock is a generated GoMock package.
package mock
import (
context "context"
"github.com/caos/zitadel/internal/eventstore"
"github.com/caos/zitadel/internal/eventstore/v1"
models "github.com/caos/zitadel/internal/eventstore/v1/models"
gomock "github.com/golang/mock/gomock"
reflect "reflect"
)
// NOTE(review): this file is generated by MockGen (see the header above) —
// do not hand-edit; regenerate with mockgen when the Eventstore interface
// changes.

// MockEventstore is a mock of Eventstore interface
type MockEventstore struct {
	ctrl     *gomock.Controller
	recorder *MockEventstoreMockRecorder
}

// MockEventstoreMockRecorder is the mock recorder for MockEventstore
type MockEventstoreMockRecorder struct {
	mock *MockEventstore
}

// NewMockEventstore creates a new mock instance
func NewMockEventstore(ctrl *gomock.Controller) *MockEventstore {
	mock := &MockEventstore{ctrl: ctrl}
	mock.recorder = &MockEventstoreMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockEventstore) EXPECT() *MockEventstoreMockRecorder {
	return m.recorder
}

// AggregateCreator mocks base method
func (m *MockEventstore) AggregateCreator() *models.AggregateCreator {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "AggregateCreator")
	ret0, _ := ret[0].(*models.AggregateCreator)
	return ret0
}

// AggregateCreator indicates an expected call of AggregateCreator
func (mr *MockEventstoreMockRecorder) AggregateCreator() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AggregateCreator", reflect.TypeOf((*MockEventstore)(nil).AggregateCreator))
}

// FilterEvents mocks base method
func (m *MockEventstore) FilterEvents(arg0 context.Context, arg1 *models.SearchQuery) ([]*models.Event, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "FilterEvents", arg0, arg1)
	ret0, _ := ret[0].([]*models.Event)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// FilterEvents indicates an expected call of FilterEvents
func (mr *MockEventstoreMockRecorder) FilterEvents(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterEvents", reflect.TypeOf((*MockEventstore)(nil).FilterEvents), arg0, arg1)
}

// Health mocks base method
func (m *MockEventstore) Health(arg0 context.Context) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Health", arg0)
	ret0, _ := ret[0].(error)
	return ret0
}

// Health indicates an expected call of Health
func (mr *MockEventstoreMockRecorder) Health(arg0 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockEventstore)(nil).Health), arg0)
}

// LatestSequence mocks base method
func (m *MockEventstore) LatestSequence(arg0 context.Context, arg1 *models.SearchQueryFactory) (uint64, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "LatestSequence", arg0, arg1)
	ret0, _ := ret[0].(uint64)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// LatestSequence indicates an expected call of LatestSequence
func (mr *MockEventstoreMockRecorder) LatestSequence(arg0, arg1 interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LatestSequence", reflect.TypeOf((*MockEventstore)(nil).LatestSequence), arg0, arg1)
}

// PushAggregates mocks base method
func (m *MockEventstore) PushAggregates(arg0 context.Context, arg1 ...*models.Aggregate) error {
	m.ctrl.T.Helper()
	varargs := []interface{}{arg0}
	for _, a := range arg1 {
		varargs = append(varargs, a)
	}
	ret := m.ctrl.Call(m, "PushAggregates", varargs...)
	ret0, _ := ret[0].(error)
	return ret0
}

// PushAggregates indicates an expected call of PushAggregates
func (mr *MockEventstoreMockRecorder) PushAggregates(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	varargs := append([]interface{}{arg0}, arg1...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PushAggregates", reflect.TypeOf((*MockEventstore)(nil).PushAggregates), varargs...)
}

// Subscribe mocks base method
func (m *MockEventstore) Subscribe(arg0 ...models.AggregateType) *v1.Subscription {
	m.ctrl.T.Helper()
	varargs := []interface{}{}
	for _, a := range arg0 {
		varargs = append(varargs, a)
	}
	ret := m.ctrl.Call(m, "Subscribe", varargs...)
	ret0, _ := ret[0].(*v1.Subscription)
	return ret0
}

// Subscribe indicates an expected call of Subscribe
func (mr *MockEventstoreMockRecorder) Subscribe(arg0 ...interface{}) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Subscribe", reflect.TypeOf((*MockEventstore)(nil).Subscribe), arg0...)
}

// V2 mocks base method
func (m *MockEventstore) V2() *eventstore.Eventstore {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "V2")
	ret0, _ := ret[0].(*eventstore.Eventstore)
	return ret0
}

// V2 indicates an expected call of V2
func (mr *MockEventstoreMockRecorder) V2() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "V2", reflect.TypeOf((*MockEventstore)(nil).V2))
}

View File

@@ -0,0 +1,97 @@
package models
import (
"time"
"github.com/caos/zitadel/internal/errors"
)
// AggregateType identifies the kind of an aggregate (e.g. user, org).
type AggregateType string

// String implements fmt.Stringer.
func (at AggregateType) String() string {
	return string(at)
}
// Aggregates is a list of aggregates handled together.
type Aggregates []*Aggregate

// Aggregate is a v1 event-sourcing aggregate: an identified event stream
// plus the editor metadata stamped onto every appended event.
type Aggregate struct {
	ID               string
	typ              AggregateType
	PreviousSequence uint64 // sequence of the last already-stored event of this aggregate
	version          Version
	editorService string
	editorUser    string
	resourceOwner string
	Events       []*Event
	Precondition *precondition // optional check run before pushing; see SetPrecondition
}
// Type returns the aggregate's type (read-only accessor for the private field).
func (a *Aggregate) Type() AggregateType {
	return a.typ
}
// precondition pairs a search query with a validation over its result;
// the aggregate may only be pushed when Validation accepts the loaded events.
type precondition struct {
	Query      *SearchQuery
	Validation func(...*Event) error
}
// AppendEvent serializes payload and appends a new event of the given type
// to the aggregate, stamped with the aggregate's identity and editor
// metadata. It returns the aggregate (unchanged on failure) together with an
// error when the event type is empty or the payload cannot be serialized.
func (a *Aggregate) AppendEvent(typ EventType, payload interface{}) (*Aggregate, error) {
	if string(typ) == "" {
		return a, errors.ThrowInvalidArgument(nil, "MODEL-TGoCb", "no event type")
	}
	data, err := eventData(payload)
	if err != nil {
		return a, err
	}

	a.Events = append(a.Events, &Event{
		CreationDate:     time.Now(),
		Data:             data,
		Type:             typ,
		AggregateID:      a.ID,
		AggregateType:    a.typ,
		AggregateVersion: a.version,
		EditorService:    a.editorService,
		EditorUser:       a.editorUser,
		ResourceOwner:    a.resourceOwner,
	})
	return a, nil
}
// SetPrecondition registers a query whose result must pass validateFunc
// before the aggregate's events may be pushed. It returns the aggregate to
// allow chaining.
func (a *Aggregate) SetPrecondition(query *SearchQuery, validateFunc func(...*Event) error) *Aggregate {
	a.Precondition = &precondition{
		Query:      query,
		Validation: validateFunc,
	}
	return a
}
// Validate checks that all mandatory aggregate fields are set and that a
// configured precondition is complete (query AND validation func).
// It returns a PreconditionFailed error for the first violation found;
// a nil receiver is also rejected.
func (a *Aggregate) Validate() error {
	if a == nil {
		return errors.ThrowPreconditionFailed(nil, "MODEL-yi5AC", "aggregate is nil")
	}
	if a.ID == "" {
		return errors.ThrowPreconditionFailed(nil, "MODEL-FSjKV", "id not set")
	}
	if string(a.typ) == "" {
		return errors.ThrowPreconditionFailed(nil, "MODEL-aj4t2", "type not set")
	}
	if err := a.version.Validate(); err != nil {
		return errors.ThrowPreconditionFailed(err, "MODEL-PupjX", "invalid version")
	}
	if a.editorService == "" {
		return errors.ThrowPreconditionFailed(nil, "MODEL-clYbY", "editor service not set")
	}
	if a.editorUser == "" {
		return errors.ThrowPreconditionFailed(nil, "MODEL-Xcssi", "editor user not set")
	}
	if a.resourceOwner == "" {
		return errors.ThrowPreconditionFailed(nil, "MODEL-eBYUW", "resource owner not set")
	}
	// A precondition is all-or-nothing: a query without a validation func
	// (or vice versa) cannot be evaluated.
	if a.Precondition != nil && (a.Precondition.Query == nil || a.Precondition.Validation == nil) {
		return errors.ThrowPreconditionFailed(nil, "MODEL-EEUvA", "invalid precondition")
	}
	return nil
}

View File

@@ -0,0 +1,56 @@
package models
import (
"context"
"github.com/caos/zitadel/internal/api/authz"
)
// AggregateCreator builds aggregates that are stamped with the creating
// service's name as editor service.
type AggregateCreator struct {
	serviceName string
}

// NewAggregateCreator returns a creator for the given service name.
func NewAggregateCreator(serviceName string) *AggregateCreator {
	return &AggregateCreator{serviceName: serviceName}
}

// option mutates an aggregate during creation (see OverwriteEditorUser,
// OverwriteResourceOwner).
type option func(*Aggregate)
// NewAggregate creates a validated aggregate for the given id, type, version
// and previous sequence. Editor user and resource owner default to the authz
// data on ctx and can be overridden via options. The aggregate is validated
// after the options were applied; on failure nil and the error are returned.
func (c *AggregateCreator) NewAggregate(ctx context.Context, id string, typ AggregateType, version Version, previousSequence uint64, opts ...option) (*Aggregate, error) {
	ctxData := authz.GetCtxData(ctx)
	editorUser := ctxData.UserID
	resourceOwner := ctxData.OrgID
	aggregate := &Aggregate{
		ID:               id,
		typ:              typ,
		PreviousSequence: previousSequence,
		version:          version,
		Events:           make([]*Event, 0, 2),
		editorService:    c.serviceName,
		editorUser:       editorUser,
		resourceOwner:    resourceOwner,
	}
	for _, opt := range opts {
		opt(aggregate)
	}
	// Validate after the options ran so overrides are taken into account.
	if err := aggregate.Validate(); err != nil {
		return nil, err
	}
	return aggregate, nil
}
// OverwriteEditorUser replaces the editor user taken from the context data.
func OverwriteEditorUser(userID string) func(*Aggregate) {
	return func(a *Aggregate) {
		a.editorUser = userID
	}
}

// OverwriteResourceOwner replaces the resource owner taken from the context data.
func OverwriteResourceOwner(resourceOwner string) func(*Aggregate) {
	return func(a *Aggregate) {
		a.resourceOwner = resourceOwner
	}
}

View File

@@ -0,0 +1,118 @@
package models
import (
"context"
"reflect"
"testing"
)
// TestAggregateCreator_NewAggregate covers validation failures (missing ctx
// data, id, type, version) and the happy path with Overwrite* options.
func TestAggregateCreator_NewAggregate(t *testing.T) {
	type args struct {
		ctx     context.Context
		id      string
		typ     AggregateType
		version Version
		opts    []option
	}
	tests := []struct {
		name    string
		creator *AggregateCreator
		args    args
		want    *Aggregate
		wantErr bool
	}{
		{
			name:    "no ctxdata and no options",
			creator: &AggregateCreator{serviceName: "admin"},
			wantErr: true,
			want:    nil,
			args: args{
				ctx:     context.Background(),
				id:      "hodor",
				typ:     "user",
				version: "v1.0.0",
			},
		},
		{
			name:    "no id error",
			creator: &AggregateCreator{serviceName: "admin"},
			wantErr: true,
			want:    nil,
			args: args{
				ctx:     context.Background(),
				typ:     "user",
				version: "v1.0.0",
				opts: []option{
					OverwriteEditorUser("hodor"),
					OverwriteResourceOwner("org"),
				},
			},
		},
		{
			name:    "no type error",
			creator: &AggregateCreator{serviceName: "admin"},
			wantErr: true,
			want:    nil,
			args: args{
				ctx:     context.Background(),
				id:      "hodor",
				version: "v1.0.0",
				opts: []option{
					OverwriteEditorUser("hodor"),
					OverwriteResourceOwner("org"),
				},
			},
		},
		{
			name:    "invalid version error",
			creator: &AggregateCreator{serviceName: "admin"},
			wantErr: true,
			want:    nil,
			args: args{
				ctx: context.Background(),
				id:  "hodor",
				typ: "user",
				opts: []option{
					OverwriteEditorUser("hodor"),
					OverwriteResourceOwner("org"),
				},
			},
		},
		{
			name:    "create ok",
			creator: &AggregateCreator{serviceName: "admin"},
			wantErr: false,
			want: &Aggregate{
				ID:            "hodor",
				Events:        make([]*Event, 0, 2),
				typ:           "user",
				version:       "v1.0.0",
				editorService: "admin",
				editorUser:    "hodor",
				resourceOwner: "org",
			},
			args: args{
				ctx:     context.Background(),
				id:      "hodor",
				typ:     "user",
				version: "v1.0.0",
				opts: []option{
					OverwriteEditorUser("hodor"),
					OverwriteResourceOwner("org"),
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := tt.creator.NewAggregate(tt.args.ctx, tt.args.id, tt.args.typ, tt.args.version, 0, tt.args.opts...)
			if (err != nil) != tt.wantErr {
				t.Errorf("AggregateCreator.NewAggregate() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("AggregateCreator.NewAggregate() = %v, want %v", got, tt.want)
			}
		})
	}
}

View File

@@ -0,0 +1,310 @@
package models
import (
"testing"
"github.com/caos/zitadel/internal/errors"
)
// TestAggregate_AppendEvent checks that AppendEvent rejects missing event
// types and unserializable payloads, and that valid events grow Events.
func TestAggregate_AppendEvent(t *testing.T) {
	type fields struct {
		aggregate *Aggregate
	}
	type args struct {
		typ     EventType
		payload interface{}
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		want    *Aggregate
		wantErr bool
	}{
		{
			name:    "no event type error",
			fields:  fields{aggregate: &Aggregate{}},
			args:    args{},
			want:    &Aggregate{},
			wantErr: true,
		},
		{
			name:    "invalid payload error",
			fields:  fields{aggregate: &Aggregate{}},
			args:    args{typ: "user", payload: 134},
			want:    &Aggregate{},
			wantErr: true,
		},
		{
			// appends to an empty (non-nil) event list
			name:   "event added",
			fields: fields{aggregate: &Aggregate{Events: []*Event{}}},
			args:   args{typ: "user.deactivated"},
			want: &Aggregate{Events: []*Event{
				{Type: "user.deactivated"},
			}},
			wantErr: false,
		},
		{
			// appends after an already existing event
			name: "event added",
			fields: fields{aggregate: &Aggregate{Events: []*Event{
				{},
			}}},
			args: args{typ: "user.deactivated"},
			want: &Aggregate{Events: []*Event{
				{},
				{Type: "user.deactivated"},
			}},
			wantErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := tt.fields.aggregate.AppendEvent(tt.args.typ, tt.args.payload)
			if (err != nil) != tt.wantErr {
				t.Errorf("Aggregate.AppendEvent() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if len(tt.fields.aggregate.Events) != len(got.Events) {
				t.Errorf("events len should be %d but was %d", len(tt.fields.aggregate.Events), len(got.Events))
			}
		})
	}
}

// TestAggregate_Validate walks through every mandatory-field check of
// Aggregate.Validate and both precondition variants; every failure must be
// a PreconditionFailed error.
func TestAggregate_Validate(t *testing.T) {
	type fields struct {
		aggregate *Aggregate
	}
	tests := []struct {
		name    string
		fields  fields
		wantErr bool
	}{
		{
			name:    "aggregate nil error",
			wantErr: true,
		},
		{
			name:    "aggregate empty error",
			wantErr: true,
			fields:  fields{aggregate: &Aggregate{}},
		},
		{
			name:    "no id error",
			wantErr: true,
			fields: fields{aggregate: &Aggregate{
				typ:              "user",
				version:          "v1.0.0",
				editorService:    "svc",
				editorUser:       "hodor",
				resourceOwner:    "org",
				PreviousSequence: 5,
				Events: []*Event{
					{
						AggregateType:    "user",
						AggregateVersion: "v1.0.0",
						EditorService:    "management",
						EditorUser:       "hodor",
						ResourceOwner:    "org",
						Type:             "born",
					}},
			}},
		},
		{
			name:    "no type error",
			wantErr: true,
			fields: fields{aggregate: &Aggregate{
				ID:               "aggID",
				version:          "v1.0.0",
				editorService:    "svc",
				editorUser:       "hodor",
				resourceOwner:    "org",
				PreviousSequence: 5,
				Events: []*Event{
					{
						AggregateID:      "hodor",
						AggregateVersion: "v1.0.0",
						EditorService:    "management",
						EditorUser:       "hodor",
						ResourceOwner:    "org",
						Type:             "born",
					}},
			}},
		},
		{
			name:    "invalid version error",
			wantErr: true,
			fields: fields{aggregate: &Aggregate{
				ID:               "aggID",
				typ:              "user",
				editorService:    "svc",
				editorUser:       "hodor",
				resourceOwner:    "org",
				PreviousSequence: 5,
				Events: []*Event{
					{
						AggregateID:   "hodor",
						AggregateType: "user",
						EditorService: "management",
						EditorUser:    "hodor",
						ResourceOwner: "org",
						Type:          "born",
					}},
			}},
		},
		{
			name:    "no query in precondition error",
			wantErr: true,
			fields: fields{aggregate: &Aggregate{
				ID:               "aggID",
				typ:              "user",
				version:          "v1.0.0",
				editorService:    "svc",
				editorUser:       "hodor",
				resourceOwner:    "org",
				PreviousSequence: 5,
				Precondition: &precondition{
					Validation: func(...*Event) error { return nil },
				},
				Events: []*Event{
					{
						AggregateID:      "hodor",
						AggregateType:    "user",
						AggregateVersion: "v1.0.0",
						EditorService:    "management",
						EditorUser:       "hodor",
						ResourceOwner:    "org",
						Type:             "born",
					}},
			}},
		},
		{
			name:    "no func in precondition error",
			wantErr: true,
			fields: fields{aggregate: &Aggregate{
				ID:               "aggID",
				typ:              "user",
				version:          "v1.0.0",
				editorService:    "svc",
				editorUser:       "hodor",
				resourceOwner:    "org",
				PreviousSequence: 5,
				Precondition: &precondition{
					Query: NewSearchQuery().AggregateIDFilter("hodor"),
				},
				Events: []*Event{
					{
						AggregateID:      "hodor",
						AggregateType:    "user",
						AggregateVersion: "v1.0.0",
						EditorService:    "management",
						EditorUser:       "hodor",
						ResourceOwner:    "org",
						Type:             "born",
					}},
			}},
		},
		{
			name:    "validation without precondition ok",
			wantErr: false,
			fields: fields{aggregate: &Aggregate{
				ID:               "aggID",
				typ:              "user",
				version:          "v1.0.0",
				editorService:    "svc",
				editorUser:       "hodor",
				resourceOwner:    "org",
				PreviousSequence: 5,
				Events: []*Event{
					{
						AggregateID:      "hodor",
						AggregateType:    "user",
						AggregateVersion: "v1.0.0",
						EditorService:    "management",
						EditorUser:       "hodor",
						ResourceOwner:    "org",
						Type:             "born",
					}},
			}},
		},
		{
			name:    "validation with precondition ok",
			wantErr: false,
			fields: fields{aggregate: &Aggregate{
				ID:               "aggID",
				typ:              "user",
				version:          "v1.0.0",
				editorService:    "svc",
				editorUser:       "hodor",
				resourceOwner:    "org",
				PreviousSequence: 5,
				Precondition: &precondition{
					Validation: func(...*Event) error { return nil },
					Query:      NewSearchQuery().AggregateIDFilter("hodor"),
				},
				Events: []*Event{
					{
						AggregateID:      "hodor",
						AggregateType:    "user",
						AggregateVersion: "v1.0.0",
						EditorService:    "management",
						EditorUser:       "hodor",
						ResourceOwner:    "org",
						Type:             "born",
					}},
			}},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := tt.fields.aggregate.Validate()
			if (err != nil) != tt.wantErr {
				t.Errorf("Aggregate.Validate() error = %v, wantErr %v", err, tt.wantErr)
			}
			// every validation failure must be a PreconditionFailed error
			if tt.wantErr && !errors.IsPreconditionFailed(err) {
				t.Errorf("error must extend precondition failed: %v", err)
			}
		})
	}
}

// TestAggregate_SetPrecondition checks that both query and validation func
// are stored on the aggregate.
func TestAggregate_SetPrecondition(t *testing.T) {
	type fields struct {
		aggregate *Aggregate
	}
	type args struct {
		query        *SearchQuery
		validateFunc func(...*Event) error
	}
	tests := []struct {
		name   string
		fields fields
		args   args
		want   *Aggregate
	}{
		{
			name:   "set precondition",
			fields: fields{aggregate: &Aggregate{}},
			args: args{
				query:        &SearchQuery{},
				validateFunc: func(...*Event) error { return nil },
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := tt.fields.aggregate.SetPrecondition(tt.args.query, tt.args.validateFunc)
			if got.Precondition == nil {
				t.Error("precondition must not be nil")
				t.FailNow()
			}
			if got.Precondition.Query == nil {
				t.Error("query of precondition must not be nil")
			}
			if got.Precondition.Validation == nil {
				t.Error("precondition func must not be nil")
			}
		})
	}
}

View File

@@ -0,0 +1,89 @@
package models
import (
"encoding/json"
"reflect"
"time"
"github.com/caos/zitadel/internal/errors"
)
// EventType identifies the kind of an event (e.g. "user.deactivated").
type EventType string

// String implements fmt.Stringer.
func (et EventType) String() string {
	return string(et)
}
// Event is a single stored occurrence in an aggregate's event stream,
// including the JSON payload (Data) and editor metadata.
type Event struct {
	ID               string
	Sequence         uint64 // global sequence assigned by the store
	CreationDate     time.Time
	Type             EventType
	PreviousSequence uint64 // sequence of the preceding event of the aggregate
	Data             []byte // JSON-serialized payload, see eventData
	AggregateID      string
	AggregateType    AggregateType
	AggregateVersion Version
	EditorService    string
	EditorUser       string
	ResourceOwner    string
}
// eventData serializes an event payload to JSON.
// Accepted payloads: raw []byte (passed through unchanged),
// map[string]interface{}, structs and pointers to structs (JSON-marshalled),
// and nil (no data). Every other type yields an invalid-argument error.
func eventData(i interface{}) ([]byte, error) {
	if i == nil {
		return nil, nil
	}
	switch payload := i.(type) {
	case []byte:
		return payload, nil
	case map[string]interface{}:
		encoded, err := json.Marshal(payload)
		if err != nil {
			return nil, errors.ThrowInvalidArgument(err, "MODEL-s2fgE", "unable to marshal data")
		}
		return encoded, nil
	}
	// fall back to reflection: only (pointers to) structs are serializable
	payloadType := reflect.TypeOf(i)
	if payloadType.Kind() == reflect.Ptr {
		payloadType = payloadType.Elem()
	}
	if payloadType.Kind() != reflect.Struct {
		return nil, errors.ThrowInvalidArgument(nil, "MODEL-rjWdN", "data is not valid")
	}
	encoded, err := json.Marshal(i)
	if err != nil {
		return nil, errors.ThrowInvalidArgument(err, "MODEL-Y2OpM", "unable to marshal data")
	}
	return encoded, nil
}
// Validate checks that all mandatory event fields are set and that the
// aggregate version is valid. It returns a PreconditionFailed error for the
// first violation found; a nil receiver is also rejected.
func (e *Event) Validate() error {
	if e == nil {
		return errors.ThrowPreconditionFailed(nil, "MODEL-oEAG4", "event is nil")
	}
	if string(e.Type) == "" {
		return errors.ThrowPreconditionFailed(nil, "MODEL-R2sB0", "type not defined")
	}
	if e.AggregateID == "" {
		return errors.ThrowPreconditionFailed(nil, "MODEL-A6WwL", "aggregate id not set")
	}
	if e.AggregateType == "" {
		return errors.ThrowPreconditionFailed(nil, "MODEL-EzdyK", "aggregate type not set")
	}
	if err := e.AggregateVersion.Validate(); err != nil {
		return err
	}
	if e.EditorService == "" {
		return errors.ThrowPreconditionFailed(nil, "MODEL-4Yqik", "editor service not set")
	}
	if e.EditorUser == "" {
		return errors.ThrowPreconditionFailed(nil, "MODEL-L3NHO", "editor user not set")
	}
	if e.ResourceOwner == "" {
		// fixed truncated message (was "resource ow")
		return errors.ThrowPreconditionFailed(nil, "MODEL-omFVT", "resource owner not set")
	}
	return nil
}

View File

@@ -0,0 +1,198 @@
package models
import (
"reflect"
"testing"
)
// Test_eventData covers all accepted payload kinds ([]byte, pointer, struct,
// map, nil) plus rejection of non-serializable values.
func Test_eventData(t *testing.T) {
	type args struct {
		i interface{}
	}
	tests := []struct {
		name    string
		args    args
		want    []byte
		wantErr bool
	}{
		{
			name:    "from bytes",
			args:    args{[]byte(`{"hodor":"asdf"}`)},
			want:    []byte(`{"hodor":"asdf"}`),
			wantErr: false,
		},
		{
			name: "from pointer",
			args: args{&struct {
				Hodor string `json:"hodor"`
			}{Hodor: "asdf"}},
			want:    []byte(`{"hodor":"asdf"}`),
			wantErr: false,
		},
		{
			name: "from struct",
			args: args{struct {
				Hodor string `json:"hodor"`
			}{Hodor: "asdf"}},
			want:    []byte(`{"hodor":"asdf"}`),
			wantErr: false,
		},
		{
			name: "from map",
			args: args{
				map[string]interface{}{"hodor": "asdf"},
			},
			want:    []byte(`{"hodor":"asdf"}`),
			wantErr: false,
		},
		{
			name:    "from nil",
			args:    args{},
			want:    nil,
			wantErr: false,
		},
		{
			name:    "invalid data",
			args:    args{876},
			want:    nil,
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := eventData(tt.args.i)
			if (err != nil) != tt.wantErr {
				t.Errorf("eventData() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("eventData() = %s, want %s", got, tt.want)
			}
		})
	}
}

// TestEvent_Validate walks through every mandatory-field check of
// Event.Validate, one missing field per case, plus the fully-set happy path.
func TestEvent_Validate(t *testing.T) {
	type fields struct {
		event *Event
	}
	tests := []struct {
		name    string
		fields  fields
		wantErr bool
	}{
		{
			name:    "event nil",
			wantErr: true,
		},
		{
			name:    "event empty",
			fields:  fields{event: &Event{}},
			wantErr: true,
		},
		{
			name: "no aggregate id",
			fields: fields{event: &Event{
				AggregateType:    "user",
				AggregateVersion: "v1.0.0",
				EditorService:    "management",
				EditorUser:       "hodor",
				ResourceOwner:    "org",
				Type:             "born",
			}},
			wantErr: true,
		},
		{
			name: "no aggregate type",
			fields: fields{event: &Event{
				AggregateID:      "hodor",
				AggregateVersion: "v1.0.0",
				EditorService:    "management",
				EditorUser:       "hodor",
				ResourceOwner:    "org",
				Type:             "born",
			}},
			wantErr: true,
		},
		{
			name: "no aggregate version",
			fields: fields{event: &Event{
				AggregateID:   "hodor",
				AggregateType: "user",
				EditorService: "management",
				EditorUser:    "hodor",
				ResourceOwner: "org",
				Type:          "born",
			}},
			wantErr: true,
		},
		{
			name: "no editor service",
			fields: fields{event: &Event{
				AggregateID:      "hodor",
				AggregateType:    "user",
				AggregateVersion: "v1.0.0",
				EditorUser:       "hodor",
				ResourceOwner:    "org",
				Type:             "born",
			}},
			wantErr: true,
		},
		{
			name: "no editor user",
			fields: fields{event: &Event{
				AggregateID:      "hodor",
				AggregateType:    "user",
				AggregateVersion: "v1.0.0",
				EditorService:    "management",
				ResourceOwner:    "org",
				Type:             "born",
			}},
			wantErr: true,
		},
		{
			name: "no resource owner",
			fields: fields{event: &Event{
				AggregateID:      "hodor",
				AggregateType:    "user",
				AggregateVersion: "v1.0.0",
				EditorService:    "management",
				EditorUser:       "hodor",
				Type:             "born",
			}},
			wantErr: true,
		},
		{
			name: "no type",
			fields: fields{event: &Event{
				AggregateID:      "hodor",
				AggregateType:    "user",
				AggregateVersion: "v1.0.0",
				EditorService:    "management",
				EditorUser:       "hodor",
				ResourceOwner:    "org",
			}},
			wantErr: true,
		},
		{
			name: "all fields set",
			fields: fields{event: &Event{
				AggregateID:      "hodor",
				AggregateType:    "user",
				AggregateVersion: "v1.0.0",
				EditorService:    "management",
				EditorUser:       "hodor",
				ResourceOwner:    "org",
				Type:             "born",
			}},
			wantErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := tt.fields.event.Validate(); (err != nil) != tt.wantErr {
				t.Errorf("Event.Validate() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

View File

@@ -0,0 +1,13 @@
package models
// Field enumerates the event columns a Filter can match against.
// Values start at 1 so the zero value is invalid (Filter.Validate relies on
// field > 0 to detect an unset field).
type Field int32

const (
	Field_AggregateType Field = 1 + iota
	Field_AggregateID
	Field_LatestSequence
	Field_ResourceOwner
	Field_EditorService
	Field_EditorUser
	Field_EventType
)

View File

@@ -0,0 +1,46 @@
package models
import (
"github.com/caos/zitadel/internal/errors"
)
// Filter is one predicate of a search query: <field> <operation> <value>.
type Filter struct {
	field     Field
	value     interface{}
	operation Operation
}

//NewFilter is used in tests. Use searchQuery.*Filter() instead
func NewFilter(field Field, value interface{}, operation Operation) *Filter {
	return &Filter{
		field:     field,
		value:     value,
		operation: operation,
	}
}
// GetField returns the column the filter matches against.
func (f *Filter) GetField() Field {
	return f.field
}

// GetOperation returns the comparison operator of the filter.
func (f *Filter) GetOperation() Operation {
	return f.operation
}

// GetValue returns the comparison value of the filter.
func (f *Filter) GetValue() interface{} {
	return f.value
}
// Validate returns a PreconditionFailed error if the filter is nil or if
// field, value or operation is unset. Field and Operation values start at 1,
// so the zero value means "not defined".
func (f *Filter) Validate() error {
	if f == nil {
		return errors.ThrowPreconditionFailed(nil, "MODEL-z6KcG", "filter is nil")
	}
	if f.field <= 0 {
		// typo fix: messages previously read "definded"
		return errors.ThrowPreconditionFailed(nil, "MODEL-zw62U", "field not defined")
	}
	if f.value == nil {
		return errors.ThrowPreconditionFailed(nil, "MODEL-GJ9ct", "no value defined")
	}
	if f.operation <= 0 {
		return errors.ThrowPreconditionFailed(nil, "MODEL-RrQTy", "operation not defined")
	}
	return nil
}

View File

@@ -0,0 +1,104 @@
package models
import (
"reflect"
"testing"
)
// TestNewFilter checks that the constructor stores field, value and operation.
func TestNewFilter(t *testing.T) {
	type args struct {
		field     Field
		value     interface{}
		operation Operation
	}
	tests := []struct {
		name string
		args args
		want *Filter
	}{
		{
			name: "aggregateID equals",
			args: args{
				field:     Field_AggregateID,
				value:     "hodor",
				operation: Operation_Equals,
			},
			want: &Filter{field: Field_AggregateID, operation: Operation_Equals, value: "hodor"},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := NewFilter(tt.args.field, tt.args.value, tt.args.operation); !reflect.DeepEqual(got, tt.want) {
				t.Errorf("NewFilter() = %v, want %v", got, tt.want)
			}
		})
	}
}

// TestFilter_Validate covers the nil receiver plus each missing component
// (field, value, operation) and one fully valid filter.
func TestFilter_Validate(t *testing.T) {
	type fields struct {
		field     Field
		value     interface{}
		operation Operation
		isNil     bool
	}
	tests := []struct {
		name    string
		fields  fields
		wantErr bool
	}{
		{
			name: "correct filter",
			fields: fields{
				field:     Field_LatestSequence,
				operation: Operation_Greater,
				value:     uint64(235),
			},
			wantErr: false,
		},
		{
			name:    "filter is nil",
			fields:  fields{isNil: true},
			wantErr: true,
		},
		{
			name: "no field error",
			fields: fields{
				operation: Operation_Greater,
				value:     uint64(235),
			},
			wantErr: true,
		},
		{
			name: "no value error",
			fields: fields{
				field:     Field_LatestSequence,
				operation: Operation_Greater,
			},
			wantErr: true,
		},
		{
			name: "no operation error",
			fields: fields{
				field: Field_LatestSequence,
				value: uint64(235),
			},
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var f *Filter
			if !tt.fields.isNil {
				f = &Filter{
					field:     tt.fields.field,
					value:     tt.fields.value,
					operation: tt.fields.operation,
				}
			}
			if err := f.Validate(); (err != nil) != tt.wantErr {
				t.Errorf("Filter.Validate() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

View File

@@ -0,0 +1,34 @@
package models
import (
"time"
)
// ObjectRoot is the common read-model base that tracks identity, sequence
// and timestamps while events are applied. All fields are excluded from
// JSON serialization.
type ObjectRoot struct {
	AggregateID   string    `json:"-"`
	Sequence      uint64    `json:"-"`
	ResourceOwner string    `json:"-"`
	CreationDate  time.Time `json:"-"`
	ChangeDate    time.Time `json:"-"`
}
// AppendEvent folds the given event into the object root: it adopts the
// aggregate id and resource owner on first use, updates ChangeDate and
// Sequence, and sets CreationDate on the very first event. Events belonging
// to a different aggregate are silently ignored.
func (o *ObjectRoot) AppendEvent(event *Event) {
	switch o.AggregateID {
	case "":
		// first event: adopt the aggregate's identity
		o.AggregateID = event.AggregateID
	case event.AggregateID:
		// same aggregate, continue below
	default:
		// foreign aggregate: do not apply
		return
	}
	if len(o.ResourceOwner) == 0 {
		o.ResourceOwner = event.ResourceOwner
	}
	o.ChangeDate = event.CreationDate
	if o.CreationDate.IsZero() {
		o.CreationDate = event.CreationDate
	}
	o.Sequence = event.Sequence
}
// IsZero reports whether no event has been applied yet (no aggregate id).
func (o *ObjectRoot) IsZero() bool {
	return o.AggregateID == ""
}

View File

@@ -0,0 +1,81 @@
package models
import (
"testing"
"time"
)
// TestObjectRoot_AppendEvent checks timestamp/sequence bookkeeping for both
// a fresh root (creation date adopted) and an existing one (dates diverge).
func TestObjectRoot_AppendEvent(t *testing.T) {
	type fields struct {
		ID           string
		Sequence     uint64
		CreationDate time.Time
		ChangeDate   time.Time
	}
	type args struct {
		event     *Event
		isNewRoot bool
	}
	tests := []struct {
		name   string
		fields fields
		args   args
	}{
		{
			"new root",
			fields{},
			args{
				&Event{
					AggregateID:  "aggID",
					Sequence:     34555,
					CreationDate: time.Now(),
				},
				true,
			},
		},
		{
			"existing root",
			fields{
				"agg",
				234,
				time.Now().Add(-24 * time.Hour),
				time.Now().Add(-12 * time.Hour),
			},
			args{
				&Event{
					AggregateID:      "agg",
					Sequence:         34555425,
					CreationDate:     time.Now(),
					PreviousSequence: 22,
				},
				false,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			o := &ObjectRoot{
				AggregateID:  tt.fields.ID,
				Sequence:     tt.fields.Sequence,
				CreationDate: tt.fields.CreationDate,
				ChangeDate:   tt.fields.ChangeDate,
			}
			o.AppendEvent(tt.args.event)
			if tt.args.isNewRoot {
				if !o.CreationDate.Equal(tt.args.event.CreationDate) {
					t.Error("creationDate should be equal to event on new root")
				}
			} else {
				if o.CreationDate.Equal(o.ChangeDate) {
					t.Error("creationDate and changedate should differ")
				}
			}
			if o.Sequence != tt.args.event.Sequence {
				t.Errorf("sequence not equal to event: event: %d root: %d", tt.args.event.Sequence, o.Sequence)
			}
			if !o.ChangeDate.Equal(tt.args.event.CreationDate) {
				t.Errorf("changedate should be equal to event creation date: event: %v root: %v", tt.args.event.CreationDate, o.ChangeDate)
			}
		})
	}
}

View File

@@ -0,0 +1,10 @@
package models
// Operation enumerates the comparison operators a Filter can use.
// Values start at 1 so the zero value is invalid (see Filter.Validate).
type Operation int32

const (
	Operation_Equals Operation = 1 + iota
	Operation_Greater
	Operation_Less
	Operation_In
)

View File

@@ -0,0 +1,213 @@
package models
import (
"github.com/caos/logging"
"github.com/caos/zitadel/internal/errors"
)
// SearchQueryFactory collects search criteria via a fluent builder API and
// produces a searchQuery through Build.
type SearchQueryFactory struct {
	columns        Columns
	limit          uint64
	desc           bool
	aggregateTypes []AggregateType
	aggregateIDs   []string
	sequenceFrom   uint64
	sequenceTo     uint64
	eventTypes     []EventType
	resourceOwner  string
}

// searchQuery is the built, repository-ready form of a factory.
type searchQuery struct {
	Columns Columns
	Limit   uint64
	Desc    bool
	Filters []*Filter
}

// Columns selects what the repository returns for a query
// (full events or only the max sequence).
type Columns int32

const (
	Columns_Event = iota
	Columns_Max_Sequence
	//insert new columns-types before this columnsCount because count is needed for validation
	columnsCount
)
//FactoryFromSearchQuery is deprecated because it's for migration purposes. use NewSearchQueryFactory
//
// It translates each filter of the legacy SearchQuery into the equivalent
// factory setter. NOTE: the filter values are type-asserted without ok-checks,
// so a filter holding an unexpected value type panics; editor-service and
// editor-user filters are intentionally unsupported and panic via logging.
func FactoryFromSearchQuery(query *SearchQuery) *SearchQueryFactory {
	factory := &SearchQueryFactory{
		columns: Columns_Event,
		desc:    query.Desc,
		limit:   query.Limit,
	}
	for _, filter := range query.Filters {
		switch filter.field {
		case Field_AggregateType:
			factory = factory.aggregateTypesMig(filter.value.([]AggregateType)...)
		case Field_AggregateID:
			if aggregateID, ok := filter.value.(string); ok {
				factory = factory.AggregateIDs(aggregateID)
			} else if aggregateIDs, ok := filter.value.([]string); ok {
				factory = factory.AggregateIDs(aggregateIDs...)
			}
		case Field_LatestSequence:
			// greater => lower bound, everything else => upper bound
			if filter.operation == Operation_Greater {
				factory = factory.SequenceGreater(filter.value.(uint64))
			} else {
				factory = factory.SequenceLess(filter.value.(uint64))
			}
		case Field_ResourceOwner:
			factory = factory.ResourceOwner(filter.value.(string))
		case Field_EventType:
			factory = factory.EventTypes(filter.value.([]EventType)...)
		case Field_EditorService, Field_EditorUser:
			logging.Log("MODEL-Mr0VN").WithField("value", filter.value).Panic("field not converted to factory")
		}
	}
	return factory
}
// NewSearchQueryFactory creates a factory limited to the given aggregate
// types. At least one type must be provided for Build to succeed.
func NewSearchQueryFactory(aggregateTypes ...AggregateType) *SearchQueryFactory {
	return &SearchQueryFactory{
		aggregateTypes: aggregateTypes,
	}
}
// Columns sets which columns the query returns (events or max sequence).
func (factory *SearchQueryFactory) Columns(columns Columns) *SearchQueryFactory {
	factory.columns = columns
	return factory
}

// Limit caps the number of returned events (0 = no limit).
func (factory *SearchQueryFactory) Limit(limit uint64) *SearchQueryFactory {
	factory.limit = limit
	return factory
}

// SequenceGreater sets the exclusive lower sequence bound.
func (factory *SearchQueryFactory) SequenceGreater(sequence uint64) *SearchQueryFactory {
	factory.sequenceFrom = sequence
	return factory
}

// SequenceLess sets the exclusive upper sequence bound.
func (factory *SearchQueryFactory) SequenceLess(sequence uint64) *SearchQueryFactory {
	factory.sequenceTo = sequence
	return factory
}

// AggregateIDs restricts the query to the given aggregate ids.
func (factory *SearchQueryFactory) AggregateIDs(ids ...string) *SearchQueryFactory {
	factory.aggregateIDs = ids
	return factory
}

// aggregateTypesMig replaces the aggregate types; only used by
// FactoryFromSearchQuery during migration.
func (factory *SearchQueryFactory) aggregateTypesMig(types ...AggregateType) *SearchQueryFactory {
	factory.aggregateTypes = types
	return factory
}

// EventTypes restricts the query to the given event types.
func (factory *SearchQueryFactory) EventTypes(types ...EventType) *SearchQueryFactory {
	factory.eventTypes = types
	return factory
}

// ResourceOwner restricts the query to one resource owner (org).
func (factory *SearchQueryFactory) ResourceOwner(resourceOwner string) *SearchQueryFactory {
	factory.resourceOwner = resourceOwner
	return factory
}

// OrderDesc sorts the result by sequence descending.
func (factory *SearchQueryFactory) OrderDesc() *SearchQueryFactory {
	factory.desc = true
	return factory
}

// OrderAsc sorts the result by sequence ascending (the default).
func (factory *SearchQueryFactory) OrderAsc() *SearchQueryFactory {
	factory.desc = false
	return factory
}
// Build validates the factory and assembles the repository-ready searchQuery.
// It requires at least one aggregate type and a known columns value; the
// aggregate-type filter is always present, all other filters are optional.
func (factory *SearchQueryFactory) Build() (*searchQuery, error) {
	if factory == nil ||
		len(factory.aggregateTypes) < 1 ||
		(factory.columns < 0 || factory.columns >= columnsCount) {
		return nil, errors.ThrowPreconditionFailed(nil, "MODEL-tGAD3", "factory invalid")
	}
	filters := make([]*Filter, 0, 6)
	// the aggregate-type filter is mandatory and therefore always first
	filters = append(filters, factory.aggregateTypeFilter())
	for _, optional := range []*Filter{
		factory.aggregateIDFilter(),
		factory.sequenceFromFilter(),
		factory.sequenceToFilter(),
		factory.eventTypeFilter(),
		factory.resourceOwnerFilter(),
	} {
		if optional != nil {
			filters = append(filters, optional)
		}
	}
	return &searchQuery{
		Columns: factory.columns,
		Limit:   factory.limit,
		Desc:    factory.desc,
		Filters: filters,
	}, nil
}
// aggregateIDFilter returns an equals/in filter on the aggregate ids,
// or nil if none are set.
func (factory *SearchQueryFactory) aggregateIDFilter() *Filter {
	if len(factory.aggregateIDs) < 1 {
		return nil
	}
	if len(factory.aggregateIDs) == 1 {
		return NewFilter(Field_AggregateID, factory.aggregateIDs[0], Operation_Equals)
	}
	return NewFilter(Field_AggregateID, factory.aggregateIDs, Operation_In)
}

// eventTypeFilter returns an equals/in filter on the event types,
// or nil if none are set.
func (factory *SearchQueryFactory) eventTypeFilter() *Filter {
	if len(factory.eventTypes) < 1 {
		return nil
	}
	if len(factory.eventTypes) == 1 {
		return NewFilter(Field_EventType, factory.eventTypes[0], Operation_Equals)
	}
	return NewFilter(Field_EventType, factory.eventTypes, Operation_In)
}

// aggregateTypeFilter returns an equals/in filter on the aggregate types.
// Never nil: Build guarantees at least one aggregate type.
func (factory *SearchQueryFactory) aggregateTypeFilter() *Filter {
	if len(factory.aggregateTypes) == 1 {
		return NewFilter(Field_AggregateType, factory.aggregateTypes[0], Operation_Equals)
	}
	return NewFilter(Field_AggregateType, factory.aggregateTypes, Operation_In)
}

// sequenceFromFilter returns the lower sequence bound, or nil if unset.
// With descending order the comparison is inverted.
func (factory *SearchQueryFactory) sequenceFromFilter() *Filter {
	if factory.sequenceFrom == 0 {
		return nil
	}
	sortOrder := Operation_Greater
	if factory.desc {
		sortOrder = Operation_Less
	}
	return NewFilter(Field_LatestSequence, factory.sequenceFrom, sortOrder)
}

// sequenceToFilter returns the upper sequence bound, or nil if unset.
// With descending order the comparison is inverted.
func (factory *SearchQueryFactory) sequenceToFilter() *Filter {
	if factory.sequenceTo == 0 {
		return nil
	}
	sortOrder := Operation_Less
	if factory.desc {
		sortOrder = Operation_Greater
	}
	return NewFilter(Field_LatestSequence, factory.sequenceTo, sortOrder)
}

// resourceOwnerFilter returns an equals filter on the resource owner,
// or nil if unset.
func (factory *SearchQueryFactory) resourceOwnerFilter() *Filter {
	if factory.resourceOwner == "" {
		return nil
	}
	return NewFilter(Field_ResourceOwner, factory.resourceOwner, Operation_Equals)
}

View File

@@ -0,0 +1,96 @@
package models
import "github.com/caos/zitadel/internal/errors"
//SearchQuery is deprecated. Use SearchQueryFactory
type SearchQuery struct {
Limit uint64
Desc bool
Filters []*Filter
}
//NewSearchQuery is deprecated. Use SearchQueryFactory
func NewSearchQuery() *SearchQuery {
return &SearchQuery{
Filters: make([]*Filter, 0, 4),
}
}
// SetLimit caps the number of returned events (0 = no limit).
func (q *SearchQuery) SetLimit(limit uint64) *SearchQuery {
	q.Limit = limit
	return q
}

// OrderDesc sorts the result by sequence descending.
func (q *SearchQuery) OrderDesc() *SearchQuery {
	q.Desc = true
	return q
}

// OrderAsc sorts the result by sequence ascending (the default).
func (q *SearchQuery) OrderAsc() *SearchQuery {
	q.Desc = false
	return q
}

// AggregateIDFilter restricts the query to a single aggregate id.
func (q *SearchQuery) AggregateIDFilter(id string) *SearchQuery {
	return q.setFilter(NewFilter(Field_AggregateID, id, Operation_Equals))
}

// AggregateIDsFilter restricts the query to the given aggregate ids.
func (q *SearchQuery) AggregateIDsFilter(ids ...string) *SearchQuery {
	return q.setFilter(NewFilter(Field_AggregateID, ids, Operation_In))
}

// AggregateTypeFilter restricts the query to the given aggregate types.
func (q *SearchQuery) AggregateTypeFilter(types ...AggregateType) *SearchQuery {
	return q.setFilter(NewFilter(Field_AggregateType, types, Operation_In))
}

// EventTypesFilter restricts the query to the given event types.
func (q *SearchQuery) EventTypesFilter(types ...EventType) *SearchQuery {
	return q.setFilter(NewFilter(Field_EventType, types, Operation_In))
}

// LatestSequenceFilter adds a sequence bound relative to the sort order:
// "greater" when ascending, "less" when descending. A zero sequence is a
// no-op. NOTE: call after setting the order, since Desc is read here.
func (q *SearchQuery) LatestSequenceFilter(sequence uint64) *SearchQuery {
	if sequence == 0 {
		return q
	}
	sortOrder := Operation_Greater
	if q.Desc {
		sortOrder = Operation_Less
	}
	return q.setFilter(NewFilter(Field_LatestSequence, sequence, sortOrder))
}

// SequenceBetween restricts the query to events with from < sequence < to.
func (q *SearchQuery) SequenceBetween(from, to uint64) *SearchQuery {
	q.setFilter(NewFilter(Field_LatestSequence, from, Operation_Greater))
	q.setFilter(NewFilter(Field_LatestSequence, to, Operation_Less))
	return q
}

// ResourceOwnerFilter restricts the query to one resource owner (org).
func (q *SearchQuery) ResourceOwnerFilter(resourceOwner string) *SearchQuery {
	return q.setFilter(NewFilter(Field_ResourceOwner, resourceOwner, Operation_Equals))
}
// setFilter adds the filter, replacing an existing filter on the same field.
// Field_LatestSequence is exempt from replacement so both a lower and an
// upper sequence bound can coexist (see SequenceBetween).
func (q *SearchQuery) setFilter(filter *Filter) *SearchQuery {
	for i, f := range q.Filters {
		if f.field == filter.field && f.field != Field_LatestSequence {
			q.Filters[i] = filter
			return q
		}
	}
	q.Filters = append(q.Filters, filter)
	return q
}
// Validate checks that the query is non-nil, has at least one filter and
// that every filter itself is valid. Returns a PreconditionFailed error
// (or the first failing filter's error) otherwise.
func (q *SearchQuery) Validate() error {
	if q == nil {
		return errors.ThrowPreconditionFailed(nil, "MODEL-J5xQi", "search query is nil")
	}
	if len(q.Filters) == 0 {
		return errors.ThrowPreconditionFailed(nil, "MODEL-pF3DR", "no filters set")
	}
	for _, filter := range q.Filters {
		if err := filter.Validate(); err != nil {
			return err
		}
	}
	return nil
}

View File

@@ -0,0 +1,65 @@
package models
import (
"reflect"
"testing"
)
// TestSearchQuery_setFilter checks that filters are appended and that a
// second filter on the same field replaces the first.
func TestSearchQuery_setFilter(t *testing.T) {
	type fields struct {
		query *SearchQuery
	}
	type args struct {
		filters []*Filter
	}
	tests := []struct {
		name   string
		fields fields
		args   args
		want   *SearchQuery
	}{
		{
			name:   "set idFilter",
			fields: fields{query: NewSearchQuery()},
			args: args{filters: []*Filter{
				{field: Field_AggregateID, operation: Operation_Equals, value: "hodor"},
			}},
			want: &SearchQuery{Filters: []*Filter{
				{field: Field_AggregateID, operation: Operation_Equals, value: "hodor"},
			}},
		},
		{
			name:   "overwrite idFilter",
			fields: fields{query: NewSearchQuery()},
			args: args{filters: []*Filter{
				{field: Field_AggregateID, operation: Operation_Equals, value: "hodor"},
				{field: Field_AggregateID, operation: Operation_Equals, value: "ursli"},
			}},
			want: &SearchQuery{Filters: []*Filter{
				{field: Field_AggregateID, operation: Operation_Equals, value: "ursli"},
			}},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := tt.fields.query
			for _, filter := range tt.args.filters {
				got = got.setFilter(filter)
			}
			for _, wantFilter := range tt.want.Filters {
				found := false
				for _, gotFilter := range got.Filters {
					if gotFilter.field == wantFilter.field {
						found = true
						if !reflect.DeepEqual(wantFilter, gotFilter) {
							t.Errorf("filter not as expected: want: %v got %v", wantFilter, gotFilter)
						}
					}
				}
				if !found {
					t.Errorf("filter field %v not found", wantFilter.field)
				}
			}
		})
	}
}

View File

@@ -0,0 +1,465 @@
package models
import (
"reflect"
"testing"
"github.com/caos/zitadel/internal/errors"
)
// testSetColumns returns a setter applying Columns to a factory (test helper).
func testSetColumns(columns Columns) func(factory *SearchQueryFactory) *SearchQueryFactory {
	return func(factory *SearchQueryFactory) *SearchQueryFactory {
		factory = factory.Columns(columns)
		return factory
	}
}

// testSetLimit returns a setter applying Limit to a factory (test helper).
func testSetLimit(limit uint64) func(factory *SearchQueryFactory) *SearchQueryFactory {
	return func(factory *SearchQueryFactory) *SearchQueryFactory {
		factory = factory.Limit(limit)
		return factory
	}
}

// testSetSequence returns a setter applying SequenceGreater (test helper).
func testSetSequence(sequence uint64) func(factory *SearchQueryFactory) *SearchQueryFactory {
	return func(factory *SearchQueryFactory) *SearchQueryFactory {
		factory = factory.SequenceGreater(sequence)
		return factory
	}
}

// testSetAggregateIDs returns a setter applying AggregateIDs (test helper).
func testSetAggregateIDs(aggregateIDs ...string) func(factory *SearchQueryFactory) *SearchQueryFactory {
	return func(factory *SearchQueryFactory) *SearchQueryFactory {
		factory = factory.AggregateIDs(aggregateIDs...)
		return factory
	}
}

// testSetEventTypes returns a setter applying EventTypes (test helper).
func testSetEventTypes(eventTypes ...EventType) func(factory *SearchQueryFactory) *SearchQueryFactory {
	return func(factory *SearchQueryFactory) *SearchQueryFactory {
		factory = factory.EventTypes(eventTypes...)
		return factory
	}
}

// testSetResourceOwner returns a setter applying ResourceOwner (test helper).
func testSetResourceOwner(resourceOwner string) func(factory *SearchQueryFactory) *SearchQueryFactory {
	return func(factory *SearchQueryFactory) *SearchQueryFactory {
		factory = factory.ResourceOwner(resourceOwner)
		return factory
	}
}

// testSetSortOrder returns a setter choosing asc/desc ordering (test helper).
func testSetSortOrder(asc bool) func(factory *SearchQueryFactory) *SearchQueryFactory {
	return func(factory *SearchQueryFactory) *SearchQueryFactory {
		if asc {
			factory = factory.OrderAsc()
		} else {
			factory = factory.OrderDesc()
		}
		return factory
	}
}
// TestSearchQueryFactorySetters verifies that each fluent setter writes
// exactly the expected field(s) of SearchQueryFactory; the expected structs
// are compared field-by-field via reflect.DeepEqual.
func TestSearchQueryFactorySetters(t *testing.T) {
	type args struct {
		aggregateTypes []AggregateType
		setters        []func(*SearchQueryFactory) *SearchQueryFactory
	}
	tests := []struct {
		name string
		args args
		res  *SearchQueryFactory
	}{
		{
			name: "New factory",
			args: args{
				aggregateTypes: []AggregateType{"user", "org"},
			},
			res: &SearchQueryFactory{
				aggregateTypes: []AggregateType{"user", "org"},
			},
		},
		{
			name: "set columns",
			args: args{
				setters: []func(*SearchQueryFactory) *SearchQueryFactory{testSetColumns(Columns_Max_Sequence)},
			},
			res: &SearchQueryFactory{
				columns: Columns_Max_Sequence,
			},
		},
		{
			name: "set limit",
			args: args{
				setters: []func(*SearchQueryFactory) *SearchQueryFactory{testSetLimit(100)},
			},
			res: &SearchQueryFactory{
				limit: 100,
			},
		},
		{
			name: "set sequence",
			args: args{
				setters: []func(*SearchQueryFactory) *SearchQueryFactory{testSetSequence(90)},
			},
			res: &SearchQueryFactory{
				sequenceFrom: 90,
			},
		},
		{
			name: "set aggregateIDs",
			args: args{
				setters: []func(*SearchQueryFactory) *SearchQueryFactory{testSetAggregateIDs("1235", "09824")},
			},
			res: &SearchQueryFactory{
				aggregateIDs: []string{"1235", "09824"},
			},
		},
		{
			name: "set eventTypes",
			args: args{
				setters: []func(*SearchQueryFactory) *SearchQueryFactory{testSetEventTypes("user.created", "user.updated")},
			},
			res: &SearchQueryFactory{
				eventTypes: []EventType{"user.created", "user.updated"},
			},
		},
		{
			name: "set resource owner",
			args: args{
				setters: []func(*SearchQueryFactory) *SearchQueryFactory{testSetResourceOwner("hodor")},
			},
			res: &SearchQueryFactory{
				resourceOwner: "hodor",
			},
		},
		{
			name: "default search query",
			args: args{
				aggregateTypes: []AggregateType{"user"},
				setters:        []func(*SearchQueryFactory) *SearchQueryFactory{testSetAggregateIDs("1235", "024"), testSetSortOrder(false)},
			},
			res: &SearchQueryFactory{
				aggregateTypes: []AggregateType{"user"},
				aggregateIDs:   []string{"1235", "024"},
				desc:           true,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			// start from a fresh factory and fold all setters over it
			factory := NewSearchQueryFactory(tt.args.aggregateTypes...)
			for _, setter := range tt.args.setters {
				factory = setter(factory)
			}
			if !reflect.DeepEqual(factory, tt.res) {
				t.Errorf("NewSearchQueryFactory() = %v, want %v", factory, tt.res)
			}
		})
	}
}
// TestSearchQueryFactoryBuild verifies that Build translates factory state
// into the expected searchQuery filters, and that invalid configurations
// (no aggregate types, out-of-range columns) yield precondition-failed errors.
func TestSearchQueryFactoryBuild(t *testing.T) {
	type args struct {
		aggregateTypes []AggregateType
		setters        []func(*SearchQueryFactory) *SearchQueryFactory
	}
	type res struct {
		isErr func(err error) bool
		query *searchQuery
	}
	tests := []struct {
		name string
		args args
		res  res
	}{
		{
			name: "no aggregate types",
			args: args{
				aggregateTypes: []AggregateType{},
				setters:        []func(*SearchQueryFactory) *SearchQueryFactory{},
			},
			res: res{
				isErr: errors.IsPreconditionFailed,
				query: nil,
			},
		},
		{
			name: "invalid column (too low)",
			args: args{
				aggregateTypes: []AggregateType{"user"},
				setters: []func(*SearchQueryFactory) *SearchQueryFactory{
					testSetColumns(Columns(-1)),
				},
			},
			res: res{
				isErr: errors.IsPreconditionFailed,
			},
		},
		{
			name: "invalid column (too high)",
			args: args{
				aggregateTypes: []AggregateType{"user"},
				setters: []func(*SearchQueryFactory) *SearchQueryFactory{
					testSetColumns(columnsCount),
				},
			},
			res: res{
				isErr: errors.IsPreconditionFailed,
			},
		},
		{
			name: "filter aggregate type",
			args: args{
				aggregateTypes: []AggregateType{"user"},
				setters:        []func(*SearchQueryFactory) *SearchQueryFactory{},
			},
			res: res{
				isErr: nil,
				query: &searchQuery{
					Columns: 0,
					Desc:    false,
					Limit:   0,
					Filters: []*Filter{
						NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
					},
				},
			},
		},
		{
			name: "filter aggregate types",
			args: args{
				aggregateTypes: []AggregateType{"user", "org"},
				setters:        []func(*SearchQueryFactory) *SearchQueryFactory{},
			},
			res: res{
				isErr: nil,
				query: &searchQuery{
					Columns: 0,
					Desc:    false,
					Limit:   0,
					Filters: []*Filter{
						NewFilter(Field_AggregateType, []AggregateType{"user", "org"}, Operation_In),
					},
				},
			},
		},
		{
			name: "filter aggregate type, limit, desc",
			args: args{
				aggregateTypes: []AggregateType{"user"},
				setters: []func(*SearchQueryFactory) *SearchQueryFactory{
					testSetLimit(5),
					testSetSortOrder(false),
					testSetSequence(100),
				},
			},
			res: res{
				isErr: nil,
				query: &searchQuery{
					Columns: 0,
					Desc:    true,
					Limit:   5,
					Filters: []*Filter{
						NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
						NewFilter(Field_LatestSequence, uint64(100), Operation_Less),
					},
				},
			},
		},
		{
			name: "filter aggregate type, limit, asc",
			args: args{
				aggregateTypes: []AggregateType{"user"},
				setters: []func(*SearchQueryFactory) *SearchQueryFactory{
					testSetLimit(5),
					testSetSortOrder(true),
					testSetSequence(100),
				},
			},
			res: res{
				isErr: nil,
				query: &searchQuery{
					Columns: 0,
					Desc:    false,
					Limit:   5,
					Filters: []*Filter{
						NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
						NewFilter(Field_LatestSequence, uint64(100), Operation_Greater),
					},
				},
			},
		},
		{
			name: "filter aggregate type, limit, desc, max event sequence cols",
			args: args{
				aggregateTypes: []AggregateType{"user"},
				setters: []func(*SearchQueryFactory) *SearchQueryFactory{
					testSetLimit(5),
					testSetSortOrder(false),
					testSetSequence(100),
					testSetColumns(Columns_Max_Sequence),
				},
			},
			res: res{
				isErr: nil,
				query: &searchQuery{
					Columns: Columns_Max_Sequence,
					Desc:    true,
					Limit:   5,
					Filters: []*Filter{
						NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
						NewFilter(Field_LatestSequence, uint64(100), Operation_Less),
					},
				},
			},
		},
		{
			name: "filter aggregate type and aggregate id",
			args: args{
				aggregateTypes: []AggregateType{"user"},
				setters: []func(*SearchQueryFactory) *SearchQueryFactory{
					testSetAggregateIDs("1234"),
				},
			},
			res: res{
				isErr: nil,
				query: &searchQuery{
					Columns: 0,
					Desc:    false,
					Limit:   0,
					Filters: []*Filter{
						NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
						NewFilter(Field_AggregateID, "1234", Operation_Equals),
					},
				},
			},
		},
		{
			name: "filter aggregate type and aggregate ids",
			args: args{
				aggregateTypes: []AggregateType{"user"},
				setters: []func(*SearchQueryFactory) *SearchQueryFactory{
					testSetAggregateIDs("1234", "0815"),
				},
			},
			res: res{
				isErr: nil,
				query: &searchQuery{
					Columns: 0,
					Desc:    false,
					Limit:   0,
					Filters: []*Filter{
						NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
						NewFilter(Field_AggregateID, []string{"1234", "0815"}, Operation_In),
					},
				},
			},
		},
		{
			name: "filter aggregate type and sequence greater",
			args: args{
				aggregateTypes: []AggregateType{"user"},
				setters: []func(*SearchQueryFactory) *SearchQueryFactory{
					testSetSequence(8),
				},
			},
			res: res{
				isErr: nil,
				query: &searchQuery{
					Columns: 0,
					Desc:    false,
					Limit:   0,
					Filters: []*Filter{
						NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
						NewFilter(Field_LatestSequence, uint64(8), Operation_Greater),
					},
				},
			},
		},
		{
			name: "filter aggregate type and event type",
			args: args{
				aggregateTypes: []AggregateType{"user"},
				setters: []func(*SearchQueryFactory) *SearchQueryFactory{
					testSetEventTypes("user.created"),
				},
			},
			res: res{
				isErr: nil,
				query: &searchQuery{
					Columns: 0,
					Desc:    false,
					Limit:   0,
					Filters: []*Filter{
						NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
						NewFilter(Field_EventType, EventType("user.created"), Operation_Equals),
					},
				},
			},
		},
		{
			name: "filter aggregate type and event types",
			args: args{
				aggregateTypes: []AggregateType{"user"},
				setters: []func(*SearchQueryFactory) *SearchQueryFactory{
					testSetEventTypes("user.created", "user.changed"),
				},
			},
			res: res{
				isErr: nil,
				query: &searchQuery{
					Columns: 0,
					Desc:    false,
					Limit:   0,
					Filters: []*Filter{
						NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
						NewFilter(Field_EventType, []EventType{"user.created", "user.changed"}, Operation_In),
					},
				},
			},
		},
		{
			name: "filter aggregate type resource owner",
			args: args{
				aggregateTypes: []AggregateType{"user"},
				setters: []func(*SearchQueryFactory) *SearchQueryFactory{
					testSetResourceOwner("hodor"),
				},
			},
			res: res{
				isErr: nil,
				query: &searchQuery{
					Columns: 0,
					Desc:    false,
					Limit:   0,
					Filters: []*Filter{
						NewFilter(Field_AggregateType, AggregateType("user"), Operation_Equals),
						NewFilter(Field_ResourceOwner, "hodor", Operation_Equals),
					},
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			factory := NewSearchQueryFactory(tt.args.aggregateTypes...)
			for _, f := range tt.args.setters {
				factory = f(factory)
			}
			query, err := factory.Build()
			if tt.res.isErr != nil && !tt.res.isErr(err) {
				t.Errorf("wrong error: %v", err)
				return
			}
			if err != nil && tt.res.isErr == nil {
				t.Errorf("no error expected: %v", err)
				return
			}
			// report the values actually compared; the original printed the
			// factory and the whole res struct, hiding the real mismatch
			if !reflect.DeepEqual(query, tt.res.query) {
				t.Errorf("factory.Build() = %v, want %v", query, tt.res.query)
			}
		})
	}
}

View File

@@ -0,0 +1,22 @@
package models
import (
"regexp"
"github.com/caos/zitadel/internal/errors"
)
// versionRegexp matches versions with a mandatory "v" prefix and one to three
// numeric parts, e.g. "v1", "v1.2" or "v1.2.3". Compiled once at package init.
var versionRegexp = regexp.MustCompile(`^v[0-9]+(\.[0-9]+){0,2}$`)
// Version is an aggregate model version in the "vX[.Y[.Z]]" format.
type Version string
// Validate returns a precondition-failed error when v does not match the
// required "vX[.Y[.Z]]" format, nil otherwise.
func (v Version) Validate() error {
	if versionRegexp.MatchString(string(v)) {
		return nil
	}
	return errors.ThrowPreconditionFailed(nil, "MODEL-luDuS", "version is not semver")
}
// String implements fmt.Stringer.
func (v Version) String() string {
	return string(v)
}

View File

@@ -0,0 +1,39 @@
package models
import "testing"
// TestVersion_Validate checks accepted and rejected version strings:
// the "v" prefix is mandatory, parts must be numeric, empty is invalid.
func TestVersion_Validate(t *testing.T) {
	tests := []struct {
		name    string
		v       Version
		wantErr bool
	}{
		{
			"correct version",
			"v1.23.23",
			false,
		},
		{
			"no v prefix",
			"1.2.2",
			true,
		},
		{
			"letters in version",
			"v1.as.3",
			true,
		},
		{
			"no version",
			"",
			true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := tt.v.Validate(); (err != nil) != tt.wantErr {
				t.Errorf("Version.Validate() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

View File

@@ -0,0 +1,74 @@
package query
import (
"context"
"github.com/caos/zitadel/internal/eventstore/v1"
"time"
"github.com/caos/logging"
"github.com/caos/zitadel/internal/eventstore/v1/models"
)
const (
	// eventLimit caps the number of events loaded per catch-up query in ReduceEvent.
	eventLimit = 10000
)
// Handler is implemented by view projections that consume eventstore events.
type Handler interface {
	// ViewModel returns the name of the view this handler maintains.
	ViewModel() string
	// EventQuery builds the query used to load events for this view.
	EventQuery() (*models.SearchQuery, error)
	// Reduce applies a single event to the view.
	Reduce(*models.Event) error
	// OnError is called when Reduce fails; returning nil skips the event.
	OnError(event *models.Event, err error) error
	// OnSuccess is called after a successful processing cycle.
	OnSuccess() error
	MinimumCycleDuration() time.Duration
	LockDuration() time.Duration
	// QueryLimit is the maximum number of events loaded per query.
	QueryLimit() uint64
	AggregateTypes() []models.AggregateType
	// CurrentSequence returns the sequence the view has processed up to.
	CurrentSequence() (uint64, error)
	Eventstore() v1.Eventstore
}
// ReduceEvent catches the handler's view up to the given event: it loads the
// events between the view's current sequence and event.Sequence and reduces
// them before reducing the event itself.
//
// It returns without reducing when the current sequence cannot be read, when
// the catch-up query fails, when an already-processed (lower-sequence) event
// appears, or when eventLimit events were pending (a later cycle continues).
func ReduceEvent(handler Handler, event *models.Event) {
	currentSequence, err := handler.CurrentSequence()
	if err != nil {
		logging.Log("HANDL-BmpkC").WithError(err).Warn("unable to get current sequence")
		return
	}
	searchQuery := models.NewSearchQuery().
		AggregateTypeFilter(handler.AggregateTypes()...).
		SequenceBetween(currentSequence, event.Sequence).
		SetLimit(eventLimit)
	unprocessedEvents, err := handler.Eventstore().FilterEvents(context.Background(), searchQuery)
	if err != nil {
		// include the error itself; the original log dropped it
		logging.LogWithFields("HANDL-L6YH1", "seq", event.Sequence).WithError(err).Warn("filter failed")
		return
	}
	for _, unprocessedEvent := range unprocessedEvents {
		// re-read the sequence: the view may have advanced concurrently
		currentSequence, err := handler.CurrentSequence()
		if err != nil {
			logging.Log("HANDL-BmpkC").WithError(err).Warn("unable to get current sequence")
			return
		}
		if unprocessedEvent.Sequence < currentSequence {
			// someone else already processed this event; abort to avoid double reduce
			logging.LogWithFields("QUERY-DOYVN",
				"unprocessed", unprocessedEvent.Sequence,
				"current", currentSequence,
				"view", handler.ViewModel()).
				Warn("sequence not matching")
			return
		}
		err = handler.Reduce(unprocessedEvent)
		logging.LogWithFields("HANDL-V42TI", "seq", unprocessedEvent.Sequence).OnError(err).Warn("reduce failed")
	}
	if len(unprocessedEvents) == eventLimit {
		// more events might still be pending; leave event for the next cycle
		logging.LogWithFields("QUERY-BSqe9", "seq", event.Sequence).Warn("didnt process event")
		return
	}
	err = handler.Reduce(event)
	logging.LogWithFields("HANDL-wQDL2", "seq", event.Sequence).OnError(err).Warn("reduce failed")
}

View File

@@ -0,0 +1,36 @@
package sdk
import (
"fmt"
"github.com/caos/zitadel/internal/errors"
)
var (
	// compile-time checks that appendEventError satisfies both interfaces
	_ AppendEventError = (*appendEventError)(nil)
	_ errors.Error     = (*appendEventError)(nil)
)
// AppendEventError marks errors that occurred while appending events to a model.
type AppendEventError interface {
	error
	IsAppendEventError()
}
// appendEventError is the concrete AppendEventError built on CaosError.
type appendEventError struct {
	*errors.CaosError
}
// ThrowAppendEventError wraps parent into an appendEventError with the given id and message.
func ThrowAppendEventError(parent error, id, message string) error {
	return &appendEventError{errors.CreateCaosError(parent, id, message)}
}
// ThrowAggregaterf is the printf-style variant of ThrowAppendEventError.
func ThrowAggregaterf(parent error, id, format string, a ...interface{}) error {
	message := fmt.Sprintf(format, a...)
	return ThrowAppendEventError(parent, id, message)
}
// IsAppendEventError marks appendEventError as an AppendEventError.
func (err *appendEventError) IsAppendEventError() {}
// IsAppendEventError reports whether err implements AppendEventError.
func IsAppendEventError(err error) bool {
	_, ok := err.(AppendEventError)
	return ok
}

View File

@@ -0,0 +1,31 @@
package sdk
import (
"errors"
"testing"
"github.com/stretchr/testify/assert"
)
// TestAppendEventError ensures a *appendEventError can be recovered from an
// interface value by type assertion.
func TestAppendEventError(t *testing.T) {
	var err interface{} = new(appendEventError)
	_, ok := err.(*appendEventError)
	assert.True(t, ok)
}
// TestThrowAppendEventErrorf ensures ThrowAggregaterf produces an *appendEventError.
func TestThrowAppendEventErrorf(t *testing.T) {
	wrapped := ThrowAggregaterf(nil, "id", "msg")
	_, ok := wrapped.(*appendEventError)
	assert.True(t, ok)
}
func TestIsAppendEventError(t *testing.T) {
err := ThrowAppendEventError(nil, "id", "msg")
ok := IsAppendEventError(err)
assert.True(t, ok)
err = errors.New("i am found")
ok = IsAppendEventError(err)
assert.False(t, ok)
}

View File

@@ -0,0 +1,87 @@
package sdk
import (
"context"
"github.com/caos/zitadel/internal/errors"
es_models "github.com/caos/zitadel/internal/eventstore/v1/models"
)
// filterFunc loads events matching a search query from the eventstore.
type filterFunc func(context.Context, *es_models.SearchQuery) ([]*es_models.Event, error)
// appendFunc applies events to a read model.
type appendFunc func(...*es_models.Event) error
// AggregateFunc builds an aggregate (with its pending events) for pushing.
type AggregateFunc func(context.Context) (*es_models.Aggregate, error)
// pushFunc stores the given aggregates' events in the eventstore.
type pushFunc func(context.Context, ...*es_models.Aggregate) error
// Filter loads events with the given filter and applies them to the model via
// appender. It returns a not-found error when the query matches no events and
// an AppendEventError when applying them fails.
func Filter(ctx context.Context, filter filterFunc, appender appendFunc, query *es_models.SearchQuery) error {
	events, err := filter(ctx, query)
	if err != nil {
		return err
	}
	if len(events) == 0 {
		return errors.ThrowNotFound(nil, "EVENT-8due3", "no events found")
	}
	if err = appender(events...); err != nil {
		return ThrowAppendEventError(err, "SDK-awiWK", "Errors.Internal")
	}
	return nil
}
// Push creates the aggregates from the given AggregateFuncs, pushes them to
// the eventstore via push and appends the stored events through appender.
//
// Deprecated: use PushAggregates instead.
func Push(ctx context.Context, push pushFunc, appender appendFunc, aggregaters ...AggregateFunc) (err error) {
	if len(aggregaters) < 1 {
		return errors.ThrowPreconditionFailed(nil, "SDK-q9wjp", "Errors.Internal")
	}
	aggregates, err := makeAggregates(ctx, aggregaters)
	if err != nil {
		return err
	}
	err = push(ctx, aggregates...)
	if err != nil {
		return err
	}
	return appendAggregates(appender, aggregates)
}
// PushAggregates stores the given aggregates' events via push and, when an
// appender is provided, applies the stored events to the read model.
// At least one aggregate is required.
func PushAggregates(ctx context.Context, push pushFunc, appender appendFunc, aggregates ...*es_models.Aggregate) (err error) {
	if len(aggregates) == 0 {
		return errors.ThrowPreconditionFailed(nil, "SDK-q9wjp", "Errors.Internal")
	}
	if err = push(ctx, aggregates...); err != nil {
		return err
	}
	if appender == nil {
		return nil
	}
	return appendAggregates(appender, aggregates)
}
// appendAggregates feeds the events of every aggregate to the appender,
// wrapping any failure into an AppendEventError.
func appendAggregates(appender appendFunc, aggregates []*es_models.Aggregate) error {
	for _, aggregate := range aggregates {
		if err := appender(aggregate.Events...); err != nil {
			return ThrowAppendEventError(err, "SDK-o6kzK", "Errors.Internal")
		}
	}
	return nil
}
// makeAggregates resolves every AggregateFunc into its aggregate, failing fast
// on the first error.
func makeAggregates(ctx context.Context, aggregaters []AggregateFunc) ([]*es_models.Aggregate, error) {
	aggregates := make([]*es_models.Aggregate, len(aggregaters))
	for i, aggregater := range aggregaters {
		aggregate, err := aggregater(ctx)
		if err != nil {
			return nil, err
		}
		aggregates[i] = aggregate
	}
	return aggregates, nil
}

View File

@@ -0,0 +1,193 @@
package sdk
import (
"context"
"testing"
"github.com/caos/zitadel/internal/errors"
es_models "github.com/caos/zitadel/internal/eventstore/v1/models"
)
// TestFilter exercises Filter's error paths (query failure, empty result,
// append failure) and the happy path, using stubbed filter/append funcs.
func TestFilter(t *testing.T) {
	type args struct {
		filter   filterFunc
		appender appendFunc
	}
	tests := []struct {
		name    string
		args    args
		wantErr func(error) bool
	}{
		{
			name: "filter error",
			args: args{
				filter: func(context.Context, *es_models.SearchQuery) ([]*es_models.Event, error) {
					return nil, errors.ThrowInternal(nil, "test-46VX2", "test error")
				},
				appender: nil,
			},
			wantErr: errors.IsInternal,
		},
		{
			name: "no events found",
			args: args{
				filter: func(context.Context, *es_models.SearchQuery) ([]*es_models.Event, error) {
					return []*es_models.Event{}, nil
				},
				appender: nil,
			},
			wantErr: errors.IsNotFound,
		},
		{
			name: "append fails",
			args: args{
				filter: func(context.Context, *es_models.SearchQuery) ([]*es_models.Event, error) {
					return []*es_models.Event{&es_models.Event{}}, nil
				},
				appender: func(...*es_models.Event) error {
					return errors.ThrowInvalidArgument(nil, "SDK-DhBzl", "test error")
				},
			},
			wantErr: IsAppendEventError,
		},
		{
			name: "filter correct",
			args: args{
				filter: func(context.Context, *es_models.SearchQuery) ([]*es_models.Event, error) {
					return []*es_models.Event{&es_models.Event{}}, nil
				},
				appender: func(...*es_models.Event) error {
					return nil
				},
			},
			wantErr: nil,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := Filter(context.Background(), tt.args.filter, tt.args.appender, nil)
			if tt.wantErr == nil && err != nil {
				t.Errorf("no error expected %v", err)
			}
			if tt.wantErr != nil && !tt.wantErr(err) {
				t.Errorf("no error has wrong type %v", err)
			}
		})
	}
}
// TestPush exercises Push with stubbed push/append/aggregate funcs: missing
// aggregates, aggregate-creation failure, push failure, append failure, and
// success with one and with multiple aggregates.
func TestPush(t *testing.T) {
	type args struct {
		push        pushFunc
		appender    appendFunc
		aggregaters []AggregateFunc
	}
	tests := []struct {
		name    string
		args    args
		wantErr func(error) bool
	}{
		{
			name: "no aggregates",
			args: args{
				push:        nil,
				appender:    nil,
				aggregaters: nil,
			},
			wantErr: errors.IsPreconditionFailed,
		},
		{
			name: "aggregater fails",
			args: args{
				push:     nil,
				appender: nil,
				aggregaters: []AggregateFunc{
					func(context.Context) (*es_models.Aggregate, error) {
						return nil, errors.ThrowInternal(nil, "SDK-Ec5x2", "test err")
					},
				},
			},
			wantErr: errors.IsInternal,
		},
		{
			name: "push fails",
			args: args{
				push: func(context.Context, ...*es_models.Aggregate) error {
					return errors.ThrowInternal(nil, "SDK-0g4gW", "test error")
				},
				appender: nil,
				aggregaters: []AggregateFunc{
					func(context.Context) (*es_models.Aggregate, error) {
						return &es_models.Aggregate{}, nil
					},
				},
			},
			wantErr: errors.IsInternal,
		},
		{
			name: "append aggregates fails",
			args: args{
				push: func(context.Context, ...*es_models.Aggregate) error {
					return nil
				},
				appender: func(...*es_models.Event) error {
					return errors.ThrowInvalidArgument(nil, "SDK-BDhcT", "test err")
				},
				aggregaters: []AggregateFunc{
					func(context.Context) (*es_models.Aggregate, error) {
						return &es_models.Aggregate{Events: []*es_models.Event{&es_models.Event{}}}, nil
					},
				},
			},
			wantErr: IsAppendEventError,
		},
		{
			name: "correct one aggregate",
			args: args{
				push: func(context.Context, ...*es_models.Aggregate) error {
					return nil
				},
				appender: func(...*es_models.Event) error {
					return nil
				},
				aggregaters: []AggregateFunc{
					func(context.Context) (*es_models.Aggregate, error) {
						return &es_models.Aggregate{Events: []*es_models.Event{&es_models.Event{}}}, nil
					},
				},
			},
			wantErr: nil,
		},
		{
			name: "correct multiple aggregate",
			args: args{
				push: func(context.Context, ...*es_models.Aggregate) error {
					return nil
				},
				appender: func(...*es_models.Event) error {
					return nil
				},
				aggregaters: []AggregateFunc{
					func(context.Context) (*es_models.Aggregate, error) {
						return &es_models.Aggregate{Events: []*es_models.Event{&es_models.Event{}}}, nil
					},
					func(context.Context) (*es_models.Aggregate, error) {
						return &es_models.Aggregate{Events: []*es_models.Event{&es_models.Event{}}}, nil
					},
				},
			},
			wantErr: nil,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := Push(context.Background(), tt.args.push, tt.args.appender, tt.args.aggregaters...)
			if tt.wantErr == nil && err != nil {
				t.Errorf("no error expected %v", err)
			}
			if tt.wantErr != nil && !tt.wantErr(err) {
				t.Errorf("no error has wrong type %v", err)
			}
		})
	}
}

View File

@@ -0,0 +1,40 @@
package spooler
import (
	"math/rand"
	"os"
	"time"

	"github.com/caos/logging"

	"github.com/caos/zitadel/internal/eventstore/v1"
	"github.com/caos/zitadel/internal/eventstore/v1/query"
	"github.com/caos/zitadel/internal/id"
)
// Config holds everything needed to assemble a Spooler.
type Config struct {
	// Eventstore supplies the events the handlers project.
	Eventstore v1.Eventstore
	// Locker coordinates which instance works on which view.
	Locker Locker
	// ViewHandlers are the projections the spooler keeps up to date.
	ViewHandlers []query.Handler
	// ConcurrentWorkers is the number of worker goroutines Start launches.
	ConcurrentWorkers int
}
// New wires the configured handlers into a Spooler. The lock identifier is the
// hostname; when that is empty or unavailable a generated sonyflake ID is used.
func (c *Config) New() *Spooler {
	lockID, err := os.Hostname()
	if err != nil || lockID == "" {
		lockID, err = id.SonyFlakeGenerator.Next()
		logging.Log("SPOOL-bdO56").OnError(err).Panic("unable to generate lockID")
	}

	//shuffle the handlers for better balance when running multiple pods.
	//use an explicitly seeded source: before Go 1.20 the global rand source
	//is deterministic by default, so every pod would shuffle identically and
	//the balancing would have no effect
	r := rand.New(rand.NewSource(time.Now().UnixNano()))
	r.Shuffle(len(c.ViewHandlers), func(i, j int) {
		c.ViewHandlers[i], c.ViewHandlers[j] = c.ViewHandlers[j], c.ViewHandlers[i]
	})

	return &Spooler{
		handlers:   c.ViewHandlers,
		lockID:     lockID,
		eventstore: c.Eventstore,
		locker:     c.Locker,
		queue:      make(chan *spooledHandler, len(c.ViewHandlers)),
		workers:    c.ConcurrentWorkers,
	}
}

View File

@@ -0,0 +1,3 @@
package spooler
//go:generate mockgen -source spooler.go -destination ./mock/spooler.go -package mock

View File

@@ -0,0 +1,48 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: spooler.go
// Package mock is a generated GoMock package.
package mock
import (
gomock "github.com/golang/mock/gomock"
reflect "reflect"
time "time"
)
// MockLocker is a mock of Locker interface
type MockLocker struct {
ctrl *gomock.Controller
recorder *MockLockerMockRecorder
}
// MockLockerMockRecorder is the mock recorder for MockLocker
type MockLockerMockRecorder struct {
mock *MockLocker
}
// NewMockLocker creates a new mock instance
func NewMockLocker(ctrl *gomock.Controller) *MockLocker {
mock := &MockLocker{ctrl: ctrl}
mock.recorder = &MockLockerMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockLocker) EXPECT() *MockLockerMockRecorder {
return m.recorder
}
// Renew mocks base method
func (m *MockLocker) Renew(lockerID, viewModel string, waitTime time.Duration) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Renew", lockerID, viewModel, waitTime)
ret0, _ := ret[0].(error)
return ret0
}
// Renew indicates an expected call of Renew
func (mr *MockLockerMockRecorder) Renew(lockerID, viewModel, waitTime interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Renew", reflect.TypeOf((*MockLocker)(nil).Renew), lockerID, viewModel, waitTime)
}

View File

@@ -0,0 +1,206 @@
package spooler
import (
"context"
"github.com/caos/zitadel/internal/eventstore/v1"
"strconv"
"sync"
"time"
"github.com/caos/logging"
"github.com/caos/zitadel/internal/eventstore/v1/models"
"github.com/caos/zitadel/internal/eventstore/v1/query"
"github.com/caos/zitadel/internal/telemetry/tracing"
"github.com/caos/zitadel/internal/view/repository"
)
// Spooler runs a pool of workers that repeatedly lock and catch up the
// registered view handlers against the eventstore.
type Spooler struct {
	handlers []query.Handler
	locker Locker
	// lockID identifies this instance when renewing view locks.
	lockID string
	eventstore v1.Eventstore
	// workers is the number of goroutines consuming queue.
	workers int
	queue chan *spooledHandler
}
// Locker coordinates exclusive processing of a view across instances.
type Locker interface {
	// Renew obtains or extends the lock on viewModel for lockerID; it returns
	// an error when the lock is held by someone else.
	Renew(lockerID, viewModel string, waitTime time.Duration) error
}
// spooledHandler is a queued unit of work: one handler plus the
// infrastructure it needs to lock, query and process.
type spooledHandler struct {
	query.Handler
	locker Locker
	// queuedAt is when the handler was (re)enqueued; used to pace requeueing.
	queuedAt time.Time
	eventstore v1.Eventstore
}
// Start launches the configured number of worker goroutines and enqueues every
// handler once. Each dequeued handler immediately re-enqueues itself after its
// minimum cycle duration (requeueTask), so processing cycles indefinitely.
// With workers < 1 nothing is processed.
func (s *Spooler) Start() {
	defer logging.LogWithFields("SPOOL-N0V1g", "lockerID", s.lockID, "workers", s.workers).Info("spooler started")
	if s.workers < 1 {
		return
	}
	for i := 0; i < s.workers; i++ {
		go func(workerIdx int) {
			// workerID distinguishes this worker in the lock table
			workerID := s.lockID + "--" + strconv.Itoa(workerIdx)
			for task := range s.queue {
				// schedule the next cycle before processing so a long load
				// does not delay requeueing
				go requeueTask(task, s.queue)
				task.load(workerID)
			}
		}(i)
	}
	go func() {
		for _, handler := range s.handlers {
			s.queue <- &spooledHandler{Handler: handler, locker: s.locker, queuedAt: time.Now(), eventstore: s.eventstore}
		}
	}()
}
// requeueTask puts the task back on the queue once its minimum cycle duration
// has elapsed since it was last enqueued.
func requeueTask(task *spooledHandler, queue chan<- *spooledHandler) {
	wait := task.MinimumCycleDuration() - time.Since(task.queuedAt)
	time.Sleep(wait)
	task.queuedAt = time.Now()
	queue <- task
}
// load runs one processing cycle for the handler: it acquires the view lock
// and then repeatedly queries and processes events until fewer than
// QueryLimit events remain. Any failure (and also normal completion, as a nil
// error) is sent to errs, which cancels the cycle via awaitError. load
// returns once the cycle's context is done.
func (s *spooledHandler) load(workerID string) {
	errs := make(chan error)
	defer close(errs)
	ctx, cancel := context.WithCancel(context.Background())
	go s.awaitError(cancel, errs, workerID)
	hasLocked := s.lock(ctx, errs, workerID)
	if <-hasLocked {
		for {
			events, err := s.query(ctx)
			if err != nil {
				errs <- err
				break
			}
			err = s.process(ctx, events, workerID)
			if err != nil {
				errs <- err
				break
			}
			if uint64(len(events)) < s.QueryLimit() {
				// no more events to process
				// stop chan
				if ctx.Err() == nil {
					errs <- nil
				}
				break
			}
		}
	}
	// wait until awaitError (or the lock goroutine) cancelled the context
	<-ctx.Done()
}
// awaitError blocks until a value arrives on errs, then cancels the load's
// context. A nil error signals normal completion of the processing cycle.
func (s *spooledHandler) awaitError(cancel func(), errs chan error, workerID string) {
	// a plain receive is equivalent to the original single-case select
	// (staticcheck S1000) and reads more clearly
	err := <-errs
	cancel()
	logging.Log("SPOOL-OT8di").OnError(err).WithField("view", s.ViewModel()).WithField("worker", workerID).Debug("load canceled")
}
// process reduces the given events in order. When Reduce fails and OnError
// does not swallow the error, it waits briefly and retries from the failing
// event (recursively, with the remaining slice). It stops early, without
// error, when ctx is canceled. OnSuccess is invoked once all events are done.
func (s *spooledHandler) process(ctx context.Context, events []*models.Event, workerID string) error {
	for i, event := range events {
		select {
		case <-ctx.Done():
			logging.LogWithFields("SPOOL-FTKwH", "view", s.ViewModel(), "worker", workerID, "traceID", tracing.TraceIDFromCtx(ctx)).Debug("context canceled")
			return nil
		default:
			if err := s.Reduce(event); err != nil {
				err = s.OnError(event, err)
				if err == nil {
					// handler decided to skip the failing event
					continue
				}
				// back off briefly, then retry from the failing event
				time.Sleep(100 * time.Millisecond)
				return s.process(ctx, events[i:], workerID)
			}
		}
	}
	err := s.OnSuccess()
	logging.LogWithFields("SPOOL-49ods", "view", s.ViewModel(), "worker", workerID, "traceID", tracing.TraceIDFromCtx(ctx)).OnError(err).Warn("could not process on success func")
	return err
}
// query builds the handler's event query and loads up to QueryLimit events.
// It short-circuits with (nil, nil) when the eventstore's latest sequence
// equals the sequence the view has already processed (nothing new to do).
func (s *spooledHandler) query(ctx context.Context) ([]*models.Event, error) {
	query, err := s.EventQuery()
	if err != nil {
		return nil, err
	}
	factory := models.FactoryFromSearchQuery(query)
	sequence, err := s.eventstore.LatestSequence(ctx, factory)
	logging.Log("SPOOL-7SciK").OnError(err).WithField("traceID", tracing.TraceIDFromCtx(ctx)).Debug("unable to query latest sequence")
	// extract the sequence the view has processed from the query's filters
	var processedSequence uint64
	for _, filter := range query.Filters {
		if filter.GetField() == models.Field_LatestSequence {
			processedSequence = filter.GetValue().(uint64)
		}
	}
	// sequence == 0 also covers the error case above; then we still query
	if sequence != 0 && processedSequence == sequence {
		return nil, nil
	}
	query.Limit = s.QueryLimit()
	return s.eventstore.FilterEvents(ctx, query)
}
// lock ensures the lock on the database.
// The first value sent on the returned channel reports whether the initial
// lock attempt succeeded; the lock is then renewed periodically. The channel
// is closed when ctx is done or an error occurred during lock renewal.
func (s *spooledHandler) lock(ctx context.Context, errs chan<- error, workerID string) chan bool {
	renewTimer := time.After(0)
	locked := make(chan bool)
	go func(locked chan bool) {
		// firstLock guarantees exactly one result is sent to the caller
		var firstLock sync.Once
		defer close(locked)
		for {
			select {
			case <-ctx.Done():
				return
			case <-renewTimer:
				err := s.locker.Renew(workerID, s.ViewModel(), s.LockDuration())
				firstLock.Do(func() {
					locked <- err == nil
				})
				if err == nil {
					// renew again before the lock expires
					renewTimer = time.After(s.LockDuration())
					continue
				}
				if ctx.Err() == nil {
					errs <- err
				}
				return
			}
		}
	}(locked)
	return locked
}
// HandleError records a failed event: it increments the failure counter and
// stores the error message. Once errorCountUntilSkip is reached the event is
// skipped by persisting its sequence; otherwise the original error is
// returned so the event will be retried.
func HandleError(event *models.Event, failedErr error,
	latestFailedEvent func(sequence uint64) (*repository.FailedEvent, error),
	processFailedEvent func(*repository.FailedEvent) error,
	processSequence func(*models.Event) error,
	errorCountUntilSkip uint64) error {
	failedEvent, err := latestFailedEvent(event.Sequence)
	if err != nil {
		return err
	}

	failedEvent.FailureCount++
	failedEvent.ErrMsg = failedErr.Error()
	if err = processFailedEvent(failedEvent); err != nil {
		return err
	}

	if failedEvent.FailureCount >= errorCountUntilSkip {
		return processSequence(event)
	}
	return failedErr
}
// HandleSuccess records a successful spooler run by updating its timestamp.
func HandleSuccess(updateSpoolerRunTimestamp func() error) error {
	return updateSpoolerRunTimestamp()
}

View File

@@ -0,0 +1,490 @@
package spooler
import (
"context"
"fmt"
"github.com/caos/zitadel/internal/eventstore"
"github.com/caos/zitadel/internal/eventstore/v1"
"testing"
"time"
"github.com/caos/zitadel/internal/errors"
"github.com/caos/zitadel/internal/eventstore/v1/models"
"github.com/caos/zitadel/internal/eventstore/v1/query"
"github.com/caos/zitadel/internal/eventstore/v1/spooler/mock"
"github.com/caos/zitadel/internal/view/repository"
"github.com/golang/mock/gomock"
)
// testHandler is a configurable query.Handler stub for spooler tests.
type testHandler struct {
	cycleDuration time.Duration
	// processSleep delays every Reduce call to simulate work.
	processSleep time.Duration
	// processError is returned by Reduce when set.
	processError error
	// queryError is returned by EventQuery when set.
	queryError error
	viewModel string
	bulkLimit uint64
	// maxErrCount counts OnError invocations; after 2 the error is swallowed.
	maxErrCount int
}
func (h *testHandler) AggregateTypes() []models.AggregateType {
	return nil
}
func (h *testHandler) CurrentSequence() (uint64, error) {
	return 0, nil
}
func (h *testHandler) Eventstore() v1.Eventstore {
	return nil
}
func (h *testHandler) ViewModel() string {
	return h.viewModel
}
// EventQuery returns the configured queryError or an empty query.
func (h *testHandler) EventQuery() (*models.SearchQuery, error) {
	if h.queryError != nil {
		return nil, h.queryError
	}
	return &models.SearchQuery{}, nil
}
// Reduce sleeps processSleep, then returns the configured processError.
func (h *testHandler) Reduce(*models.Event) error {
	<-time.After(h.processSleep)
	return h.processError
}
// OnError swallows the error after two failed attempts, so process retries
// exactly twice before continuing.
func (h *testHandler) OnError(event *models.Event, err error) error {
	if h.maxErrCount == 2 {
		return nil
	}
	h.maxErrCount++
	return err
}
func (h *testHandler) OnSuccess() error {
	return nil
}
func (h *testHandler) MinimumCycleDuration() time.Duration {
	return h.cycleDuration
}
// LockDuration is half the cycle so the lock is renewed within a cycle.
func (h *testHandler) LockDuration() time.Duration {
	return h.cycleDuration / 2
}
func (h *testHandler) QueryLimit() uint64 {
	return h.bulkLimit
}
// eventstoreStub is a v1.Eventstore stub returning canned events or an error.
type eventstoreStub struct {
	events []*models.Event
	err error
}
func (es *eventstoreStub) Subscribe(...models.AggregateType) *v1.Subscription { return nil }
func (es *eventstoreStub) Health(ctx context.Context) error {
	return nil
}
func (es *eventstoreStub) AggregateCreator() *models.AggregateCreator {
	return nil
}
// FilterEvents returns the configured err or the canned events.
func (es *eventstoreStub) FilterEvents(ctx context.Context, in *models.SearchQuery) ([]*models.Event, error) {
	if es.err != nil {
		return nil, es.err
	}
	return es.events, nil
}
func (es *eventstoreStub) PushAggregates(ctx context.Context, in ...*models.Aggregate) error {
	return nil
}
func (es *eventstoreStub) LatestSequence(ctx context.Context, in *models.SearchQueryFactory) (uint64, error) {
	return 0, nil
}
func (es *eventstoreStub) V2() *eventstore.Eventstore {
	return nil
}
// TestSpooler_process verifies event processing, context-deadline handling
// and the retry behavior of spooledHandler.process.
func TestSpooler_process(t *testing.T) {
	type fields struct {
		currentHandler *testHandler
	}
	type args struct {
		timeout time.Duration
		events  []*models.Event
	}
	tests := []struct {
		name        string
		fields      fields
		args        args
		wantErr     bool
		wantRetries int
	}{
		{
			name: "process all events",
			fields: fields{
				currentHandler: &testHandler{},
			},
			args: args{
				timeout: 0,
				events:  []*models.Event{{}, {}},
			},
			wantErr: false,
		},
		{
			name: "deadline exeeded",
			fields: fields{
				currentHandler: &testHandler{processSleep: 501 * time.Millisecond},
			},
			args: args{
				timeout: 1 * time.Second,
				events:  []*models.Event{{}, {}, {}, {}},
			},
			wantErr: false,
		},
		{
			name: "process error",
			fields: fields{
				currentHandler: &testHandler{processSleep: 1 * time.Second, processError: fmt.Errorf("i am an error")},
			},
			args: args{
				events: []*models.Event{{}, {}},
			},
			wantErr:     false,
			wantRetries: 2,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			s := &spooledHandler{
				Handler: tt.fields.currentHandler,
			}
			ctx := context.Background()
			var start time.Time
			if tt.args.timeout > 0 {
				// keep the cancel func and release the timer when the subtest
				// ends (the original discarded it; go vet's lostcancel flags that)
				var cancel context.CancelFunc
				ctx, cancel = context.WithTimeout(ctx, tt.args.timeout)
				defer cancel()
				start = time.Now()
			}
			if err := s.process(ctx, tt.args.events, "test"); (err != nil) != tt.wantErr {
				t.Errorf("Spooler.process() error = %v, wantErr %v", err, tt.wantErr)
			}
			if tt.fields.currentHandler.maxErrCount != tt.wantRetries {
				t.Errorf("Spooler.process() wrong retry count got: %d want %d", tt.fields.currentHandler.maxErrCount, tt.wantRetries)
			}
			elapsed := time.Since(start).Round(1 * time.Second)
			if tt.args.timeout != 0 && elapsed != tt.args.timeout {
				t.Errorf("wrong timeout wanted %v elapsed %v since %v", tt.args.timeout, elapsed, time.Since(start))
			}
		})
	}
}
// TestSpooler_awaitError checks that awaitError cancels the context both for
// a nil error (normal completion) and for a real error.
// NOTE(review): the canceled field is never read — candidate for removal.
func TestSpooler_awaitError(t *testing.T) {
	type fields struct {
		currentHandler query.Handler
		err            error
		canceled       bool
	}
	tests := []struct {
		name   string
		fields fields
	}{
		{
			"no error",
			fields{
				err:            nil,
				currentHandler: &testHandler{processSleep: 500 * time.Millisecond},
				canceled:       false,
			},
		},
		{
			"with error",
			fields{
				err:            fmt.Errorf("hodor"),
				currentHandler: &testHandler{processSleep: 500 * time.Millisecond},
				canceled:       false,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			s := &spooledHandler{
				Handler: tt.fields.currentHandler,
			}
			errs := make(chan error)
			ctx, cancel := context.WithCancel(context.Background())
			go s.awaitError(cancel, errs, "test")
			errs <- tt.fields.err
			if ctx.Err() == nil {
				t.Error("cancel function was not called")
			}
		})
	}
}
// TestSpooler_load checks if load terminates
// for every failure mode: lock already held, lock renewal failure,
// query failure, and repeated process failures.
func TestSpooler_load(t *testing.T) {
	type fields struct {
		currentHandler query.Handler
		locker         *testLocker
		eventstore     v1.Eventstore
	}
	tests := []struct {
		name   string
		fields fields
	}{
		{
			"lock exists",
			fields{
				currentHandler: &testHandler{processSleep: 500 * time.Millisecond, viewModel: "testView1", cycleDuration: 1 * time.Second, bulkLimit: 10},
				locker:         newTestLocker(t, "testID", "testView1").expectRenew(t, fmt.Errorf("lock already exists"), 500*time.Millisecond),
			},
		},
		{
			"lock fails",
			fields{
				currentHandler: &testHandler{processSleep: 100 * time.Millisecond, viewModel: "testView2", cycleDuration: 1 * time.Second, bulkLimit: 10},
				locker:         newTestLocker(t, "testID", "testView2").expectRenew(t, fmt.Errorf("fail"), 500*time.Millisecond),
				eventstore:     &eventstoreStub{events: []*models.Event{{}}},
			},
		},
		{
			"query fails",
			fields{
				currentHandler: &testHandler{processSleep: 100 * time.Millisecond, viewModel: "testView3", queryError: fmt.Errorf("query fail"), cycleDuration: 1 * time.Second, bulkLimit: 10},
				locker:         newTestLocker(t, "testID", "testView3").expectRenew(t, nil, 500*time.Millisecond),
				eventstore:     &eventstoreStub{err: fmt.Errorf("fail")},
			},
		},
		{
			"process event fails",
			fields{
				currentHandler: &testHandler{processError: fmt.Errorf("oups"), processSleep: 100 * time.Millisecond, viewModel: "testView4", cycleDuration: 500 * time.Millisecond, bulkLimit: 10},
				locker:         newTestLocker(t, "testID", "testView4").expectRenew(t, nil, 250*time.Millisecond),
				eventstore:     &eventstoreStub{events: []*models.Event{{}}},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			defer tt.fields.locker.finish()
			s := &spooledHandler{
				Handler:    tt.fields.currentHandler,
				locker:     tt.fields.locker.mock,
				eventstore: tt.fields.eventstore,
			}
			s.load("test-worker")
		})
	}
}
// TestSpooler_lock verifies that lock reports success through the returned
// channel while renew succeeds, and pushes an error and a false lock result
// once renewing fails.
func TestSpooler_lock(t *testing.T) {
	type fields struct {
		currentHandler query.Handler
		locker         *testLocker
		expectsErr     bool
	}
	type args struct {
		deadline time.Time
	}
	tests := []struct {
		name   string
		fields fields
		args   args
	}{
		{
			"renew correct",
			fields{
				currentHandler: &testHandler{cycleDuration: 1 * time.Second, viewModel: "testView"},
				locker:         newTestLocker(t, "testID", "testView").expectRenew(t, nil, 500*time.Millisecond),
				expectsErr:     false,
			},
			args{
				deadline: time.Now().Add(1 * time.Second),
			},
		},
		{
			"renew fails",
			fields{
				currentHandler: &testHandler{cycleDuration: 900 * time.Millisecond, viewModel: "testView"},
				locker:         newTestLocker(t, "testID", "testView").expectRenew(t, fmt.Errorf("renew failed"), 450*time.Millisecond),
				expectsErr:     true,
			},
			args{
				deadline: time.Now().Add(5 * time.Second),
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			defer tt.fields.locker.finish()
			s := &spooledHandler{
				Handler: tt.fields.currentHandler,
				locker:  tt.fields.locker.mock,
			}
			errs := make(chan error, 1)
			defer close(errs)
			// release the context's resources when the subtest ends; the
			// discarded cancel func previously leaked the deadline timer
			// (go vet "lostcancel").
			ctx, cancel := context.WithDeadline(context.Background(), tt.args.deadline)
			defer cancel()
			locked := s.lock(ctx, errs, "test-worker")
			if tt.fields.expectsErr {
				lock := <-locked
				err := <-errs
				if err == nil {
					t.Error("No error in error queue")
				}
				if lock {
					t.Error("lock should have failed")
				}
			} else {
				lock := <-locked
				if !lock {
					t.Error("lock should be true")
				}
			}
		})
	}
}
// testLocker bundles a gomock-backed Locker with the identifiers used to set
// up Renew expectations in the spooler tests above.
type testLocker struct {
	mock *mock.MockLocker
	lockerID string
	viewName string
	ctrl *gomock.Controller
}
// newTestLocker creates a testLocker whose mock is driven by a fresh gomock
// controller bound to the given test.
func newTestLocker(t *testing.T, lockerID, viewName string) *testLocker {
	controller := gomock.NewController(t)
	return &testLocker{
		mock:     mock.NewMockLocker(controller),
		lockerID: lockerID,
		viewName: viewName,
		ctrl:     controller,
	}
}
// expectRenew registers an expectation that Renew is called one to three
// times for this locker's view with exactly waitTime as lock duration and
// makes each call return err.
func (l *testLocker) expectRenew(t *testing.T, err error, waitTime time.Duration) *testLocker {
	t.Helper()
	l.mock.EXPECT().Renew(gomock.Any(), l.viewName, gomock.Any()).DoAndReturn(
		func(_, _ string, gotten time.Duration) error {
			t.Helper()
			// the spooler must request the lock for exactly the expected time
			if gotten != waitTime {
				t.Errorf("expected waittime %v got %v", waitTime, gotten)
			}
			return err
		}).MinTimes(1).MaxTimes(3)
	return l
}
// finish asserts that all expectations registered on the mock controller
// were satisfied; call it (typically via defer) at the end of each test case.
func (l *testLocker) finish() {
	l.ctrl.Finish()
}
// TestHandleError checks whether HandleError skips an event (and still calls
// the process callback) once its failure count reached the configured
// threshold, and keeps failing below that threshold.
func TestHandleError(t *testing.T) {
	type args struct {
		event               *models.Event
		failedErr           error
		latestFailedEvent   func(sequence uint64) (*repository.FailedEvent, error)
		errorCountUntilSkip uint64
	}
	type res struct {
		wantErr               bool
		shouldProcessSequence bool
	}
	testCases := []struct {
		name string
		args args
		res  res
	}{
		{
			name: "should process sequence already too high",
			args: args{
				event:     &models.Event{Sequence: 30000000},
				failedErr: errors.ThrowInternal(nil, "SPOOL-Wk53B", "this was wrong"),
				latestFailedEvent: func(seq uint64) (*repository.FailedEvent, error) {
					return &repository.FailedEvent{
						ErrMsg:         "blub",
						FailedSequence: seq - 1,
						FailureCount:   6,
						ViewName:       "super.table",
					}, nil
				},
				errorCountUntilSkip: 5,
			},
			res: res{
				shouldProcessSequence: true,
			},
		},
		{
			name: "should process sequence after this event too high",
			args: args{
				event:     &models.Event{Sequence: 30000000},
				failedErr: errors.ThrowInternal(nil, "SPOOL-Wk53B", "this was wrong"),
				latestFailedEvent: func(seq uint64) (*repository.FailedEvent, error) {
					return &repository.FailedEvent{
						ErrMsg:         "blub",
						FailedSequence: seq - 1,
						FailureCount:   5,
						ViewName:       "super.table",
					}, nil
				},
				errorCountUntilSkip: 6,
			},
			res: res{
				shouldProcessSequence: true,
			},
		},
		{
			name: "should not process sequence",
			args: args{
				event:     &models.Event{Sequence: 30000000},
				failedErr: errors.ThrowInternal(nil, "SPOOL-Wk53B", "this was wrong"),
				latestFailedEvent: func(seq uint64) (*repository.FailedEvent, error) {
					return &repository.FailedEvent{
						ErrMsg:         "blub",
						FailedSequence: seq - 1,
						FailureCount:   3,
						ViewName:       "super.table",
					}, nil
				},
				errorCountUntilSkip: 5,
			},
			res: res{
				shouldProcessSequence: false,
				wantErr:               true,
			},
		},
	}
	for _, tc := range testCases {
		t.Run(tc.name, func(t *testing.T) {
			// flips to true when HandleError invokes the process callback
			sequenceProcessed := false
			err := HandleError(
				tc.args.event,
				tc.args.failedErr,
				tc.args.latestFailedEvent,
				func(*repository.FailedEvent) error {
					return nil
				},
				func(*models.Event) error {
					sequenceProcessed = true
					return nil
				},
				tc.args.errorCountUntilSkip)
			if (err != nil) != tc.res.wantErr {
				t.Errorf("HandleError() error = %v, wantErr %v", err, tc.res.wantErr)
			}
			if tc.res.shouldProcessSequence != sequenceProcessed {
				t.Error("should not process sequence")
			}
		})
	}
}

View File

@@ -0,0 +1,73 @@
package v1
import (
"sync"
"github.com/caos/zitadel/internal/eventstore/v1/models"
)
var (
	// subscriptions maps each aggregate type to the subscriptions interested
	// in its events; guarded by subsMutext.
	subscriptions map[models.AggregateType][]*Subscription = map[models.AggregateType][]*Subscription{}
	// subsMutext guards subscriptions. NOTE(review): the identifier misspells
	// "Mutex"; it is referenced throughout this file, so renaming it should be
	// done in one coordinated change.
	subsMutext sync.Mutex
)
// Subscription delivers events of the subscribed aggregate types through its
// Events channel until Unsubscribe is called.
type Subscription struct {
	// Events receives every pushed event of the subscribed aggregate types
	// (buffered; closed by Unsubscribe).
	Events chan *models.Event
	// aggregates lists the aggregate types this subscription registered for.
	aggregates []models.AggregateType
}
// Subscribe registers the caller for all events of the given aggregate types.
// The returned subscription receives events through its buffered Events
// channel and must be released with Unsubscribe when no longer needed.
func (es *eventstore) Subscribe(aggregates ...models.AggregateType) *Subscription {
	events := make(chan *models.Event, 100)
	sub := &Subscription{
		Events:     events,
		aggregates: aggregates,
	}

	subsMutext.Lock()
	defer subsMutext.Unlock()
	for _, aggregate := range aggregates {
		// appending to the map's zero-value (nil) slice allocates on demand,
		// so the previous explicit existence check was redundant
		subscriptions[aggregate] = append(subscriptions[aggregate], sub)
	}
	return sub
}
// notify fans out the events of the given aggregates to every subscription
// registered for the respective aggregate type.
//
// NOTE(review): the send on sub.Events happens while subsMutext is held; if a
// subscriber stops draining its channel and the 100-element buffer fills,
// notify blocks and with it every Subscribe/Unsubscribe call — confirm
// subscribers always consume promptly.
func notify(aggregates []*models.Aggregate) {
	subsMutext.Lock()
	defer subsMutext.Unlock()
	for _, aggregate := range aggregates {
		subs, ok := subscriptions[aggregate.Type()]
		if !ok {
			// no one subscribed to this aggregate type
			continue
		}
		for _, sub := range subs {
			for _, event := range aggregate.Events {
				sub.Events <- event
			}
		}
	}
}
// Unsubscribe removes the subscription from every aggregate type it was
// registered for and closes its Events channel. It must not be called more
// than once.
func (s *Subscription) Unsubscribe() {
	subsMutext.Lock()
	defer subsMutext.Unlock()
	for _, aggregate := range s.aggregates {
		subs, ok := subscriptions[aggregate]
		if !ok {
			continue
		}
		for i := len(subs) - 1; i >= 0; i-- {
			if subs[i] == s {
				// swap-delete: move the last element into slot i, nil the
				// freed tail slot so the GC can reclaim it, and shrink
				subs[i] = subs[len(subs)-1]
				subs[len(subs)-1] = nil
				subs = subs[:len(subs)-1]
			}
		}
		// BUG FIX: the shortened slice was never written back to the map, so
		// subscriptions[aggregate] kept its old length with a trailing nil
		// entry; a later notify would dereference that nil *Subscription and
		// panic. Write the result back (or drop the empty entry entirely).
		if len(subs) == 0 {
			delete(subscriptions, aggregate)
		} else {
			subscriptions[aggregate] = subs
		}
	}
	close(s.Events)
}