Mirror of https://github.com/zitadel/zitadel.git, synced 2025-08-12 03:57:32 +00:00
chore: move the go code into a subfolder
apps/api/internal/eventstore/repository/asset.go (new file, 24 lines)
@@ -0,0 +1,24 @@
package repository

// Asset represents all information about an asset (img)
type Asset struct {
	// ID is used to refer to the asset
	ID string
	// Asset is the actual image
	Asset []byte
	// Action defines if the asset should be added or removed
	Action AssetAction
}

type AssetAction int32

const (
	AssetAdded AssetAction = iota
	AssetRemoved

	assetCount
)

func (f AssetAction) Valid() bool {
	return f >= 0 && f < assetCount
}
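For illustration only (not part of this diff): a minimal sketch, written as a package-internal test with a hypothetical name, of how the Valid guard separates the defined AssetAction values from out-of-range ones.

package repository

import "testing"

// Hypothetical sketch: actions produced by the iota block are valid, anything outside the range is not.
func TestAssetActionValid_sketch(t *testing.T) {
	if !AssetAdded.Valid() || !AssetRemoved.Valid() {
		t.Error("expected the defined actions to be valid")
	}
	if AssetAction(42).Valid() {
		t.Error("expected an out-of-range action to be invalid")
	}
}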
apps/api/internal/eventstore/repository/event.go (new file, 133 lines)
@@ -0,0 +1,133 @@
package repository

import (
	"database/sql"
	"encoding/json"
	"strconv"
	"strings"
	"time"

	"github.com/shopspring/decimal"
	"github.com/zitadel/logging"

	"github.com/zitadel/zitadel/internal/eventstore"
)

var _ eventstore.Event = (*Event)(nil)

// Event represents all information about a manipulation of an aggregate
type Event struct {
	// ID is a generated uuid for this event
	ID string

	// Seq is the sequence of the event
	Seq uint64
	// Pos is the global sequence of the event; multiple events can have the same position
	Pos decimal.Decimal

	// CreationDate is the time the event was created
	// it's used for human readability.
	// Don't use it for event ordering,
	// time drifts in different services could cause integrity problems
	CreationDate time.Time

	// Typ describes the cause of the event (e.g. user.added)
	// it should always be in past form
	Typ eventstore.EventType

	// Data describes the changed fields (e.g. userName = "hodor")
	// data must always be a pointer to a struct, a struct, or a byte array containing JSON bytes
	Data []byte

	// EditorUser should be a unique identifier for the user which created the event
	// it's meant for maintainability.
	// It's recommended to use the aggregate id of the user
	EditorUser string

	// Version describes the definition of the aggregate at a certain point in time
	// it's used in read models to reduce the events in the correct definition
	Version eventstore.Version
	// AggregateID is the unique identifier of the aggregate
	// the client must generate it on its own
	AggregateID string
	// AggregateType describes the meaning of the aggregate for this event
	// it could be an object like user
	AggregateType eventstore.AggregateType
	// ResourceOwner is the organisation which owns this aggregate
	// an aggregate can only be managed by one organisation
	// use the ID of the org
	ResourceOwner sql.NullString
	// InstanceID is the instance this event belongs to
	// use the ID of the instance
	InstanceID string

	Constraints []*eventstore.UniqueConstraint
}

// Aggregate implements [eventstore.Event]
func (e *Event) Aggregate() *eventstore.Aggregate {
	return &eventstore.Aggregate{
		ID:            e.AggregateID,
		Type:          e.AggregateType,
		ResourceOwner: e.ResourceOwner.String,
		InstanceID:    e.InstanceID,
		Version:       e.Version,
	}
}

// Creator implements [eventstore.Event]
func (e *Event) Creator() string {
	return e.EditorUser
}

// Type implements [eventstore.Event]
func (e *Event) Type() eventstore.EventType {
	return e.Typ
}

// Revision implements [eventstore.Event]
func (e *Event) Revision() uint16 {
	revision, err := strconv.ParseUint(strings.TrimPrefix(string(e.Version), "v"), 10, 16)
	logging.OnError(err).Debug("failed to parse event revision")
	return uint16(revision)
}

// Sequence implements [eventstore.Event]
func (e *Event) Sequence() uint64 {
	return e.Seq
}

// Position implements [eventstore.Event]
func (e *Event) Position() decimal.Decimal {
	return e.Pos
}

// CreatedAt implements [eventstore.Event]
func (e *Event) CreatedAt() time.Time {
	return e.CreationDate
}

// Unmarshal implements [eventstore.Event]
func (e *Event) Unmarshal(ptr any) error {
	if len(e.Data) == 0 {
		return nil
	}
	return json.Unmarshal(e.Data, ptr)
}

// DataAsBytes implements [eventstore.Event]
func (e *Event) DataAsBytes() []byte {
	return e.Data
}

func (e *Event) Payload() any {
	return e.Data
}

func (e *Event) UniqueConstraints() []*eventstore.UniqueConstraint {
	return e.Constraints
}

func (e *Event) Fields() []*eventstore.FieldOperation {
	return nil
}
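For illustration only (not part of this diff): a minimal sketch of how a stored row surfaces through the eventstore.Event interface, assuming eventstore.Version is a string-based type and using "user.added" purely as an example event type.

package repository

import (
	"testing"

	"github.com/zitadel/zitadel/internal/eventstore"
)

// Hypothetical sketch: the "v"-prefixed version is parsed into a numeric revision
// and the raw JSON payload is decoded through Unmarshal.
func TestEventInterface_sketch(t *testing.T) {
	e := &Event{
		AggregateID:   "A1",
		AggregateType: "user",
		Typ:           "user.added",
		Version:       eventstore.Version("v2"),
		Data:          []byte(`{"userName":"hodor"}`),
	}
	if e.Revision() != 2 {
		t.Errorf("unexpected revision %d", e.Revision())
	}
	var payload struct {
		UserName string `json:"userName"`
	}
	if err := e.Unmarshal(&payload); err != nil || payload.UserName != "hodor" {
		t.Errorf("unexpected payload %+v (err: %v)", payload, err)
	}
}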
apps/api/internal/eventstore/repository/mock/gen_mock.go (new file, 3 lines)
@@ -0,0 +1,3 @@
package mock

//go:generate mockgen -package mock -destination ./repository.mock.go github.com/zitadel/zitadel/internal/eventstore Querier,Pusher
apps/api/internal/eventstore/repository/mock/repository.mock.go (new file, 186 lines)
@@ -0,0 +1,186 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/zitadel/internal/eventstore (interfaces: Querier,Pusher)
//
// Generated by this command:
//
//	mockgen -package mock -destination ./repository.mock.go github.com/zitadel/zitadel/internal/eventstore Querier,Pusher
//

// Package mock is a generated GoMock package.
package mock

import (
	context "context"
	reflect "reflect"

	decimal "github.com/shopspring/decimal"
	database "github.com/zitadel/zitadel/internal/database"
	eventstore "github.com/zitadel/zitadel/internal/eventstore"
	gomock "go.uber.org/mock/gomock"
)

// MockQuerier is a mock of Querier interface.
type MockQuerier struct {
	ctrl     *gomock.Controller
	recorder *MockQuerierMockRecorder
}

// MockQuerierMockRecorder is the mock recorder for MockQuerier.
type MockQuerierMockRecorder struct {
	mock *MockQuerier
}

// NewMockQuerier creates a new mock instance.
func NewMockQuerier(ctrl *gomock.Controller) *MockQuerier {
	mock := &MockQuerier{ctrl: ctrl}
	mock.recorder = &MockQuerierMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockQuerier) EXPECT() *MockQuerierMockRecorder {
	return m.recorder
}

// Client mocks base method.
func (m *MockQuerier) Client() *database.DB {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Client")
	ret0, _ := ret[0].(*database.DB)
	return ret0
}

// Client indicates an expected call of Client.
func (mr *MockQuerierMockRecorder) Client() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Client", reflect.TypeOf((*MockQuerier)(nil).Client))
}

// FilterToReducer mocks base method.
func (m *MockQuerier) FilterToReducer(arg0 context.Context, arg1 *eventstore.SearchQueryBuilder, arg2 eventstore.Reducer) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "FilterToReducer", arg0, arg1, arg2)
	ret0, _ := ret[0].(error)
	return ret0
}

// FilterToReducer indicates an expected call of FilterToReducer.
func (mr *MockQuerierMockRecorder) FilterToReducer(arg0, arg1, arg2 any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterToReducer", reflect.TypeOf((*MockQuerier)(nil).FilterToReducer), arg0, arg1, arg2)
}

// Health mocks base method.
func (m *MockQuerier) Health(arg0 context.Context) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Health", arg0)
	ret0, _ := ret[0].(error)
	return ret0
}

// Health indicates an expected call of Health.
func (mr *MockQuerierMockRecorder) Health(arg0 any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockQuerier)(nil).Health), arg0)
}

// InstanceIDs mocks base method.
func (m *MockQuerier) InstanceIDs(arg0 context.Context, arg1 *eventstore.SearchQueryBuilder) ([]string, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "InstanceIDs", arg0, arg1)
	ret0, _ := ret[0].([]string)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// InstanceIDs indicates an expected call of InstanceIDs.
func (mr *MockQuerierMockRecorder) InstanceIDs(arg0, arg1 any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InstanceIDs", reflect.TypeOf((*MockQuerier)(nil).InstanceIDs), arg0, arg1)
}

// LatestPosition mocks base method.
func (m *MockQuerier) LatestPosition(arg0 context.Context, arg1 *eventstore.SearchQueryBuilder) (decimal.Decimal, error) {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "LatestPosition", arg0, arg1)
	ret0, _ := ret[0].(decimal.Decimal)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// LatestPosition indicates an expected call of LatestPosition.
func (mr *MockQuerierMockRecorder) LatestPosition(arg0, arg1 any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LatestPosition", reflect.TypeOf((*MockQuerier)(nil).LatestPosition), arg0, arg1)
}

// MockPusher is a mock of Pusher interface.
type MockPusher struct {
	ctrl     *gomock.Controller
	recorder *MockPusherMockRecorder
}

// MockPusherMockRecorder is the mock recorder for MockPusher.
type MockPusherMockRecorder struct {
	mock *MockPusher
}

// NewMockPusher creates a new mock instance.
func NewMockPusher(ctrl *gomock.Controller) *MockPusher {
	mock := &MockPusher{ctrl: ctrl}
	mock.recorder = &MockPusherMockRecorder{mock}
	return mock
}

// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockPusher) EXPECT() *MockPusherMockRecorder {
	return m.recorder
}

// Client mocks base method.
func (m *MockPusher) Client() *database.DB {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Client")
	ret0, _ := ret[0].(*database.DB)
	return ret0
}

// Client indicates an expected call of Client.
func (mr *MockPusherMockRecorder) Client() *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Client", reflect.TypeOf((*MockPusher)(nil).Client))
}

// Health mocks base method.
func (m *MockPusher) Health(arg0 context.Context) error {
	m.ctrl.T.Helper()
	ret := m.ctrl.Call(m, "Health", arg0)
	ret0, _ := ret[0].(error)
	return ret0
}

// Health indicates an expected call of Health.
func (mr *MockPusherMockRecorder) Health(arg0 any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockPusher)(nil).Health), arg0)
}

// Push mocks base method.
func (m *MockPusher) Push(arg0 context.Context, arg1 database.ContextQueryExecuter, arg2 ...eventstore.Command) ([]eventstore.Event, error) {
	m.ctrl.T.Helper()
	varargs := []any{arg0, arg1}
	for _, a := range arg2 {
		varargs = append(varargs, a)
	}
	ret := m.ctrl.Call(m, "Push", varargs...)
	ret0, _ := ret[0].([]eventstore.Event)
	ret1, _ := ret[1].(error)
	return ret0, ret1
}

// Push indicates an expected call of Push.
func (mr *MockPusherMockRecorder) Push(arg0, arg1 any, arg2 ...any) *gomock.Call {
	mr.mock.ctrl.T.Helper()
	varargs := append([]any{arg0, arg1}, arg2...)
	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Push", reflect.TypeOf((*MockPusher)(nil).Push), varargs...)
}
@@ -0,0 +1,235 @@
package mock

import (
	"context"
	"encoding/json"
	"fmt"
	"testing"
	"time"

	"github.com/shopspring/decimal"
	"github.com/stretchr/testify/assert"
	"go.uber.org/mock/gomock"

	"github.com/zitadel/zitadel/internal/database"
	"github.com/zitadel/zitadel/internal/eventstore"
	"github.com/zitadel/zitadel/internal/eventstore/repository"
)

type MockRepository struct {
	*MockPusher
	*MockQuerier
}

func NewRepo(t *testing.T) *MockRepository {
	controller := gomock.NewController(t)
	return &MockRepository{
		MockPusher:  NewMockPusher(controller),
		MockQuerier: NewMockQuerier(controller),
	}
}

func (m *MockRepository) ExpectFilterNoEventsNoError() *MockRepository {
	m.MockQuerier.ctrl.T.Helper()

	m.MockQuerier.EXPECT().FilterToReducer(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
	return m
}

func (m *MockRepository) ExpectFilterEvents(events ...eventstore.Event) *MockRepository {
	m.MockQuerier.ctrl.T.Helper()

	m.MockQuerier.EXPECT().FilterToReducer(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func(_ context.Context, _ *eventstore.SearchQueryBuilder, reduce eventstore.Reducer) error {
			for _, event := range events {
				if err := reduce(event); err != nil {
					return err
				}
			}
			return nil
		},
	)
	return m
}

func (m *MockRepository) ExpectFilterEventsError(err error) *MockRepository {
	m.MockQuerier.ctrl.T.Helper()

	m.MockQuerier.EXPECT().FilterToReducer(gomock.Any(), gomock.Any(), gomock.Any()).Return(err)
	return m
}

func (m *MockRepository) ExpectInstanceIDs(hasFilters []*repository.Filter, instanceIDs ...string) *MockRepository {
	m.MockQuerier.ctrl.T.Helper()

	matcher := gomock.Any()
	if len(hasFilters) > 0 {
		matcher = &filterQueryMatcher{SubQueries: [][]*repository.Filter{hasFilters}}
	}
	m.MockQuerier.EXPECT().InstanceIDs(gomock.Any(), matcher).Return(instanceIDs, nil)
	return m
}

func (m *MockRepository) ExpectInstanceIDsError(err error) *MockRepository {
	m.MockQuerier.ctrl.T.Helper()

	m.MockQuerier.EXPECT().InstanceIDs(gomock.Any(), gomock.Any()).Return(nil, err)
	return m
}

// ExpectPush checks if the expectedCommands are sent to the Push method.
// The call sleeps for at least the passed duration.
func (m *MockRepository) ExpectPush(expectedCommands []eventstore.Command, sleep time.Duration) *MockRepository {
	m.MockPusher.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func(ctx context.Context, _ database.ContextQueryExecuter, commands ...eventstore.Command) ([]eventstore.Event, error) {
			m.MockPusher.ctrl.T.Helper()

			time.Sleep(sleep)

			if len(expectedCommands) != len(commands) {
				return nil, fmt.Errorf("unexpected amount of commands: want %d, got %d", len(expectedCommands), len(commands))
			}
			for i, expectedCommand := range expectedCommands {
				if !assert.Equal(m.MockPusher.ctrl.T, expectedCommand.Aggregate(), commands[i].Aggregate()) {
					m.MockPusher.ctrl.T.Errorf("invalid command.Aggregate [%d]: expected: %#v got: %#v", i, expectedCommand.Aggregate(), commands[i].Aggregate())
				}
				if !assert.Equal(m.MockPusher.ctrl.T, expectedCommand.Creator(), commands[i].Creator()) {
					m.MockPusher.ctrl.T.Errorf("invalid command.Creator [%d]: expected: %#v got: %#v", i, expectedCommand.Creator(), commands[i].Creator())
				}
				if !assert.Equal(m.MockPusher.ctrl.T, expectedCommand.Type(), commands[i].Type()) {
					m.MockPusher.ctrl.T.Errorf("invalid command.Type [%d]: expected: %#v got: %#v", i, expectedCommand.Type(), commands[i].Type())
				}
				if !assert.Equal(m.MockPusher.ctrl.T, expectedCommand.Revision(), commands[i].Revision()) {
					m.MockPusher.ctrl.T.Errorf("invalid command.Revision [%d]: expected: %#v got: %#v", i, expectedCommand.Revision(), commands[i].Revision())
				}

				var expectedPayload []byte
				expectedPayload, ok := expectedCommand.Payload().([]byte)
				if !ok {
					expectedPayload, _ = json.Marshal(expectedCommand.Payload())
				}
				if string(expectedPayload) == "" {
					expectedPayload = []byte("null")
				}
				gotPayload, _ := json.Marshal(commands[i].Payload())

				if !assert.Equal(m.MockPusher.ctrl.T, expectedPayload, gotPayload) {
					m.MockPusher.ctrl.T.Errorf("invalid command.Payload [%d]: expected: %#v got: %#v", i, expectedCommand.Payload(), commands[i].Payload())
				}
				if !assert.ElementsMatch(m.MockPusher.ctrl.T, expectedCommand.UniqueConstraints(), commands[i].UniqueConstraints()) {
					m.MockPusher.ctrl.T.Errorf("invalid command.UniqueConstraints [%d]: expected: %#v got: %#v", i, expectedCommand.UniqueConstraints(), commands[i].UniqueConstraints())
				}
			}
			events := make([]eventstore.Event, len(commands))
			for i, command := range commands {
				events[i] = &mockEvent{
					Command: command,
				}
			}
			return events, nil
		},
	)
	return m
}

func (m *MockRepository) ExpectPushFailed(err error, expectedCommands []eventstore.Command) *MockRepository {
	m.MockPusher.ctrl.T.Helper()

	m.MockPusher.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func(ctx context.Context, _ database.ContextQueryExecuter, commands ...eventstore.Command) ([]eventstore.Event, error) {
			if len(expectedCommands) != len(commands) {
				return nil, fmt.Errorf("unexpected amount of commands: want %d, got %d", len(expectedCommands), len(commands))
			}
			for i, expectedCommand := range expectedCommands {
				assert.Equal(m.MockPusher.ctrl.T, expectedCommand.Aggregate(), commands[i].Aggregate())
				assert.Equal(m.MockPusher.ctrl.T, expectedCommand.Creator(), commands[i].Creator())
				assert.Equal(m.MockPusher.ctrl.T, expectedCommand.Type(), commands[i].Type())
				assert.Equal(m.MockPusher.ctrl.T, expectedCommand.Revision(), commands[i].Revision())
				var expectedPayload []byte
				expectedPayload, ok := expectedCommand.Payload().([]byte)
				if !ok {
					expectedPayload, _ = json.Marshal(expectedCommand.Payload())
				}
				if string(expectedPayload) == "" {
					expectedPayload = []byte("null")
				}
				gotPayload, _ := json.Marshal(commands[i].Payload())

				assert.Equal(m.MockPusher.ctrl.T, expectedPayload, gotPayload)
				assert.ElementsMatch(m.MockPusher.ctrl.T, expectedCommand.UniqueConstraints(), commands[i].UniqueConstraints())
			}

			return nil, err
		},
	)
	return m
}

type mockEvent struct {
	eventstore.Command
	sequence  uint64
	createdAt time.Time
}

// DataAsBytes implements eventstore.Event
func (e *mockEvent) DataAsBytes() []byte {
	if e.Payload() == nil {
		return nil
	}
	payload, err := json.Marshal(e.Payload())
	if err != nil {
		panic(err)
	}
	return payload
}

func (e *mockEvent) Unmarshal(ptr any) error {
	if e.Payload() == nil {
		return nil
	}
	payload, err := json.Marshal(e.Payload())
	if err != nil {
		return err
	}
	return json.Unmarshal(payload, ptr)
}

func (e *mockEvent) Sequence() uint64 {
	return e.sequence
}

func (e *mockEvent) Position() decimal.Decimal {
	return decimal.Decimal{}
}

func (e *mockEvent) CreatedAt() time.Time {
	return e.createdAt
}

func (m *MockRepository) ExpectRandomPush(expectedCommands []eventstore.Command) *MockRepository {
	m.MockPusher.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func(ctx context.Context, _ database.ContextQueryExecuter, commands ...eventstore.Command) ([]eventstore.Event, error) {
			assert.Len(m.MockPusher.ctrl.T, commands, len(expectedCommands))

			events := make([]eventstore.Event, len(commands))
			for i, command := range commands {
				events[i] = &mockEvent{
					Command: command,
				}
			}

			return events, nil
		},
	)
	return m
}

func (m *MockRepository) ExpectRandomPushFailed(err error, expectedEvents []eventstore.Command) *MockRepository {
	m.MockPusher.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func(ctx context.Context, _ database.ContextQueryExecuter, events ...eventstore.Command) ([]eventstore.Event, error) {
			assert.Len(m.MockPusher.ctrl.T, events, len(expectedEvents))
			return nil, err
		},
	)
	return m
}
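For illustration only (not part of this diff): a minimal sketch, with a hypothetical test name, of how a command-side test might wire these fluent expectation helpers; the mock would then be handed to the code under test, which is expected to call FilterToReducer once and Push once.

package mock

import (
	"testing"
	"time"

	"github.com/zitadel/zitadel/internal/eventstore"
)

// Hypothetical sketch: build a mock repository that returns no stored events
// and accepts a single push within roughly 10ms.
func TestMockRepositoryUsage_sketch(t *testing.T) {
	repo := NewRepo(t).
		ExpectFilterNoEventsNoError().
		ExpectPush([]eventstore.Command{ /* the commands the code under test should push */ }, 10*time.Millisecond)

	// repo.MockQuerier and repo.MockPusher would now be injected into the code under test.
	_ = repo
}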
@@ -0,0 +1,33 @@
package mock

import (
	"encoding/json"
	"fmt"
	"reflect"

	"go.uber.org/mock/gomock"

	"github.com/zitadel/zitadel/internal/eventstore/repository"
)

var _ gomock.Matcher = (*filterMatcher)(nil)
var _ gomock.GotFormatter = (*filterMatcher)(nil)

type filterMatcher repository.Filter

func (f *filterMatcher) String() string {
	jsonValue, err := json.Marshal(f.Value)
	if err != nil {
		panic(err)
	}
	return fmt.Sprintf("%d %d (content=%+v,type=%T,json=%s)", f.Field, f.Operation, f.Value, f.Value, string(jsonValue))
}

func (f *filterMatcher) Matches(x interface{}) bool {
	other := x.(*repository.Filter)
	return f.Field == other.Field && f.Operation == other.Operation && reflect.DeepEqual(f.Value, other.Value)
}

func (f *filterMatcher) Got(got interface{}) string {
	return (*filterMatcher)(got.(*repository.Filter)).String()
}
@@ -0,0 +1,45 @@
package mock

import (
	"fmt"
	"strings"

	"github.com/zitadel/zitadel/internal/eventstore/repository"
)

type filterQueryMatcher repository.SearchQuery

func (f *filterQueryMatcher) String() string {
	var filterLists []string
	for _, filterSlice := range f.SubQueries {
		var str string
		for _, filter := range filterSlice {
			str += "," + (*filterMatcher)(filter).String()
		}
		filterLists = append(filterLists, fmt.Sprintf("[%s]", strings.TrimPrefix(str, ",")))
	}
	return fmt.Sprintf("Filters: %s", strings.Join(filterLists, " "))
}

func (f *filterQueryMatcher) Matches(x interface{}) bool {
	other := x.(*repository.SearchQuery)
	if len(f.SubQueries) != len(other.SubQueries) {
		return false
	}
	for filterSliceIdx, filterSlice := range f.SubQueries {
		if len(filterSlice) != len(other.SubQueries[filterSliceIdx]) {
			return false
		}
		for filterIdx, filter := range f.SubQueries[filterSliceIdx] {
			if !(*filterMatcher)(filter).Matches(other.SubQueries[filterSliceIdx][filterIdx]) {
				return false
			}
		}
	}
	return true
}

func (f *filterQueryMatcher) Got(got interface{}) string {
	return (*filterQueryMatcher)(got.(*repository.SearchQuery)).String()
}
apps/api/internal/eventstore/repository/search_query.go (new file, 326 lines)
@@ -0,0 +1,326 @@
package repository

import (
	"database/sql"

	"github.com/shopspring/decimal"

	"github.com/zitadel/zitadel/internal/database"
	"github.com/zitadel/zitadel/internal/eventstore"
	"github.com/zitadel/zitadel/internal/zerrors"
)

// SearchQuery defines which data are queried and how
type SearchQuery struct {
	Columns eventstore.Columns

	SubQueries            [][]*Filter
	Tx                    *sql.Tx
	LockRows              bool
	LockOption            eventstore.LockOption
	AwaitOpenTransactions bool
	Limit                 uint64
	Offset                uint32
	Desc                  bool

	InstanceID          *Filter
	InstanceIDs         *Filter
	ExcludedInstances   *Filter
	Creator             *Filter
	Owner               *Filter
	Position            *Filter
	Sequence            *Filter
	CreatedAfter        *Filter
	CreatedBefore       *Filter
	ExcludeAggregateIDs []*Filter
}

// Filter represents all fields needed to compare a field of an event with a value
type Filter struct {
	Field     Field
	Value     interface{}
	Operation Operation
}

// Operation defines how fields are compared
type Operation int32

const (
	// OperationEquals compares two values for equality
	OperationEquals Operation = iota + 1
	// OperationGreater compares if the given value is greater than the stored one
	OperationGreater
	// OperationLess compares if the given value is less than the stored one
	OperationLess
	// OperationIn checks if a stored value matches one of the passed values
	OperationIn
	// OperationJSONContains checks if a stored value matches the given json
	OperationJSONContains
	// OperationNotIn checks if a stored value does not match one of the passed values
	OperationNotIn

	OperationGreaterOrEquals

	operationCount
)

// Field is the representation of a field from the event
type Field int32

const (
	// FieldAggregateType represents the aggregate type field
	FieldAggregateType Field = iota + 1
	// FieldAggregateID represents the aggregate id field
	FieldAggregateID
	// FieldSequence represents the sequence field
	FieldSequence
	// FieldResourceOwner represents the resource owner field
	FieldResourceOwner
	// FieldInstanceID represents the instance id field
	FieldInstanceID
	// FieldEditorService represents the editor service field
	FieldEditorService
	// FieldEditorUser represents the editor user field
	FieldEditorUser
	// FieldEventType represents the event type field
	FieldEventType
	// FieldEventData represents the event data field
	FieldEventData
	// FieldCreationDate represents the creation date field
	FieldCreationDate
	// FieldPosition represents the field of the global sequence
	FieldPosition

	fieldCount
)

// NewFilter is used in tests. Use searchQuery.*Filter() instead
func NewFilter(field Field, value interface{}, operation Operation) *Filter {
	return &Filter{
		Field:     field,
		Value:     value,
		Operation: operation,
	}
}

// Validate checks if the fields of the filter have valid values
func (f *Filter) Validate() error {
	if f == nil {
		return zerrors.ThrowPreconditionFailed(nil, "REPO-z6KcG", "filter is nil")
	}
	if f.Field <= 0 || f.Field >= fieldCount {
		return zerrors.ThrowPreconditionFailed(nil, "REPO-zw62U", "field not defined")
	}
	if f.Value == nil {
		return zerrors.ThrowPreconditionFailed(nil, "REPO-GJ9ct", "no value defined")
	}
	if f.Operation <= 0 || f.Operation >= operationCount {
		return zerrors.ThrowPreconditionFailed(nil, "REPO-RrQTy", "operation not defined")
	}
	return nil
}

func QueryFromBuilder(builder *eventstore.SearchQueryBuilder) (*SearchQuery, error) {
	if builder == nil ||
		builder.GetColumns().Validate() != nil {
		return nil, zerrors.ThrowPreconditionFailed(nil, "MODEL-4m9gs", "builder invalid")
	}

	query := &SearchQuery{
		Columns:               builder.GetColumns(),
		Limit:                 builder.GetLimit(),
		Offset:                builder.GetOffset(),
		Desc:                  builder.GetDesc(),
		Tx:                    builder.GetTx(),
		AwaitOpenTransactions: builder.GetAwaitOpenTransactions(),
		SubQueries:            make([][]*Filter, len(builder.GetQueries())),
	}
	query.LockRows, query.LockOption = builder.GetLockRows()

	for _, f := range []func(builder *eventstore.SearchQueryBuilder, query *SearchQuery) *Filter{
		instanceIDFilter,
		instanceIDsFilter,
		editorUserFilter,
		resourceOwnerFilter,
		positionAfterFilter,
		eventSequenceGreaterFilter,
		creationDateAfterFilter,
		creationDateBeforeFilter,
	} {
		filter := f(builder, query)
		if filter == nil {
			continue
		}
		if err := filter.Validate(); err != nil {
			return nil, err
		}
	}

	for i, q := range builder.GetQueries() {
		for _, f := range []func(query *eventstore.SearchQuery) *Filter{
			aggregateTypeFilter,
			aggregateIDFilter,
			eventTypeFilter,
			eventDataFilter,
			eventPositionAfterFilter,
		} {
			filter := f(q)
			if filter == nil {
				continue
			}
			if err := filter.Validate(); err != nil {
				return nil, err
			}
			query.SubQueries[i] = append(query.SubQueries[i], filter)
		}
	}
	if excludeAggregateIDs := builder.GetExcludeAggregateIDs(); excludeAggregateIDs != nil {
		for _, f := range []func(query *eventstore.ExclusionQuery) *Filter{
			excludeAggregateTypeFilter,
			excludeEventTypeFilter,
		} {
			filter := f(excludeAggregateIDs)
			if filter == nil {
				continue
			}
			if err := filter.Validate(); err != nil {
				return nil, err
			}
			query.ExcludeAggregateIDs = append(query.ExcludeAggregateIDs, filter)
		}
	}

	return query, nil
}

func eventSequenceGreaterFilter(builder *eventstore.SearchQueryBuilder, query *SearchQuery) *Filter {
	if builder.GetEventSequenceGreater() == 0 {
		return nil
	}
	sortOrder := OperationGreater
	if builder.GetDesc() {
		sortOrder = OperationLess
	}
	query.Sequence = NewFilter(FieldSequence, builder.GetEventSequenceGreater(), sortOrder)
	return query.Sequence
}

func creationDateAfterFilter(builder *eventstore.SearchQueryBuilder, query *SearchQuery) *Filter {
	if builder.GetCreationDateAfter().IsZero() {
		return nil
	}
	query.CreatedAfter = NewFilter(FieldCreationDate, builder.GetCreationDateAfter(), OperationGreater)
	return query.CreatedAfter
}

func creationDateBeforeFilter(builder *eventstore.SearchQueryBuilder, query *SearchQuery) *Filter {
	if builder.GetCreationDateBefore().IsZero() {
		return nil
	}
	query.CreatedBefore = NewFilter(FieldCreationDate, builder.GetCreationDateBefore(), OperationLess)
	return query.CreatedBefore
}

func resourceOwnerFilter(builder *eventstore.SearchQueryBuilder, query *SearchQuery) *Filter {
	if builder.GetResourceOwner() == "" {
		return nil
	}
	query.Owner = NewFilter(FieldResourceOwner, builder.GetResourceOwner(), OperationEquals)
	return query.Owner
}

func editorUserFilter(builder *eventstore.SearchQueryBuilder, query *SearchQuery) *Filter {
	if builder.GetEditorUser() == "" {
		return nil
	}
	query.Creator = NewFilter(FieldEditorUser, builder.GetEditorUser(), OperationEquals)
	return query.Creator
}

func instanceIDFilter(builder *eventstore.SearchQueryBuilder, query *SearchQuery) *Filter {
	if builder.GetInstanceID() == nil {
		return nil
	}
	query.InstanceID = NewFilter(FieldInstanceID, *builder.GetInstanceID(), OperationEquals)
	return query.InstanceID
}

func instanceIDsFilter(builder *eventstore.SearchQueryBuilder, query *SearchQuery) *Filter {
	if builder.GetInstanceIDs() == nil {
		return nil
	}
	query.InstanceIDs = NewFilter(FieldInstanceID, database.TextArray[string](builder.GetInstanceIDs()), OperationIn)
	return query.InstanceIDs
}

func positionAfterFilter(builder *eventstore.SearchQueryBuilder, query *SearchQuery) *Filter {
	if builder.GetPositionAtLeast().IsZero() {
		return nil
	}
	query.Position = NewFilter(FieldPosition, builder.GetPositionAtLeast(), OperationGreaterOrEquals)
	return query.Position
}

func aggregateIDFilter(query *eventstore.SearchQuery) *Filter {
	if len(query.GetAggregateIDs()) < 1 {
		return nil
	}
	if len(query.GetAggregateIDs()) == 1 {
		return NewFilter(FieldAggregateID, query.GetAggregateIDs()[0], OperationEquals)
	}
	return NewFilter(FieldAggregateID, database.TextArray[string](query.GetAggregateIDs()), OperationIn)
}

func eventTypeFilter(query *eventstore.SearchQuery) *Filter {
	if len(query.GetEventTypes()) < 1 {
		return nil
	}
	if len(query.GetEventTypes()) == 1 {
		return NewFilter(FieldEventType, query.GetEventTypes()[0], OperationEquals)
	}
	return NewFilter(FieldEventType, database.TextArray[eventstore.EventType](query.GetEventTypes()), OperationIn)
}

func aggregateTypeFilter(query *eventstore.SearchQuery) *Filter {
	if len(query.GetAggregateTypes()) < 1 {
		return nil
	}
	if len(query.GetAggregateTypes()) == 1 {
		return NewFilter(FieldAggregateType, query.GetAggregateTypes()[0], OperationEquals)
	}
	return NewFilter(FieldAggregateType, database.TextArray[eventstore.AggregateType](query.GetAggregateTypes()), OperationIn)
}

func eventDataFilter(query *eventstore.SearchQuery) *Filter {
	if len(query.GetEventData()) == 0 {
		return nil
	}
	return NewFilter(FieldEventData, query.GetEventData(), OperationJSONContains)
}

func eventPositionAfterFilter(query *eventstore.SearchQuery) *Filter {
	if pos := query.GetPositionAfter(); !pos.Equal(decimal.Decimal{}) {
		return NewFilter(FieldPosition, pos, OperationGreater)
	}
	return nil
}

func excludeEventTypeFilter(query *eventstore.ExclusionQuery) *Filter {
	if len(query.GetEventTypes()) < 1 {
		return nil
	}
	if len(query.GetEventTypes()) == 1 {
		return NewFilter(FieldEventType, query.GetEventTypes()[0], OperationEquals)
	}
	return NewFilter(FieldEventType, database.TextArray[eventstore.EventType](query.GetEventTypes()), OperationIn)
}

func excludeAggregateTypeFilter(query *eventstore.ExclusionQuery) *Filter {
	if len(query.GetAggregateTypes()) < 1 {
		return nil
	}
	if len(query.GetAggregateTypes()) == 1 {
		return NewFilter(FieldAggregateType, query.GetAggregateTypes()[0], OperationEquals)
	}
	return NewFilter(FieldAggregateType, database.TextArray[eventstore.AggregateType](query.GetAggregateTypes()), OperationIn)
}
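For illustration only (not part of this diff): a minimal sketch of how QueryFromBuilder turns a builder with one sub-query into one slice of repository filters. The builder chain (NewSearchQueryBuilder, AddQuery, AggregateTypes, Builder) is assumed from the eventstore package and may differ in detail.

package repository

import (
	"testing"

	"github.com/zitadel/zitadel/internal/eventstore"
)

// Hypothetical sketch: a single sub-query with one aggregate type produces a single filter slice.
func TestQueryFromBuilder_sketch(t *testing.T) {
	builder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
		AddQuery().
		AggregateTypes("user").
		Builder() // assumed builder API from the eventstore package

	query, err := QueryFromBuilder(builder)
	if err != nil {
		t.Fatal(err)
	}
	if len(query.SubQueries) != 1 || len(query.SubQueries[0]) != 1 {
		t.Errorf("expected a single aggregate-type filter, got %+v", query.SubQueries)
	}
}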
apps/api/internal/eventstore/repository/search_query_test.go (new file, 146 lines)
@@ -0,0 +1,146 @@
package repository

import (
	"reflect"
	"testing"

	"github.com/zitadel/zitadel/internal/eventstore"
)

func TestNewFilter(t *testing.T) {
	type args struct {
		field     Field
		value     interface{}
		operation Operation
	}
	tests := []struct {
		name string
		args args
		want *Filter
	}{
		{
			name: "aggregateID equals",
			args: args{
				field:     FieldAggregateID,
				value:     "hodor",
				operation: OperationEquals,
			},
			want: &Filter{Field: FieldAggregateID, Operation: OperationEquals, Value: "hodor"},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := NewFilter(tt.args.field, tt.args.value, tt.args.operation); !reflect.DeepEqual(got, tt.want) {
				t.Errorf("NewFilter() = %v, want %v", got, tt.want)
			}
		})
	}
}

func TestFilter_Validate(t *testing.T) {
	type fields struct {
		field     Field
		value     interface{}
		operation Operation
		isNil     bool
	}
	tests := []struct {
		name    string
		fields  fields
		wantErr bool
	}{
		{
			name: "correct filter",
			fields: fields{
				field:     FieldSequence,
				operation: OperationGreater,
				value:     uint64(235),
			},
			wantErr: false,
		},
		{
			name:    "filter is nil",
			fields:  fields{isNil: true},
			wantErr: true,
		},
		{
			name: "no field error",
			fields: fields{
				operation: OperationGreater,
				value:     uint64(235),
			},
			wantErr: true,
		},
		{
			name: "no value error",
			fields: fields{
				field:     FieldSequence,
				operation: OperationGreater,
			},
			wantErr: true,
		},
		{
			name: "no operation error",
			fields: fields{
				field: FieldSequence,
				value: uint64(235),
			},
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var f *Filter
			if !tt.fields.isNil {
				f = &Filter{
					Field:     tt.fields.field,
					Value:     tt.fields.value,
					Operation: tt.fields.operation,
				}
			}
			if err := f.Validate(); (err != nil) != tt.wantErr {
				t.Errorf("Filter.Validate() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

func TestColumns_Validate(t *testing.T) {
	type fields struct {
		columns eventstore.Columns
	}
	tests := []struct {
		name    string
		fields  fields
		wantErr bool
	}{
		{
			name: "correct filter",
			fields: fields{
				columns: eventstore.ColumnsEvent,
			},
			wantErr: false,
		},
		{
			name: "columns too low",
			fields: fields{
				columns: 0,
			},
			wantErr: true,
		},
		{
			name: "columns too high",
			fields: fields{
				columns: 100,
			},
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := tt.fields.columns.Validate(); (err != nil) != tt.wantErr {
				t.Errorf("Columns.Validate() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}
@@ -0,0 +1,115 @@
package sql

import (
	"context"
	"database/sql"
	"os"
	"testing"
	"time"

	pgxdecimal "github.com/jackc/pgx-shopspring-decimal"
	"github.com/jackc/pgx/v5"
	"github.com/jackc/pgx/v5/pgxpool"
	"github.com/jackc/pgx/v5/stdlib"
	"github.com/zitadel/logging"

	"github.com/zitadel/zitadel/cmd/initialise"
	"github.com/zitadel/zitadel/internal/database"
	"github.com/zitadel/zitadel/internal/database/dialect"
	"github.com/zitadel/zitadel/internal/database/postgres"
	new_es "github.com/zitadel/zitadel/internal/eventstore/v3"
)

var (
	testClient *sql.DB
)

func TestMain(m *testing.M) {
	os.Exit(func() int {
		config, cleanup := postgres.StartEmbedded()
		defer cleanup()

		connConfig, err := pgxpool.ParseConfig(config.GetConnectionURL())
		logging.OnError(err).Fatal("unable to parse db url")

		connConfig.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error {
			pgxdecimal.Register(conn.TypeMap())
			return new_es.RegisterEventstoreTypes(ctx, conn)
		}

		pool, err := pgxpool.NewWithConfig(context.Background(), connConfig)
		logging.OnError(err).Fatal("unable to create db pool")

		testClient = stdlib.OpenDBFromPool(pool)

		err = testClient.Ping()
		logging.OnError(err).Fatal("unable to ping db")

		defer func() {
			logging.OnError(testClient.Close()).Error("unable to close db")
		}()

		err = initDB(context.Background(), &database.DB{DB: testClient, Database: &postgres.Config{Database: "zitadel"}})
		logging.OnError(err).Fatal("migrations failed")

		return m.Run()
	}())
}

func initDB(ctx context.Context, db *database.DB) error {
	config := new(database.Config)
	config.SetConnector(&postgres.Config{User: postgres.User{Username: "zitadel"}, Database: "zitadel"})

	if err := initialise.ReadStmts(); err != nil {
		return err
	}

	err := initialise.Init(ctx, db,
		initialise.VerifyUser(config.Username(), ""),
		initialise.VerifyDatabase(config.DatabaseName()),
		initialise.VerifyGrant(config.DatabaseName(), config.Username()))
	if err != nil {
		return err
	}

	err = initialise.VerifyZitadel(context.Background(), db, *config)
	if err != nil {
		return err
	}

	// create old events
	_, err = db.Exec(oldEventsTable)
	return err
}

type testDB struct{}

func (_ *testDB) Timetravel(time.Duration) string { return " AS OF SYSTEM TIME '-1 ms' " }

func (*testDB) DatabaseName() string { return "db" }

func (*testDB) Username() string { return "user" }

func (*testDB) Type() dialect.DatabaseType { return dialect.DatabaseTypePostgres }

const oldEventsTable = `CREATE TABLE IF NOT EXISTS eventstore.events (
	id UUID DEFAULT gen_random_uuid()
	, event_type TEXT NOT NULL
	, aggregate_type TEXT NOT NULL
	, aggregate_id TEXT NOT NULL
	, aggregate_version TEXT NOT NULL
	, event_sequence BIGINT NOT NULL
	, previous_aggregate_sequence BIGINT
	, previous_aggregate_type_sequence INT8
	, creation_date TIMESTAMPTZ NOT NULL DEFAULT now()
	, created_at TIMESTAMPTZ NOT NULL DEFAULT clock_timestamp()
	, event_data JSONB
	, editor_user TEXT NOT NULL
	, editor_service TEXT
	, resource_owner TEXT NOT NULL
	, instance_id TEXT NOT NULL
	, "position" DECIMAL NOT NULL
	, in_tx_order INTEGER NOT NULL

	, PRIMARY KEY (instance_id, aggregate_type, aggregate_id, event_sequence)
);`
apps/api/internal/eventstore/repository/sql/postgres.go (new file, 242 lines)
@@ -0,0 +1,242 @@
package sql

import (
	"context"
	"errors"
	"regexp"
	"strconv"

	"github.com/jackc/pgx/v5/pgconn"
	"github.com/shopspring/decimal"

	"github.com/zitadel/zitadel/internal/database"
	"github.com/zitadel/zitadel/internal/eventstore"
	"github.com/zitadel/zitadel/internal/eventstore/repository"
	"github.com/zitadel/zitadel/internal/telemetry/tracing"
)

// awaitOpenTransactions ensures event ordering, so we don't return events younger than open transactions
var (
	awaitOpenTransactionsV1 = ` AND EXTRACT(EPOCH FROM created_at) < (SELECT COALESCE(EXTRACT(EPOCH FROM min(xact_start)), EXTRACT(EPOCH FROM now())) FROM pg_stat_activity WHERE datname = current_database() AND application_name = ANY(?) AND state <> 'idle')`
	awaitOpenTransactionsV2 = ` AND "position" < (SELECT COALESCE(EXTRACT(EPOCH FROM min(xact_start)), EXTRACT(EPOCH FROM now())) FROM pg_stat_activity WHERE datname = current_database() AND application_name = ANY(?) AND state <> 'idle')`
)

func awaitOpenTransactions(useV1 bool) string {
	if useV1 {
		return awaitOpenTransactionsV1
	}
	return awaitOpenTransactionsV2
}

type Postgres struct {
	*database.DB
}

func NewPostgres(client *database.DB) *Postgres {
	return &Postgres{client}
}

func (db *Postgres) Health(ctx context.Context) error { return db.Ping() }

// FilterToReducer finds all events matching the given search query and passes them to the reduce function.
func (psql *Postgres) FilterToReducer(ctx context.Context, searchQuery *eventstore.SearchQueryBuilder, reduce eventstore.Reducer) (err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	err = query(ctx, psql, searchQuery, reduce, false)
	if err == nil {
		return nil
	}
	pgErr := new(pgconn.PgError)
	// check if events2 does not exist yet
	if errors.As(err, &pgErr) && pgErr.Code == "42P01" {
		return query(ctx, psql, searchQuery, reduce, true)
	}
	return err
}

// LatestPosition returns the latest position found by the search query
func (db *Postgres) LatestPosition(ctx context.Context, searchQuery *eventstore.SearchQueryBuilder) (decimal.Decimal, error) {
	var position decimal.Decimal
	err := query(ctx, db, searchQuery, &position, false)
	return position, err
}

// InstanceIDs returns the instance ids found by the search query
func (db *Postgres) InstanceIDs(ctx context.Context, searchQuery *eventstore.SearchQueryBuilder) ([]string, error) {
	var ids []string
	err := query(ctx, db, searchQuery, &ids, false)
	if err != nil {
		return nil, err
	}
	return ids, nil
}

func (db *Postgres) Client() *database.DB {
	return db.DB
}

func (db *Postgres) orderByEventSequence(desc, shouldOrderBySequence, useV1 bool) string {
	if useV1 {
		if desc {
			return ` ORDER BY event_sequence DESC`
		}
		return ` ORDER BY event_sequence`
	}
	if shouldOrderBySequence {
		if desc {
			return ` ORDER BY "sequence" DESC`
		}
		return ` ORDER BY "sequence"`
	}

	if desc {
		return ` ORDER BY "position" DESC, in_tx_order DESC`
	}
	return ` ORDER BY "position", in_tx_order`
}

func (db *Postgres) eventQuery(useV1 bool) string {
	if useV1 {
		return "SELECT" +
			" creation_date" +
			", event_type" +
			", event_sequence" +
			", event_data" +
			", editor_user" +
			", resource_owner" +
			", instance_id" +
			", aggregate_type" +
			", aggregate_id" +
			", aggregate_version" +
			" FROM eventstore.events"
	}
	return "SELECT" +
		" created_at" +
		", event_type" +
		`, "sequence"` +
		`, "position"` +
		", payload" +
		", creator" +
		`, "owner"` +
		", instance_id" +
		", aggregate_type" +
		", aggregate_id" +
		", revision" +
		" FROM eventstore.events2"
}

func (db *Postgres) maxPositionQuery(useV1 bool) string {
	if useV1 {
		return `SELECT event_sequence FROM eventstore.events`
	}
	return `SELECT "position" FROM eventstore.events2`
}

func (db *Postgres) instanceIDsQuery(useV1 bool) string {
	table := "eventstore.events2"
	if useV1 {
		table = "eventstore.events"
	}
	return "SELECT DISTINCT instance_id FROM " + table
}

func (db *Postgres) columnName(col repository.Field, useV1 bool) string {
	switch col {
	case repository.FieldAggregateID:
		return "aggregate_id"
	case repository.FieldAggregateType:
		return "aggregate_type"
	case repository.FieldSequence:
		if useV1 {
			return "event_sequence"
		}
		return `"sequence"`
	case repository.FieldResourceOwner:
		if useV1 {
			return "resource_owner"
		}
		return `"owner"`
	case repository.FieldInstanceID:
		return "instance_id"
	case repository.FieldEditorService:
		if useV1 {
			return "editor_service"
		}
		return ""
	case repository.FieldEditorUser:
		if useV1 {
			return "editor_user"
		}
		return "creator"
	case repository.FieldEventType:
		return "event_type"
	case repository.FieldEventData:
		if useV1 {
			return "event_data"
		}
		return "payload"
	case repository.FieldCreationDate:
		if useV1 {
			return "creation_date"
		}
		return "created_at"
	case repository.FieldPosition:
		return `"position"`
	default:
		return ""
	}
}

func (db *Postgres) conditionFormat(operation repository.Operation) string {
	switch operation {
	case repository.OperationIn:
		return "%s %s ANY(?)"
	case repository.OperationNotIn:
		return "%s %s ALL(?)"
	case repository.OperationEquals, repository.OperationGreater, repository.OperationLess, repository.OperationJSONContains:
		fallthrough
	default:
		return "%s %s ?"
	}
}

func (db *Postgres) operation(operation repository.Operation) string {
	switch operation {
	case repository.OperationEquals, repository.OperationIn:
		return "="
	case repository.OperationGreater:
		return ">"
	case repository.OperationGreaterOrEquals:
		return ">="
	case repository.OperationLess:
		return "<"
	case repository.OperationJSONContains:
		return "@>"
	case repository.OperationNotIn:
		return "<>"
	}
	return ""
}

var (
	placeholder = regexp.MustCompile(`\?`)
)

// placeholder replaces all "?" with postgres placeholders ($<NUMBER>)
func (db *Postgres) placeholder(query string) string {
	occurrences := placeholder.FindAllStringIndex(query, -1)
	if len(occurrences) == 0 {
		return query
	}
	replaced := query[:occurrences[0][0]]

	for i, l := range occurrences {
		nextIDX := len(query)
		if i < len(occurrences)-1 {
			nextIDX = occurrences[i+1][0]
		}
		replaced = replaced + "$" + strconv.Itoa(i+1) + query[l[1]:nextIDX]
	}
	return replaced
}
apps/api/internal/eventstore/repository/sql/postgres_test.go (new file, 325 lines)
@@ -0,0 +1,325 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"testing"
|
||||
|
||||
"github.com/shopspring/decimal"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/repository"
|
||||
)
|
||||
|
||||
func TestPostgres_placeholder(t *testing.T) {
|
||||
type args struct {
|
||||
query string
|
||||
}
|
||||
type res struct {
|
||||
query string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "no placeholders",
|
||||
args: args{
|
||||
query: "SELECT * FROM eventstore.events2",
|
||||
},
|
||||
res: res{
|
||||
query: "SELECT * FROM eventstore.events2",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "one placeholder",
|
||||
args: args{
|
||||
query: "SELECT * FROM eventstore.events2 WHERE aggregate_type = ?",
|
||||
},
|
||||
res: res{
|
||||
query: "SELECT * FROM eventstore.events2 WHERE aggregate_type = $1",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple placeholders",
|
||||
args: args{
|
||||
query: "SELECT * FROM eventstore.events2 WHERE aggregate_type = ? AND aggregate_id = ? LIMIT ?",
|
||||
},
|
||||
res: res{
|
||||
query: "SELECT * FROM eventstore.events2 WHERE aggregate_type = $1 AND aggregate_id = $2 LIMIT $3",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := &Postgres{}
|
||||
if query := db.placeholder(tt.args.query); query != tt.res.query {
|
||||
t.Errorf("Postgres.placeholder() = %v, want %v", query, tt.res.query)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgres_operation(t *testing.T) {
|
||||
type res struct {
|
||||
op string
|
||||
}
|
||||
type args struct {
|
||||
operation repository.Operation
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "no op",
|
||||
args: args{
|
||||
operation: repository.Operation(-1),
|
||||
},
|
||||
res: res{
|
||||
op: "",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "greater",
|
||||
args: args{
|
||||
operation: repository.OperationGreater,
|
||||
},
|
||||
res: res{
|
||||
op: ">",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "less",
|
||||
args: args{
|
||||
operation: repository.OperationLess,
|
||||
},
|
||||
res: res{
|
||||
op: "<",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "equals",
|
||||
args: args{
|
||||
operation: repository.OperationEquals,
|
||||
},
|
||||
res: res{
|
||||
op: "=",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "in",
|
||||
args: args{
|
||||
operation: repository.OperationIn,
|
||||
},
|
||||
res: res{
|
||||
op: "=",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := &Postgres{}
|
||||
if got := db.operation(tt.args.operation); got != tt.res.op {
|
||||
t.Errorf("Postgres.operation() = %v, want %v", got, tt.res.op)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPostgres_conditionFormat(t *testing.T) {
|
||||
type res struct {
|
||||
format string
|
||||
}
|
||||
type args struct {
|
||||
operation repository.Operation
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
name: "default",
|
||||
args: args{
|
||||
operation: repository.OperationEquals,
|
||||
},
|
||||
res: res{
|
||||
format: "%s %s ?",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "in",
|
||||
args: args{
|
||||
operation: repository.OperationIn,
|
||||
},
|
||||
res: res{
|
||||
format: "%s %s ANY(?)",
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
db := &Postgres{}
|
||||
if got := db.conditionFormat(tt.args.operation); got != tt.res.format {
|
||||
t.Errorf("Postgres.conditionFormat() = %v, want %v", got, tt.res.format)
|
||||
}
|
||||
})
|
||||
}
|
||||
}

func TestPostgres_columnName(t *testing.T) {
	type res struct {
		name string
	}
	type args struct {
		field repository.Field
		useV1 bool
	}
	tests := []struct {
		name string
		args args
		res  res
	}{
		{
			name: "invalid field",
			args: args{
				field: repository.Field(-1),
			},
			res: res{
				name: "",
			},
		},
		{
			name: "aggregate id",
			args: args{
				field: repository.FieldAggregateID,
			},
			res: res{
				name: "aggregate_id",
			},
		},
		{
			name: "aggregate type",
			args: args{
				field: repository.FieldAggregateType,
			},
			res: res{
				name: "aggregate_type",
			},
		},
		{
			name: "editor service",
			args: args{
				field: repository.FieldEditorService,
				useV1: true,
			},
			res: res{
				name: "editor_service",
			},
		},
		{
			name: "editor service v2",
			args: args{
				field: repository.FieldEditorService,
			},
			res: res{
				name: "",
			},
		},
		{
			name: "editor user",
			args: args{
				field: repository.FieldEditorUser,
				useV1: true,
			},
			res: res{
				name: "editor_user",
			},
		},
		{
			name: "editor user v2",
			args: args{
				field: repository.FieldEditorUser,
			},
			res: res{
				name: "creator",
			},
		},
		{
			name: "event type",
			args: args{
				field: repository.FieldEventType,
			},
			res: res{
				name: "event_type",
			},
		},
		{
			name: "latest sequence",
			args: args{
				field: repository.FieldSequence,
				useV1: true,
			},
			res: res{
				name: "event_sequence",
			},
		},
		{
			name: "latest sequence v2",
			args: args{
				field: repository.FieldSequence,
			},
			res: res{
				name: `"sequence"`,
			},
		},
		{
			name: "resource owner",
			args: args{
				field: repository.FieldResourceOwner,
				useV1: true,
			},
			res: res{
				name: "resource_owner",
			},
		},
		{
			name: "resource owner v2",
			args: args{
				field: repository.FieldResourceOwner,
			},
			res: res{
				name: `"owner"`,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			db := &Postgres{}
			if got := db.columnName(tt.args.field, tt.args.useV1); got != tt.res.name {
				t.Errorf("Postgres.columnName() = %v, want %v", got, tt.res.name)
			}
		})
	}
}

func generateEvent(t *testing.T, aggregateID string, opts ...func(*repository.Event)) *repository.Event {
	t.Helper()
	e := &repository.Event{
		AggregateID:   aggregateID,
		AggregateType: eventstore.AggregateType(t.Name()),
		EditorUser:    "user",
		ResourceOwner: sql.NullString{String: "ro", Valid: true},
		Typ:           "test.created",
		Version:       "v1",
		Pos:           decimal.NewFromInt(42),
	}

	for _, opt := range opts {
		opt(e)
	}

	return e
}
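
// generateEvent builds a minimal, valid event for the tests; individual fields can
// be overridden through option functions. Hypothetical usage (values invented for
// illustration, not taken from the original file):
//
//	e := generateEvent(t, "agg-1", func(e *repository.Event) {
//		e.Seq = 2
//		e.Typ = "test.updated"
//	})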
355
apps/api/internal/eventstore/repository/sql/query.go
Normal file
@@ -0,0 +1,355 @@
package sql

import (
	"context"
	"database/sql"
	"encoding/json"
	"errors"
	"fmt"
	"strconv"
	"strings"

	"github.com/shopspring/decimal"
	"github.com/zitadel/logging"

	"github.com/zitadel/zitadel/internal/database"
	"github.com/zitadel/zitadel/internal/database/dialect"
	"github.com/zitadel/zitadel/internal/eventstore"
	"github.com/zitadel/zitadel/internal/eventstore/repository"
	"github.com/zitadel/zitadel/internal/zerrors"
)

type querier interface {
	columnName(field repository.Field, useV1 bool) string
	operation(repository.Operation) string
	conditionFormat(repository.Operation) string
	placeholder(query string) string
	eventQuery(useV1 bool) string
	maxPositionQuery(useV1 bool) string
	instanceIDsQuery(useV1 bool) string
	Client() *database.DB
	orderByEventSequence(desc, shouldOrderBySequence, useV1 bool) string
	dialect.Database
}

type scan func(dest ...interface{}) error

type tx struct {
	*sql.Tx
}

func (t *tx) QueryContext(ctx context.Context, scan func(rows *sql.Rows) error, query string, args ...any) error {
	rows, err := t.Tx.QueryContext(ctx, query, args...)
	if err != nil {
		return err
	}
	defer func() {
		closeErr := rows.Close()
		logging.OnError(closeErr).Info("rows.Close failed")
	}()

	if err = scan(rows); err != nil {
		return err
	}
	return rows.Err()
}

func query(ctx context.Context, criteria querier, searchQuery *eventstore.SearchQueryBuilder, dest interface{}, useV1 bool) error {
	q, err := repository.QueryFromBuilder(searchQuery)
	if err != nil {
		return err
	}

	query, rowScanner := prepareColumns(criteria, q.Columns, useV1)
	where, values := prepareConditions(criteria, q, useV1)
	if where == "" || query == "" {
		return zerrors.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
	}
	query += where

	// instead of using the max function of the database (which doesn't work for postgres)
	// we select the most recent row
	if q.Columns == eventstore.ColumnsMaxPosition {
		q.Limit = 1
		q.Desc = true
	}

	// if there is only one subquery we can optimize the query ordering by ordering by sequence
	var shouldOrderBySequence bool
	if len(q.SubQueries) == 1 {
		for _, filter := range q.SubQueries[0] {
			if filter.Field == repository.FieldAggregateID {
				shouldOrderBySequence = filter.Operation == repository.OperationEquals
			}
		}
	}

	switch q.Columns {
	case eventstore.ColumnsEvent,
		eventstore.ColumnsMaxPosition:
		query += criteria.orderByEventSequence(q.Desc, shouldOrderBySequence, useV1)
	}

	if q.Limit > 0 {
		values = append(values, q.Limit)
		query += " LIMIT ?"
	}

	if q.Offset > 0 {
		values = append(values, q.Offset)
		query += " OFFSET ?"
	}

	if q.LockRows {
		query += " FOR UPDATE"
		switch q.LockOption {
		case eventstore.LockOptionWait: // default behavior
		case eventstore.LockOptionNoWait:
			query += " NOWAIT"
		case eventstore.LockOptionSkipLocked:
			query += " SKIP LOCKED"
		}
	}

	query = criteria.placeholder(query)

	var contextQuerier interface {
		QueryContext(context.Context, func(rows *sql.Rows) error, string, ...interface{}) error
	}
	contextQuerier = criteria.Client()
	if q.Tx != nil {
		contextQuerier = &tx{Tx: q.Tx}
	}

	err = contextQuerier.QueryContext(ctx,
		func(rows *sql.Rows) error {
			for rows.Next() {
				err := rowScanner(rows.Scan, dest)
				if err != nil {
					return err
				}
			}
			return nil
		}, query, values...)
	if err != nil {
		logging.New().WithError(err).Info("query failed")
		return zerrors.ThrowInternal(err, "SQL-KyeAx", "unable to filter events")
	}

	return nil
}
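
// For orientation: once the placeholders are rewritten, a single-aggregate filter
// with a limit ends up roughly as the statement below. The shape is illustrative
// only; the exact column list comes from eventQuery, the WHERE clause from
// prepareConditions, and the ordering from orderByEventSequence:
//
//	SELECT <event columns> FROM eventstore.events2
//	WHERE aggregate_type = $1 AND aggregate_id = $2
//	ORDER BY "sequence" LIMIT $3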

func prepareColumns(criteria querier, columns eventstore.Columns, useV1 bool) (string, func(s scan, dest interface{}) error) {
	switch columns {
	case eventstore.ColumnsMaxPosition:
		return criteria.maxPositionQuery(useV1), maxPositionScanner
	case eventstore.ColumnsInstanceIDs:
		return criteria.instanceIDsQuery(useV1), instanceIDsScanner
	case eventstore.ColumnsEvent:
		return criteria.eventQuery(useV1), eventsScanner(useV1)
	default:
		return "", nil
	}
}

func maxPositionScanner(row scan, dest interface{}) (err error) {
	position, ok := dest.(*decimal.Decimal)
	if !ok {
		return zerrors.ThrowInvalidArgumentf(nil, "SQL-NBjA9", "type must be pointer to decimal.Decimal got: %T", dest)
	}
	var res decimal.NullDecimal
	err = row(&res)
	if err == nil || errors.Is(err, sql.ErrNoRows) {
		*position = res.Decimal
		return nil
	}
	return zerrors.ThrowInternal(err, "SQL-bN5xg", "something went wrong")
}

func instanceIDsScanner(scanner scan, dest interface{}) (err error) {
	ids, ok := dest.(*[]string)
	if !ok {
		return zerrors.ThrowInvalidArgument(nil, "SQL-Begh2", "type must be an array of string")
	}
	var id string
	err = scanner(&id)
	if err != nil {
		logging.WithError(err).Warn("unable to scan row")
		return zerrors.ThrowInternal(err, "SQL-DEFGe", "unable to scan row")
	}
	*ids = append(*ids, id)

	return nil
}

func eventsScanner(useV1 bool) func(scanner scan, dest interface{}) (err error) {
	return func(scanner scan, dest interface{}) (err error) {
		reduce, ok := dest.(eventstore.Reducer)
		if !ok {
			return zerrors.ThrowInvalidArgumentf(nil, "SQL-4GP6F", "events scanner: invalid type %T", dest)
		}
		event := new(repository.Event)
		position := new(decimal.NullDecimal)

		if useV1 {
			err = scanner(
				&event.CreationDate,
				&event.Typ,
				&event.Seq,
				&event.Data,
				&event.EditorUser,
				&event.ResourceOwner,
				&event.InstanceID,
				&event.AggregateType,
				&event.AggregateID,
				&event.Version,
			)
		} else {
			var revision uint8
			err = scanner(
				&event.CreationDate,
				&event.Typ,
				&event.Seq,
				position,
				&event.Data,
				&event.EditorUser,
				&event.ResourceOwner,
				&event.InstanceID,
				&event.AggregateType,
				&event.AggregateID,
				&revision,
			)
			event.Version = eventstore.Version("v" + strconv.Itoa(int(revision)))
		}

		if err != nil {
			logging.New().WithError(err).Warn("unable to scan row")
			return zerrors.ThrowInternal(err, "SQL-M0dsf", "unable to scan row")
		}
		event.Pos = position.Decimal
		return reduce(event)
	}
}

func prepareConditions(criteria querier, query *repository.SearchQuery, useV1 bool) (_ string, args []any) {
	clauses, args := prepareQuery(criteria, useV1, query.InstanceID, query.InstanceIDs, query.ExcludedInstances)
	if clauses != "" && len(query.SubQueries) > 0 {
		clauses += " AND "
	}
	subClauses := make([]string, len(query.SubQueries))
	for i, filters := range query.SubQueries {
		var subArgs []any
		subClauses[i], subArgs = prepareQuery(criteria, useV1, filters...)
		// an error is thrown in [query]
		if subClauses[i] == "" {
			return "", nil
		}
		if len(query.SubQueries) > 1 && len(subArgs) > 1 {
			subClauses[i] = "(" + subClauses[i] + ")"
		}
		args = append(args, subArgs...)
	}
	if len(subClauses) == 1 {
		clauses += subClauses[0]
	} else if len(subClauses) > 1 {
		clauses += "(" + strings.Join(subClauses, " OR ") + ")"
	}

	additionalClauses, additionalArgs := prepareQuery(criteria, useV1,
		query.Position,
		query.Owner,
		query.Sequence,
		query.CreatedAfter,
		query.CreatedBefore,
		query.Creator,
	)
	if additionalClauses != "" {
		if clauses != "" {
			clauses += " AND "
		}
		clauses += additionalClauses
		args = append(args, additionalArgs...)
	}

	excludeAggregateIDs := query.ExcludeAggregateIDs
	if len(excludeAggregateIDs) > 0 {
		excludeAggregateIDs = append(excludeAggregateIDs, query.InstanceID, query.InstanceIDs, query.Position, query.CreatedAfter, query.CreatedBefore)
	}
	excludeAggregateIDsClauses, excludeAggregateIDsArgs := prepareQuery(criteria, useV1, excludeAggregateIDs...)
	if excludeAggregateIDsClauses != "" {
		if clauses != "" {
			clauses += " AND "
		}
		if useV1 {
			clauses += "aggregate_id NOT IN (SELECT aggregate_id FROM eventstore.events WHERE " + excludeAggregateIDsClauses + ")"
		} else {
			clauses += "aggregate_id NOT IN (SELECT aggregate_id FROM eventstore.events2 WHERE " + excludeAggregateIDsClauses + ")"
		}
		args = append(args, excludeAggregateIDsArgs...)
	}

	if query.AwaitOpenTransactions {
		instanceIDs := make(database.TextArray[string], 0, 3)
		if query.InstanceID != nil {
			instanceIDs = append(instanceIDs, query.InstanceID.Value.(string))
		} else if query.InstanceIDs != nil {
			instanceIDs = append(instanceIDs, query.InstanceIDs.Value.(database.TextArray[string])...)
		}

		for i := range instanceIDs {
			instanceIDs[i] = "zitadel_es_pusher_" + instanceIDs[i]
		}

		clauses += awaitOpenTransactions(useV1)
		args = append(args, instanceIDs)
	}

	if clauses == "" {
		return "", nil
	}

	return " WHERE " + clauses, args
}

func prepareQuery(criteria querier, useV1 bool, filters ...*repository.Filter) (_ string, args []any) {
	clauses := make([]string, 0, len(filters))
	args = make([]any, 0, len(filters))
	for _, filter := range filters {
		if filter == nil {
			continue
		}
		arg := filter.Value

		// marshal if payload filter
		if filter.Field == repository.FieldEventData {
			var err error
			arg, err = json.Marshal(arg)
			if err != nil {
				logging.WithError(err).Warn("unable to marshal search value")
				continue
			}
		}

		clauses = append(clauses, getCondition(criteria, filter, useV1))
		// if mapping failed an error is thrown in [query]
		if clauses[len(clauses)-1] == "" {
			return "", nil
		}
		args = append(args, arg)
	}

	return strings.Join(clauses, " AND "), args
}
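
// prepareQuery hands each non-nil filter to getCondition (below), collects one
// argument per clause, and joins the clauses with AND. A hypothetical two-filter
// call (values invented for illustration; assumes db satisfies querier):
//
//	clause, args := prepareQuery(db, false,
//		&repository.Filter{Field: repository.FieldAggregateType, Operation: repository.OperationEquals, Value: "user"},
//		&repository.Filter{Field: repository.FieldAggregateID, Operation: repository.OperationEquals, Value: "agg-1"},
//	)
//	// clause == "aggregate_type = ? AND aggregate_id = ?", args == []any{"user", "agg-1"}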

func getCondition(cond querier, filter *repository.Filter, useV1 bool) (condition string) {
	field := cond.columnName(filter.Field, useV1)
	operation := cond.operation(filter.Operation)
	if field == "" || operation == "" {
		return ""
	}
	format := cond.conditionFormat(filter.Operation)

	return fmt.Sprintf(format, field, operation)
}
1056
apps/api/internal/eventstore/repository/sql/query_test.go
Normal file
File diff suppressed because it is too large