chore: move the go code into a subfolder

This commit is contained in:
Florian Forster
2025-08-05 15:20:32 -07:00
parent 4ad22ba456
commit cd2921de26
2978 changed files with 373 additions and 300 deletions

View File

@@ -0,0 +1,111 @@
package eventstore
import (
"context"
"github.com/zitadel/zitadel/internal/api/authz"
)
// aggregateOpt is a functional option applied to an Aggregate by NewAggregate
// after the defaults have been derived from the context.
type aggregateOpt func(*Aggregate)
// NewAggregate is the default constructor of an aggregate.
// The resource owner and instance id default to the values carried by ctx;
// opts overwrite values calculated by given parameters.
func NewAggregate(
	ctx context.Context,
	id string,
	typ AggregateType,
	version Version,
	opts ...aggregateOpt,
) *Aggregate {
	aggregate := &Aggregate{
		ID:            id,
		Type:          typ,
		ResourceOwner: authz.GetCtxData(ctx).OrgID,
		InstanceID:    authz.GetInstance(ctx).InstanceID(),
		Version:       version,
	}
	// options run last so they can override the context-derived defaults
	for i := range opts {
		opts[i](aggregate)
	}
	return aggregate
}
// WithResourceOwner overwrites the resource owner of the aggregate;
// by default the resource owner is set by the context.
func WithResourceOwner(resourceOwner string) aggregateOpt {
	return func(a *Aggregate) {
		a.ResourceOwner = resourceOwner
	}
}
// WithInstanceID overwrites the instance id of the aggregate;
// by default the instance is set by the context.
func WithInstanceID(id string) aggregateOpt {
	return func(a *Aggregate) {
		a.InstanceID = id
	}
}
// AggregateFromWriteModel maps the given WriteModel to an Aggregate.
// Deprecated: Creates linter errors on missing context. Use [AggregateFromWriteModelCtx] instead.
func AggregateFromWriteModel(
	wm *WriteModel,
	typ AggregateType,
	version Version,
) *Aggregate {
	// context.Background only satisfies the ctx parameter; the aggregate's
	// resource owner and instance id are overwritten from wm by the callee.
	return AggregateFromWriteModelCtx(context.Background(), wm, typ, version)
}
// AggregateFromWriteModelCtx maps the given WriteModel to an Aggregate.
// The resource owner and instance id are taken from the write model rather
// than from the context.
func AggregateFromWriteModelCtx(
	ctx context.Context,
	wm *WriteModel,
	typ AggregateType,
	version Version,
) *Aggregate {
	overrides := []aggregateOpt{
		WithResourceOwner(wm.ResourceOwner),
		WithInstanceID(wm.InstanceID),
	}
	return NewAggregate(ctx, wm.AggregateID, typ, version, overrides...)
}
// Aggregate is the basic implementation of Aggregater
type Aggregate struct {
	// ID is the unique identifier of this aggregate
	ID string `json:"id"`
	// Type is the name of the aggregate.
	Type AggregateType `json:"type"`
	// ResourceOwner is the org this aggregate belongs to
	ResourceOwner string `json:"resourceOwner"`
	// InstanceID is the instance this aggregate belongs to
	InstanceID string `json:"instanceId"`
	// Version is the semver this aggregate represents
	Version Version `json:"version"`
}
// AggregateType is the object name identifying the kind of an aggregate.
type AggregateType string
// isAggregateTypes reports whether the aggregate's type matches any of the given types.
func isAggregateTypes(a *Aggregate, types ...AggregateType) bool {
	for i := range types {
		if types[i] == a.Type {
			return true
		}
	}
	return false
}
// isAggregateIDs reports whether the aggregate's id matches any of the given ids.
func isAggregateIDs(a *Aggregate, ids ...string) bool {
	for i := range ids {
		if ids[i] == a.ID {
			return true
		}
	}
	return false
}

View File

@@ -0,0 +1,35 @@
package eventstore
// Asset describes a binary object referenced by events together with the
// action that should be applied to it.
type Asset struct {
	// ID is to refer to the asset
	ID string
	// Asset is the actual image
	Asset []byte
	// Action defines if asset should be added or removed
	Action AssetAction
}

// AssetAction enumerates what should happen to an Asset.
type AssetAction int32

const (
	// AssetAdd marks the asset to be stored.
	AssetAdd AssetAction = iota
	// AssetRemove marks the asset to be deleted.
	AssetRemove
)
// NewAddAsset constructs an Asset marked for addition.
func NewAddAsset(
	id string,
	asset []byte) *Asset {
	a := &Asset{
		ID:     id,
		Action: AssetAdd,
	}
	a.Asset = asset
	return a
}
// NewRemoveAsset constructs an Asset marked for removal; no payload is needed.
func NewRemoveAsset(
	id string) *Asset {
	a := new(Asset)
	a.ID = id
	a.Action = AssetRemove
	return a
}

View File

@@ -0,0 +1,33 @@
Lorem ipsum dolor sit amet, consectetur adipiscing elit. Nulla nec pharetra neque. Nam viverra elit lorem, sit amet euismod velit pulvinar vulputate. Donec bibendum convallis sodales. Pellentesque mattis massa id hendrerit suscipit. Vivamus a placerat mauris. Cras laoreet sapien eget commodo aliquet. Integer bibendum gravida augue, ultricies varius felis porttitor sed. Nullam scelerisque neque nec magna mattis rhoncus. Curabitur dictum luctus erat, vel dapibus est porta hendrerit. Donec in mauris eleifend, imperdiet nunc in, condimentum mi. Suspendisse viverra rhoncus pharetra. Sed ornare ipsum vitae eros consequat rutrum. In ullamcorper non ipsum vel aliquam. Praesent at tortor ut metus elementum mollis. Morbi orci nisi, feugiat placerat tempor id, imperdiet sed diam.
Quisque finibus sit amet erat quis lobortis. Praesent sit amet eros lectus. Pellentesque viverra purus in augue pretium, vitae ultricies orci fermentum. Sed ut eleifend metus. Suspendisse eget facilisis velit. Vestibulum sagittis turpis felis, in ultricies mauris sagittis sed. Donec facilisis suscipit placerat. Ut consequat varius elit ac semper. Nulla convallis nisi eu lorem porttitor posuere. Aliquam molestie egestas odio a scelerisque. Nunc aliquet dui eget ipsum hendrerit, et aliquet ligula vulputate. Duis ac ullamcorper tellus. Sed sit amet semper magna, ut blandit neque. Mauris vitae sem tempor, ullamcorper velit sit amet, rutrum sapien. Quisque scelerisque sollicitudin lectus quis dictum.
Duis convallis, dui aliquet imperdiet aliquam, nisl sem vehicula nisi, eleifend dignissim velit nunc nec sem. In euismod laoreet lacinia. Etiam euismod risus neque, sed dapibus nisl iaculis non. Vivamus porta nec risus ut tincidunt. Sed elementum metus at mauris pellentesque, ultrices pharetra arcu accumsan. Lorem ipsum dolor sit amet, consectetur adipiscing elit. Phasellus risus mi, mollis eget condimentum sed, pulvinar viverra nunc. Nam sit amet scelerisque mauris.
Pellentesque vitae tellus blandit ligula gravida iaculis. Pellentesque nec ex tellus. Vivamus egestas metus lacus, non iaculis est rutrum sodales. Phasellus eu mi tempus ligula mollis molestie. Suspendisse sit amet urna faucibus, tincidunt sem nec, faucibus massa. Integer vitae dui odio. Praesent finibus, quam posuere sodales maximus, libero quam vehicula massa, non dictum est risus non dui. Etiam elit risus, molestie et viverra in, dignissim et tortor. Curabitur posuere nunc id ante hendrerit, non semper neque rhoncus.
Pellentesque in libero in dolor euismod elementum. Vestibulum malesuada sodales lacinia. Praesent tellus nisl, ultrices vel massa eu, bibendum tincidunt eros. Pellentesque in rhoncus mi. Nulla sit amet lectus eleifend, varius quam sed, hendrerit risus. Sed molestie ipsum vel sapien condimentum blandit. Etiam ex nibh, dignissim et vestibulum ac, tempus et lectus. Mauris elit libero, tempus non scelerisque at, lacinia id elit. Integer erat velit, mollis vel laoreet eu, interdum feugiat mauris. Praesent et tincidunt lorem. Praesent finibus lobortis orci, non sagittis turpis rutrum et.
Donec et nisi nec neque condimentum sodales nec eget nisi. Nulla facilisi. Maecenas convallis, enim eu tristique rhoncus, neque urna aliquet lectus, sit amet porttitor eros neque nec ante. Suspendisse malesuada leo ut leo molestie iaculis. Nullam pretium ac eros vel finibus. Curabitur interdum iaculis enim vel imperdiet. Nam rutrum facilisis nisl, ut faucibus urna congue nec. Aenean ac diam magna.
Curabitur tempor, magna in eleifend tincidunt, nulla quam pellentesque augue, et hendrerit dui quam quis erat. Praesent id odio sed arcu eleifend volutpat. Integer ut massa sit amet ipsum egestas facilisis. Nulla facilisis mi et velit cursus, non sagittis turpis blandit. Duis iaculis neque a sapien imperdiet pellentesque. Nam urna nisi, euismod eu velit ac, mattis efficitur lorem. Duis in dignissim ante, a elementum felis. Quisque vestibulum malesuada malesuada. Donec ultricies, purus sed tincidunt efficitur, nisl nunc aliquet orci, at euismod est magna id dolor. Aliquam ornare placerat ex, at vulputate dui sagittis nec. Proin lorem lorem, pharetra ut lectus et, gravida euismod augue. Nullam aliquam id neque a tristique. Quisque non ante a orci porta vulputate eu non eros. Vestibulum velit velit, placerat sit amet malesuada quis, auctor at nunc.
Phasellus viverra lacus sed metus cursus fermentum. Fusce at congue nisl, at pharetra nunc. Nam maximus viverra dolor, sit amet vestibulum erat dignissim sit amet. Class aptent taciti sociosqu ad litora torquent per conubia nostra, per inceptos himenaeos. Nam at pellentesque mauris. Morbi vestibulum diam et vehicula suscipit. Suspendisse ex purus, varius et efficitur id, blandit a libero. Sed dapibus, neque eget finibus vulputate, ipsum sapien aliquam risus, convallis varius nulla velit a massa. Quisque eleifend ligula in sapien bibendum maximus. In ut bibendum nisi, fringilla mollis ex.
Suspendisse a est vel neque venenatis mollis ac egestas quam. Aenean at pretium massa, vitae rutrum est. Quisque et turpis velit. Mauris egestas nisl non mattis tempus. Pellentesque et magna nulla. Sed dolor odio, vehicula a lacus a, placerat pellentesque ante. Nam maximus dui fermentum, ornare ipsum ut, scelerisque odio. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Aliquam ac dictum mi. Integer porta dapibus orci, sollicitudin pharetra arcu mollis sed. Donec iaculis quam urna, ac tempus erat auctor eu. Maecenas pulvinar bibendum justo, sit amet blandit tortor. Maecenas id facilisis diam, a viverra odio.
Proin at ullamcorper orci. Pellentesque dapibus, mauris ut auctor rutrum, elit arcu fermentum dolor, sit amet porta leo lacus sed sapien. Fusce pharetra nulla a venenatis tempor. Curabitur ut nibh nunc. Nunc vel libero ante. Fusce quis erat sit amet dui facilisis laoreet. Sed egestas neque quam, sed mollis sem scelerisque imperdiet. Integer vulputate orci et eros scelerisque, sed feugiat lorem tincidunt. Fusce porta semper nulla, eget sollicitudin mi tincidunt quis. Nulla facilisi. Duis suscipit turpis id sem faucibus bibendum.
Cras elementum auctor mauris, ac finibus nisi fermentum non. Fusce congue enim mi, ac viverra lectus rutrum in. Curabitur sodales nibh est, in scelerisque dolor lobortis non. Duis non massa non ante finibus vehicula sed at est. Nulla facilisis porta magna, non convallis urna efficitur eget. Nunc nec libero nec felis facilisis sagittis. Cras aliquet neque et sapien dignissim ultrices. Donec blandit est metus. Vivamus pellentesque nec risus et tincidunt. Etiam nulla nibh, vehicula et tincidunt sit amet, porttitor a purus. Sed vel leo ante. Fusce purus felis, maximus sed erat ac, tincidunt volutpat felis. Vestibulum a ullamcorper magna. In luctus fringilla lectus, viverra vehicula velit. Suspendisse at nisl eget felis pretium faucibus tempor at nibh.
Pellentesque vitae ligula suscipit, fermentum mauris eget, mollis magna. Duis posuere iaculis enim sit amet mattis. Donec tincidunt libero sapien, vitae convallis risus fringilla quis. Suspendisse elementum nisl urna. Aliquam erat volutpat. Mauris lacinia elit quis turpis lacinia rutrum. Donec varius a sapien et maximus. Nullam nec finibus libero. Duis pulvinar est urna, eleifend faucibus urna fringilla sed. Pellentesque pretium fermentum est, a dignissim ante cursus ut.
Ut et massa finibus, pulvinar metus quis, pulvinar nulla. Vestibulum euismod congue euismod. Praesent tincidunt scelerisque suscipit. Sed condimentum massa quis cursus sodales. Vestibulum feugiat elit ac lobortis porta. Phasellus luctus aliquet enim ac efficitur. Nam accumsan est nunc, non malesuada nisl pretium ac. Integer id turpis hendrerit, rutrum nibh malesuada, mollis ipsum. Etiam porta aliquam leo, eget luctus nunc pretium eu. Quisque molestie risus sem, vel commodo enim ultricies tempus. Nulla cursus facilisis lacus nec aliquet. Nullam consequat arcu in mi vehicula, sit amet viverra lectus lacinia. Aliquam congue, sapien sed mattis mollis, lacus felis ultricies lectus, sed ultricies augue tellus vel purus. Vestibulum ante ipsum primis in faucibus orci luctus et ultrices posuere cubilia curae;
Integer sagittis, eros luctus semper malesuada, leo est sodales lacus, quis mollis tellus ex in justo. Donec ante nibh, viverra quis nibh quis, tempor interdum sem. Donec eu consequat sapien. Pellentesque aliquam neque mauris, et dictum arcu ultricies eu. Sed mattis risus eu turpis sagittis placerat. Nunc viverra mi elit, ac dapibus leo malesuada nec. Nunc eget augue ut massa condimentum faucibus sed vitae justo. Nunc leo lectus, iaculis in lobortis ac, dictum ut metus. Pellentesque habitant morbi tristique senectus et netus et malesuada fames ac turpis egestas. Quisque semper vestibulum lacinia.
Quisque ornare gravida libero et laoreet. Morbi condimentum aliquam libero, sed interdum risus efficitur nec. Pellentesque in justo blandit, varius eros non, placerat nisl. Mauris turpis sem, congue at egestas sed, luctus id sapien. Maecenas volutpat, massa sit amet cursus interdum, augue odio fermentum orci, eget accumsan metus mi a augue. Quisque pharetra efficitur turpis, in mattis ipsum suscipit quis. Duis eu consectetur libero. Aenean imperdiet purus nec nibh luctus, sit amet blandit leo euismod.
Donec placerat ut urna vitae porta. Proin ipsum elit, consectetur quis magna at, ultrices commodo lectus. Nulla cursus felis sed libero laoreet blandit. Maecenas mollis, sapien eget aliquet maximus, est tellus consequat dolor, eu tempor sapien ante sit amet arcu. Integer et lectus mi. Sed ullamcorper, enim sed porttitor commodo, sem leo rhoncus velit, sed gravida ipsum est nec ante. Aenean malesuada metus ornare mattis gravida. Fusce porttitor mattis dolor, quis convallis sem. Sed sed ex nisi. Etiam lacus felis, pulvinar et rutrum sed, lobortis nec ex. Pellentesque at dignissim metus. Nunc quis pellentesque ligula. Maecenas ut neque eget dolor ultricies condimentum. Suspendisse eget luctus dolor.
Etiam vitae mattis arcu. Morbi faucibus et justo vel malesuada. Suspendisse semper molestie lorem in tempus. Aliquam ut mauris non velit laoreet sagittis. Sed ac efficitur nibh. Praesent commodo erat a nunc placerat vestibulum. Phasellus laoreet vitae tortor nec aliquet. Fusce nibh augue, mollis sed urna ut, tristique venenatis arcu.

View File

@@ -0,0 +1,14 @@
package eventstore
import (
"time"
)
// Config holds the dependencies and tuning options for an Eventstore.
type Config struct {
	// PushTimeout bounds a single push; zero disables the timeout (see PushWithClient).
	PushTimeout time.Duration
	// MaxRetries is the number of additional push attempts on retryable errors.
	MaxRetries uint32
	// Pusher writes events to the storage.
	Pusher Pusher
	// Querier reads events from the storage.
	Querier Querier
	// Searcher looks up indexed fields.
	Searcher Searcher
}

View File

@@ -0,0 +1,111 @@
package eventstore
import (
"encoding/json"
"reflect"
"time"
"github.com/shopspring/decimal"
"github.com/zitadel/zitadel/internal/zerrors"
)
// action is the common behavior shared by commands and stored events.
type action interface {
	Aggregate() *Aggregate
	// Creator is the userid of the user which created the action
	Creator() string
	// Type describes the action
	Type() EventType
	// Revision of the action
	Revision() uint16
}

// Command is the intent to store an event into the eventstore
type Command interface {
	action
	// Payload returns the payload of the event. It represents the changed fields by the event
	// valid types are:
	// * nil: no payload
	// * struct: which can be marshalled to json
	// * pointer: to struct which can be marshalled to json
	// * []byte: json marshalled data
	Payload() any
	// UniqueConstraints should be added for unique attributes of an event, if nil constraints will not be checked
	UniqueConstraints() []*UniqueConstraint
	// Fields should be added for fields which should be indexed for lookup, if nil fields will not be indexed
	Fields() []*FieldOperation
}

// Event is a stored activity
type Event interface {
	action
	// Sequence of the event in the aggregate
	Sequence() uint64
	// CreatedAt is the time the event was created at
	CreatedAt() time.Time
	// Position is the global position of the event
	Position() decimal.Decimal
	// Unmarshal parses the payload and stores the result
	// in the value pointed to by ptr. If ptr is nil or not a pointer,
	// Unmarshal returns an error
	Unmarshal(ptr any) error
	// DataAsBytes returns the raw stored payload.
	// Deprecated: only use for migration
	DataAsBytes() []byte
}

// EventType uniquely identifies the kind of an event (the event's name).
type EventType string
// EventData serializes the payload of the given command to JSON.
// Allowed payload kinds (see [Command.Payload]):
//   - nil: returns (nil, nil)
//   - []byte: must already be valid JSON and is returned unchanged
//   - struct or pointer to struct: marshalled to JSON
//
// Any other kind yields an invalid-argument error.
func EventData(event Command) ([]byte, error) {
	// call Payload once instead of three times through the interface
	payload := event.Payload()
	switch data := payload.(type) {
	case nil:
		return nil, nil
	case []byte:
		if json.Valid(data) {
			return data, nil
		}
		return nil, zerrors.ThrowInvalidArgument(nil, "V2-6SbbS", "data bytes are not json")
	}
	dataType := reflect.TypeOf(payload)
	if dataType.Kind() == reflect.Ptr {
		dataType = dataType.Elem()
	}
	if dataType.Kind() != reflect.Struct {
		return nil, zerrors.ThrowInvalidArgument(nil, "V2-91NRm", "wrong type of event data")
	}
	dataBytes, err := json.Marshal(payload)
	if err != nil {
		return nil, zerrors.ThrowInvalidArgument(err, "V2-xG87M", "could not marshal data")
	}
	return dataBytes, nil
}
// BaseEventSetter is satisfied by pointer event types that embed a BaseEvent
// and allow it to be replaced; it is the constraint used by GenericEventMapper.
type BaseEventSetter[T any] interface {
	Event
	SetBaseEvent(*BaseEvent)
	*T
}
// GenericEventMapper maps a stored event onto the concrete event type T:
// it copies the base metadata via SetBaseEvent and then unmarshals the
// payload into the freshly allocated instance.
func GenericEventMapper[T any, PT BaseEventSetter[T]](event Event) (Event, error) {
	e := PT(new(T))
	e.SetBaseEvent(BaseEventFromRepo(event))
	err := event.Unmarshal(e)
	if err != nil {
		return nil, zerrors.ThrowInternal(err, "ES-Thai6", "unable to unmarshal event")
	}
	return e, nil
}
// isEventTypes reports whether the command's type matches any of the given types.
func isEventTypes(command Command, types ...EventType) bool {
	for i := range types {
		if types[i] == command.Type() {
			return true
		}
	}
	return false
}

View File

@@ -0,0 +1,136 @@
package eventstore
import (
"context"
"encoding/json"
"strconv"
"strings"
"time"
"github.com/shopspring/decimal"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/service"
)
var (
	// compile-time assertion that BaseEvent satisfies the Event interface
	_ Event = (*BaseEvent)(nil)
)

// BaseEvent represents the minimum metadata of an event
type BaseEvent struct {
	ID string
	EventType EventType `json:"-"`
	Agg *Aggregate `json:"-"`
	Seq uint64
	Pos decimal.Decimal
	Creation time.Time
	previousAggregateSequence uint64
	previousAggregateTypeSequence uint64
	// User who created the event
	User string `json:"-"`
	// Service which created the event
	Service string `json:"-"`
	// Data is the raw JSON payload; excluded from (un)marshalling of the event itself
	Data []byte `json:"-"`
}
// Position implements Event.
func (e *BaseEvent) Position() decimal.Decimal {
	return e.Pos
}

// EditorService implements Command
func (e *BaseEvent) EditorService() string {
	return e.Service
}

// EditorUser implements Command
func (e *BaseEvent) EditorUser() string {
	return e.User
}

// Creator implements action
func (e *BaseEvent) Creator() string {
	return e.EditorUser()
}

// Type implements action
func (e *BaseEvent) Type() EventType {
	return e.EventType
}

// Sequence is an upcounting unique number of the event
func (e *BaseEvent) Sequence() uint64 {
	return e.Seq
}

// CreationDate is the time the event is inserted into the eventstore
func (e *BaseEvent) CreationDate() time.Time {
	return e.Creation
}

// CreatedAt implements Event
func (e *BaseEvent) CreatedAt() time.Time {
	return e.CreationDate()
}

// Aggregate implements action
func (e *BaseEvent) Aggregate() *Aggregate {
	return e.Agg
}

// DataAsBytes returns the raw payload of the event. It represents the changed fields by the event
func (e *BaseEvent) DataAsBytes() []byte {
	return e.Data
}
// Revision implements action: it parses the numeric part of the aggregate's
// version (e.g. "v2" -> 2); on parse failure the error is only logged at
// debug level and 0 is returned.
func (e *BaseEvent) Revision() uint16 {
	versionNumber := strings.TrimPrefix(string(e.Agg.Version), "v")
	parsed, err := strconv.ParseUint(versionNumber, 10, 16)
	logging.OnError(err).Debug("failed to parse event revision")
	return uint16(parsed)
}
// Unmarshal implements Event: it decodes the stored JSON payload into ptr.
// An empty payload is a no-op and returns nil.
func (e *BaseEvent) Unmarshal(ptr any) error {
	if len(e.Data) > 0 {
		return json.Unmarshal(e.Data, ptr)
	}
	return nil
}
// defaultService is recorded as the creating service for events read back from storage.
const defaultService = "zitadel"
// BaseEventFromRepo maps a stored event to a BaseEvent.
// The service is always set to defaultService because the originating
// service is not persisted.
func BaseEventFromRepo(event Event) *BaseEvent {
	base := &BaseEvent{
		EventType: event.Type(),
		Agg:       event.Aggregate(),
		Seq:       event.Sequence(),
		Pos:       event.Position(),
		Creation:  event.CreatedAt(),
		User:      event.Creator(),
		Data:      event.DataAsBytes(),
		Service:   defaultService,
	}
	return base
}
// NewBaseEventForPush is the constructor for event's which will be pushed into the eventstore
// the resource owner of the aggregate is only used if it's the first event of this aggregate type
// afterwards the resource owner of the first previous events is taken
func NewBaseEventForPush(ctx context.Context, aggregate *Aggregate, typ EventType) *BaseEvent {
	return &BaseEvent{
		Agg: aggregate,
		// creator and service are derived from the request context
		User:      authz.GetCtxData(ctx).UserID,
		Service:   service.FromContext(ctx),
		EventType: typ,
	}
}
// Fields implements Command; a plain BaseEvent indexes no fields.
func (*BaseEvent) Fields() []*FieldOperation {
	return nil
}

View File

@@ -0,0 +1,322 @@
package eventstore
import (
"context"
"errors"
"sort"
"time"
"github.com/jackc/pgx/v5/pgconn"
"github.com/shopspring/decimal"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/zerrors"
)
func init() {
	// this is needed to ensure that position is marshaled as a number
	// otherwise it will be marshaled as a string
	decimal.MarshalJSONWithoutQuotes = true
}

// Eventstore abstracts all functions needed to store valid events
// and filters the stored events
type Eventstore struct {
	// PushTimeout bounds a single push including retries; zero disables the timeout.
	PushTimeout time.Duration
	// maxRetries is the number of additional push attempts on retryable errors.
	maxRetries int
	pusher Pusher
	querier Querier
	searcher Searcher
}

var (
	// eventInterceptors holds the registered mapper per event type (see RegisterFilterEventMapper).
	eventInterceptors map[EventType]eventTypeInterceptors
	// eventTypes and aggregateTypes are kept sorted by appendEventType/appendAggregateType.
	eventTypes []string
	aggregateTypes []string
	// eventTypeMapping resolves an event type to the aggregate type it belongs to.
	eventTypeMapping = map[EventType]AggregateType{}
)
// RegisterFilterEventMapper registers a function for mapping an eventstore
// event to a concrete event type. Registrations with an empty event type or
// a nil mapper are ignored.
func RegisterFilterEventMapper(aggregateType AggregateType, eventType EventType, mapper func(Event) (Event, error)) {
	if eventType == "" || mapper == nil {
		return
	}
	appendEventType(eventType)
	appendAggregateType(aggregateType)
	// lazily allocate the registry on first registration
	if eventInterceptors == nil {
		eventInterceptors = map[EventType]eventTypeInterceptors{}
	}
	eventInterceptors[eventType] = eventTypeInterceptors{eventMapper: mapper}
	eventTypeMapping[eventType] = aggregateType
}
// eventTypeInterceptors bundles the hooks registered for a single event type.
type eventTypeInterceptors struct {
	// eventMapper converts a stored event into its concrete event type
	eventMapper func(Event) (Event, error)
}
// NewEventstore builds an Eventstore from the given configuration.
func NewEventstore(config *Config) *Eventstore {
	es := &Eventstore{
		PushTimeout: config.PushTimeout,
		maxRetries:  int(config.MaxRetries),
	}
	es.pusher = config.Pusher
	es.querier = config.Querier
	es.searcher = config.Searcher
	return es
}
// Health checks if the eventstore can properly work.
// It verifies both the pusher and the querier can serve load.
func (es *Eventstore) Health(ctx context.Context) error {
	err := es.pusher.Health(ctx)
	if err != nil {
		return err
	}
	return es.querier.Health(ctx)
}
// Push pushes the events in a single transaction
// an event needs at least an aggregate
func (es *Eventstore) Push(ctx context.Context, cmds ...Command) ([]Event, error) {
	// nil client lets the pusher use its own connection
	return es.PushWithClient(ctx, nil, cmds...)
}
// PushWithClient pushes the events in a single transaction using the provided database client
// an event needs at least an aggregate
func (es *Eventstore) PushWithClient(ctx context.Context, client database.ContextQueryExecuter, cmds ...Command) ([]Event, error) {
	// bound the whole push (including retries) by the configured timeout
	if es.PushTimeout > 0 {
		var cancel func()
		ctx, cancel = context.WithTimeout(ctx, es.PushTimeout)
		defer cancel()
	}
	var (
		events []Event
		err error
	)
	// Retry when there is a collision of the sequence as part of the primary key.
	// "duplicate key value violates unique constraint \"events2_pkey\" (SQLSTATE 23505)"
	// https://github.com/zitadel/zitadel/issues/7202
	retry:
	for i := 0; i <= es.maxRetries; i++ {
		events, err = es.pusher.Push(ctx, client, cmds...)
		// if there is a transaction passed the calling function needs to retry
		if _, ok := client.(database.Tx); ok {
			break retry
		}
		var pgErr *pgconn.PgError
		if !errors.As(err, &pgErr) {
			// success, or an error that is not a postgres error: nothing retryable
			break retry
		}
		// primary-key collision on the event sequence: safe to retry
		if pgErr.ConstraintName == "events2_pkey" && pgErr.SQLState() == "23505" {
			logging.WithError(err).Info("eventstore push retry")
			continue
		}
		// retryable serialization SQL states (CR000, 40001)
		if pgErr.SQLState() == "CR000" || pgErr.SQLState() == "40001" {
			logging.WithError(err).Info("eventstore push retry")
			continue
		}
		break retry
	}
	if err != nil {
		return nil, err
	}
	// map the stored events to their concrete types and notify subscribers
	mappedEvents, err := es.mapEvents(events)
	if err != nil {
		return mappedEvents, err
	}
	es.notify(mappedEvents)
	return mappedEvents, nil
}
// AggregateTypeFromEventType resolves the aggregate type an event type was registered for.
func AggregateTypeFromEventType(typ EventType) AggregateType {
	return eventTypeMapping[typ]
}

// EventTypes returns all registered event types (sorted).
func (es *Eventstore) EventTypes() []string {
	return eventTypes
}

// AggregateTypes returns all registered aggregate types (sorted).
func (es *Eventstore) AggregateTypes() []string {
	return aggregateTypes
}

// FillFields implements the [Searcher] interface
func (es *Eventstore) FillFields(ctx context.Context, events ...FillFieldsEvent) error {
	return es.searcher.FillFields(ctx, events...)
}
// Search implements the [Searcher] interface.
// At least one condition must be provided.
func (es *Eventstore) Search(ctx context.Context, conditions ...map[FieldType]any) ([]*SearchResult, error) {
	if len(conditions) > 0 {
		return es.searcher.Search(ctx, conditions...)
	}
	return nil, zerrors.ThrowInvalidArgument(nil, "V3-5Xbr1", "no search conditions")
}
// Filter filters the stored events based on the searchQuery
// and maps the events to the defined event structs
//
// Deprecated: Use [FilterToQueryReducer] instead to avoid allocations.
func (es *Eventstore) Filter(ctx context.Context, searchQuery *SearchQueryBuilder) ([]Event, error) {
	events := make([]Event, 0, searchQuery.GetLimit())
	searchQuery.ensureInstanceID(ctx)
	collect := func(event Event) error {
		mapped, mapErr := es.mapEvent(event)
		if mapErr != nil {
			return mapErr
		}
		events = append(events, mapped)
		return nil
	}
	if err := es.querier.FilterToReducer(ctx, searchQuery, collect); err != nil {
		return nil, err
	}
	return events, nil
}
// mapEvents maps each stored event to its registered concrete type,
// failing on the first mapping error.
func (es *Eventstore) mapEvents(events []Event) (mappedEvents []Event, err error) {
	mappedEvents = make([]Event, 0, len(events))
	for _, event := range events {
		mapped, mapErr := es.mapEventLocked(event)
		if mapErr != nil {
			return nil, mapErr
		}
		mappedEvents = append(mappedEvents, mapped)
	}
	return mappedEvents, nil
}
// mapEvent maps a single stored event to its registered concrete type.
func (es *Eventstore) mapEvent(event Event) (Event, error) {
	return es.mapEventLocked(event)
}

// mapEventLocked applies the registered mapper for the event's type,
// falling back to a plain BaseEvent when none is registered.
func (es *Eventstore) mapEventLocked(event Event) (Event, error) {
	interceptors, ok := eventInterceptors[event.Type()]
	if !ok || interceptors.eventMapper == nil {
		return BaseEventFromRepo(event), nil
	}
	return interceptors.eventMapper(event)
}
// TODO: refactor so we can change to the following interface:
/*
type reducer interface {
	// Reduce applies an event on the object.
	Reduce(Event) error
}
*/

// reducer is the consumer side of event filtering: events are appended
// first and then reduced onto the implementing object.
type reducer interface {
	// Reduce handles the events of the internal events list
	// it only appends the newly added events
	Reduce() error
	// AppendEvents appends the passed events to an internal list of events
	AppendEvents(...Event)
}
// FilterToReducer filters the events based on the search query, appends each
// mapped event to the reducer and calls its Reduce function.
func (es *Eventstore) FilterToReducer(ctx context.Context, searchQuery *SearchQueryBuilder, r reducer) error {
	searchQuery.ensureInstanceID(ctx)
	handle := func(event Event) error {
		mapped, err := es.mapEvent(event)
		if err != nil {
			return err
		}
		r.AppendEvents(mapped)
		return r.Reduce()
	}
	return es.querier.FilterToReducer(ctx, searchQuery, handle)
}
// LatestPosition filters the latest position for the given search query
func (es *Eventstore) LatestPosition(ctx context.Context, queryFactory *SearchQueryBuilder) (decimal.Decimal, error) {
	// scope the query to the caller's instance
	queryFactory.InstanceID(authz.GetInstance(ctx).InstanceID())
	return es.querier.LatestPosition(ctx, queryFactory)
}

// InstanceIDs returns the distinct instance ids found by the search query
// Warning: this function can have high impact on performance, only use this function during setup
func (es *Eventstore) InstanceIDs(ctx context.Context, queryFactory *SearchQueryBuilder) ([]string, error) {
	return es.querier.InstanceIDs(ctx, queryFactory)
}

// Client returns the underlying database connection of the querier.
func (es *Eventstore) Client() *database.DB {
	return es.querier.Client()
}
// QueryReducer is a reducer that can also describe which events it needs.
type QueryReducer interface {
	reducer
	// Query returns the SearchQueryFactory for the events needed in reducer
	Query() *SearchQueryBuilder
}
// FilterToQueryReducer filters the events based on the search query of the query function,
// appends all events to the reducer and calls it's reduce function
func (es *Eventstore) FilterToQueryReducer(ctx context.Context, r QueryReducer) error {
	return es.FilterToReducer(ctx, r.Query(), r)
}
// Reducer handles a single event produced while filtering.
type Reducer func(event Event) error

// Querier reads events from the storage.
type Querier interface {
	// Health checks if the connection to the storage is available
	Health(ctx context.Context) error
	// FilterToReducer calls r for every event returned from the storage
	FilterToReducer(ctx context.Context, searchQuery *SearchQueryBuilder, r Reducer) error
	// LatestPosition returns the latest position found by the search query
	LatestPosition(ctx context.Context, queryFactory *SearchQueryBuilder) (decimal.Decimal, error)
	// InstanceIDs returns the instance ids found by the search query
	InstanceIDs(ctx context.Context, queryFactory *SearchQueryBuilder) ([]string, error)
	// Client returns the underlying database connection
	Client() *database.DB
}

// Pusher writes events to the storage.
type Pusher interface {
	// Health checks if the connection to the storage is available
	Health(ctx context.Context) error
	// Push stores the actions
	Push(ctx context.Context, client database.ContextQueryExecuter, commands ...Command) (_ []Event, err error)
	// Client returns the underlying database connection
	Client() *database.DB
}

// FillFieldsEvent is an event whose indexed fields can be (re)computed.
type FillFieldsEvent interface {
	Event
	Fields() []*FieldOperation
}

// Searcher looks up indexed fields of objects.
type Searcher interface {
	// Search allows to search for specific fields of objects
	// The instance id is taken from the context
	// The list of conditions are combined with AND
	// The search fields are combined with OR
	// At least one must be defined
	Search(ctx context.Context, conditions ...map[FieldType]any) (result []*SearchResult, err error)
	// FillFields is to insert the fields of previously stored events
	FillFields(ctx context.Context, events ...FillFieldsEvent) error
}
// appendEventType inserts typ into the sorted eventTypes slice,
// keeping it sorted and free of duplicates.
func appendEventType(typ EventType) {
	i := sort.SearchStrings(eventTypes, string(typ))
	if i < len(eventTypes) && eventTypes[i] == string(typ) {
		// already registered
		return
	}
	eventTypes = append(eventTypes[:i], append([]string{string(typ)}, eventTypes[i:]...)...)
}
// appendAggregateType inserts typ into the sorted aggregateTypes slice,
// keeping it sorted and free of duplicates.
func appendAggregateType(typ AggregateType) {
	i := sort.SearchStrings(aggregateTypes, string(typ))
	// condition ordered `i < len(...)` for consistency with appendEventType
	if i < len(aggregateTypes) && aggregateTypes[i] == string(typ) {
		// already registered
		return
	}
	aggregateTypes = append(aggregateTypes[:i], append([]string{string(typ)}, aggregateTypes[i:]...)...)
}

View File

@@ -0,0 +1,162 @@
package eventstore_test
import (
"context"
_ "embed"
"fmt"
"strconv"
"testing"
"github.com/zitadel/zitadel/internal/eventstore"
)
// text holds the large benchmark payload, embedded at compile time.
//go:embed bench_payload.txt
var text string
// Benchmark_Push_SameAggregate measures pushing batches of commands that all
// target the same aggregate id, for every configured pusher.
func Benchmark_Push_SameAggregate(b *testing.B) {
	ctx := context.Background()

	// payload fixtures of different sizes
	smallPayload := struct {
		Username string
		Firstname string
		Lastname string
	}{
		Username: "username",
		Firstname: "firstname",
		Lastname: "lastname",
	}
	bigPayload := struct {
		Username string
		Firstname string
		Lastname string
		Text string
	}{
		Username: "username",
		Firstname: "firstname",
		Lastname: "lastname",
		Text: text,
	}

	// command batches to benchmark, keyed by a descriptive name
	commands := map[string][]eventstore.Command{
		"no payload one command": {
			generateCommand(eventstore.AggregateType(b.Name()), "id"),
		},
		"small payload one command": {
			generateCommand(eventstore.AggregateType(b.Name()), "id", withTestData(smallPayload)),
		},
		"big payload one command": {
			generateCommand(eventstore.AggregateType(b.Name()), "id", withTestData(bigPayload)),
		},
		"no payload multiple commands": {
			generateCommand(eventstore.AggregateType(b.Name()), "id"),
			generateCommand(eventstore.AggregateType(b.Name()), "id"),
			generateCommand(eventstore.AggregateType(b.Name()), "id"),
		},
		"mixed payload multiple command": {
			generateCommand(eventstore.AggregateType(b.Name()), "id", withTestData(smallPayload)),
			generateCommand(eventstore.AggregateType(b.Name()), "id", withTestData(bigPayload)),
			generateCommand(eventstore.AggregateType(b.Name()), "id", withTestData(smallPayload)),
			generateCommand(eventstore.AggregateType(b.Name()), "id", withTestData(bigPayload)),
		},
	}

	for cmdsKey, cmds := range commands {
		for pusherKey, store := range pushers {
			b.Run(fmt.Sprintf("Benchmark_Push_SameAggregate-%s-%s", pusherKey, cmdsKey), func(b *testing.B) {
				// reset the store outside the timed region
				b.StopTimer()
				cleanupEventstore(clients[pusherKey])
				b.StartTimer()

				for n := 0; n < b.N; n++ {
					_, err := store.Push(ctx, store.Client().DB, cmds...)
					if err != nil {
						b.Error(err)
					}
				}
			})
		}
	}
}
// Benchmark_Push_MultipleAggregate_Parallel measures push throughput when
// concurrent goroutines each push to a distinct aggregate (a unique id per
// iteration), for each registered pusher and several payload mixes.
func Benchmark_Push_MultipleAggregate_Parallel(b *testing.B) {
	// small JSON payload: three short string fields
	smallPayload := struct {
		Username  string
		Firstname string
		Lastname  string
	}{
		Username:  "username",
		Firstname: "firstname",
		Lastname:  "lastname",
	}
	// big payload: same fields plus the embedded bench_payload.txt content
	bigPayload := struct {
		Username  string
		Firstname string
		Lastname  string
		Text      string
	}{
		Username:  "username",
		Firstname: "firstname",
		Lastname:  "lastname",
		Text:      text,
	}

	// creators produce fresh commands per aggregate id so each iteration
	// writes to its own aggregate
	commandCreators := map[string]func(id string) []eventstore.Command{
		"no payload one command": func(id string) []eventstore.Command {
			return []eventstore.Command{
				generateCommand(eventstore.AggregateType(b.Name()), id),
			}
		},
		"small payload one command": func(id string) []eventstore.Command {
			return []eventstore.Command{
				generateCommand(eventstore.AggregateType(b.Name()), id, withTestData(smallPayload)),
			}
		},
		"big payload one command": func(id string) []eventstore.Command {
			return []eventstore.Command{
				generateCommand(eventstore.AggregateType(b.Name()), id, withTestData(bigPayload)),
			}
		},
		"no payload multiple commands": func(id string) []eventstore.Command {
			return []eventstore.Command{
				generateCommand(eventstore.AggregateType(b.Name()), id),
				generateCommand(eventstore.AggregateType(b.Name()), id),
				generateCommand(eventstore.AggregateType(b.Name()), id),
			}
		},
		"mixed payload multiple command": func(id string) []eventstore.Command {
			return []eventstore.Command{
				generateCommand(eventstore.AggregateType(b.Name()), id, withTestData(smallPayload)),
				generateCommand(eventstore.AggregateType(b.Name()), id, withTestData(bigPayload)),
				generateCommand(eventstore.AggregateType(b.Name()), id, withTestData(smallPayload)),
				generateCommand(eventstore.AggregateType(b.Name()), id, withTestData(bigPayload)),
			}
		},
	}

	for cmdsKey, commandCreator := range commandCreators {
		for pusherKey, store := range pushers {
			b.Run(fmt.Sprintf("Benchmark_Push_DifferentAggregate-%s-%s", cmdsKey, pusherKey), func(b *testing.B) {
				// reset the store outside of the timed region
				b.StopTimer()
				cleanupEventstore(clients[pusherKey])
				ctx, cancel := context.WithCancel(context.Background())
				b.StartTimer()

				// b.RunParallel executes the body from multiple goroutines, so
				// the per-iteration aggregate-id counter must be atomic. The
				// previous plain `i++` was a data race (detected by -race) and
				// could hand the same id to two goroutines.
				var i atomic.Int64
				b.RunParallel(func(p *testing.PB) {
					for p.Next() {
						id := strconv.FormatInt(i.Add(1), 10)
						_, err := store.Push(ctx, store.Client().DB, commandCreator(id)...)
						if err != nil {
							b.Error(err)
						}
					}
				})
				cancel()
			})
		}
	}
}

View File

@@ -0,0 +1,712 @@
package eventstore_test
import (
"context"
"database/sql"
"sync"
"testing"
"time"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
)
// TestEventstore_Push_OneAggregate verifies pushing events for a single
// aggregate against every registered pusher: plain pushes, canceled-context
// failure, and add/remove of unique constraints (optionally prefilled per
// test case via uniqueData* fields).
func TestEventstore_Push_OneAggregate(t *testing.T) {
	type args struct {
		ctx      context.Context
		commands []eventstore.Command
		// when set, a unique-constraint row is inserted before the push
		uniqueDataType       string
		uniqueDataField      string
		uniqueDataInstanceID string
	}
	type eventsRes struct {
		pushedEventsCount int
		uniqueCount       int
		// NOTE(review): assetCount is declared and set by some cases but not
		// asserted anywhere in this function — verify against asset checks.
		assetCount int
		aggType    eventstore.AggregateType
		aggIDs     database.TextArray[string]
	}
	type res struct {
		wantErr   bool
		eventsRes eventsRes
	}
	tests := []struct {
		name string
		args args
		res  res
	}{
		{
			name: "push 1 event",
			args: args{
				ctx: context.Background(),
				commands: []eventstore.Command{
					generateCommand(eventstore.AggregateType(t.Name()), "1"),
				},
			},
			res: res{
				wantErr: false,
				eventsRes: eventsRes{
					pushedEventsCount: 1,
					aggIDs:            []string{"1"},
					aggType:           eventstore.AggregateType(t.Name()),
				}},
		},
		{
			name: "push two events on agg",
			args: args{
				ctx: context.Background(),
				commands: []eventstore.Command{
					generateCommand(eventstore.AggregateType(t.Name()), "6"),
					generateCommand(eventstore.AggregateType(t.Name()), "6"),
				},
			},
			res: res{
				wantErr: false,
				eventsRes: eventsRes{
					pushedEventsCount: 2,
					aggIDs:            []string{"6"},
					aggType:           eventstore.AggregateType(t.Name()),
				},
			},
		},
		{
			name: "failed push because context canceled",
			args: args{
				ctx: canceledCtx(),
				commands: []eventstore.Command{
					generateCommand(eventstore.AggregateType(t.Name()), "9"),
				},
			},
			res: res{
				wantErr: true,
				eventsRes: eventsRes{
					pushedEventsCount: 0,
					aggIDs:            []string{"9"},
					aggType:           eventstore.AggregateType(t.Name()),
				},
			},
		},
		{
			name: "push 1 event and add unique constraint",
			args: args{
				ctx: context.Background(),
				commands: []eventstore.Command{
					generateCommand(eventstore.AggregateType(t.Name()), "10",
						generateAddUniqueConstraint("usernames", "field"),
					),
				},
			},
			res: res{
				wantErr: false,
				eventsRes: eventsRes{
					pushedEventsCount: 1,
					uniqueCount:       1,
					aggIDs:            []string{"10"},
					aggType:           eventstore.AggregateType(t.Name()),
				}},
		},
		{
			name: "push 1 event and remove unique constraint",
			args: args{
				ctx: context.Background(),
				commands: []eventstore.Command{
					generateCommand(eventstore.AggregateType(t.Name()), "11",
						generateRemoveUniqueConstraint("usernames", "testremove"),
					),
				},
				// prefill the constraint that the pushed event removes
				uniqueDataType:  "usernames",
				uniqueDataField: "testremove",
			},
			res: res{
				wantErr: false,
				eventsRes: eventsRes{
					pushedEventsCount: 1,
					uniqueCount:       0,
					aggIDs:            []string{"11"},
					aggType:           eventstore.AggregateType(t.Name()),
				}},
		},
		{
			name: "push 1 event and remove instance unique constraints",
			args: args{
				ctx: context.Background(),
				commands: []eventstore.Command{
					generateCommand(eventstore.AggregateType(t.Name()), "12",
						generateRemoveUniqueConstraint("instance", "instanceID"),
					),
				},
				// prefill an instance-scoped constraint; removing the
				// "instance" type drops all constraints of that instance
				uniqueDataType:       "usernames",
				uniqueDataField:      "testremove",
				uniqueDataInstanceID: "instanceID",
			},
			res: res{
				wantErr: false,
				eventsRes: eventsRes{
					pushedEventsCount: 1,
					uniqueCount:       0,
					aggIDs:            []string{"12"},
					aggType:           eventstore.AggregateType(t.Name()),
				}},
		},
		{
			name: "push 1 event and add asset",
			args: args{
				ctx: context.Background(),
				commands: []eventstore.Command{
					generateCommand(eventstore.AggregateType(t.Name()), "13"),
				},
			},
			res: res{
				wantErr: false,
				eventsRes: eventsRes{
					pushedEventsCount: 1,
					assetCount:        1,
					aggIDs:            []string{"13"},
					aggType:           eventstore.AggregateType(t.Name()),
				}},
		},
		{
			name: "push 1 event and remove asset",
			args: args{
				ctx: context.Background(),
				commands: []eventstore.Command{
					generateCommand(eventstore.AggregateType(t.Name()), "14"),
				},
			},
			res: res{
				wantErr: false,
				eventsRes: eventsRes{
					pushedEventsCount: 1,
					assetCount:        0,
					aggIDs:            []string{"14"},
					aggType:           eventstore.AggregateType(t.Name()),
				}},
		},
	}
	for _, tt := range tests {
		for pusherName, pusher := range pushers {
			t.Run(pusherName+"/"+tt.name, func(t *testing.T) {
				t.Cleanup(cleanupEventstore(clients[pusherName]))

				db := eventstore.NewEventstore(
					&eventstore.Config{
						Querier: queriers["v2(inmemory)"],
						Pusher:  pusher,
					},
				)
				// optional setup: insert a unique-constraint row for the
				// "remove" test cases
				if tt.args.uniqueDataType != "" && tt.args.uniqueDataField != "" {
					err := fillUniqueData(tt.args.uniqueDataType, tt.args.uniqueDataField, tt.args.uniqueDataInstanceID)
					if err != nil {
						t.Error("unable to prefill insert unique data: ", err)
						return
					}
				}
				if _, err := db.Push(tt.args.ctx, tt.args.commands...); (err != nil) != tt.res.wantErr {
					t.Errorf("eventstore.Push() error = %v, wantErr %v", err, tt.res.wantErr)
				}

				assertEventCount(t,
					clients[pusherName],
					database.TextArray[eventstore.AggregateType]{tt.res.eventsRes.aggType},
					tt.res.eventsRes.aggIDs,
					tt.res.eventsRes.pushedEventsCount,
				)

				assertUniqueConstraint(t, clients[pusherName], tt.args.commands, tt.res.eventsRes.uniqueCount)
			})
		}
	}
}
// TestEventstore_Push_MultipleAggregate verifies that a single push call can
// write events that span several aggregates, for every registered pusher.
func TestEventstore_Push_MultipleAggregate(t *testing.T) {
	type args struct {
		commands []eventstore.Command
	}
	type eventsRes struct {
		pushedEventsCount int
		aggType           database.TextArray[eventstore.AggregateType]
		aggID             database.TextArray[string]
	}
	type res struct {
		wantErr   bool
		eventsRes eventsRes
	}
	tests := []struct {
		name string
		args args
		res  res
	}{
		{
			name: "push two aggregates",
			args: args{
				commands: []eventstore.Command{
					generateCommand(eventstore.AggregateType(t.Name()), "100"),
					generateCommand(eventstore.AggregateType(t.Name()), "101"),
				},
			},
			res: res{
				wantErr: false,
				eventsRes: eventsRes{
					pushedEventsCount: 2,
					aggID:             []string{"100", "101"},
					aggType:           database.TextArray[eventstore.AggregateType]{eventstore.AggregateType(t.Name())},
				},
			},
		},
		{
			name: "push two aggregates both multiple events",
			args: args{
				commands: []eventstore.Command{
					generateCommand(eventstore.AggregateType(t.Name()), "102"),
					generateCommand(eventstore.AggregateType(t.Name()), "102"),
					generateCommand(eventstore.AggregateType(t.Name()), "103"),
					generateCommand(eventstore.AggregateType(t.Name()), "103"),
				},
			},
			res: res{
				wantErr: false,
				eventsRes: eventsRes{
					pushedEventsCount: 4,
					aggID:             []string{"102", "103"},
					aggType:           database.TextArray[eventstore.AggregateType]{eventstore.AggregateType(t.Name())},
				},
			},
		},
		{
			name: "push two aggregates mixed multiple events",
			args: args{
				commands: []eventstore.Command{
					generateCommand(eventstore.AggregateType(t.Name()), "106"),
					generateCommand(eventstore.AggregateType(t.Name()), "106"),
					generateCommand(eventstore.AggregateType(t.Name()), "106"),
					generateCommand(eventstore.AggregateType(t.Name()), "106"),
					generateCommand(eventstore.AggregateType(t.Name()), "107"),
					generateCommand(eventstore.AggregateType(t.Name()), "107"),
					generateCommand(eventstore.AggregateType(t.Name()), "107"),
					generateCommand(eventstore.AggregateType(t.Name()), "107"),
					generateCommand(eventstore.AggregateType(t.Name()), "108"),
					generateCommand(eventstore.AggregateType(t.Name()), "108"),
					generateCommand(eventstore.AggregateType(t.Name()), "108"),
					generateCommand(eventstore.AggregateType(t.Name()), "108"),
				},
			},
			res: res{
				wantErr: false,
				eventsRes: eventsRes{
					pushedEventsCount: 12,
					aggID:             []string{"106", "107", "108"},
					aggType:           database.TextArray[eventstore.AggregateType]{eventstore.AggregateType(t.Name())},
				},
			},
		},
	}
	for _, tt := range tests {
		for pusherName, pusher := range pushers {
			t.Run(pusherName+"/"+tt.name, func(t *testing.T) {
				t.Cleanup(cleanupEventstore(clients[pusherName]))

				db := eventstore.NewEventstore(
					&eventstore.Config{
						Querier: queriers["v2(inmemory)"],
						Pusher:  pusher,
					},
				)
				if _, err := db.Push(context.Background(), tt.args.commands...); (err != nil) != tt.res.wantErr {
					t.Errorf("eventstore.Push() error = %v, wantErr %v", err, tt.res.wantErr)
				}

				assertEventCount(t, clients[pusherName], tt.res.eventsRes.aggType, tt.res.eventsRes.aggID, tt.res.eventsRes.pushedEventsCount)
			})
		}
	}
}
// TestEventstore_Push_Parallel pushes several command batches concurrently
// (via pushAggregates) and verifies the stored event count; minErrCount
// allows cases where concurrent pushes are expected to conflict.
func TestEventstore_Push_Parallel(t *testing.T) {
	type args struct {
		// each inner slice is pushed by its own goroutine
		commands [][]eventstore.Command
	}
	type eventsRes struct {
		pushedEventsCount int
		aggTypes          database.TextArray[eventstore.AggregateType]
		aggIDs            database.TextArray[string]
	}
	type res struct {
		minErrCount int
		eventsRes   eventsRes
	}
	tests := []struct {
		name string
		args args
		res  res
	}{
		{
			name: "clients push different aggregates",
			args: args{
				commands: [][]eventstore.Command{
					{
						generateCommand(eventstore.AggregateType(t.Name()), "200"),
						generateCommand(eventstore.AggregateType(t.Name()), "200"),
						generateCommand(eventstore.AggregateType(t.Name()), "200"),
						generateCommand(eventstore.AggregateType(t.Name()), "201"),
						generateCommand(eventstore.AggregateType(t.Name()), "201"),
						generateCommand(eventstore.AggregateType(t.Name()), "201"),
					},
					{
						generateCommand(eventstore.AggregateType(t.Name()), "202"),
						generateCommand(eventstore.AggregateType(t.Name()), "203"),
						generateCommand(eventstore.AggregateType(t.Name()), "203"),
					},
				},
			},
			res: res{
				minErrCount: 0,
				eventsRes: eventsRes{
					aggIDs:            []string{"200", "201", "202", "203"},
					pushedEventsCount: 9,
					aggTypes:          database.TextArray[eventstore.AggregateType]{eventstore.AggregateType(t.Name())},
				},
			},
		},
		{
			name: "clients push same aggregates",
			args: args{
				commands: [][]eventstore.Command{
					{
						generateCommand(eventstore.AggregateType(t.Name()), "204"),
						generateCommand(eventstore.AggregateType(t.Name()), "204"),
					},
					{
						generateCommand(eventstore.AggregateType(t.Name()), "204"),
						generateCommand(eventstore.AggregateType(t.Name()), "204"),
					},
					{
						generateCommand(eventstore.AggregateType(t.Name()), "204"),
						generateCommand(eventstore.AggregateType(t.Name()), "204"),
					},
					{
						generateCommand(eventstore.AggregateType(t.Name()), "204"),
						generateCommand(eventstore.AggregateType(t.Name()), "204"),
					},
				},
			},
			res: res{
				minErrCount: 0,
				eventsRes: eventsRes{
					aggIDs:            []string{"204"},
					pushedEventsCount: 8,
					aggTypes:          database.TextArray[eventstore.AggregateType]{eventstore.AggregateType(t.Name())},
				},
			},
		},
		{
			name: "clients push different aggregates",
			args: args{
				commands: [][]eventstore.Command{
					{
						generateCommand(eventstore.AggregateType(t.Name()), "207"),
						generateCommand(eventstore.AggregateType(t.Name()), "207"),
						generateCommand(eventstore.AggregateType(t.Name()), "207"),
						generateCommand(eventstore.AggregateType(t.Name()), "207"),
						generateCommand(eventstore.AggregateType(t.Name()), "207"),
						generateCommand(eventstore.AggregateType(t.Name()), "207"),
					},
					{
						generateCommand(eventstore.AggregateType(t.Name()), "208"),
						generateCommand(eventstore.AggregateType(t.Name()), "208"),
						generateCommand(eventstore.AggregateType(t.Name()), "208"),
						generateCommand(eventstore.AggregateType(t.Name()), "208"),
						generateCommand(eventstore.AggregateType(t.Name()), "208"),
					},
				},
			},
			res: res{
				minErrCount: 0,
				eventsRes: eventsRes{
					aggIDs:            []string{"207", "208"},
					pushedEventsCount: 11,
					aggTypes:          database.TextArray[eventstore.AggregateType]{eventstore.AggregateType(t.Name())},
				},
			},
		},
	}
	for _, tt := range tests {
		for pusherName, pusher := range pushers {
			t.Run(pusherName+"/"+tt.name, func(t *testing.T) {
				t.Cleanup(cleanupEventstore(clients[pusherName]))

				db := eventstore.NewEventstore(
					&eventstore.Config{
						Querier: queriers["v2(inmemory)"],
						Pusher:  pusher,
					},
				)
				errs := pushAggregates(db, tt.args.commands)

				if len(errs) < tt.res.minErrCount {
					t.Errorf("eventstore.Push() error count = %d, wanted err count %d, errs: %v", len(errs), tt.res.minErrCount, errs)
				}

				assertEventCount(t, clients[pusherName], tt.res.eventsRes.aggTypes, tt.res.eventsRes.aggIDs, tt.res.eventsRes.pushedEventsCount)
			})
		}
	}
}
// TestEventstore_Push_ResourceOwner verifies resource-owner handling on push:
// the first event of an aggregate fixes the owner, and differing owners on
// later events of the same aggregate are ignored (see the "ignored" cases).
func TestEventstore_Push_ResourceOwner(t *testing.T) {
	type args struct {
		commands []eventstore.Command
	}
	type res struct {
		// expected owner per pushed event, in push order
		resourceOwners database.TextArray[string]
	}
	type fields struct {
		aggregateIDs  database.TextArray[string]
		aggregateType string
	}
	tests := []struct {
		name   string
		args   args
		res    res
		fields fields
	}{
		{
			name: "two events of same aggregate same resource owner",
			args: args{
				commands: []eventstore.Command{
					generateCommand(eventstore.AggregateType(t.Name()), "500", func(e *testEvent) { e.BaseEvent.Agg.ResourceOwner = "caos" }),
					generateCommand(eventstore.AggregateType(t.Name()), "500", func(e *testEvent) { e.BaseEvent.Agg.ResourceOwner = "caos" }),
				},
			},
			fields: fields{
				aggregateIDs:  []string{"500"},
				aggregateType: t.Name(),
			},
			res: res{
				resourceOwners: []string{"caos", "caos"},
			},
		},
		{
			name: "two events of different aggregate same resource owner",
			args: args{
				commands: []eventstore.Command{
					generateCommand(eventstore.AggregateType(t.Name()), "501", func(e *testEvent) { e.BaseEvent.Agg.ResourceOwner = "caos" }),
					generateCommand(eventstore.AggregateType(t.Name()), "502", func(e *testEvent) { e.BaseEvent.Agg.ResourceOwner = "caos" }),
				},
			},
			fields: fields{
				aggregateIDs:  []string{"501", "502"},
				aggregateType: t.Name(),
			},
			res: res{
				resourceOwners: []string{"caos", "caos"},
			},
		},
		{
			name: "two events of different aggregate different resource owner",
			args: args{
				commands: []eventstore.Command{
					generateCommand(eventstore.AggregateType(t.Name()), "503", func(e *testEvent) { e.BaseEvent.Agg.ResourceOwner = "caos" }),
					generateCommand(eventstore.AggregateType(t.Name()), "504", func(e *testEvent) { e.BaseEvent.Agg.ResourceOwner = "zitadel" }),
				},
			},
			fields: fields{
				aggregateIDs:  []string{"503", "504"},
				aggregateType: t.Name(),
			},
			res: res{
				resourceOwners: []string{"caos", "zitadel"},
			},
		},
		{
			name: "events of different aggregate different resource owner",
			args: args{
				commands: []eventstore.Command{
					generateCommand(eventstore.AggregateType(t.Name()), "505", func(e *testEvent) { e.BaseEvent.Agg.ResourceOwner = "caos" }),
					generateCommand(eventstore.AggregateType(t.Name()), "505", func(e *testEvent) { e.BaseEvent.Agg.ResourceOwner = "caos" }),
					generateCommand(eventstore.AggregateType(t.Name()), "506", func(e *testEvent) { e.BaseEvent.Agg.ResourceOwner = "zitadel" }),
					generateCommand(eventstore.AggregateType(t.Name()), "506", func(e *testEvent) { e.BaseEvent.Agg.ResourceOwner = "zitadel" }),
				},
			},
			fields: fields{
				aggregateIDs:  []string{"505", "506"},
				aggregateType: t.Name(),
			},
			res: res{
				resourceOwners: []string{"caos", "caos", "zitadel", "zitadel"},
			},
		},
		{
			name: "events of different aggregate different resource owner per event",
			args: args{
				commands: []eventstore.Command{
					generateCommand(eventstore.AggregateType(t.Name()), "507", func(e *testEvent) { e.BaseEvent.Agg.ResourceOwner = "caos" }),
					generateCommand(eventstore.AggregateType(t.Name()), "507", func(e *testEvent) { e.BaseEvent.Agg.ResourceOwner = "ignored" }),
					generateCommand(eventstore.AggregateType(t.Name()), "508", func(e *testEvent) { e.BaseEvent.Agg.ResourceOwner = "zitadel" }),
					generateCommand(eventstore.AggregateType(t.Name()), "508", func(e *testEvent) { e.BaseEvent.Agg.ResourceOwner = "ignored" }),
				},
			},
			fields: fields{
				aggregateIDs:  []string{"507", "508"},
				aggregateType: t.Name(),
			},
			res: res{
				// the owner of the first event per aggregate wins
				resourceOwners: []string{"caos", "caos", "zitadel", "zitadel"},
			},
		},
		{
			name: "events of one aggregate different resource owner per event",
			args: args{
				commands: []eventstore.Command{
					generateCommand(eventstore.AggregateType(t.Name()), "509", func(e *testEvent) { e.BaseEvent.Agg.ResourceOwner = "caos" }),
					generateCommand(eventstore.AggregateType(t.Name()), "509", func(e *testEvent) { e.BaseEvent.Agg.ResourceOwner = "ignored" }),
					generateCommand(eventstore.AggregateType(t.Name()), "509", func(e *testEvent) { e.BaseEvent.Agg.ResourceOwner = "ignored" }),
					generateCommand(eventstore.AggregateType(t.Name()), "509", func(e *testEvent) { e.BaseEvent.Agg.ResourceOwner = "ignored" }),
				},
			},
			fields: fields{
				aggregateIDs:  []string{"509"},
				aggregateType: t.Name(),
			},
			res: res{
				resourceOwners: []string{"caos", "caos", "caos", "caos"},
			},
		},
	}
	for _, tt := range tests {
		for pusherName, pusher := range pushers {
			t.Run(pusherName+"/"+tt.name, func(t *testing.T) {
				t.Cleanup(cleanupEventstore(clients[pusherName]))

				db := eventstore.NewEventstore(
					&eventstore.Config{
						Querier: queriers["v2(inmemory)"],
						Pusher:  pusher,
					},
				)
				events, err := db.Push(context.Background(), tt.args.commands...)
				if err != nil {
					t.Errorf("eventstore.Push() error = %v", err)
				}

				if len(events) != len(tt.res.resourceOwners) {
					t.Errorf("length of events (%d) and resource owners (%d) must be equal", len(events), len(tt.res.resourceOwners))
					return
				}

				// check the owners on the returned events ...
				for i, event := range events {
					if event.Aggregate().ResourceOwner != tt.res.resourceOwners[i] {
						t.Errorf("resource owner not expected want: %q got: %q", tt.res.resourceOwners[i], event.Aggregate().ResourceOwner)
					}
				}

				// ... and on the persisted rows
				assertResourceOwners(t, clients[pusherName], tt.res.resourceOwners, tt.fields.aggregateIDs, tt.fields.aggregateType)
			})
		}
	}
}
// pushAggregates pushes each command slice in its own goroutine to provoke
// concurrent writes. All goroutines first block on ctx.Done() and are
// released together by cancel() after a short grace period, maximizing
// overlap. It returns every error produced by the concurrent pushes.
func pushAggregates(es *eventstore.Eventstore, aggregateCommands [][]eventstore.Command) []error {
	wg := sync.WaitGroup{}
	errs := make([]error, 0)
	errsMu := sync.Mutex{}
	wg.Add(len(aggregateCommands))

	// ctx is only used as a start signal, not passed to Push
	ctx, cancel := context.WithCancel(context.Background())

	for _, commands := range aggregateCommands {
		go func(events []eventstore.Command) {
			// block until cancel() releases all goroutines at once
			<-ctx.Done()

			_, err := es.Push(context.Background(), events...) //nolint:contextcheck
			if err != nil {
				errsMu.Lock()
				errs = append(errs, err)
				errsMu.Unlock()
			}
			wg.Done()
		}(commands)
	}

	// wait till all routines are started
	time.Sleep(100 * time.Millisecond)
	cancel()

	wg.Wait()

	return errs
}
// assertResourceOwners queries the persisted events of the given aggregates
// (ordered by position) and checks that each row's owner matches the
// expected resourceOwners entry at the same index, and that the row count
// equals len(resourceOwners).
func assertResourceOwners(t *testing.T, db *database.DB, resourceOwners, aggregateIDs database.TextArray[string], aggregateType string) {
	t.Helper()

	eventCount := 0
	err := db.Query(func(rows *sql.Rows) error {
		for i := 0; rows.Next(); i++ {
			var resourceOwner string
			err := rows.Scan(&resourceOwner)
			if err != nil {
				return err
			}
			// Bounds guard: if the query yields more rows than expected the
			// original code panicked with an index-out-of-range instead of
			// reporting a test failure; the final count check below still
			// flags the mismatch.
			if i < len(resourceOwners) && resourceOwner != resourceOwners[i] {
				t.Errorf("unexpected resource owner in queried event. want %q, got: %q", resourceOwners[i], resourceOwner)
			}
			eventCount++
		}

		return nil
	}, "SELECT owner FROM eventstore.events2 WHERE aggregate_type = $1 AND aggregate_id = ANY($2) ORDER BY position, in_tx_order", aggregateType, aggregateIDs)
	if err != nil {
		t.Error("query failed: ", err)
		return
	}

	if eventCount != len(resourceOwners) {
		t.Errorf("wrong queried event count: want %d, got %d", len(resourceOwners), eventCount)
	}
}
// assertEventCount checks that the number of stored events for the given
// aggregate types and ids does not exceed maxPushedEventsCount.
func assertEventCount(t *testing.T, db *database.DB, aggTypes database.TextArray[eventstore.AggregateType], aggIDs database.TextArray[string], maxPushedEventsCount int) {
	t.Helper()

	var count int
	scan := func(row *sql.Row) error {
		return row.Scan(&count)
	}
	const countQuery = "SELECT count(*) FROM eventstore.events2 where aggregate_type = ANY($1) AND aggregate_id = ANY($2)"
	if err := db.QueryRow(scan, countQuery, aggTypes, aggIDs); err != nil {
		t.Errorf("unexpected err in row.Scan: %v", err)
		return
	}

	if count > maxPushedEventsCount {
		t.Errorf("expected push count %d got %d", maxPushedEventsCount, count)
	}
}
// assertUniqueConstraint verifies that the unique constraint carried by the
// pushed commands (the first one found, if any) exists expectedCount times
// in the unique_constraints table. Without a constrained command it is a no-op.
func assertUniqueConstraint(t *testing.T, db *database.DB, commands []eventstore.Command, expectedCount int) {
	t.Helper()

	// pick the first command that carries a unique constraint
	var constraint *eventstore.UniqueConstraint
	for _, command := range commands {
		event := command.(*testEvent)
		if len(event.uniqueConstraints) == 0 {
			continue
		}
		constraint = event.uniqueConstraints[0]
		break
	}
	if constraint == nil {
		return
	}

	var uniqueCount int
	err := db.QueryRow(
		func(row *sql.Row) error {
			return row.Scan(&uniqueCount)
		},
		"SELECT COUNT(*) FROM eventstore.unique_constraints where unique_type = $1 AND unique_field = $2",
		constraint.UniqueType, constraint.UniqueField,
	)
	if err != nil {
		t.Error("unable to query inserted rows: ", err)
		return
	}

	if uniqueCount != expectedCount {
		t.Errorf("expected unique count %d got %d", expectedCount, uniqueCount)
	}
}

View File

@@ -0,0 +1,217 @@
package eventstore_test
import (
"context"
"testing"
"github.com/shopspring/decimal"
"github.com/zitadel/zitadel/internal/eventstore"
)
func TestEventstore_Filter(t *testing.T) {
type args struct {
searchQuery *eventstore.SearchQueryBuilder
}
type fields struct {
existingEvents []eventstore.Command
}
type res struct {
eventCount int
}
tests := []struct {
name string
fields fields
args args
res res
wantErr bool
}{
{
name: "aggregate type filter no events",
args: args{
searchQuery: eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes("not found").
Builder(),
},
fields: fields{
existingEvents: []eventstore.Command{
generateCommand(eventstore.AggregateType(t.Name()), "300"),
generateCommand(eventstore.AggregateType(t.Name()), "300"),
generateCommand(eventstore.AggregateType(t.Name()), "300"),
},
},
res: res{
eventCount: 0,
},
wantErr: false,
},
{
name: "aggregate type and id filter events found",
args: args{
searchQuery: eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes(eventstore.AggregateType(t.Name())).
AggregateIDs("303").
Builder(),
},
fields: fields{
existingEvents: []eventstore.Command{
generateCommand(eventstore.AggregateType(t.Name()), "303"),
generateCommand(eventstore.AggregateType(t.Name()), "303"),
generateCommand(eventstore.AggregateType(t.Name()), "303"),
generateCommand(eventstore.AggregateType(t.Name()), "305"),
},
},
res: res{
eventCount: 3,
},
wantErr: false,
},
{
name: "exclude aggregate type and event type",
args: args{
searchQuery: eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes(eventstore.AggregateType(t.Name())).
Builder().
ExcludeAggregateIDs().
EventTypes("test.updated").
AggregateTypes(eventstore.AggregateType(t.Name())).
Builder(),
},
fields: fields{
existingEvents: []eventstore.Command{
generateCommand(eventstore.AggregateType(t.Name()), "306"),
generateCommand(
eventstore.AggregateType(t.Name()),
"306",
func(te *testEvent) {
te.EventType = "test.updated"
},
),
generateCommand(
eventstore.AggregateType(t.Name()),
"308",
),
},
},
res: res{
eventCount: 1,
},
wantErr: false,
},
}
for _, tt := range tests {
for querierName, querier := range queriers {
t.Run(querierName+"/"+tt.name, func(t *testing.T) {
t.Cleanup(cleanupEventstore(clients[querierName]))
db := eventstore.NewEventstore(
&eventstore.Config{
Querier: querier,
Pusher: pushers["v3(inmemory)"],
},
)
// setup initial data for query
if _, err := db.Push(context.Background(), tt.fields.existingEvents...); err != nil {
t.Errorf("error in setup = %v", err)
return
}
events, err := db.Filter(context.Background(), tt.args.searchQuery)
if (err != nil) != tt.wantErr {
t.Errorf("eventstore.query() error = %v, wantErr %v", err, tt.wantErr)
}
if len(events) != tt.res.eventCount {
t.Errorf("eventstore.query() expected event count: %d got %d", tt.res.eventCount, len(events))
}
})
}
}
}
// TestEventstore_LatestPosition verifies LatestPosition for every registered
// querier. The expected position is the zero decimal in all cases, so the
// assertion only checks that the returned position is not smaller than it.
func TestEventstore_LatestPosition(t *testing.T) {
	type args struct {
		searchQuery *eventstore.SearchQueryBuilder
	}
	type fields struct {
		// events pushed before the query under test runs
		existingEvents []eventstore.Command
	}
	type res struct {
		position decimal.Decimal
	}
	tests := []struct {
		name    string
		fields  fields
		args    args
		res     res
		wantErr bool
	}{
		{
			name: "aggregate type filter no sequence",
			args: args{
				searchQuery: eventstore.NewSearchQueryBuilder(eventstore.ColumnsMaxPosition).
					AddQuery().
					AggregateTypes("not found").
					Builder(),
			},
			fields: fields{
				existingEvents: []eventstore.Command{
					generateCommand(eventstore.AggregateType(t.Name()), "400"),
					generateCommand(eventstore.AggregateType(t.Name()), "400"),
					generateCommand(eventstore.AggregateType(t.Name()), "400"),
				},
			},
			wantErr: false,
		},
		{
			name: "aggregate type filter sequence",
			args: args{
				searchQuery: eventstore.NewSearchQueryBuilder(eventstore.ColumnsMaxPosition).
					AddQuery().
					AggregateTypes(eventstore.AggregateType(t.Name())).
					Builder(),
			},
			fields: fields{
				existingEvents: []eventstore.Command{
					generateCommand(eventstore.AggregateType(t.Name()), "401"),
					generateCommand(eventstore.AggregateType(t.Name()), "401"),
					generateCommand(eventstore.AggregateType(t.Name()), "401"),
				},
			},
			wantErr: false,
		},
	}
	for _, tt := range tests {
		for querierName, querier := range queriers {
			t.Run(querierName+"/"+tt.name, func(t *testing.T) {
				t.Cleanup(cleanupEventstore(clients[querierName]))

				db := eventstore.NewEventstore(
					&eventstore.Config{
						Querier: querier,
						Pusher:  pushers["v3(inmemory)"],
					},
				)

				// setup initial data for query
				_, err := db.Push(context.Background(), tt.fields.existingEvents...)
				if err != nil {
					t.Errorf("error in setup = %v", err)
					return
				}

				position, err := db.LatestPosition(context.Background(), tt.args.searchQuery)
				if (err != nil) != tt.wantErr {
					t.Errorf("eventstore.query() error = %v, wantErr %v", err, tt.wantErr)
				}

				if tt.res.position.GreaterThan(position) {
					t.Errorf("eventstore.query() expected position: %v got %v", tt.res.position, position)
				}
			})
		}
	}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,325 @@
package eventstore_test
import (
"context"
"testing"
"time"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/eventstore"
query_repo "github.com/zitadel/zitadel/internal/eventstore/repository/sql"
v3 "github.com/zitadel/zitadel/internal/eventstore/v3"
)
// ------------------------------------------------------------
// User aggregate start
// ------------------------------------------------------------
// NewUserAggregate builds the "test.user" aggregate (version v1) used by this
// example, scoped to a mocked authz context
// (instance "zitadel", org "caos", user "adlerhurst").
func NewUserAggregate(id string) *eventstore.Aggregate {
	ctx := authz.NewMockContext("zitadel", "caos", "adlerhurst")
	return eventstore.NewAggregate(ctx, id, "test.user", "v1")
}
// ------------------------------------------------------------
// User added event start
// ------------------------------------------------------------
// UserAddedEvent is pushed when a new user is created.
type UserAddedEvent struct {
	eventstore.BaseEvent `json:"-"`

	FirstName string `json:"firstName"`
}

// NewUserAddedEvent creates the push command for adding the user with the
// given id and first name.
func NewUserAddedEvent(id string, firstName string) *UserAddedEvent {
	return &UserAddedEvent{
		FirstName: firstName,
		BaseEvent: *eventstore.NewBaseEventForPush(
			context.Background(),
			NewUserAggregate(id),
			"user.added"),
	}
}

// UserAddedEventMapper returns the registration triple (aggregate type,
// event type, mapper func) that decodes a stored event into *UserAddedEvent.
// NOTE(review): it registers aggregate type "user" while NewUserAggregate
// uses "test.user" — verify against RegisterFilterEventMapper usage.
func UserAddedEventMapper() (eventstore.AggregateType, eventstore.EventType, func(eventstore.Event) (eventstore.Event, error)) {
	return "user", "user.added", func(event eventstore.Event) (eventstore.Event, error) {
		e := &UserAddedEvent{
			BaseEvent: *eventstore.BaseEventFromRepo(event),
		}
		err := event.Unmarshal(e)
		if err != nil {
			return nil, err
		}

		return e, nil
	}
}

// Payload returns the event itself, which is marshaled as the payload.
func (e *UserAddedEvent) Payload() interface{} {
	return e
}

// UniqueConstraints returns no constraints for this event.
func (e *UserAddedEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
	return nil
}

// Assets returns no assets for this event.
func (e *UserAddedEvent) Assets() []*eventstore.Asset {
	return nil
}
// ------------------------------------------------------------
// User first name changed event start
// ------------------------------------------------------------
// UserFirstNameChangedEvent is pushed when a user's first name changes.
type UserFirstNameChangedEvent struct {
	eventstore.BaseEvent `json:"-"`

	FirstName string `json:"firstName"`
}

// NewUserFirstNameChangedEvent creates the push command for changing the
// first name of the user aggregate identified by id.
func NewUserFirstNameChangedEvent(id, firstName string) *UserFirstNameChangedEvent {
	return &UserFirstNameChangedEvent{
		FirstName: firstName,
		BaseEvent: *eventstore.NewBaseEventForPush(
			context.Background(),
			NewUserAggregate(id),
			"user.firstname.changed"),
	}
}

// UserFirstNameChangedMapper returns the registration triple that decodes a
// stored event into *UserFirstNameChangedEvent.
// Bug fix: the registered event type was "user.firstName.changed" (capital N)
// while NewUserFirstNameChangedEvent pushes "user.firstname.changed", so the
// mapper could never match a pushed event; the casing is now aligned with the
// constructor.
// NOTE(review): it registers aggregate type "user" while NewUserAggregate
// uses "test.user" — verify against RegisterFilterEventMapper usage.
func UserFirstNameChangedMapper() (eventstore.AggregateType, eventstore.EventType, func(eventstore.Event) (eventstore.Event, error)) {
	return "user", "user.firstname.changed", func(event eventstore.Event) (eventstore.Event, error) {
		e := &UserFirstNameChangedEvent{
			BaseEvent: *eventstore.BaseEventFromRepo(event),
		}
		err := event.Unmarshal(e)
		if err != nil {
			return nil, err
		}

		return e, nil
	}
}

// Payload returns the event itself, which is marshaled as the payload.
func (e *UserFirstNameChangedEvent) Payload() interface{} {
	return e
}

// UniqueConstraints returns no constraints for this event.
func (e *UserFirstNameChangedEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
	return nil
}

// Assets returns no assets for this event.
func (e *UserFirstNameChangedEvent) Assets() []*eventstore.Asset {
	return nil
}
// ------------------------------------------------------------
// User password checked event start
// ------------------------------------------------------------
// UserPasswordCheckedEvent is pushed whenever a user's password was checked.
// It carries no payload.
type UserPasswordCheckedEvent struct {
	eventstore.BaseEvent `json:"-"`
}

// NewUserPasswordCheckedEvent creates the push command recording a password
// check on the user aggregate identified by id.
func NewUserPasswordCheckedEvent(id string) *UserPasswordCheckedEvent {
	return &UserPasswordCheckedEvent{
		BaseEvent: *eventstore.NewBaseEventForPush(
			context.Background(),
			NewUserAggregate(id),
			"user.password.checked"),
	}
}

// UserPasswordCheckedMapper returns the registration triple that decodes a
// stored event into *UserPasswordCheckedEvent (no payload to unmarshal).
// NOTE(review): it registers aggregate type "user" while NewUserAggregate
// uses "test.user" — verify against RegisterFilterEventMapper usage.
func UserPasswordCheckedMapper() (eventstore.AggregateType, eventstore.EventType, func(eventstore.Event) (eventstore.Event, error)) {
	return "user", "user.password.checked", func(event eventstore.Event) (eventstore.Event, error) {
		return &UserPasswordCheckedEvent{
			BaseEvent: *eventstore.BaseEventFromRepo(event),
		}, nil
	}
}

// Payload returns nil: this event has no payload.
func (e *UserPasswordCheckedEvent) Payload() interface{} {
	return nil
}

// UniqueConstraints returns no constraints for this event.
func (e *UserPasswordCheckedEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
	return nil
}

// Assets returns no assets for this event.
func (e *UserPasswordCheckedEvent) Assets() []*eventstore.Asset {
	return nil
}
// ------------------------------------------------------------
// User deleted event
// ------------------------------------------------------------
// UserDeletedEvent is pushed when a user is removed. It carries no payload.
type UserDeletedEvent struct {
	eventstore.BaseEvent `json:"-"`
}

// NewUserDeletedEvent creates the push command for deleting the user
// aggregate identified by id.
func NewUserDeletedEvent(id string) *UserDeletedEvent {
	return &UserDeletedEvent{
		BaseEvent: *eventstore.NewBaseEventForPush(
			context.Background(),
			NewUserAggregate(id),
			"user.deleted"),
	}
}

// UserDeletedMapper returns the registration triple that decodes a stored
// event into *UserDeletedEvent (no payload to unmarshal).
// NOTE(review): it registers aggregate type "user" while NewUserAggregate
// uses "test.user" — verify against RegisterFilterEventMapper usage.
func UserDeletedMapper() (eventstore.AggregateType, eventstore.EventType, func(eventstore.Event) (eventstore.Event, error)) {
	return "user", "user.deleted", func(event eventstore.Event) (eventstore.Event, error) {
		return &UserDeletedEvent{
			BaseEvent: *eventstore.BaseEventFromRepo(event),
		}, nil
	}
}

// Payload returns nil: this event has no payload.
func (e *UserDeletedEvent) Payload() interface{} {
	return nil
}

// UniqueConstraints returns no constraints for this event.
func (e *UserDeletedEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
	return nil
}

// Assets returns no assets for this event.
func (e *UserDeletedEvent) Assets() []*eventstore.Asset {
	return nil
}
// ------------------------------------------------------------
// Users read model start
// ------------------------------------------------------------
type UsersReadModel struct {
eventstore.ReadModel
Users []*UserReadModel
}
func (rm *UsersReadModel) AppendEvents(events ...eventstore.Event) {
rm.ReadModel.AppendEvents(events...)
for _, event := range events {
switch e := event.(type) {
case *UserAddedEvent:
//insert
user := NewUserReadModel(e.Aggregate().ID)
rm.Users = append(rm.Users, user)
user.AppendEvents(e)
case *UserFirstNameChangedEvent, *UserPasswordCheckedEvent:
//update
_, user := rm.userByID(e.Aggregate().ID)
if user == nil {
return
}
user.AppendEvents(e)
case *UserDeletedEvent:
idx, _ := rm.userByID(e.Aggregate().ID)
if idx < 0 {
return
}
copy(rm.Users[idx:], rm.Users[idx+1:])
rm.Users[len(rm.Users)-1] = nil // or the zero value of T
rm.Users = rm.Users[:len(rm.Users)-1]
}
}
}
// Reduce projects the buffered events, first into every per-user read model,
// then into the embedded ReadModel's bookkeeping.
func (rm *UsersReadModel) Reduce() error {
	for _, u := range rm.Users {
		if err := u.Reduce(); err != nil {
			return err
		}
	}
	rm.ReadModel.Reduce()
	return nil
}
// userByID linearly scans rm.Users and returns the index and pointer of the
// user with the given aggregate ID, or (-1, nil) when absent.
func (rm *UsersReadModel) userByID(id string) (int, *UserReadModel) {
	for i, u := range rm.Users {
		if u.ID == id {
			return i, u
		}
	}
	return -1, nil
}
// ------------------------------------------------------------
// User read model start
// ------------------------------------------------------------

// UserReadModel is the projection of a single user aggregate.
type UserReadModel struct {
	eventstore.ReadModel
	// ID is the aggregate ID of the user.
	ID string
	// FirstName reflects the latest added or changed first name.
	FirstName string
	// pwCheckCount counts reduced password-check events.
	pwCheckCount int
	// lastPasswordCheck is the creation date of the latest password check.
	lastPasswordCheck time.Time
}
// NewUserReadModel returns an empty read model for the user with the given
// aggregate ID; state is populated via AppendEvents and Reduce.
func NewUserReadModel(id string) *UserReadModel {
	return &UserReadModel{
		ID: id,
	}
}
// Reduce folds the buffered events into the user's projected state and then
// lets the embedded ReadModel update its bookkeeping.
func (rm *UserReadModel) Reduce() error {
	for _, evt := range rm.ReadModel.Events {
		switch typed := evt.(type) {
		case *UserAddedEvent:
			rm.FirstName = typed.FirstName
		case *UserFirstNameChangedEvent:
			rm.FirstName = typed.FirstName
		case *UserPasswordCheckedEvent:
			rm.lastPasswordCheck = typed.CreationDate()
			rm.pwCheckCount++
		}
	}
	rm.ReadModel.Reduce()
	return nil
}
// ------------------------------------------------------------
// Tests
// ------------------------------------------------------------

// TestUserReadModel pushes a small user lifecycle (two adds, two password
// checks, a rename and a delete) and verifies the events can be filtered back
// into the users read model.
func TestUserReadModel(t *testing.T) {
	es := eventstore.NewEventstore(
		&eventstore.Config{
			Querier: query_repo.NewPostgres(testClient),
			Pusher:  v3.NewEventstore(testClient),
		},
	)
	eventstore.RegisterFilterEventMapper(UserAddedEventMapper())
	eventstore.RegisterFilterEventMapper(UserFirstNameChangedMapper())
	eventstore.RegisterFilterEventMapper(UserPasswordCheckedMapper())
	eventstore.RegisterFilterEventMapper(UserDeletedMapper())

	events, err := es.Push(context.Background(),
		NewUserAddedEvent("1", "hodor"),
		NewUserAddedEvent("2", "hodor"),
		NewUserPasswordCheckedEvent("2"),
		NewUserPasswordCheckedEvent("2"),
		NewUserFirstNameChangedEvent("2", "ueli"),
		NewUserDeletedEvent("2"))
	if err != nil {
		// Fatalf: the remainder of the test is meaningless without pushed events.
		// (The previous version used Errorf and kept going, and also appended a
		// stray nil to the events slice before logging — leftover debug code.)
		t.Fatalf("unexpected error on push aggregates: %v", err)
	}
	t.Logf("%+v\n", events)

	users := UsersReadModel{}
	err = es.FilterToReducer(context.Background(), eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).AddQuery().AggregateTypes("test.user").Builder(), &users)
	if err != nil {
		t.Errorf("unexpected error on filter to reducer: %v", err)
	}
	t.Logf("%+v", users)
}

View File

@@ -0,0 +1,140 @@
package eventstore
// FieldOperation is the definition of the operation to be executed on the field.
// Exactly one of Set or Remove is expected to be populated per operation.
type FieldOperation struct {
	// Set a field in the field table
	// if [SearchField.UpsertConflictFields] are set the field will be updated if the conflict fields match
	// if no [SearchField.UpsertConflictFields] are set the field will be inserted
	Set *Field
	// Remove fields using the map as `AND`ed conditions
	Remove map[FieldType]any
}

// SearchResult is a single row returned by a field search.
type SearchResult struct {
	Aggregate Aggregate
	Object    Object
	FieldName string
	// Value represents the stored value
	// use the Unmarshal method to parse the value to the desired type
	Value interface {
		// Unmarshal parses the value to ptr
		Unmarshal(ptr any) error
	}
}

// NOTE(review): the original carried a dangling doubled comment here:
// "NumericResultValue marshals the value to the given type" — no such symbol
// is defined in this file; confirm before removing.

// Object identifies the sub-object of an aggregate a field belongs to.
type Object struct {
	// Type of the object
	Type string
	// ID of the object
	ID string
	// Revision of the object, if an object evolves the revision should be increased
	// analog to current projection versioning
	Revision uint8
}

// Field is a single field row to be written for an aggregate's object.
type Field struct {
	Aggregate            *Aggregate
	Object               Object
	UpsertConflictFields []FieldType
	FieldName            string
	Value                Value
}

// Value wraps the stored value together with its storage semantics.
type Value struct {
	Value any
	// MustBeUnique defines if the field must be unique
	// This field will replace unique constraints in the future
	// If MustBeUnique is true the value must be a primitive type
	MustBeUnique bool
	// ShouldIndex defines if the field should be indexed
	// If the field is not indexed it can not be used in search queries
	// If ShouldIndex is true the value must be a primitive type
	ShouldIndex bool
}

// SearchValueType discriminates how a stored value is represented.
type SearchValueType int8

const (
	SearchValueTypeString SearchValueType = iota
	SearchValueTypeNumeric
)
// SetField sets the field based on the defined parameters
// if conflictFields are set the field will be updated if the conflict fields match
// (doc previously referred to this function as SetSearchField)
func SetField(aggregate *Aggregate, object Object, fieldName string, value *Value, conflictFields ...FieldType) *FieldOperation {
	return &FieldOperation{
		Set: &Field{
			Aggregate:            aggregate,
			Object:               object,
			UpsertConflictFields: conflictFields,
			FieldName:            fieldName,
			Value:                *value,
		},
	}
}
// RemoveSearchFields removes fields using the map as `AND`ed conditions
// The caller is responsible for providing enough conditions to avoid
// removing unrelated fields.
func RemoveSearchFields(clause map[FieldType]any) *FieldOperation {
	return &FieldOperation{
		Remove: clause,
	}
}
// RemoveSearchFieldsByAggregate removes fields using the aggregate as `AND`ed conditions
func RemoveSearchFieldsByAggregate(aggregate *Aggregate) *FieldOperation {
	conditions := map[FieldType]any{
		FieldTypeInstanceID:    aggregate.InstanceID,
		FieldTypeResourceOwner: aggregate.ResourceOwner,
		FieldTypeAggregateType: aggregate.Type,
		FieldTypeAggregateID:   aggregate.ID,
	}
	return &FieldOperation{Remove: conditions}
}
// RemoveSearchFieldsByAggregateAndObject removes fields using the aggregate and object as `AND`ed conditions.
// It builds on RemoveSearchFieldsByAggregate so the shared aggregate
// conditions are only maintained in one place (previously the full condition
// map was duplicated across all three Remove* helpers).
func RemoveSearchFieldsByAggregateAndObject(aggregate *Aggregate, object Object) *FieldOperation {
	op := RemoveSearchFieldsByAggregate(aggregate)
	op.Remove[FieldTypeObjectType] = object.Type
	op.Remove[FieldTypeObjectID] = object.ID
	op.Remove[FieldTypeObjectRevision] = object.Revision
	return op
}

// RemoveSearchFieldsByAggregateAndObjectAndField removes fields using the aggregate, object and field as `AND`ed conditions.
// It narrows RemoveSearchFieldsByAggregateAndObject down to a single field name.
func RemoveSearchFieldsByAggregateAndObjectAndField(aggregate *Aggregate, object Object, field string) *FieldOperation {
	op := RemoveSearchFieldsByAggregateAndObject(aggregate, object)
	op.Remove[FieldTypeFieldName] = field
	return op
}
// FieldType enumerates the addressable columns of a field row; used as keys
// in Remove condition maps and as upsert conflict targets.
type FieldType int8

const (
	FieldTypeAggregateType FieldType = iota
	FieldTypeAggregateID
	FieldTypeInstanceID
	FieldTypeResourceOwner
	FieldTypeObjectType
	FieldTypeObjectID
	FieldTypeObjectRevision
	FieldTypeFieldName
	FieldTypeValue
)

View File

@@ -0,0 +1,89 @@
package crdb
import (
"database/sql/driver"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/zitadel/zitadel/internal/database"
)
// mockExpectation registers one expected statement on the sqlmock instance.
type mockExpectation func(sqlmock.Sqlmock)
// singleInstanceLockStmt builds the regex for the lock upsert issued for
// exactly one instance: $1 locker ID, $2 duration, $3 projection name,
// $4 instance ID, $5 instance-ID filter array. Shared by the single-instance
// expectations below, which previously duplicated this regex three times.
func singleInstanceLockStmt(lockTable string) string {
	return `INSERT INTO ` + lockTable +
		` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\)` +
		` ON CONFLICT \(projection_name, instance_id\)` +
		` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL` +
		` WHERE ` + lockTable + `\.projection_name = \$3 AND ` + lockTable + `\.instance_id = ANY \(\$5\) AND \(` + lockTable + `\.locker_id = \$1 OR ` + lockTable + `\.locked_until < now\(\)\)`
}

// expectLock expects one successful lock renewal for a single instance.
func expectLock(lockTable, workerName string, d time.Duration, instanceID string) func(sqlmock.Sqlmock) {
	return func(m sqlmock.Sqlmock) {
		m.ExpectExec(singleInstanceLockStmt(lockTable)).
			WithArgs(
				workerName,
				d,
				projectionName,
				instanceID,
				database.TextArray[string]{instanceID},
			).
			WillReturnResult(
				sqlmock.NewResult(1, 1),
			)
	}
}

// expectLockMultipleInstances expects one successful lock renewal covering two
// instances: the VALUES clause gains a second tuple ($5) and the filter array
// moves to $6.
func expectLockMultipleInstances(lockTable, workerName string, d time.Duration, instanceID1, instanceID2 string) func(sqlmock.Sqlmock) {
	return func(m sqlmock.Sqlmock) {
		m.ExpectExec(`INSERT INTO `+lockTable+
			` \(locker_id, locked_until, projection_name, instance_id\) VALUES \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$4\), \(\$1, now\(\)\+\$2::INTERVAL, \$3\, \$5\)`+
			` ON CONFLICT \(projection_name, instance_id\)`+
			` DO UPDATE SET locker_id = \$1, locked_until = now\(\)\+\$2::INTERVAL`+
			` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$6\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
			WithArgs(
				workerName,
				d,
				projectionName,
				instanceID1,
				instanceID2,
				database.TextArray[string]{instanceID1, instanceID2},
			).
			WillReturnResult(
				sqlmock.NewResult(1, 1),
			)
	}
}

// expectLockNoRows expects a lock attempt that matches no row (the lock is
// held by another worker).
func expectLockNoRows(lockTable, workerName string, d time.Duration, instanceID string) func(sqlmock.Sqlmock) {
	return func(m sqlmock.Sqlmock) {
		m.ExpectExec(singleInstanceLockStmt(lockTable)).
			WithArgs(
				workerName,
				d,
				projectionName,
				instanceID,
				database.TextArray[string]{instanceID},
			).
			WillReturnResult(driver.ResultNoRows)
	}
}

// expectLockErr expects a lock attempt that fails with err.
func expectLockErr(lockTable, workerName string, d time.Duration, instanceID string, err error) func(sqlmock.Sqlmock) {
	return func(m sqlmock.Sqlmock) {
		m.ExpectExec(singleInstanceLockStmt(lockTable)).
			WithArgs(
				workerName,
				d,
				projectionName,
				instanceID,
				database.TextArray[string]{instanceID},
			).
			WillReturnError(err)
	}
}

View File

@@ -0,0 +1,107 @@
package crdb
import (
"context"
"database/sql"
"fmt"
"strconv"
"strings"
"time"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/id"
"github.com/zitadel/zitadel/internal/zerrors"
)
const (
	// lockStmtFormat upserts one lock row per instance.
	// Placeholders: $1 locker_id, $2 lock duration, $3 projection name,
	// $4.. one instance ID per VALUES tuple, last placeholder the
	// instance-ID filter array.
	// Format verbs: %[1]s lock table, %[2]s VALUES tuples, %[3]d index of
	// the filter-array placeholder.
	lockStmtFormat = "INSERT INTO %[1]s" +
		" (locker_id, locked_until, projection_name, instance_id) VALUES %[2]s" +
		" ON CONFLICT (projection_name, instance_id)" +
		" DO UPDATE SET locker_id = $1, locked_until = now()+$2::INTERVAL" +
		" WHERE %[1]s.projection_name = $3 AND %[1]s.instance_id = ANY ($%[3]d) AND (%[1]s.locker_id = $1 OR %[1]s.locked_until < now())"
)

// Locker acquires, continuously renews and releases projection locks per
// instance.
type Locker interface {
	// Lock starts renewing the lock until ctx is done; every renewal result
	// is reported on the returned channel.
	Lock(ctx context.Context, lockDuration time.Duration, instanceIDs ...string) <-chan error
	// Unlock releases the lock for the given instances.
	Unlock(instanceIDs ...string) error
}

// locker is the SQL-backed Locker implementation.
type locker struct {
	client *sql.DB
	// lockStmt renders lockStmtFormat for the given VALUES tuples and the
	// placeholder index of the instance-filter array.
	lockStmt       func(values string, instances int) string
	workerName     string
	projectionName string
}
// NewLocker returns a Locker for the given lock table and projection; the
// worker name is a freshly generated sonyflake ID and generation failure
// panics.
func NewLocker(client *sql.DB, lockTable, projectionName string) Locker {
	workerName, err := id.SonyFlakeGenerator().Next()
	logging.OnError(err).Panic("unable to generate lockID")
	renderStmt := func(values string, instances int) string {
		return fmt.Sprintf(lockStmtFormat, lockTable, values, instances)
	}
	return &locker{
		client:         client,
		lockStmt:       renderStmt,
		workerName:     workerName,
		projectionName: projectionName,
	}
}
// Lock starts a goroutine that keeps renewing the lock until ctx is canceled.
// Each renewal result (nil or error) is sent on the returned channel; the
// channel is closed when ctx is done.
func (h *locker) Lock(ctx context.Context, lockDuration time.Duration, instanceIDs ...string) <-chan error {
	errs := make(chan error)
	go h.handleLock(ctx, errs, lockDuration, instanceIDs...)
	return errs
}

// handleLock renews the lock in a timer loop until ctx is done.
// NOTE(review): the send on errs is unbuffered, so the consumer must keep
// reading or renewal stalls — confirm all callers drain the channel.
// NOTE(review): for lockDuration <= 500ms the Reset value is non-positive and
// the timer fires immediately — presumably durations are always larger; verify.
func (h *locker) handleLock(ctx context.Context, errs chan error, lockDuration time.Duration, instanceIDs ...string) {
	renewLock := time.NewTimer(0)
	for {
		select {
		case <-renewLock.C:
			errs <- h.renewLock(ctx, lockDuration, instanceIDs...)
			//refresh the lock 500ms before it times out. 500ms should be enough for one transaction
			renewLock.Reset(lockDuration - (500 * time.Millisecond))
		case <-ctx.Done():
			close(errs)
			renewLock.Stop()
			return
		}
	}
}
// renewLock executes the lock upsert once; it reports AlreadyExists when no
// row was affected (the lock is held by another worker).
func (h *locker) renewLock(ctx context.Context, lockDuration time.Duration, instanceIDs ...string) error {
	stmt, args := h.lockStatement(lockDuration, instanceIDs)
	result, err := h.client.ExecContext(ctx, stmt, args...)
	if err != nil {
		return zerrors.ThrowInternal(err, "CRDB-uaDoR", "unable to execute lock")
	}
	affected, _ := result.RowsAffected()
	if affected == 0 {
		return zerrors.ThrowAlreadyExists(nil, "CRDB-mmi4J", "projection already locked")
	}
	return nil
}
// Unlock issues the lock statement with a zero duration for the given
// instances.
func (h *locker) Unlock(instanceIDs ...string) error {
	stmt, args := h.lockStatement(0, instanceIDs)
	if _, err := h.client.Exec(stmt, args...); err != nil {
		return zerrors.ThrowUnknown(err, "CRDB-JjfwO", "unlock failed")
	}
	return nil
}
// lockStatement renders the lock SQL and its arguments for the given
// instances. Argument layout: $1 worker name, $2 duration, $3 projection
// name, $4..$(3+n) one instance ID per VALUES tuple, and the final argument
// (placeholder index len(values)) the instance-ID array used in the WHERE
// filter.
func (h *locker) lockStatement(lockDuration time.Duration, instanceIDs database.TextArray[string]) (string, []interface{}) {
	valueQueries := make([]string, len(instanceIDs))
	// +4: worker, duration, projection name and the trailing filter array
	values := make([]interface{}, len(instanceIDs)+4)
	values[0] = h.workerName
	//the unit of crdb interval is seconds (https://www.cockroachlabs.com/docs/stable/interval.html).
	values[1] = lockDuration
	values[2] = h.projectionName
	for i, instanceID := range instanceIDs {
		valueQueries[i] = "($1, now()+$2::INTERVAL, $3, $" + strconv.Itoa(i+4) + ")"
		values[i+3] = instanceID
	}
	// the last slot holds the full instance-ID array; len(values) is passed
	// to lockStmt as the placeholder index for %[3]d
	values[len(values)-1] = instanceIDs
	return h.lockStmt(strings.Join(valueQueries, ", "), len(values)), values
}

View File

@@ -0,0 +1,337 @@
package crdb
import (
"context"
"database/sql"
"errors"
"fmt"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/zitadel/zitadel/internal/database"
db_mock "github.com/zitadel/zitadel/internal/database/mock"
"github.com/zitadel/zitadel/internal/zerrors"
)
const (
	// fixed identifiers shared by all lock tests in this file
	workerName     = "test_worker"
	projectionName = "my_projection"
	lockTable      = "my_lock_table"
)

var (
	// renewNoRowsAffectedErr mirrors the error renewLock returns when no row
	// was affected (lock held by another worker)
	renewNoRowsAffectedErr = zerrors.ThrowAlreadyExists(nil, "CRDB-mmi4J", "projection already locked")
	errLock                = errors.New("lock err")
)
// TestStatementHandler_handleLock runs the renewal loop against sqlmock and
// verifies the sequence of lock statements: N successful renewals, optionally
// followed by one failing renewal. The errsMock consumes the results and
// cancels the context when its expected iterations are done.
func TestStatementHandler_handleLock(t *testing.T) {
	type want struct {
		expectations []mockExpectation
	}
	type args struct {
		lockDuration time.Duration
		ctx          context.Context
		errMock      *errsMock
		instanceIDs  []string
	}
	tests := []struct {
		name string
		want want
		args args
	}{
		{
			name: "lock fails",
			want: want{
				expectations: []mockExpectation{
					expectLock(lockTable, workerName, 2*time.Second, "instanceID"),
					expectLock(lockTable, workerName, 2*time.Second, "instanceID"),
					expectLockErr(lockTable, workerName, 2*time.Second, "instanceID", errLock),
				},
			},
			args: args{
				lockDuration: 2 * time.Second,
				ctx:          context.Background(),
				errMock: &errsMock{
					errs:            make(chan error),
					successfulIters: 2,
					shouldErr:       true,
				},
				instanceIDs: []string{"instanceID"},
			},
		},
		{
			name: "success",
			want: want{
				expectations: []mockExpectation{
					expectLock(lockTable, workerName, 2*time.Second, "instanceID"),
					expectLock(lockTable, workerName, 2*time.Second, "instanceID"),
				},
			},
			args: args{
				lockDuration: 2 * time.Second,
				ctx:          context.Background(),
				errMock: &errsMock{
					errs:            make(chan error),
					successfulIters: 2,
				},
				instanceIDs: []string{"instanceID"},
			},
		},
		{
			name: "success with multiple",
			want: want{
				expectations: []mockExpectation{
					expectLockMultipleInstances(lockTable, workerName, 2*time.Second, "instanceID1", "instanceID2"),
					expectLockMultipleInstances(lockTable, workerName, 2*time.Second, "instanceID1", "instanceID2"),
				},
			},
			args: args{
				lockDuration: 2 * time.Second,
				ctx:          context.Background(),
				errMock: &errsMock{
					errs:            make(chan error),
					successfulIters: 2,
				},
				instanceIDs: []string{"instanceID1", "instanceID2"},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			client, mock, err := sqlmock.New(sqlmock.ValueConverterOption(new(db_mock.TypeConverter)))
			if err != nil {
				t.Fatal(err)
			}
			h := &locker{
				projectionName: projectionName,
				client:         client,
				workerName:     workerName,
				lockStmt: func(values string, instances int) string {
					return fmt.Sprintf(lockStmtFormat, lockTable, values, instances)
				},
			}
			for _, expectation := range tt.want.expectations {
				expectation(mock)
			}
			// errsMock cancels the context once it has seen the expected results
			ctx, cancel := context.WithCancel(tt.args.ctx)
			go tt.args.errMock.handleErrs(t, cancel)
			go h.handleLock(ctx, tt.args.errMock.errs, tt.args.lockDuration, tt.args.instanceIDs...)
			<-ctx.Done()
			mock.MatchExpectationsInOrder(true)
			if err := mock.ExpectationsWereMet(); err != nil {
				t.Errorf("expectations not met: %v", err)
			}
		})
	}
}
// TestStatementHandler_renewLock covers one-shot renewal: execution error,
// no-rows-affected (lock held elsewhere) and success with one or two
// instances.
func TestStatementHandler_renewLock(t *testing.T) {
	type want struct {
		expectations []mockExpectation
		isErr        func(err error) bool
	}
	type args struct {
		lockDuration time.Duration
		instanceIDs  []string
	}
	tests := []struct {
		name string
		want want
		args args
	}{
		{
			name: "lock fails",
			want: want{
				expectations: []mockExpectation{
					expectLockErr(lockTable, workerName, 1*time.Second, "instanceID", sql.ErrTxDone),
				},
				isErr: func(err error) bool {
					return errors.Is(err, sql.ErrTxDone)
				},
			},
			args: args{
				lockDuration: 1 * time.Second,
				instanceIDs:  database.TextArray[string]{"instanceID"},
			},
		},
		{
			name: "lock no rows",
			want: want{
				expectations: []mockExpectation{
					expectLockNoRows(lockTable, workerName, 2*time.Second, "instanceID"),
				},
				isErr: func(err error) bool {
					return errors.Is(err, renewNoRowsAffectedErr)
				},
			},
			args: args{
				lockDuration: 2 * time.Second,
				instanceIDs:  database.TextArray[string]{"instanceID"},
			},
		},
		{
			name: "success",
			want: want{
				expectations: []mockExpectation{
					expectLock(lockTable, workerName, 3*time.Second, "instanceID"),
				},
				isErr: func(err error) bool {
					return errors.Is(err, nil)
				},
			},
			args: args{
				lockDuration: 3 * time.Second,
				instanceIDs:  database.TextArray[string]{"instanceID"},
			},
		},
		{
			name: "success with multiple",
			want: want{
				expectations: []mockExpectation{
					expectLockMultipleInstances(lockTable, workerName, 3*time.Second, "instanceID1", "instanceID2"),
				},
				isErr: func(err error) bool {
					return errors.Is(err, nil)
				},
			},
			args: args{
				lockDuration: 3 * time.Second,
				instanceIDs:  []string{"instanceID1", "instanceID2"},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			client, mock, err := sqlmock.New(sqlmock.ValueConverterOption(new(db_mock.TypeConverter)))
			if err != nil {
				t.Fatal(err)
			}
			h := &locker{
				projectionName: projectionName,
				client:         client,
				workerName:     workerName,
				lockStmt: func(values string, instances int) string {
					return fmt.Sprintf(lockStmtFormat, lockTable, values, instances)
				},
			}
			for _, expectation := range tt.want.expectations {
				expectation(mock)
			}
			err = h.renewLock(context.Background(), tt.args.lockDuration, tt.args.instanceIDs...)
			if !tt.want.isErr(err) {
				t.Errorf("unexpected error = %v", err)
			}
			mock.MatchExpectationsInOrder(true)
			if err := mock.ExpectationsWereMet(); err != nil {
				t.Errorf("expectations not met: %v", err)
			}
		})
	}
}
// TestStatementHandler_Unlock verifies that Unlock issues the lock statement
// with a zero duration and propagates execution errors.
func TestStatementHandler_Unlock(t *testing.T) {
	type want struct {
		expectations []mockExpectation
		isErr        func(err error) bool
	}
	type args struct {
		instanceID string
	}
	tests := []struct {
		name string
		args args
		want want
	}{
		{
			name: "unlock fails",
			args: args{
				instanceID: "instanceID",
			},
			want: want{
				expectations: []mockExpectation{
					expectLockErr(lockTable, workerName, 0, "instanceID", sql.ErrTxDone),
				},
				isErr: func(err error) bool {
					return errors.Is(err, sql.ErrTxDone)
				},
			},
		},
		{
			name: "success",
			args: args{
				instanceID: "instanceID",
			},
			want: want{
				expectations: []mockExpectation{
					expectLock(lockTable, workerName, 0, "instanceID"),
				},
				isErr: func(err error) bool {
					return errors.Is(err, nil)
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			client, mock, err := sqlmock.New(sqlmock.ValueConverterOption(new(db_mock.TypeConverter)))
			if err != nil {
				t.Fatal(err)
			}
			h := &locker{
				projectionName: projectionName,
				client:         client,
				workerName:     workerName,
				lockStmt: func(values string, instances int) string {
					return fmt.Sprintf(lockStmtFormat, lockTable, values, instances)
				},
			}
			for _, expectation := range tt.want.expectations {
				expectation(mock)
			}
			err = h.Unlock(tt.args.instanceID)
			if !tt.want.isErr(err) {
				t.Errorf("unexpected error = %v", err)
			}
			mock.MatchExpectationsInOrder(true)
			if err := mock.ExpectationsWereMet(); err != nil {
				t.Errorf("expectations not met: %v", err)
			}
		})
	}
}
// errsMock consumes renewal results from handleLock in tests.
type errsMock struct {
	errs            chan error
	successfulIters int
	shouldErr       bool
}

// handleErrs reads successfulIters results that must be nil, optionally one
// result that must be non-nil (shouldErr), and always cancels the lock loop
// when done.
func (m *errsMock) handleErrs(t *testing.T, cancel func()) {
	defer cancel()
	for i := 0; i < m.successfulIters; i++ {
		err := <-m.errs
		if err == nil {
			continue
		}
		t.Errorf("unexpected err in iteration %d: %v", i, err)
		return
	}
	if !m.shouldErr {
		return
	}
	if err := <-m.errs; err == nil {
		t.Error("error must not be nil")
	}
}

View File

@@ -0,0 +1,14 @@
package handler
import "context"
// Init initializes the projection with the given check
type Init func(context.Context, *Check) error

// Check bundles the statements executed to initialize a projection.
type Check struct {
	// Executes are run with the projection name; each returns a bool and an
	// error. NOTE(review): the bool's meaning (executed vs. skipped?) is not
	// visible in this file — confirm against the executor.
	Executes []func(ctx context.Context, executer Executer, projectionName string) (bool, error)
}

// IsNoop reports whether the check contains no statements to execute.
func (c *Check) IsNoop() bool {
	return len(c.Executes) == 0
}

View File

@@ -0,0 +1,9 @@
package handler
import (
"database/sql"
)
// Executer is the minimal statement-execution interface; the signature
// matches database/sql handles such as *sql.DB and *sql.Tx.
type Executer interface {
	Exec(string, ...interface{}) (sql.Result, error)
}

View File

@@ -0,0 +1,95 @@
package handler
import (
"database/sql"
_ "embed"
"time"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/zerrors"
)
var (
	// setFailedEventStmt upserts the failure row for an event
	//go:embed failed_event_set.sql
	setFailedEventStmt string
	// failureCountStmt reads the current failure count for an event
	//go:embed failed_event_get_count.sql
	failureCountStmt string
)

// failure captures the identity of an event (or statement) whose handling
// failed, together with the causing error.
type failure struct {
	sequence      uint64
	instance      string
	aggregateID   string
	aggregateType eventstore.AggregateType
	eventDate     time.Time
	err           error
}
// failureFromEvent records err against the identity of the given event.
func failureFromEvent(event eventstore.Event, err error) *failure {
	agg := event.Aggregate()
	return &failure{
		sequence:      event.Sequence(),
		instance:      agg.InstanceID,
		aggregateID:   agg.ID,
		aggregateType: agg.Type,
		eventDate:     event.CreatedAt(),
		err:           err,
	}
}
// failureFromStatement records err against the identity of the given
// statement.
func failureFromStatement(statement *Statement, err error) *failure {
	f := &failure{
		sequence:      statement.Sequence,
		instance:      statement.Aggregate.InstanceID,
		aggregateID:   statement.Aggregate.ID,
		aggregateType: statement.Aggregate.Type,
		eventDate:     statement.CreationDate,
	}
	f.err = err
	return f
}
// handleFailedStmt increments and persists the failure count for f and
// reports whether processing should continue past the event: true once the
// count reaches maxFailureCount (the event is given up on), false while
// retries remain or the count could not be read.
func (h *Handler) handleFailedStmt(tx *sql.Tx, f *failure) (shouldContinue bool) {
	failureCount, err := h.failureCount(tx, f)
	if err != nil {
		h.logFailure(f).WithError(err).Warn("unable to get failure count")
		return false
	}
	failureCount += 1
	err = h.setFailureCount(tx, failureCount, f)
	// best effort: a failed persist is only logged, the decision still stands
	h.logFailure(f).OnError(err).Warn("unable to update failure count")
	return failureCount >= h.maxFailureCount
}
// failureCount reads the stored failure count for the event identified by f;
// the underlying query returns 0 when no failure row exists yet.
func (h *Handler) failureCount(tx *sql.Tx, f *failure) (count uint8, err error) {
	row := tx.QueryRow(failureCountStmt,
		h.projection.Name(),
		f.instance,
		f.aggregateType,
		f.aggregateID,
		f.sequence,
	)
	if err = row.Err(); err != nil {
		// message fixed: this path reads the count (previously said
		// "unable to update failure count", copy-pasted from the setter)
		return 0, zerrors.ThrowInternal(err, "CRDB-Unnex", "unable to get failure count")
	}
	if err = row.Scan(&count); err != nil {
		return 0, zerrors.ThrowInternal(err, "CRDB-RwSMV", "unable to scan count")
	}
	return count, nil
}
// setFailureCount upserts the failure row for the event identified by f with
// the given count and the error message (see failed_event_set.sql).
func (h *Handler) setFailureCount(tx *sql.Tx, count uint8, f *failure) error {
	_, err := tx.Exec(setFailedEventStmt,
		h.projection.Name(),
		f.instance,
		f.aggregateType,
		f.aggregateID,
		f.eventDate,
		f.sequence,
		count,
		f.err.Error(),
	)
	if err != nil {
		return zerrors.ThrowInternal(err, "CRDB-4Ht4x", "set failure count failed")
	}
	return nil
}

View File

@@ -0,0 +1,12 @@
-- Returns the failure count recorded for a specific failed event, or 0 when
-- no failure row exists yet.
-- Parameters: $1 projection name, $2 instance id, $3 aggregate type,
-- $4 aggregate id, $5 failed sequence.
WITH failures AS (
	SELECT
		failure_count
	FROM
		projections.failed_events2
	WHERE
		projection_name = $1
		AND instance_id = $2
		AND aggregate_type = $3
		AND aggregate_id = $4
		AND failed_sequence = $5
) SELECT COALESCE((SELECT failure_count FROM failures), 0) AS failure_count
View File

@@ -0,0 +1,31 @@
-- Upserts the failure row for an event: inserts on first failure, otherwise
-- refreshes failure_count, error and last_failed.
-- Parameters: $1 projection name, $2 instance id, $3 aggregate type,
-- $4 aggregate id, $5 event creation date, $6 failed sequence,
-- $7 failure count, $8 error message.
INSERT INTO projections.failed_events2 (
	projection_name
	, instance_id
	, aggregate_type
	, aggregate_id
	, event_creation_date
	, failed_sequence
	, failure_count
	, error
	, last_failed
) VALUES (
	$1
	, $2
	, $3
	, $4
	, $5
	, $6
	, $7
	, $8
	, now()
) ON CONFLICT (
	projection_name
	, aggregate_type
	, aggregate_id
	, failed_sequence
	, instance_id
) DO UPDATE SET
	failure_count = EXCLUDED.failure_count
	, error = EXCLUDED.error
	, last_failed = EXCLUDED.last_failed
;

View File

@@ -0,0 +1,213 @@
package handler
import (
"context"
"database/sql"
"errors"
"sync"
"time"
"github.com/jackc/pgx/v5/pgconn"
"github.com/shopspring/decimal"
"github.com/zitadel/zitadel/internal/eventstore"
)
// FieldHandler is a Handler specialization that backfills the fields table
// instead of reducing into a projection table.
type FieldHandler struct {
	Handler
}

// fieldProjection is a no-op Projection used only to give the FieldHandler a
// name and an empty reducer set.
type fieldProjection struct {
	name string
}

// Name implements Projection.
func (f *fieldProjection) Name() string {
	return f.name
}

// Reducers implements Projection.
func (f *fieldProjection) Reducers() []AggregateReducer {
	return nil
}

// compile-time interface assertion
var _ Projection = (*fieldProjection)(nil)
// NewFieldHandler returns a projection handler which backfills the `eventstore.fields` table with historic events which
// might have existed before they had any field operations defined.
// The events are filtered by the mapped aggregate types and each event type for that aggregate.
func NewFieldHandler(config *Config, name string, eventTypes map[eventstore.AggregateType][]eventstore.EventType) *FieldHandler {
	return &FieldHandler{
		Handler: Handler{
			projection:             &fieldProjection{name: name},
			client:                 config.Client,
			es:                     config.Eventstore,
			bulkLimit:              config.BulkLimit,
			eventTypes:             eventTypes,
			requeueEvery:           config.RequeueEvery,
			now:                    time.Now,
			maxFailureCount:        config.MaxFailureCount,
			retryFailedAfter:       config.RetryFailedAfter,
			triggeredInstancesSync: sync.Map{},
			triggerWithoutEvents:   config.TriggerWithoutEvents,
			txDuration:             config.TransactionDuration,
		},
	}
}
// Trigger executes the backfill job of events for the instance currently in the context.
// It loops over bounded batches until no additional iteration is needed or an
// error occurs. Returns nil without work when the instance is already being
// processed (lockInstance yields no cancel func).
func (h *FieldHandler) Trigger(ctx context.Context, opts ...TriggerOpt) (err error) {
	config := new(triggerConfig)
	for _, opt := range opts {
		opt(config)
	}
	cancel := h.lockInstance(ctx, config)
	if cancel == nil {
		return nil
	}
	defer cancel()
	for i := 0; ; i++ {
		additionalIteration, err := h.processEvents(ctx, config)
		h.log().OnError(err).Info("process events failed")
		h.log().WithField("iteration", i).Debug("trigger iteration")
		if !additionalIteration || err != nil {
			return err
		}
	}
}
// processEvents runs one bounded backfill iteration: it loads and locks the
// current state, fetches up to bulkLimit events after it, fills their fields
// and persists the advanced state. additionalIteration reports whether more
// events are likely pending.
func (h *FieldHandler) processEvents(ctx context.Context, config *triggerConfig) (additionalIteration bool, err error) {
	defer func() {
		pgErr := new(pgconn.PgError)
		if errors.As(err, &pgErr) {
			// error returned if the row is currently locked by another connection
			if pgErr.Code == "55P03" {
				h.log().Debug("state already locked")
				err = nil
				additionalIteration = false
			}
		}
	}()

	// bound the transaction (txCtx) slightly longer than the work (ctx) so
	// the state can still be stored when the iteration times out
	txCtx := ctx
	if h.txDuration > 0 {
		var cancel, cancelTx func()
		// add 100ms to store current state if iteration takes too long
		txCtx, cancelTx = context.WithTimeout(ctx, h.txDuration+100*time.Millisecond)
		defer cancelTx()
		ctx, cancel = context.WithTimeout(ctx, h.txDuration)
		defer cancel()
	}
	tx, err := h.client.BeginTx(txCtx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
	if err != nil {
		return false, err
	}
	defer func() {
		// executionError is deliberately committed so the failure count and
		// state survive; everything else rolls back
		if err != nil && !errors.Is(err, &executionError{}) {
			rollbackErr := tx.Rollback()
			h.log().OnError(rollbackErr).Debug("unable to rollback tx")
			return
		}
		commitErr := tx.Commit()
		if err == nil {
			err = commitErr
		}
	}()

	// always await currently running transactions
	config.awaitRunning = true
	currentState, err := h.currentState(ctx, tx, config)
	if err != nil {
		if errors.Is(err, errJustUpdated) {
			return false, nil
		}
		return additionalIteration, err
	}
	// stop execution if currentState.eventTimestamp >= config.maxCreatedAt
	if !config.maxPosition.IsZero() && currentState.position.GreaterThanOrEqual(config.maxPosition) {
		return false, nil
	}
	if config.minPosition.GreaterThan(decimal.NewFromInt(0)) {
		currentState.position = config.minPosition
		currentState.offset = 0
	}

	events, additionalIteration, err := h.fetchEvents(ctx, tx, currentState)
	if err != nil {
		return additionalIteration, err
	}
	if len(events) == 0 {
		err = h.setState(tx, currentState)
		return additionalIteration, err
	}

	err = h.es.FillFields(ctx, events...)
	if err != nil {
		return false, err
	}

	err = h.setState(tx, currentState)

	return additionalIteration, err
}
// fetchEvents loads the next batch of events after currentState, advances
// currentState to the last event of the batch and returns the not-yet-reduced
// events as FillFieldsEvents. additionalIteration is true when the batch was
// full (bulkLimit reached), i.e. more events may be pending.
func (h *FieldHandler) fetchEvents(ctx context.Context, tx *sql.Tx, currentState *state) (_ []eventstore.FillFieldsEvent, additionalIteration bool, err error) {
	events, err := h.es.Filter(ctx, h.eventQuery(currentState).SetTx(tx))
	if err != nil || len(events) == 0 {
		h.log().OnError(err).Debug("filter eventstore failed")
		return nil, false, err
	}
	eventAmount := len(events)

	// events at the previously stored position were partially processed
	// before; skip them and carry their offset forward
	idx, offset := skipPreviouslyReducedEvents(events, currentState)

	if currentState.position.Equal(events[len(events)-1].Position()) {
		offset += currentState.offset
	}
	currentState.position = events[len(events)-1].Position()
	currentState.offset = offset
	currentState.aggregateID = events[len(events)-1].Aggregate().ID
	currentState.aggregateType = events[len(events)-1].Aggregate().Type
	currentState.sequence = events[len(events)-1].Sequence()
	currentState.eventTimestamp = events[len(events)-1].CreatedAt()

	// every event in the batch was already reduced
	if idx+1 == len(events) {
		return nil, false, nil
	}
	events = events[idx+1:]

	additionalIteration = eventAmount == int(h.bulkLimit)

	fillFieldsEvents := make([]eventstore.FillFieldsEvent, len(events))
	highestPosition := events[len(events)-1].Position()
	for i, event := range events {
		if event.Position().Equal(highestPosition) {
			offset++
		}
		fillFieldsEvents[i] = event.(eventstore.FillFieldsEvent)
	}

	return fillFieldsEvents, additionalIteration, nil
}
// skipPreviouslyReducedEvents returns the index of the event currentState
// points at (the last event already reduced) together with that event's
// 1-based offset within its position, or (-1, 0) when the state's event is
// not contained in events.
func skipPreviouslyReducedEvents(events []eventstore.Event, currentState *state) (index int, offset uint32) {
	var position decimal.Decimal
	for i, event := range events {
		// restart the offset whenever a new position starts
		if !event.Position().Equal(position) {
			offset = 0
			position = event.Position()
		}
		offset++
		if event.Position().Equal(currentState.position) &&
			event.Aggregate().ID == currentState.aggregateID &&
			event.Aggregate().Type == currentState.aggregateType &&
			event.Sequence() == currentState.sequence {
			return i, offset
		}
	}
	return -1, 0
}

View File

@@ -0,0 +1,752 @@
package handler
import (
"context"
"database/sql"
"errors"
"math/rand"
"slices"
"sync"
"time"
"github.com/jackc/pgx/v5/pgconn"
"github.com/shopspring/decimal"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/migration"
"github.com/zitadel/zitadel/internal/repository/instance"
"github.com/zitadel/zitadel/internal/repository/pseudo"
)
// EventStore is the subset of the eventstore used by projection handlers.
type EventStore interface {
	InstanceIDs(ctx context.Context, query *eventstore.SearchQueryBuilder) ([]string, error)
	FilterToQueryReducer(ctx context.Context, reducer eventstore.QueryReducer) error
	Filter(ctx context.Context, queryFactory *eventstore.SearchQueryBuilder) ([]eventstore.Event, error)
	Push(ctx context.Context, cmds ...eventstore.Command) ([]eventstore.Event, error)
	FillFields(ctx context.Context, events ...eventstore.FillFieldsEvent) error
}

// Config carries the dependencies and tuning knobs for a projection Handler.
type Config struct {
	Client     *database.DB
	Eventstore EventStore

	// BulkLimit caps the number of events fetched per iteration.
	BulkLimit            uint16
	RequeueEvery         time.Duration
	RetryFailedAfter     time.Duration
	TransactionDuration  time.Duration
	MaxFailureCount      uint8
	TriggerWithoutEvents Reduce

	ActiveInstancer interface {
		ActiveInstances() []string
	}
}
// Handler processes events of the mapped aggregate/event types into a
// Projection, with retry bookkeeping and per-instance trigger deduplication.
type Handler struct {
	client     *database.DB
	projection Projection

	es        EventStore
	bulkLimit uint16
	// eventTypes maps each handled aggregate type to its handled event types
	eventTypes map[eventstore.AggregateType][]eventstore.EventType

	maxFailureCount  uint8
	retryFailedAfter time.Duration
	requeueEvery     time.Duration
	txDuration       time.Duration
	// now makes time.Now mockable in tests
	now nowFunc

	queryGlobal bool

	// triggeredInstancesSync deduplicates concurrent triggers per instance
	triggeredInstancesSync sync.Map

	triggerWithoutEvents Reduce
	cacheInvalidations   []func(ctx context.Context, aggregates []*eventstore.Aggregate)

	queryInstances func() ([]string, error)

	metrics *ProjectionMetrics
}

// compile-time check: a Handler can be used as a migration step
var _ migration.Migration = (*Handler)(nil)
// Execute implements migration.Migration.
// It prefills the projection for all existing instances with a worker pool
// sized by the available DB connections, logging progress every 30s.
func (h *Handler) Execute(ctx context.Context, startedEvent eventstore.Event) error {
	start := time.Now()
	logging.WithFields("projection", h.ProjectionName()).Info("projection starts prefilling")
	logTicker := time.NewTicker(30 * time.Second)
	logTickerDone := make(chan struct{})
	go func() {
		// previously `for range logTicker.C`, which leaked this goroutine:
		// Ticker.Stop does not close C, so the range never terminated (and
		// the ticker was not stopped at all on the early error return below)
		for {
			select {
			case <-logTicker.C:
				logging.WithFields("projection", h.ProjectionName()).Info("projection is prefilling")
			case <-logTickerDone:
				return
			}
		}
	}()
	defer func() {
		logTicker.Stop()
		close(logTickerDone)
	}()

	instanceIDs, err := h.existingInstances(ctx)
	if err != nil {
		return err
	}
	// default amount of workers is 10
	workerCount := 10
	if h.client.DB.Stats().MaxOpenConnections > 0 {
		workerCount = h.client.DB.Stats().MaxOpenConnections / 4
	}
	// ensure that at least one worker is active
	if workerCount == 0 {
		workerCount = 1
	}
	// spawn less workers if not all workers needed
	if workerCount > len(instanceIDs) {
		workerCount = len(instanceIDs)
	}
	instances := make(chan string, workerCount)
	var wg sync.WaitGroup
	wg.Add(workerCount)
	for i := 0; i < workerCount; i++ {
		go h.executeInstances(ctx, instances, startedEvent, &wg)
	}

	for _, instance := range instanceIDs {
		instances <- instance
	}
	close(instances)
	wg.Wait()

	logging.WithFields("projection", h.ProjectionName(), "took", time.Since(start)).Info("projections ended prefilling")
	return nil
}
// executeInstances consumes instance IDs from the channel until it is closed,
// triggering the projection up to the started event's position for each one,
// and marks the wait group done when finished.
func (h *Handler) executeInstances(ctx context.Context, instances <-chan string, startedEvent eventstore.Event, wg *sync.WaitGroup) {
	defer wg.Done()
	for instance := range instances {
		h.triggerInstances(ctx, []string{instance}, WithMaxPosition(startedEvent.Position()))
	}
}
// String implements migration.Migration.
func (h *Handler) String() string {
	return h.ProjectionName()
}

// nowFunc makes [time.Now] mockable
type nowFunc func() time.Time

// Projection describes the name of a projection and the reducers it applies
// per aggregate.
type Projection interface {
	Name() string
	Reducers() []AggregateReducer
}

// GlobalProjection marks a Projection whose event query is not restricted to
// the aggregate/event types returned by Reducers (see Handler.eventQuery).
type GlobalProjection interface {
	Projection
	FilterGlobalEvents()
}
// NewHandler constructs a Handler for the given projection using the
// dependencies and tuning parameters from config.
func NewHandler(
	ctx context.Context,
	config *Config,
	projection Projection,
) *Handler {
	// collect every event type the projection reduces, grouped by aggregate
	aggregates := make(map[eventstore.AggregateType][]eventstore.EventType, len(projection.Reducers()))
	for _, reducer := range projection.Reducers() {
		for _, eventReducer := range reducer.EventReducers {
			aggregates[reducer.Aggregate] = append(aggregates[reducer.Aggregate], eventReducer.Event)
		}
	}

	handler := &Handler{
		client:                 config.Client,
		projection:             projection,
		es:                     config.Eventstore,
		bulkLimit:              config.BulkLimit,
		eventTypes:             aggregates,
		maxFailureCount:        config.MaxFailureCount,
		retryFailedAfter:       config.RetryFailedAfter,
		requeueEvery:           config.RequeueEvery,
		txDuration:             config.TransactionDuration,
		now:                    time.Now,
		triggeredInstancesSync: sync.Map{},
		triggerWithoutEvents:   config.TriggerWithoutEvents,
		queryInstances: func() ([]string, error) {
			if config.ActiveInstancer == nil {
				return nil, nil
			}
			return config.ActiveInstancer.ActiveInstances(), nil
		},
		metrics: NewProjectionMetrics(),
	}

	// global projections query all events instead of the registered types
	_, handler.queryGlobal = projection.(GlobalProjection)

	return handler
}
// Start begins the scheduled processing of the projection and, unless the
// handler is driven by scheduled pseudo events only, also subscribes to live
// event notifications. Both loops stop when ctx is cancelled.
func (h *Handler) Start(ctx context.Context) {
	go h.schedule(ctx)
	// handlers driven by pseudo events have no event types to subscribe to
	if h.triggerWithoutEvents != nil {
		return
	}
	go h.subscribe(ctx)
}
// checkInit queries for the migration "done" event of a projection to
// determine whether it was set up before.
type checkInit struct {
	didInit        bool
	projectionName string
}

// AppendEvents implements eventstore.QueryReducer.
// Receiving any event means the setup for this projection completed before.
func (ci *checkInit) AppendEvents(...eventstore.Event) {
	ci.didInit = true
}

// Query implements eventstore.QueryReducer.
// It matches the system "done" event written for this projection's setup,
// across all instances (empty instance ID).
func (ci *checkInit) Query() *eventstore.SearchQueryBuilder {
	return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
		Limit(1).
		InstanceID("").
		AddQuery().
		AggregateTypes(migration.SystemAggregate).
		AggregateIDs(migration.SystemAggregateID).
		EventTypes(migration.DoneType).
		EventData(map[string]interface{}{
			"name": ci.projectionName,
		}).
		Builder()
}

// Reduce implements eventstore.QueryReducer.
// Nothing to reduce; AppendEvents already captured the result.
func (*checkInit) Reduce() error {
	return nil
}

var _ eventstore.QueryReducer = (*checkInit)(nil)

// didInitialize reports whether the projection was set up before.
// Query errors are treated as "not initialized".
func (h *Handler) didInitialize(ctx context.Context) bool {
	initiated := checkInit{
		projectionName: h.ProjectionName(),
	}
	err := h.es.FilterToQueryReducer(ctx, &initiated)
	if err != nil {
		return false
	}
	return initiated.didInit
}
// schedule periodically triggers the projection for all active instances
// until ctx is cancelled. The first run is delayed by a random offset within
// requeueEvery to spread load; projections that were never initialized start
// almost immediately (within 0.5s) so they prefill quickly.
func (h *Handler) schedule(ctx context.Context) {
	// start the projection and its configured `RequeueEvery`
	reset := randomizeStart(0, h.requeueEvery.Seconds())
	if !h.didInitialize(ctx) {
		reset = randomizeStart(0, 0.5)
	}
	t := time.NewTimer(reset)
	for {
		select {
		case <-ctx.Done():
			t.Stop()
			return
		case <-t.C:
			instances, err := h.queryInstances()
			h.log().OnError(err).Debug("unable to query instances")

			h.triggerInstances(call.WithTimestamp(ctx), instances)
			t.Reset(h.requeueEvery)
		}
	}
}
// triggerInstances triggers the projection for each of the given instances,
// retrying a failing instance with a delay of retryFailedAfter until the
// trigger succeeds.
//
// Previously, the first failure was handled separately from the retry loop,
// causing a double sleep and a duplicate log entry before the first retry;
// the unified loop below sleeps and logs exactly once per failed attempt.
func (h *Handler) triggerInstances(ctx context.Context, instances []string, triggerOpts ...TriggerOpt) {
	for _, instance := range instances {
		instanceCtx := authz.WithInstanceID(ctx, instance)

		// retry until the trigger for this instance succeeds
		for {
			_, err := h.Trigger(instanceCtx, triggerOpts...)
			if err == nil {
				break
			}
			h.log().WithField("instance", instance).WithError(err).Debug("trigger failed")
			time.Sleep(h.retryFailedAfter)
		}
	}
}
// randomizeStart returns a pseudo-random duration in [minSeconds, maxSeconds)
// seconds, used to jitter projection start times so handlers do not all
// trigger at once. The parameter was renamed from `min`, which shadowed the
// Go 1.21 builtin of the same name.
func randomizeStart(minSeconds, maxSeconds float64) time.Duration {
	d := minSeconds + rand.Float64()*(maxSeconds-minSeconds)
	return time.Duration(d*1000) * time.Millisecond
}
// subscribe consumes live event notifications and triggers the projection for
// every distinct instance that produced events. Events arriving close
// together are batched (checkAdditionalEvents) so each instance is triggered
// at most once per batch. Runs until ctx is cancelled.
func (h *Handler) subscribe(ctx context.Context) {
	queue := make(chan eventstore.Event, 100)
	subscription := eventstore.SubscribeEventTypes(queue, h.eventTypes)
	for {
		select {
		case <-ctx.Done():
			subscription.Unsubscribe()
			h.log().Debug("shutdown")
			return
		case event := <-queue:
			// drain events arriving within a short window to deduplicate triggers
			events := checkAdditionalEvents(queue, event)

			solvedInstances := make([]string, 0, len(events))
			queueCtx := call.WithTimestamp(ctx)
			for _, e := range events {
				// skip instances that were already triggered in this batch
				if instanceSolved(solvedInstances, e.Aggregate().InstanceID) {
					continue
				}
				queueCtx = authz.WithInstanceID(queueCtx, e.Aggregate().InstanceID)
				_, err := h.Trigger(queueCtx)
				h.log().OnError(err).Debug("trigger of queued event failed")
				// only mark the instance solved on success so a later event retries it
				if err == nil {
					solvedInstances = append(solvedInstances, e.Aggregate().InstanceID)
				}
			}
		}
	}
}
// instanceSolved reports whether instanceID is already contained in
// solvedInstances. Uses slices.Contains instead of a hand-rolled loop.
func instanceSolved(solvedInstances []string, instanceID string) bool {
	return slices.Contains(solvedInstances, instanceID)
}
// checkAdditionalEvents collects the given event plus any further events that
// arrive on eventQueue within 1ms of the previous one, so bursts can be
// handled as a single batch. The idle timer is allocated once and reset per
// received event instead of allocating a new timer every iteration.
func checkAdditionalEvents(eventQueue chan eventstore.Event, event eventstore.Event) []eventstore.Event {
	events := []eventstore.Event{event}
	wait := time.NewTimer(1 * time.Millisecond)
	defer wait.Stop()
	for {
		select {
		case event := <-eventQueue:
			events = append(events, event)
			// drain the channel before Reset if the timer already fired
			if !wait.Stop() {
				<-wait.C
			}
			wait.Reset(1 * time.Millisecond)
		case <-wait.C:
			return events
		}
	}
}
// existingInstances reduces instance added/removed events into the list of
// instance IDs that currently exist.
type existingInstances []string

// AppendEvents implements eventstore.QueryReducer.
// Added instances are collected; removed instances are deleted again.
func (ai *existingInstances) AppendEvents(events ...eventstore.Event) {
	for _, event := range events {
		switch event.Type() {
		case instance.InstanceAddedEventType:
			*ai = append(*ai, event.Aggregate().InstanceID)
		case instance.InstanceRemovedEventType:
			*ai = slices.DeleteFunc(*ai, func(s string) bool {
				return s == event.Aggregate().InstanceID
			})
		}
	}
}

// Query implements eventstore.QueryReducer.
// It selects all instance added/removed events across all instances.
func (*existingInstances) Query() *eventstore.SearchQueryBuilder {
	return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
		AddQuery().
		AggregateTypes(instance.AggregateType).
		EventTypes(
			instance.InstanceAddedEventType,
			instance.InstanceRemovedEventType,
		).
		Builder()
}

// Reduce implements eventstore.QueryReducer.
// reduce is not used as events are reduced during AppendEvents
func (*existingInstances) Reduce() error {
	return nil
}

var _ eventstore.QueryReducer = (*existingInstances)(nil)

// existingInstances returns the IDs of all instances that currently exist.
func (h *Handler) existingInstances(ctx context.Context) ([]string, error) {
	ai := existingInstances{}
	if err := h.es.FilterToQueryReducer(ctx, &ai); err != nil {
		return nil, err
	}

	return ai, nil
}
// triggerConfig holds the optional boundaries of a single trigger run.
type triggerConfig struct {
	// awaitRunning blocks until a concurrently running trigger finished
	awaitRunning bool
	// maxPosition stops processing once the state reached this position
	maxPosition decimal.Decimal
	// minPosition, if set, overrides the stored position to start from
	minPosition decimal.Decimal
}

// TriggerOpt configures a single call to [Handler.Trigger].
type TriggerOpt func(conf *triggerConfig)

// WithAwaitRunning waits for a concurrently running trigger of the same
// instance instead of returning immediately.
func WithAwaitRunning() TriggerOpt {
	return func(conf *triggerConfig) {
		conf.awaitRunning = true
	}
}

// WithMaxPosition stops processing once the given position is reached.
func WithMaxPosition(position decimal.Decimal) TriggerOpt {
	return func(conf *triggerConfig) {
		conf.maxPosition = position
	}
}

// WithMinPosition starts processing from the given position instead of the
// stored state.
func WithMinPosition(position decimal.Decimal) TriggerOpt {
	return func(conf *triggerConfig) {
		conf.minPosition = position
	}
}
// Trigger processes events for the instance in ctx until no additional
// iteration is needed or an error occurs. It returns a context with a reset
// call timestamp so subsequent queries read fresh data. If the instance is
// already locked by another trigger (and awaitRunning is not set), the call
// returns immediately without processing.
func (h *Handler) Trigger(ctx context.Context, opts ...TriggerOpt) (_ context.Context, err error) {
	config := new(triggerConfig)
	for _, opt := range opts {
		opt(config)
	}

	cancel := h.lockInstance(ctx, config)
	if cancel == nil {
		// instance is handled by another trigger; nothing to do
		return call.ResetTimestamp(ctx), nil
	}
	defer cancel()

	for i := 0; ; i++ {
		additionalIteration, err := h.processEvents(ctx, config)
		h.log().OnError(err).Info("process events failed")
		h.log().WithField("iteration", i).Debug("trigger iteration")
		if !additionalIteration || err != nil {
			return call.ResetTimestamp(ctx), err
		}
	}
}
// RegisterCacheInvalidation registers a function to be called when a cache needs to be invalidated.
// In order to avoid race conditions, this method must be called before [Handler.Start] is called.
// The registered functions run in parallel after a successful commit (see invalidateCaches).
func (h *Handler) RegisterCacheInvalidation(invalidate func(ctx context.Context, aggregates []*eventstore.Aggregate)) {
	h.cacheInvalidations = append(h.cacheInvalidations, invalidate)
}
// lockInstance tries to lock the instance.
// If the instance is already locked from another process no cancel function is returned
// the instance can be skipped then
// If the instance is locked, an unlock deferrable function is returned
func (h *Handler) lockInstance(ctx context.Context, config *triggerConfig) func() {
	instanceID := authz.GetInstance(ctx).InstanceID()

	// Check that the instance has a lock.
	// A buffered channel of size 1 serves as a per-instance mutex:
	// sending acquires the lock, receiving releases it.
	instanceLock, _ := h.triggeredInstancesSync.LoadOrStore(instanceID, make(chan bool, 1))

	// in case we don't want to wait for a running trigger / lock (e.g. spooler),
	// we can directly return if we cannot lock
	if !config.awaitRunning {
		select {
		case instanceLock.(chan bool) <- true:
			return func() {
				<-instanceLock.(chan bool)
			}
		default:
			return nil
		}
	}

	// in case we want to wait for a running trigger / lock (e.g. query),
	// we try to lock as long as the context is not cancelled
	select {
	case instanceLock.(chan bool) <- true:
		return func() {
			<-instanceLock.(chan bool)
		}
	case <-ctx.Done():
		return nil
	}
}
// processEvents runs one transactional processing iteration: load the locked
// current state, generate statements from new events, execute them and
// persist the advanced state. additionalIteration reports whether more events
// are likely waiting and the caller should iterate again.
//
// Note the deferred functions run LIFO: the commit defer (registered last)
// runs before the rollback defer, so rollback only happens when the commit
// path was not reached or the error occurred earlier.
func (h *Handler) processEvents(ctx context.Context, config *triggerConfig) (additionalIteration bool, err error) {
	defer func() {
		pgErr := new(pgconn.PgError)
		if errors.As(err, &pgErr) {
			// error returned if the row is currently locked by another connection
			if pgErr.Code == "55P03" {
				h.log().Debug("state already locked")
				err = nil
				additionalIteration = false
			}
		}
	}()

	txCtx := ctx
	if h.txDuration > 0 {
		var cancel, cancelTx func()
		// add 100ms to store current state if iteration takes too long
		txCtx, cancelTx = context.WithTimeout(ctx, h.txDuration+100*time.Millisecond)
		defer cancelTx()
		ctx, cancel = context.WithTimeout(ctx, h.txDuration)
		defer cancel()
	}

	start := time.Now()

	tx, err := h.client.BeginTx(txCtx, nil)
	if err != nil {
		return false, err
	}
	defer func() {
		// an executionError is deliberately not rolled back: the failure was
		// recorded within the same transaction (see executeStatement) and
		// must be committed
		if err != nil && !errors.Is(err, &executionError{}) {
			rollbackErr := tx.Rollback()
			h.log().OnError(rollbackErr).Debug("unable to rollback tx")
			return
		}
	}()

	currentState, err := h.currentState(ctx, tx, config)
	if err != nil {
		if errors.Is(err, errJustUpdated) {
			return false, nil
		}
		return additionalIteration, err
	}
	// stop execution if currentState.position >= config.maxPosition
	if !config.maxPosition.Equal(decimal.Decimal{}) && currentState.position.GreaterThanOrEqual(config.maxPosition) {
		return false, nil
	}

	if config.minPosition.GreaterThan(decimal.NewFromInt(0)) {
		currentState.position = config.minPosition
		currentState.offset = 0
	}

	var statements []*Statement
	statements, additionalIteration, err = h.generateStatements(ctx, tx, currentState)
	if err != nil {
		return additionalIteration, err
	}

	defer func() {
		commitErr := tx.Commit()
		if err == nil {
			err = commitErr
		}

		h.metrics.ProjectionEventsProcessed(ctx, h.ProjectionName(), int64(len(statements)), err == nil)

		if err == nil && currentState.aggregateID != "" && len(statements) > 0 {
			// Don't update projection timing or latency unless we successfully processed events
			h.metrics.ProjectionUpdateTiming(ctx, h.ProjectionName(), float64(time.Since(start).Seconds()))

			h.metrics.ProjectionStateLatency(ctx, h.ProjectionName(), time.Since(currentState.eventTimestamp).Seconds())

			h.invalidateCaches(ctx, aggregatesFromStatements(statements))
		}
	}()

	if len(statements) == 0 {
		err = h.setState(tx, currentState)
		return additionalIteration, err
	}

	lastProcessedIndex, err := h.executeStatements(ctx, tx, statements)
	h.log().OnError(err).WithField("lastProcessedIndex", lastProcessedIndex).Debug("execution of statements failed")
	if lastProcessedIndex < 0 {
		return false, err
	}

	// advance the state to the last successfully processed statement
	currentState.position = statements[lastProcessedIndex].Position
	currentState.offset = statements[lastProcessedIndex].offset
	currentState.aggregateID = statements[lastProcessedIndex].Aggregate.ID
	currentState.aggregateType = statements[lastProcessedIndex].Aggregate.Type
	currentState.sequence = statements[lastProcessedIndex].Sequence
	currentState.eventTimestamp = statements[lastProcessedIndex].CreationDate

	setStateErr := h.setState(tx, currentState)
	if setStateErr != nil {
		err = setStateErr
	}

	return additionalIteration, err
}
// generateStatements reads the next batch of events and converts them to
// statements, skipping events reduced in a previous run. additionalIteration
// reports whether the caller should run another round (a full bulk was read,
// or some events produced no statement and should be retried immediately).
func (h *Handler) generateStatements(ctx context.Context, tx *sql.Tx, currentState *state) (_ []*Statement, additionalIteration bool, err error) {
	// handlers driven without events reduce a single scheduled pseudo event
	if h.triggerWithoutEvents != nil {
		stmt, err := h.triggerWithoutEvents(pseudo.NewScheduledEvent(ctx, time.Now(), currentState.instanceID))
		if err != nil {
			return nil, false, err
		}
		return []*Statement{stmt}, false, nil
	}

	events, err := h.es.Filter(ctx, h.eventQuery(currentState).SetTx(tx))
	if err != nil {
		h.log().WithError(err).Debug("filter eventstore failed")
		return nil, false, err
	}
	eventAmount := len(events)

	statements, err := h.eventsToStatements(tx, events, currentState)
	if err != nil || len(statements) == 0 {
		return nil, false, err
	}

	idx := skipPreviouslyReducedStatements(statements, currentState)
	if idx+1 == len(statements) {
		// every statement was reduced before; only advance the state
		currentState.position = statements[len(statements)-1].Position
		currentState.offset = statements[len(statements)-1].offset
		currentState.aggregateID = statements[len(statements)-1].Aggregate.ID
		currentState.aggregateType = statements[len(statements)-1].Aggregate.Type
		currentState.sequence = statements[len(statements)-1].Sequence
		currentState.eventTimestamp = statements[len(statements)-1].CreationDate

		return nil, false, nil
	}
	statements = statements[idx+1:]

	// a full bulk likely means more events are waiting
	additionalIteration = eventAmount == int(h.bulkLimit)
	if len(statements) < len(events) {
		// retry immediately if statements failed
		additionalIteration = true
	}

	return statements, additionalIteration, nil
}
// skipPreviouslyReducedStatements returns the index of the statement matching
// the current state (position, aggregate and sequence), or -1 when no
// statement was previously reduced. Uses slices.IndexFunc instead of a
// hand-rolled search loop.
func skipPreviouslyReducedStatements(statements []*Statement, currentState *state) int {
	return slices.IndexFunc(statements, func(statement *Statement) bool {
		return statement.Position.Equal(currentState.position) &&
			statement.Aggregate.ID == currentState.aggregateID &&
			statement.Aggregate.Type == currentState.aggregateType &&
			statement.Sequence == currentState.sequence
	})
}
// executeStatements runs the statements in order, stopping on context
// cancellation or the first statement that fails permanently. It returns the
// index of the last successfully processed statement, or -1 if none was.
func (h *Handler) executeStatements(ctx context.Context, tx *sql.Tx, statements []*Statement) (lastProcessedIndex int, err error) {
	lastProcessedIndex = -1

	for i, statement := range statements {
		select {
		case <-ctx.Done():
			return lastProcessedIndex, ctx.Err()
		default:
			err := h.executeStatement(ctx, tx, statement)
			if err != nil {
				return lastProcessedIndex, err
			}
			lastProcessedIndex = i
		}
	}
	return lastProcessedIndex, nil
}
// executeStatement runs a single statement inside a savepoint. On failure it
// rolls back to the savepoint (keeping the surrounding transaction usable),
// records the failure, and either continues or returns an executionError
// depending on the failure handling.
func (h *Handler) executeStatement(ctx context.Context, tx *sql.Tx, statement *Statement) (err error) {
	if statement.Execute == nil {
		return nil
	}

	_, err = tx.ExecContext(ctx, "SAVEPOINT exec_stmt")
	if err != nil {
		h.log().WithError(err).Debug("create savepoint failed")
		return err
	}

	if err = statement.Execute(ctx, tx, h.projection.Name()); err != nil {
		h.log().WithError(err).Error("statement execution failed")

		// only roll back to the savepoint so the failure bookkeeping below
		// still happens within the same transaction
		_, rollbackErr := tx.ExecContext(ctx, "ROLLBACK TO SAVEPOINT exec_stmt")
		h.log().OnError(rollbackErr).Error("rollback to savepoint failed")

		shouldContinue := h.handleFailedStmt(tx, failureFromStatement(statement, err))
		if shouldContinue {
			return nil
		}

		return &executionError{parent: err}
	}

	return nil
}
// eventQuery builds the search query for the next batch of events starting at
// the current position/offset. Global projections receive all events; other
// projections are restricted to their registered aggregate and event types.
func (h *Handler) eventQuery(currentState *state) *eventstore.SearchQueryBuilder {
	builder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
		AwaitOpenTransactions().
		Limit(uint64(h.bulkLimit)).
		OrderAsc().
		InstanceID(currentState.instanceID)

	if currentState.position.GreaterThan(decimal.Decimal{}) {
		// the position itself is included again; the offset skips events at
		// the same position that were already processed
		builder = builder.PositionAtLeast(currentState.position)
		if currentState.offset > 0 {
			builder = builder.Offset(currentState.offset)
		}
	}

	if h.queryGlobal {
		return builder
	}

	aggregateTypes := make([]eventstore.AggregateType, 0, len(h.eventTypes))
	eventTypes := make([]eventstore.EventType, 0, len(h.eventTypes))

	for aggregate, events := range h.eventTypes {
		aggregateTypes = append(aggregateTypes, aggregate)
		eventTypes = append(eventTypes, events...)
	}

	return builder.AddQuery().AggregateTypes(aggregateTypes...).EventTypes(eventTypes...).Builder()
}
// ProjectionName returns the name of the underlying projection
// (e.g. the projection's table name).
func (h *Handler) ProjectionName() string {
	return h.projection.Name()
}
// invalidateCaches calls all registered cache invalidation callbacks in
// parallel and waits for them to finish.
func (h *Handler) invalidateCaches(ctx context.Context, aggregates []*eventstore.Aggregate) {
	if len(h.cacheInvalidations) == 0 {
		return
	}

	var wg sync.WaitGroup
	wg.Add(len(h.cacheInvalidations))

	for _, invalidate := range h.cacheInvalidations {
		// the callback is passed as an argument to avoid capturing the loop
		// variable (required before Go 1.22 loop-variable semantics)
		go func(invalidate func(context.Context, []*eventstore.Aggregate)) {
			defer wg.Done()
			invalidate(ctx, aggregates)
		}(invalidate)
	}
	wg.Wait()
}
// aggregatesFromStatements returns the unique aggregates from statements.
// Duplicate aggregates are omitted. Deduplication uses the aggregate value as
// a map key — Aggregate is comparable, as the previous `*a == *b` comparison
// already required — replacing the quadratic ContainsFunc scan with one pass.
func aggregatesFromStatements(statements []*Statement) []*eventstore.Aggregate {
	aggregates := make([]*eventstore.Aggregate, 0, len(statements))
	seen := make(map[eventstore.Aggregate]struct{}, len(statements))
	for _, statement := range statements {
		if _, ok := seen[*statement.Aggregate]; ok {
			continue
		}
		seen[*statement.Aggregate] = struct{}{}
		aggregates = append(aggregates, statement.Aggregate)
	}
	return aggregates
}

View File

@@ -0,0 +1,425 @@
package handler
import (
"context"
"errors"
"fmt"
"strings"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/eventstore/handler"
"github.com/zitadel/zitadel/internal/zerrors"
)
// Table describes a projection table to be created during initialization,
// including its primary key, indices, unique constraints and foreign keys.
type Table struct {
	columns     []*InitColumn
	primaryKey  PrimaryKey
	indices     []*Index
	constraints []*Constraint
	foreignKeys []*ForeignKey
}
// NewTable constructs a Table from the given columns and primary key and
// applies the provided options (indices, constraints, foreign keys).
func NewTable(columns []*InitColumn, key PrimaryKey, opts ...TableOption) *Table {
	table := &Table{
		columns:    columns,
		primaryKey: key,
	}
	for _, apply := range opts {
		apply(table)
	}
	return table
}
// SuffixedTable is a secondary table whose name is derived from the
// projection's table name plus a suffix (see createTableStatement callers).
type SuffixedTable struct {
	Table
	suffix string
}

// NewSuffixedTable creates a secondary table with the given suffix.
func NewSuffixedTable(columns []*InitColumn, key PrimaryKey, suffix string, opts ...TableOption) *SuffixedTable {
	return &SuffixedTable{
		Table:  *NewTable(columns, key, opts...),
		suffix: suffix,
	}
}
// TableOption modifies a Table during construction.
type TableOption func(*Table)

// WithIndex adds an index to the table.
func WithIndex(index *Index) TableOption {
	return func(table *Table) {
		table.indices = append(table.indices, index)
	}
}

// WithConstraint adds a unique constraint to the table.
func WithConstraint(constraint *Constraint) TableOption {
	return func(table *Table) {
		table.constraints = append(table.constraints, constraint)
	}
}

// WithForeignKey adds a foreign key to the table.
func WithForeignKey(key *ForeignKey) TableOption {
	return func(table *Table) {
		table.foreignKeys = append(table.foreignKeys, key)
	}
}
// InitColumn describes a single column of a projection table.
type InitColumn struct {
	Name string
	Type ColumnType
	// nullable omits the NOT NULL clause
	nullable bool
	// defaultValue, if non-nil, renders a DEFAULT clause
	defaultValue interface{}
	// deleteCascade, if set, references the named column of the table being
	// created with ON DELETE CASCADE (see createColumnsStatement)
	deleteCascade string
}
// ColumnOption modifies an InitColumn during construction.
type ColumnOption func(*InitColumn)

// NewColumn creates a column definition; by default the column is NOT NULL
// and has no default value.
func NewColumn(name string, columnType ColumnType, opts ...ColumnOption) *InitColumn {
	column := &InitColumn{
		Name:         name,
		Type:         columnType,
		nullable:     false,
		defaultValue: nil,
	}
	for _, opt := range opts {
		opt(column)
	}
	return column
}

// Nullable allows NULL values in the column.
func Nullable() ColumnOption {
	return func(c *InitColumn) {
		c.nullable = true
	}
}

// Default sets the column's DEFAULT value.
func Default(value interface{}) ColumnOption {
	return func(c *InitColumn) {
		c.defaultValue = value
	}
}

// DeleteCascade makes the column reference the given column with
// ON DELETE CASCADE.
func DeleteCascade(column string) ColumnOption {
	return func(c *InitColumn) {
		c.deleteCascade = column
	}
}
// PrimaryKey lists the column names forming a table's primary key.
type PrimaryKey []string

// NewPrimaryKey creates a primary key from the given column names.
func NewPrimaryKey(columnNames ...string) PrimaryKey {
	return columnNames
}

// ColumnType enumerates the supported column types; see columnType for the
// mapping to PostgreSQL type names.
type ColumnType int32

const (
	ColumnTypeText ColumnType = iota
	ColumnTypeTextArray
	ColumnTypeJSONB
	ColumnTypeBytes
	ColumnTypeTimestamp
	ColumnTypeInterval
	ColumnTypeEnum
	ColumnTypeEnumArray
	ColumnTypeInt64
	ColumnTypeBool
)
// NewIndex creates an index over the given columns and applies options.
func NewIndex(name string, columns []string, opts ...indexOpts) *Index {
	i := &Index{
		Name:    name,
		Columns: columns,
	}
	for _, opt := range opts {
		opt(i)
	}
	return i
}

// Index describes an index on a projection table.
type Index struct {
	Name    string
	Columns []string
	// includes are rendered as an INCLUDE clause (covering index)
	includes []string
}

type indexOpts func(*Index)

// WithInclude adds covering columns to the index.
func WithInclude(columns ...string) indexOpts {
	return func(i *Index) {
		i.includes = columns
	}
}
// NewConstraint creates a unique constraint definition over the given columns.
func NewConstraint(name string, columns []string) *Constraint {
	return &Constraint{
		Name:    name,
		Columns: columns,
	}
}

// Constraint describes a UNIQUE constraint on a set of columns.
type Constraint struct {
	Name    string
	Columns []string
}
// NewForeignKey creates a named foreign key over columns referencing refColumns.
func NewForeignKey(name string, columns []string, refColumns []string) *ForeignKey {
	return &ForeignKey{
		Name:       name,
		Columns:    columns,
		RefColumns: refColumns,
	}
}

// NewForeignKeyOfPublicKeys creates an unnamed foreign key; name and columns
// are derived from the table's primary key when the statement is rendered
// (see createTableStatement and foreignKeyName).
func NewForeignKeyOfPublicKeys() *ForeignKey {
	return &ForeignKey{}
}

// ForeignKey describes a FOREIGN KEY ... REFERENCES clause.
type ForeignKey struct {
	Name       string
	Columns    []string
	RefColumns []string
}
// initializer is implemented by projections that need to set up database
// objects (tables, indices, views) before processing events.
type initializer interface {
	Init() *handler.Check
}
// Init runs the projection's initialization checks (table/index/view
// creation) in a single transaction, if the projection implements
// initializer. Execution stops early when a check reports that no further
// steps are needed.
func (h *Handler) Init(ctx context.Context) error {
	check, ok := h.projection.(initializer)
	if !ok {
		return nil
	}
	// build the check once instead of calling Init() on every access
	checks := check.Init()
	if checks.IsNoop() {
		return nil
	}
	tx, err := h.client.BeginTx(ctx, nil)
	if err != nil {
		return zerrors.ThrowInternal(err, "CRDB-SAdf2", "begin failed")
	}
	for i, execute := range checks.Executes {
		logging.WithFields("projection", h.projection.Name(), "execute", i).Debug("executing check")
		next, err := execute(ctx, tx, h.projection.Name())
		if err != nil {
			logging.OnError(tx.Rollback()).Debug("unable to rollback")
			return err
		}
		if !next {
			logging.WithFields("projection", h.projection.Name(), "execute", i).Debug("projection set up")
			break
		}
	}
	return tx.Commit()
}
// NewTableCheck creates a check that creates the table and afterwards each of
// its indices. Every step continues with the next one even when the object
// already exists.
func NewTableCheck(table *Table, opts ...execOption) *handler.Check {
	config := execConfig{}
	create := func(config execConfig) string {
		return createTableStatement(table, config.tableName, "")
	}
	executes := make([]func(context.Context, handler.Executer, string) (bool, error), len(table.indices)+1)
	executes[0] = execNextIfExists(config, create, opts, true)
	for i, index := range table.indices {
		executes[i+1] = execNextIfExists(config, createIndexCheck(index), opts, true)
	}
	return &handler.Check{
		Executes: executes,
	}
}
// NewMultiTableCheck creates a check that creates the primary table and all
// suffixed secondary tables with a single statement.
func NewMultiTableCheck(primaryTable *Table, secondaryTables ...*SuffixedTable) *handler.Check {
	config := execConfig{}
	create := func(config execConfig) string {
		stmt := createTableStatement(primaryTable, config.tableName, "")
		for _, table := range secondaryTables {
			stmt += createTableStatement(&table.Table, config.tableName, "_"+table.suffix)
		}
		return stmt
	}

	return &handler.Check{
		Executes: []func(context.Context, handler.Executer, string) (bool, error){
			execNextIfExists(config, create, nil, true),
		},
	}
}
// NewViewCheck creates a check that creates the suffixed secondary tables and
// a view based on selectStmt. No further steps are executed when the view
// already exists (executeNext is false).
func NewViewCheck(selectStmt string, secondaryTables ...*SuffixedTable) *handler.Check {
	config := execConfig{}
	create := func(config execConfig) string {
		var stmt string
		for _, table := range secondaryTables {
			stmt += createTableStatement(&table.Table, config.tableName, "_"+table.suffix)
		}
		stmt += createViewStatement(config.tableName, selectStmt)
		return stmt
	}

	return &handler.Check{
		Executes: []func(context.Context, handler.Executer, string) (bool, error){
			execNextIfExists(config, create, nil, false),
		},
	}
}
// execNextIfExists wraps the statement execution in a savepoint so that an
// "already exists" error can be rolled back without aborting the surrounding
// transaction. In that case the deferred function clears the error and
// reports executeNext, allowing the caller to continue with the next check.
func execNextIfExists(config execConfig, q query, opts []execOption, executeNext bool) func(ctx context.Context, handler handler.Executer, name string) (bool, error) {
	return func(ctx context.Context, handler handler.Executer, name string) (shouldExecuteNext bool, err error) {
		_, err = handler.Exec("SAVEPOINT exec_stmt")
		if err != nil {
			return false, zerrors.ThrowInternal(err, "V2-U1wlz", "create savepoint failed")
		}
		defer func() {
			if err == nil {
				return
			}
			// the named results are rewritten here: a "duplicate object"
			// error is swallowed after rolling back to the savepoint
			if isErrAlreadyExists(err) {
				_, err = handler.Exec("ROLLBACK TO SAVEPOINT exec_stmt")
				shouldExecuteNext = executeNext
				return
			}
		}()
		err = exec(config, q, opts)(ctx, handler, name)
		// shouldExecuteNext may still be overridden by the deferred function
		return false, err
	}
}
// isErrAlreadyExists reports whether err wraps a ZitadelError whose parent is
// a PostgreSQL "duplicate_table" error (SQLSTATE 42P07).
func isErrAlreadyExists(err error) bool {
	caosErr := &zerrors.ZitadelError{}
	if !errors.As(err, &caosErr) {
		return false
	}
	pgErr := new(pgconn.PgError)
	if errors.As(caosErr.Parent, &pgErr) {
		return pgErr.Code == "42P07"
	}
	return false
}
// createTableStatement renders a CREATE TABLE IF NOT EXISTS statement for the
// table (name = tableName+suffix), including primary key, foreign keys and
// unique constraints, followed by the CREATE INDEX statements for its indices.
func createTableStatement(table *Table, tableName string, suffix string) string {
	stmt := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s, PRIMARY KEY (%s)",
		tableName+suffix,
		createColumnsStatement(table.columns, tableName),
		strings.Join(table.primaryKey, ", "),
	)
	for _, key := range table.foreignKeys {
		ref := tableName
		if len(key.RefColumns) > 0 {
			ref += fmt.Sprintf("(%s)", strings.Join(key.RefColumns, ","))
		}
		if len(key.Columns) == 0 {
			// NOTE(review): this assignment mutates the caller's ForeignKey
			// value; confirm the aliasing is intended
			key.Columns = table.primaryKey
		}
		stmt += fmt.Sprintf(", CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE CASCADE", foreignKeyName(key.Name, tableName, suffix), strings.Join(key.Columns, ","), ref)
	}
	for _, constraint := range table.constraints {
		stmt += fmt.Sprintf(", CONSTRAINT %s UNIQUE (%s)", constraintName(constraint.Name, tableName, suffix), strings.Join(constraint.Columns, ","))
	}
	stmt += ");"
	for _, index := range table.indices {
		stmt += createIndexStatement(index, tableName+suffix)
	}
	return stmt
}
// createViewStatement renders a CREATE VIEW statement for viewName backed by
// the given select statement.
func createViewStatement(viewName string, selectStmt string) string {
	return "CREATE VIEW " + viewName + " AS " + selectStmt
}
// createIndexCheck defers rendering of the index statement until the table
// name is known from the exec config.
func createIndexCheck(index *Index) func(config execConfig) string {
	return func(config execConfig) string {
		return createIndexStatement(index, config.tableName)
	}
}
// createIndexStatement renders a CREATE INDEX IF NOT EXISTS statement for the
// index on tableName, optionally with an INCLUDE clause for covering columns.
func createIndexStatement(index *Index, tableName string) string {
	stmt := fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s (%s)",
		indexName(index.Name, tableName),
		tableName,
		strings.Join(index.Columns, ","),
	)
	if len(index.includes) == 0 {
		return stmt + ";"
	}
	return stmt + " INCLUDE (" + strings.Join(index.includes, ", ") + ");"
}
// foreignKeyName builds the constraint name of a foreign key: unnamed keys
// ("") become "fk<suffix>_ref_<table>", named keys "fk_<table><suffix>_<name>".
func foreignKeyName(name, tableName, suffix string) string {
	if name == "" {
		key := "fk" + suffix + "_ref_" + tableNameWithoutSchema(tableName)
		return key
	}
	return "fk_" + tableNameWithoutSchema(tableName+suffix) + "_" + name
}

// constraintName builds the name of a unique constraint.
func constraintName(name, tableName, suffix string) string {
	return tableNameWithoutSchema(tableName+suffix) + "_" + name + "_unique"
}

// indexName builds the name of an index.
func indexName(name, tableName string) string {
	return tableNameWithoutSchema(tableName) + "_" + name + "_idx"
}
// tableNameWithoutSchema strips a leading "schema." qualifier and returns the
// bare table name; names without a dot are returned unchanged.
func tableNameWithoutSchema(name string) string {
	if idx := strings.LastIndex(name, "."); idx >= 0 {
		return name[idx+1:]
	}
	return name
}
// createColumnsStatement renders the comma-separated column definitions of a
// CREATE TABLE statement, including NOT NULL, DEFAULT and per-column
// ON DELETE CASCADE references to tableName.
func createColumnsStatement(cols []*InitColumn, tableName string) string {
	columns := make([]string, len(cols))
	for i, col := range cols {
		parts := []string{col.Name, columnType(col.Type)}
		if !col.nullable {
			parts = append(parts, "NOT NULL")
		}
		if col.defaultValue != nil {
			parts = append(parts, "DEFAULT "+defaultValue(col.defaultValue))
		}
		if col.deleteCascade != "" {
			parts = append(parts, fmt.Sprintf("REFERENCES %s (%s) ON DELETE CASCADE", tableName, col.deleteCascade))
		}
		columns[i] = strings.Join(parts, " ")
	}
	return strings.Join(columns, ",")
}
// defaultValue renders a Go value as a SQL DEFAULT literal: strings are
// single-quoted, everything else uses its default formatting.
func defaultValue(value interface{}) string {
	if s, ok := value.(string); ok {
		return "'" + s + "'"
	}
	if s, ok := value.(fmt.Stringer); ok {
		// NOTE(review): %#v prints Go syntax rather than s.String();
		// presumably intentional — verify against existing callers
		return fmt.Sprintf("%#v", s)
	}
	return fmt.Sprintf("%v", value)
}
// columnType maps a ColumnType to its PostgreSQL type name.
// It panics on an unknown column type, which is a programmer error.
func columnType(columnType ColumnType) string {
	names := map[ColumnType]string{
		ColumnTypeText:      "TEXT",
		ColumnTypeTextArray: "TEXT[]",
		ColumnTypeJSONB:     "JSONB",
		ColumnTypeBytes:     "BYTEA",
		ColumnTypeTimestamp: "TIMESTAMPTZ",
		ColumnTypeInterval:  "INTERVAL",
		ColumnTypeEnum:      "SMALLINT",
		ColumnTypeEnumArray: "SMALLINT[]",
		ColumnTypeInt64:     "BIGINT",
		ColumnTypeBool:      "BOOLEAN",
	}
	name, ok := names[columnType]
	if !ok {
		panic("unknown column type")
	}
	return name
}

View File

@@ -0,0 +1,23 @@
package handler
import (
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/eventstore"
)
// log returns a log entry annotated with the projection name.
func (h *Handler) log() *logging.Entry {
	return logging.WithFields("projection", h.projection.Name())
}

// logFailure annotates the projection log entry with a failed statement's
// sequence, instance and aggregate ID.
func (h *Handler) logFailure(fail *failure) *logging.Entry {
	return h.log().WithField("sequence", fail.sequence).
		WithField("instance", fail.instance).
		WithField("aggregate", fail.aggregateID)
}

// logEvent annotates the projection log entry with an event's sequence,
// instance and aggregate type.
func (h *Handler) logEvent(event eventstore.Event) *logging.Entry {
	return h.log().WithField("sequence", event.Sequence()).
		WithField("instance", event.Aggregate().InstanceID).
		WithField("aggregate", event.Aggregate().Type)
}

View File

@@ -0,0 +1,70 @@
package handler
import (
"context"
"github.com/zitadel/logging"
"go.opentelemetry.io/otel/attribute"
"github.com/zitadel/zitadel/internal/telemetry/metrics"
)
const (
	// ProjectionLabel is the metric label carrying the projection name.
	ProjectionLabel = "projection"
	// SuccessLabel is the metric label carrying whether processing succeeded.
	SuccessLabel = "success"

	ProjectionEventsProcessed    = "projection_events_processed"
	ProjectionHandleTimerMetric  = "projection_handle_timer"
	ProjectionStateLatencyMetric = "projection_state_latency"
)

// ProjectionMetrics records telemetry about projection event processing.
type ProjectionMetrics struct {
	provider metrics.Metrics
}
// NewProjectionMetrics registers the projection counter and histograms on the
// global metrics provider. Registration errors are logged but do not prevent
// construction.
func NewProjectionMetrics() *ProjectionMetrics {
	projectionMetrics := &ProjectionMetrics{provider: metrics.M}

	err := projectionMetrics.provider.RegisterCounter(
		ProjectionEventsProcessed,
		"Number of events reduced to process projection updates",
	)
	logging.OnError(err).Error("failed to register projection events processed counter")

	err = projectionMetrics.provider.RegisterHistogram(
		ProjectionHandleTimerMetric,
		"Time taken to process a projection update",
		"s",
		[]float64{0.005, 0.01, 0.05, 0.1, 1, 5, 10, 30, 60, 120},
	)
	logging.OnError(err).Error("failed to register projection handle timer metric")

	err = projectionMetrics.provider.RegisterHistogram(
		ProjectionStateLatencyMetric,
		"When finishing processing a batch of events, this track the age of the last events seen from current time",
		"s",
		[]float64{0.1, 0.5, 1, 5, 10, 30, 60, 300, 600, 1800},
	)
	logging.OnError(err).Error("failed to register projection state latency metric")

	return projectionMetrics
}
// ProjectionUpdateTiming records how long one projection update took, in seconds.
func (m *ProjectionMetrics) ProjectionUpdateTiming(ctx context.Context, projection string, duration float64) {
	err := m.provider.AddHistogramMeasurement(ctx, ProjectionHandleTimerMetric, duration, map[string]attribute.Value{
		ProjectionLabel: attribute.StringValue(projection),
	})
	logging.OnError(err).Error("failed to add projection trigger timing")
}

// ProjectionEventsProcessed counts processed events, labeled by success.
func (m *ProjectionMetrics) ProjectionEventsProcessed(ctx context.Context, projection string, count int64, success bool) {
	err := m.provider.AddCount(ctx, ProjectionEventsProcessed, count, map[string]attribute.Value{
		ProjectionLabel: attribute.StringValue(projection),
		SuccessLabel:    attribute.BoolValue(success),
	})
	logging.OnError(err).Error("failed to add projection events processed metric")
}

// ProjectionStateLatency records the age (in seconds) of the last processed
// event relative to now.
func (m *ProjectionMetrics) ProjectionStateLatency(ctx context.Context, projection string, latency float64) {
	err := m.provider.AddHistogramMeasurement(ctx, ProjectionStateLatencyMetric, latency, map[string]attribute.Value{
		ProjectionLabel: attribute.StringValue(projection),
	})
	logging.OnError(err).Error("failed to add projection state latency metric")
}

View File

@@ -0,0 +1,132 @@
package handler
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/telemetry/metrics"
)
// TestNewProjectionMetrics verifies the constructor picks up the global
// provider and returns a usable instance.
func TestNewProjectionMetrics(t *testing.T) {
	mockMetrics := metrics.NewMockMetrics()
	metrics.M = mockMetrics

	// do not name the local "metrics": it shadowed the imported package
	projectionMetrics := NewProjectionMetrics()
	require.NotNil(t, projectionMetrics)
	assert.NotNil(t, projectionMetrics.provider)
}
// TestProjectionMetrics_ProjectionUpdateTiming verifies the timing histogram
// receives the measured duration labeled with the projection name.
func TestProjectionMetrics_ProjectionUpdateTiming(t *testing.T) {
	mockMetrics := metrics.NewMockMetrics()
	metrics.M = mockMetrics
	projectionMetrics := NewProjectionMetrics()

	ctx := context.Background()
	projection := "test_projection"
	duration := 0.5

	projectionMetrics.ProjectionUpdateTiming(ctx, projection, duration)

	values := mockMetrics.GetHistogramValues(ProjectionHandleTimerMetric)
	require.Len(t, values, 1)
	assert.Equal(t, duration, values[0])

	labels := mockMetrics.GetHistogramLabels(ProjectionHandleTimerMetric)
	require.Len(t, labels, 1)
	assert.Equal(t, projection, labels[0][ProjectionLabel].AsString())
}
// TestProjectionMetrics_ProjectionEventsProcessed verifies the counter
// receives the event count with projection and success labels.
func TestProjectionMetrics_ProjectionEventsProcessed(t *testing.T) {
	mockMetrics := metrics.NewMockMetrics()
	metrics.M = mockMetrics
	projectionMetrics := NewProjectionMetrics()

	ctx := context.Background()
	projection := "test_projection"
	count := int64(5)
	success := true

	projectionMetrics.ProjectionEventsProcessed(ctx, projection, count, success)

	value := mockMetrics.GetCounterValue(ProjectionEventsProcessed)
	assert.Equal(t, count, value)

	labels := mockMetrics.GetCounterLabels(ProjectionEventsProcessed)
	require.Len(t, labels, 1)
	assert.Equal(t, projection, labels[0][ProjectionLabel].AsString())
	assert.Equal(t, success, labels[0][SuccessLabel].AsBool())
}
// TestProjectionMetrics_ProjectionStateLatency verifies that one latency value
// lands in the state-latency histogram, labeled with the projection name.
func TestProjectionMetrics_ProjectionStateLatency(t *testing.T) {
	mocked := metrics.NewMockMetrics()
	metrics.M = mocked
	pm := NewProjectionMetrics()

	const name = "test_projection"
	const lat = 10.0
	pm.ProjectionStateLatency(context.Background(), name, lat)

	recorded := mocked.GetHistogramValues(ProjectionStateLatencyMetric)
	require.Len(t, recorded, 1)
	assert.Equal(t, lat, recorded[0])

	recordedLabels := mocked.GetHistogramLabels(ProjectionStateLatencyMetric)
	require.Len(t, recordedLabels, 1)
	assert.Equal(t, name, recordedLabels[0][ProjectionLabel].AsString())
}
// TestProjectionMetrics_Integration exercises all three metric helpers in one
// flow and checks that the counter accumulates across calls while each
// histogram records exactly one value.
func TestProjectionMetrics_Integration(t *testing.T) {
	mockMetrics := metrics.NewMockMetrics()
	metrics.M = mockMetrics
	projectionMetrics := NewProjectionMetrics()
	ctx := context.Background()
	projection := "test_projection"
	start := time.Now()
	// two processed-events calls: 3 successful, 1 failed
	projectionMetrics.ProjectionEventsProcessed(ctx, projection, 3, true)
	projectionMetrics.ProjectionEventsProcessed(ctx, projection, 1, false)
	duration := time.Since(start).Seconds()
	projectionMetrics.ProjectionUpdateTiming(ctx, projection, duration)
	latency := 5.0
	projectionMetrics.ProjectionStateLatency(ctx, projection, latency)
	// counter sums both calls: 3 + 1
	value := mockMetrics.GetCounterValue(ProjectionEventsProcessed)
	assert.Equal(t, int64(4), value)
	timingValues := mockMetrics.GetHistogramValues(ProjectionHandleTimerMetric)
	require.Len(t, timingValues, 1)
	assert.Equal(t, duration, timingValues[0])
	latencyValues := mockMetrics.GetHistogramValues(ProjectionStateLatencyMetric)
	require.Len(t, latencyValues, 1)
	assert.Equal(t, latency, latencyValues[0])
	// one label set per counter call, in call order (success=true first)
	eventsLabels := mockMetrics.GetCounterLabels(ProjectionEventsProcessed)
	require.Len(t, eventsLabels, 2)
	assert.Equal(t, projection, eventsLabels[0][ProjectionLabel].AsString())
	assert.Equal(t, true, eventsLabels[0][SuccessLabel].AsBool())
	assert.Equal(t, projection, eventsLabels[1][ProjectionLabel].AsString())
	assert.Equal(t, false, eventsLabels[1][SuccessLabel].AsBool())
	timingLabels := mockMetrics.GetHistogramLabels(ProjectionHandleTimerMetric)
	require.Len(t, timingLabels, 1)
	assert.Equal(t, projection, timingLabels[0][ProjectionLabel].AsString())
	latencyLabels := mockMetrics.GetHistogramLabels(ProjectionStateLatencyMetric)
	require.Len(t, latencyLabels, 1)
	assert.Equal(t, projection, latencyLabels[0][ProjectionLabel].AsString())
}

View File

@@ -0,0 +1,23 @@
package handler
// projection is a minimal Projection implementation used by the tests in
// this package.
type projection struct {
	name     string
	reducers []AggregateReducer
}

// compile-time check that projection satisfies Projection
var _ Projection = (*projection)(nil)

// Name implements [Projection]
func (p *projection) Name() string {
	return p.name
}

// Reducers implements [Projection]
func (p *projection) Reducers() []AggregateReducer {
	return p.reducers
}

// ActiveInstances implements [Projection]; the test projection has none.
func (p *projection) ActiveInstances() []string {
	return nil
}

View File

@@ -0,0 +1,21 @@
package handler
import "github.com/zitadel/zitadel/internal/eventstore"
// EventReducer represents the required data
// to work with events: it maps a single event type to the Reduce function
// that handles it.
type EventReducer struct {
	Event  eventstore.EventType
	Reduce Reduce
}

// Reduce reduces the given event to a statement
// which is used to update the projection
type Reduce func(eventstore.Event) (*Statement, error)

// AggregateReducer represents the required data
// to work with aggregates: all event reducers registered for a single
// aggregate type.
type AggregateReducer struct {
	Aggregate     eventstore.AggregateType
	EventReducers []EventReducer
}

View File

@@ -0,0 +1,120 @@
package handler
import (
"context"
"database/sql"
_ "embed"
"errors"
"time"
"github.com/shopspring/decimal"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/zerrors"
)
// state is the in-memory representation of a row of
// projections.current_states: the last event a projection processed for one
// instance.
type state struct {
	instanceID string
	// position of the last processed event in the event stream
	position decimal.Decimal
	// eventTimestamp is the creation time of the last processed event
	eventTimestamp time.Time
	// aggregateType and aggregateID identify the aggregate of the last
	// processed event
	aggregateType eventstore.AggregateType
	aggregateID   string
	// sequence of the last processed event within its aggregate
	sequence uint64
	// offset counts events sharing the same position (filter_offset column);
	// see eventsToStatements for how it is maintained
	offset uint32
}
var (
	// currentStateStmt selects and locks the state row, failing fast
	// (FOR UPDATE NOWAIT) when it is already locked.
	//go:embed state_get.sql
	currentStateStmt string
	// currentStateAwaitStmt selects and locks the state row, blocking until
	// the lock becomes available.
	//go:embed state_get_await.sql
	currentStateAwaitStmt string
	// updateStateStmt upserts the processed position of the projection.
	//go:embed state_set.sql
	updateStateStmt string
	// lockStateStmt inserts the initial state row (ON CONFLICT DO NOTHING).
	//go:embed state_lock.sql
	lockStateStmt string

	// errJustUpdated indicates the projection was updated concurrently;
	// checked by trigger logic outside this chunk.
	errJustUpdated = errors.New("projection was just updated")
)
// currentState loads and locks the projection's state row for the instance in
// ctx within tx. With config.awaitRunning the lock blocks; otherwise it fails
// fast (NOWAIT). If no state row exists yet, one is inserted and locked via
// lockState and a state carrying only the instance id (zero values otherwise)
// is returned.
func (h *Handler) currentState(ctx context.Context, tx *sql.Tx, config *triggerConfig) (currentState *state, err error) {
	currentState = &state{
		instanceID: authz.GetInstance(ctx).InstanceID(),
	}
	var (
		aggregateID   = new(sql.NullString)
		aggregateType = new(sql.NullString)
		sequence      = new(sql.NullInt64)
		timestamp     = new(sql.NullTime)
		position      = new(decimal.NullDecimal)
		offset        = new(sql.NullInt64)
	)
	stateQuery := currentStateStmt
	if config.awaitRunning {
		stateQuery = currentStateAwaitStmt
	}
	row := tx.QueryRow(stateQuery, currentState.instanceID, h.projection.Name())
	err = row.Scan(
		aggregateID,
		aggregateType,
		sequence,
		timestamp,
		position,
		offset,
	)
	if errors.Is(err, sql.ErrNoRows) {
		// first run for this instance: insert and lock a fresh state row;
		// on success err becomes nil and the zero-valued scan targets below
		// fill the state with zero values
		err = h.lockState(tx, currentState.instanceID)
	}
	if err != nil {
		h.log().WithError(err).Debug("unable to query current state")
		return nil, err
	}
	currentState.aggregateID = aggregateID.String
	currentState.aggregateType = eventstore.AggregateType(aggregateType.String)
	currentState.sequence = uint64(sequence.Int64)
	currentState.eventTimestamp = timestamp.Time
	currentState.position = position.Decimal
	// psql does not provide unsigned numbers so we work around it
	currentState.offset = uint32(offset.Int64)
	return currentState, nil
}
// setState writes updatedState to the current_states table. It returns an
// internal error when the update fails or when no row was affected.
func (h *Handler) setState(tx *sql.Tx, updatedState *state) error {
	res, err := tx.Exec(updateStateStmt,
		h.projection.Name(),
		updatedState.instanceID,
		updatedState.aggregateID,
		updatedState.aggregateType,
		updatedState.sequence,
		updatedState.eventTimestamp,
		updatedState.position,
		updatedState.offset,
	)
	if err != nil {
		h.log().WithError(err).Warn("unable to update state")
		return zerrors.ThrowInternal(err, "V2-WF23g2", "unable to update state")
	}
	affected, affectedErr := res.RowsAffected()
	if affected == 0 {
		h.log().OnError(affectedErr).Error("unable to check if states are updated")
		return zerrors.ThrowInternal(affectedErr, "V2-FGEKi", "unable to update state")
	}
	return nil
}
// lockState inserts the initial state row for the given instance. When the
// insert affects no row (the row already exists) or row inspection fails,
// the projection is considered locked by another process.
func (h *Handler) lockState(tx *sql.Tx, instanceID string) error {
	res, execErr := tx.Exec(lockStateStmt, h.projection.Name(), instanceID)
	if execErr != nil {
		return execErr
	}
	affected, affectedErr := res.RowsAffected()
	if affectedErr != nil || affected == 0 {
		return zerrors.ThrowInternal(affectedErr, "V2-lpiK0", "projection already locked")
	}
	return nil
}

View File

@@ -0,0 +1,13 @@
-- Load and lock the current state of a projection for one instance.
-- FOR UPDATE NOWAIT fails immediately (SQLSTATE 55P03) when another worker
-- already holds the row lock, instead of blocking.
SELECT
    aggregate_id
    , aggregate_type
    , "sequence"
    , event_date
    , "position"
    , filter_offset
FROM
    projections.current_states
WHERE
    instance_id = $1
    AND projection_name = $2
FOR UPDATE NOWAIT;

View File

@@ -0,0 +1,13 @@
-- Load and lock the current state of a projection for one instance.
-- Unlike state_get.sql this variant blocks (plain FOR UPDATE) until the
-- row lock becomes available; used when the caller awaits a running trigger.
SELECT
    aggregate_id
    , aggregate_type
    , "sequence"
    , event_date
    , "position"
    , filter_offset
FROM
    projections.current_states
WHERE
    instance_id = $1
    AND projection_name = $2
FOR UPDATE;

View File

@@ -0,0 +1,9 @@
-- Insert the initial state row for a projection/instance pair.
-- ON CONFLICT DO NOTHING: when the row already exists (inserted
-- concurrently), no row is affected and the caller treats the projection as
-- already locked.
INSERT INTO projections.current_states (
    projection_name
    , instance_id
    , last_updated
) VALUES (
    $1
    , $2
    , now()
) ON CONFLICT DO NOTHING;

View File

@@ -0,0 +1,32 @@
-- Upsert the processed position of a projection for one instance.
-- On conflict (projection_name, instance_id) the existing row is overwritten
-- with the new event coordinates; last_updated uses statement_timestamp() in
-- the update branch.
INSERT INTO projections.current_states (
    projection_name
    , instance_id
    , aggregate_id
    , aggregate_type
    , "sequence"
    , event_date
    , "position"
    , last_updated
    , filter_offset
) VALUES (
    $1
    , $2
    , $3
    , $4
    , $5
    , $6
    , $7
    , now()
    , $8
) ON CONFLICT (
    projection_name
    , instance_id
) DO UPDATE SET
    aggregate_id = $3
    , aggregate_type = $4
    , "sequence" = $5
    , event_date = $6
    , "position" = $7
    , last_updated = statement_timestamp()
    , filter_offset = $8
;

View File

@@ -0,0 +1,452 @@
package handler
import (
"context"
"database/sql"
"database/sql/driver"
_ "embed"
"errors"
"reflect"
"testing"
"time"
"github.com/jackc/pgx/v5/pgconn"
"github.com/shopspring/decimal"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database/mock"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/zerrors"
)
// TestHandler_lockState verifies Handler.lockState: the insert of the initial
// state row can fail (closed tx), affect no row (row already exists, treated
// as "already locked") or succeed.
func TestHandler_lockState(t *testing.T) {
	type fields struct {
		projection Projection
		mock       *mock.SQLMock
	}
	type args struct {
		instanceID string
	}
	tests := []struct {
		name   string
		fields fields
		args   args
		// isErr asserts the returned error; nil means "expect no error"
		isErr func(t *testing.T, err error)
	}{
		{
			// exec fails: the driver error must be passed through unchanged
			name: "tx closed",
			fields: fields{
				projection: &projection{
					name: "projection",
				},
				mock: mock.NewSQLMock(t,
					mock.ExpectBegin(nil),
					mock.ExcpectExec(
						lockStateStmt,
						mock.WithExecArgs(
							"projection",
							"instance",
						),
						mock.WithExecErr(sql.ErrTxDone),
					),
				),
			},
			args: args{
				instanceID: "instance",
			},
			isErr: func(t *testing.T, err error) {
				if !errors.Is(err, sql.ErrTxDone) {
					t.Errorf("unexpected error, want: %v got: %v", sql.ErrTxDone, err)
				}
			},
		},
		{
			// insert conflicts (row exists): lockState returns the internal
			// "projection already locked" error
			name: "no rows affeced",
			fields: fields{
				projection: &projection{
					name: "projection",
				},
				mock: mock.NewSQLMock(t,
					mock.ExpectBegin(nil),
					mock.ExcpectExec(
						lockStateStmt,
						mock.WithExecArgs(
							"projection",
							"instance",
						),
						mock.WithExecNoRowsAffected(),
					),
				),
			},
			args: args{
				instanceID: "instance",
			},
			isErr: func(t *testing.T, err error) {
				if !errors.Is(err, zerrors.ThrowInternal(nil, "V2-lpiK0", "")) {
					t.Errorf("unexpected error: want internal (V2lpiK0), got: %v", err)
				}
			},
		},
		{
			// happy path: one row inserted, no error
			name: "rows affected",
			fields: fields{
				projection: &projection{
					name: "projection",
				},
				mock: mock.NewSQLMock(t,
					mock.ExpectBegin(nil),
					mock.ExcpectExec(
						lockStateStmt,
						mock.WithExecArgs(
							"projection",
							"instance",
						),
						mock.WithExecRowsAffected(1),
					),
				),
			},
			args: args{
				instanceID: "instance",
			},
		},
	}
	for _, tt := range tests {
		// default expectation: no error
		if tt.isErr == nil {
			tt.isErr = func(t *testing.T, err error) {
				if err != nil {
					t.Error("expected no error got:", err)
				}
			}
		}
		t.Run(tt.name, func(t *testing.T) {
			h := &Handler{
				projection: tt.fields.projection,
			}
			tx, err := tt.fields.mock.DB.BeginTx(context.Background(), nil)
			if err != nil {
				t.Fatalf("unable to begin transaction: %v", err)
			}
			err = h.lockState(tx, tt.args.instanceID)
			tt.isErr(t, err)
			tt.fields.mock.Assert(t)
		})
	}
}
// TestHandler_updateLastUpdated verifies Handler.setState: a failing update
// statement, the no-rows-affected case and a successful state write.
func TestHandler_updateLastUpdated(t *testing.T) {
	type fields struct {
		projection Projection
		mock       *mock.SQLMock
	}
	type args struct {
		updatedState *state
	}
	tests := []struct {
		name   string
		fields fields
		args   args
		// isErr asserts the returned error; nil means "expect no error"
		isErr func(t *testing.T, err error)
	}{
		{
			name: "update fails",
			fields: fields{
				projection: &projection{
					name: "instance",
				},
				mock: mock.NewSQLMock(t,
					mock.ExpectBegin(nil),
					mock.ExcpectExec(updateStateStmt,
						mock.WithExecErr(sql.ErrTxDone),
					),
				),
			},
			args: args{
				updatedState: &state{
					instanceID:     "instance",
					eventTimestamp: time.Now(),
					position:       decimal.NewFromInt(42),
				},
			},
			isErr: func(t *testing.T, err error) {
				if !errors.Is(err, sql.ErrTxDone) {
					t.Errorf("unexpected error, want: %v, got %v", sql.ErrTxDone, err)
				}
			},
		},
		{
			name: "no rows affected",
			fields: fields{
				projection: &projection{
					name: "instance",
				},
				mock: mock.NewSQLMock(t,
					mock.ExpectBegin(nil),
					mock.ExcpectExec(updateStateStmt,
						mock.WithExecNoRowsAffected(),
					),
				),
			},
			args: args{
				updatedState: &state{
					instanceID:     "instance",
					eventTimestamp: time.Now(),
					position:       decimal.NewFromInt(42),
				},
			},
			isErr: func(t *testing.T, err error) {
				if !errors.Is(err, zerrors.ThrowInternal(nil, "V2-FGEKi", "")) {
					// fixed: the message previously claimed want sql.ErrTxDone
					// (copy-paste from the case above) although the assertion
					// checks for the internal V2-FGEKi error
					t.Errorf("unexpected error, want internal (V2-FGEKi), got: %v", err)
				}
			},
		},
		{
			name: "success",
			fields: fields{
				projection: &projection{
					name: "projection",
				},
				mock: mock.NewSQLMock(t,
					mock.ExpectBegin(nil),
					mock.ExcpectExec(updateStateStmt,
						mock.WithExecArgs(
							"projection",
							"instance",
							"aggregate id",
							eventstore.AggregateType("aggregate type"),
							uint64(42),
							mock.AnyType[time.Time]{},
							decimal.NewFromInt(42),
							uint32(0),
						),
						mock.WithExecRowsAffected(1),
					),
				),
			},
			args: args{
				updatedState: &state{
					instanceID:     "instance",
					eventTimestamp: time.Now(),
					position:       decimal.NewFromInt(42),
					aggregateType:  "aggregate type",
					aggregateID:    "aggregate id",
					sequence:       42,
				},
			},
		},
	}
	for _, tt := range tests {
		// default expectation: no error
		if tt.isErr == nil {
			tt.isErr = func(t *testing.T, err error) {
				if err != nil {
					t.Error("expected no error got:", err)
				}
			}
		}
		t.Run(tt.name, func(t *testing.T) {
			tx, err := tt.fields.mock.DB.BeginTx(context.Background(), nil)
			if err != nil {
				t.Fatalf("unable to begin transaction: %v", err)
			}
			h := &Handler{
				projection: tt.fields.projection,
			}
			err = h.setState(tx, tt.args.updatedState)
			tt.isErr(t, err)
			tt.fields.mock.Assert(t)
		})
	}
}
// TestHandler_currentState verifies Handler.currentState: query failure,
// the no-row path falling through to lockState, a lock conflict (55P03) and
// a successful scan of an existing state row.
func TestHandler_currentState(t *testing.T) {
	testTime := time.Now()
	type fields struct {
		projection Projection
		mock       *mock.SQLMock
	}
	type args struct {
		ctx context.Context
	}
	type want struct {
		currentState *state
		// isErr asserts the returned error; nil means "expect no error"
		isErr func(t *testing.T, err error)
	}
	tests := []struct {
		name   string
		fields fields
		args   args
		want   want
	}{
		{
			// query fails outright: the driver error is passed through
			name: "connection done",
			fields: fields{
				projection: &projection{
					name: "projection",
				},
				mock: mock.NewSQLMock(t,
					mock.ExpectBegin(nil),
					mock.ExpectQuery(currentStateStmt,
						mock.WithQueryArgs(
							"instance",
							"projection",
						),
						mock.WithQueryErr(sql.ErrConnDone),
					),
				),
			},
			args: args{
				ctx: authz.WithInstanceID(context.Background(), "instance"),
			},
			want: want{
				isErr: func(t *testing.T, err error) {
					if !errors.Is(err, sql.ErrConnDone) {
						t.Errorf("unexpected error, want: %v, got: %v", sql.ErrConnDone, err)
					}
				},
			},
		},
		{
			// no state row yet: currentState falls through to lockState,
			// whose failure is returned
			name: "no row but lock err",
			fields: fields{
				projection: &projection{
					name: "projection",
				},
				mock: mock.NewSQLMock(t,
					mock.ExpectBegin(nil),
					mock.ExpectQuery(currentStateStmt,
						mock.WithQueryArgs(
							"instance",
							"projection",
						),
						mock.WithQueryErr(sql.ErrNoRows),
					),
					mock.ExcpectExec(lockStateStmt,
						mock.WithExecArgs(
							"projection",
							"instance",
						),
						mock.WithExecErr(sql.ErrTxDone),
					),
				),
			},
			args: args{
				ctx: authz.WithInstanceID(context.Background(), "instance"),
			},
			want: want{
				isErr: func(t *testing.T, err error) {
					if !errors.Is(err, sql.ErrTxDone) {
						t.Errorf("unexpected error, want: %v, got: %v", sql.ErrTxDone, err)
					}
				},
			},
		},
		{
			// another worker holds the row lock: NOWAIT surfaces 55P03
			name: "state locked",
			fields: fields{
				projection: &projection{
					name: "projection",
				},
				mock: mock.NewSQLMock(t,
					mock.ExpectBegin(nil),
					mock.ExpectQuery(currentStateStmt,
						mock.WithQueryArgs(
							"instance",
							"projection",
						),
						mock.WithQueryErr(&pgconn.PgError{Code: "55P03"}),
					),
				),
			},
			args: args{
				ctx: authz.WithInstanceID(context.Background(), "instance"),
			},
			want: want{
				isErr: func(t *testing.T, err error) {
					pgErr := new(pgconn.PgError)
					if !errors.As(err, &pgErr) {
						t.Errorf("error should be PgErr but was %T", err)
						return
					}
					if pgErr.Code != "55P03" {
						t.Errorf("expected code 55P03 got: %s", pgErr.Code)
					}
				},
			},
		},
		{
			// existing row is scanned into the returned state
			name: "success",
			fields: fields{
				projection: &projection{
					name: "projection",
				},
				mock: mock.NewSQLMock(t,
					mock.ExpectBegin(nil),
					mock.ExpectQuery(currentStateStmt,
						mock.WithQueryArgs(
							"instance",
							"projection",
						),
						mock.WithQueryResult(
							[]string{"aggregate_id", "aggregate_type", "event_sequence", "event_date", "position", "offset"},
							[][]driver.Value{
								{
									"aggregate id",
									"aggregate type",
									int64(42),
									testTime,
									decimal.NewFromInt(42).String(),
									uint16(10),
								},
							},
						),
					),
				),
			},
			args: args{
				ctx: authz.WithInstanceID(context.Background(), "instance"),
			},
			want: want{
				currentState: &state{
					instanceID:     "instance",
					eventTimestamp: testTime,
					position:       decimal.NewFromInt(42),
					aggregateType:  "aggregate type",
					aggregateID:    "aggregate id",
					sequence:       42,
					offset:         10,
				},
			},
		},
	}
	for _, tt := range tests {
		// default expectation: no error
		if tt.want.isErr == nil {
			tt.want.isErr = func(t *testing.T, err error) {
				if err != nil {
					t.Error("expected no error got:", err)
				}
			}
		}
		t.Run(tt.name, func(t *testing.T) {
			h := &Handler{
				projection: tt.fields.projection,
			}
			tx, err := tt.fields.mock.DB.BeginTx(context.Background(), nil)
			if err != nil {
				t.Fatalf("unable to begin transaction: %v", err)
			}
			gotCurrentState, err := h.currentState(tt.args.ctx, tx, new(triggerConfig))
			tt.want.isErr(t, err)
			if !reflect.DeepEqual(gotCurrentState, tt.want.currentState) {
				t.Errorf("Handler.currentState() gotCurrentState = %v, want %v", gotCurrentState, tt.want.currentState)
			}
			tt.fields.mock.Assert(t)
		})
	}
}

View File

@@ -0,0 +1,709 @@
package handler
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"time"
"github.com/shopspring/decimal"
"github.com/zitadel/logging"
"golang.org/x/exp/constraints"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/zerrors"
)
var _ error = (*executionError)(nil)
type executionError struct {
parent error
}
// Error implements error.
func (s *executionError) Error() string {
return fmt.Sprintf("statement failed: %v", s.parent)
}
func (s *executionError) Is(err error) bool {
_, ok := err.(*executionError)
return ok
}
func (s *executionError) Unwrap() error {
return s.parent
}
// eventsToStatements reduces the given events into executable statements.
//
// It maintains the per-position offset: an event sharing the previous event's
// position increments the offset, a new position resets it to 1 (counting the
// current event). Events whose reduce fails are recorded via handleFailedStmt;
// when that handler allows it, processing continues with the next event,
// otherwise the statements collected so far are returned with the error.
func (h *Handler) eventsToStatements(tx *sql.Tx, events []eventstore.Event, currentState *state) (statements []*Statement, err error) {
	statements = make([]*Statement, 0, len(events))
	previousPosition := currentState.position
	offset := currentState.offset
	for _, event := range events {
		statement, err := h.reduce(event)
		if err != nil {
			h.logEvent(event).WithError(err).Error("reduce failed")
			if shouldContinue := h.handleFailedStmt(tx, failureFromEvent(event, err)); shouldContinue {
				continue
			}
			return statements, err
		}
		offset++
		if !previousPosition.Equal(event.Position()) {
			// offset is 1 because we want to skip this event
			offset = 1
		}
		statement.offset = offset
		statement.Position = event.Position()
		previousPosition = event.Position()
		statements = append(statements, statement)
	}
	return statements, nil
}
// reduce maps the event to the statement produced by the projection's
// matching event reducer. An event without a registered reducer results in a
// no-op statement so the projection's state still advances past it.
func (h *Handler) reduce(event eventstore.Event) (*Statement, error) {
	for _, aggReducer := range h.projection.Reducers() {
		if aggReducer.Aggregate != event.Aggregate().Type {
			continue
		}
		for _, eventReducer := range aggReducer.EventReducers {
			if eventReducer.Event == event.Type() {
				return eventReducer.Reduce(event)
			}
		}
	}
	return NewNoOpStatement(event), nil
}
// Statement is the unit of work a projection executes for a single event.
type Statement struct {
	// Aggregate of the event the statement was built from
	Aggregate *eventstore.Aggregate
	// Sequence of the event within its aggregate
	Sequence uint64
	// Position of the event in the global event stream
	Position decimal.Decimal
	// CreationDate of the event
	CreationDate time.Time

	// offset counts events sharing the same position; maintained by
	// eventsToStatements
	offset uint32

	// Execute runs the statement; nil marks a no-op statement
	Execute Exec
}

// Exec executes one statement against ex, using projectionName as the target
// table name.
type Exec func(ctx context.Context, ex Executer, projectionName string) error
// WithTableSuffix appends "_<name>" to the statement's table name so it
// targets a sub-table of the projection.
func WithTableSuffix(name string) func(*execConfig) {
	return func(cfg *execConfig) {
		cfg.tableName = cfg.tableName + "_" + name
	}
}
var (
	// ErrNoProjection is returned when a statement is executed without a projection name.
	ErrNoProjection = errors.New("no projection")
	// ErrNoValues is returned when a create/update/upsert statement has no columns to write.
	ErrNoValues = errors.New("no values")
	// ErrNoCondition is returned when an update/delete/copy statement has no where condition.
	ErrNoCondition = errors.New("no condition")
)
// NewStatement builds a Statement for the given event with e as its
// execution function.
func NewStatement(event eventstore.Event, e Exec) *Statement {
	stmt := &Statement{
		Aggregate:    event.Aggregate(),
		Sequence:     event.Sequence(),
		Position:     event.Position(),
		CreationDate: event.CreatedAt(),
		Execute:      e,
	}
	return stmt
}
// NewCreateStatement builds an INSERT statement writing the given values.
// Without values the statement fails with ErrNoValues when executed.
func NewCreateStatement(event eventstore.Event, values []Column, opts ...execOption) *Statement {
	names, placeholders, args := columnsToQuery(values)

	config := execConfig{args: args}
	if len(values) == 0 {
		config.err = ErrNoValues
	}

	buildQuery := func(config execConfig) string {
		return "INSERT INTO " + config.tableName +
			" (" + strings.Join(names, ", ") +
			") VALUES (" + strings.Join(placeholders, ", ") + ")"
	}
	return NewStatement(event, exec(config, buildQuery, opts))
}
// NewUpsertStatement builds an INSERT ... ON CONFLICT ... DO UPDATE
// statement. conflictCols form the conflict target; values are the columns
// to write. Values wrapped by OnlySetValueOnInsert/OnlySetValueInCase alter
// the update branch accordingly (see getUpdateCols). The statement fails
// with ErrNoValues when no values are given or every value is part of the
// conflict target.
func NewUpsertStatement(event eventstore.Event, conflictCols []Column, values []Column, opts ...execOption) *Statement {
	cols, params, args := columnsToQuery(values)
	conflictTarget := make([]string, len(conflictCols))
	for i, col := range conflictCols {
		conflictTarget[i] = col.Name
	}
	config := execConfig{}
	if len(values) == 0 {
		config.err = ErrNoValues
	}
	// derive the DO UPDATE column lists; may append condition args
	updateCols, updateVals, args := getUpdateCols(values, conflictTarget, params, args)
	if len(updateCols) == 0 || len(updateVals) == 0 {
		config.err = ErrNoValues
	}
	config.args = args
	q := func(config execConfig) string {
		var updateStmt string
		// the postgres standard does not allow to update a single column using a multi-column update
		// discussion: https://www.postgresql.org/message-id/17451.1509381766%40sss.pgh.pa.us
		// see Compatibility in https://www.postgresql.org/docs/current/sql-update.html
		if len(updateCols) == 1 && !strings.HasPrefix(updateVals[0], "SELECT") {
			updateStmt = "UPDATE SET " + updateCols[0] + " = " + updateVals[0]
		} else {
			updateStmt = "UPDATE SET (" + strings.Join(updateCols, ", ") + ") = (" + strings.Join(updateVals, ", ") + ")"
		}
		return "INSERT INTO " + config.tableName + " (" + strings.Join(cols, ", ") + ") VALUES (" + strings.Join(params, ", ") + ")" +
			" ON CONFLICT (" + strings.Join(conflictTarget, ", ") + ") DO " + updateStmt
	}
	return NewStatement(event, exec(config, q, opts))
}
// onlySetValueOnInsert wraps a column value that is only written when the
// row is inserted; on conflict the existing value of Table is kept
// (see getUpdateCols).
type onlySetValueOnInsert struct {
	Table string
	Value interface{}
}

var _ ValueContainer = (*onlySetValueOnInsert)(nil)

// GetValue implements ValueContainer.
func (c *onlySetValueOnInsert) GetValue() interface{} {
	return c.Value
}

// OnlySetValueOnInsert returns a column value that only takes effect on
// insert; an upsert's conflict branch keeps the row's current value.
func OnlySetValueOnInsert(table string, value interface{}) *onlySetValueOnInsert {
	return &onlySetValueOnInsert{Table: table, Value: value}
}
// onlySetValueInCase wraps a column value that is only written when
// Condition holds; otherwise the current value of Table is kept
// (see the CASE expression built in getUpdateCols).
type onlySetValueInCase struct {
	Table string
	Value interface{}
	// Condition decides whether the new value is applied on conflict
	Condition Condition
}

// GetValue implements ValueContainer.
func (c *onlySetValueInCase) GetValue() interface{} {
	return c.Value
}
// ColumnChangedCondition checks the current value and if it changed to a specific new value
//
// The param passed to the returned Condition must be a bare parameter index
// (e.g. "5", as produced by getUpdateCols), not a "$n" placeholder: it is
// parsed with strconv.Atoi and the condition binds two placeholders,
// $index and $index+1. NOTE(review): the Atoi error is discarded, so a
// "$n"-style param would silently yield index 0 — confirm all callers pass
// bare indexes.
func ColumnChangedCondition(table, column string, currentValue, newValue interface{}) Condition {
	return func(param string) (string, []any) {
		index, _ := strconv.Atoi(param)
		return fmt.Sprintf("%[1]s.%[2]s = $%[3]d AND EXCLUDED.%[2]s = $%[4]d", table, column, index, index+1), []any{currentValue, newValue}
	}
}
// ColumnIsNullCondition checks if the current value of table.column is null.
// It binds no arguments, so the param is ignored.
func ColumnIsNullCondition(table, column string) Condition {
	return func(string) (string, []any) {
		clause := fmt.Sprintf("%[1]s.%[2]s IS NULL", table, column)
		return clause, nil
	}
}
// ConditionOr links multiple Conditions by OR
//
// Every linked condition receives the same param string. NOTE(review): if
// more than one linked condition binds arguments, their placeholders would
// collide because the param is not advanced between conditions — confirm
// callers combine at most one argument-binding condition (e.g. together with
// ColumnIsNullCondition, which binds none).
func ConditionOr(conditions ...Condition) Condition {
	return func(param string) (_ string, args []any) {
		if len(conditions) == 0 {
			return "", nil
		}
		b := strings.Builder{}
		s, arg := conditions[0](param)
		b.WriteString(s)
		args = append(args, arg...)
		for i := 1; i < len(conditions); i++ {
			b.WriteString(" OR ")
			s, condArgs := conditions[i](param)
			b.WriteString(s)
			args = append(args, condArgs...)
		}
		return b.String(), args
	}
}
// OnlySetValueInCase will only update to the desired value if the condition
// applies; otherwise the row keeps its current value of table.
func OnlySetValueInCase(table string, value interface{}, condition Condition) *onlySetValueInCase {
	wrapped := &onlySetValueInCase{
		Table:     table,
		Value:     value,
		Condition: condition,
	}
	return wrapped
}
// getUpdateCols derives the column lists for the DO UPDATE branch of an
// upsert from the insert columns.
//
// Iterating back to front, each column is mapped to its update expression:
//   - *onlySetValueOnInsert keeps the current row value (table.col)
//   - *onlySetValueInCase wraps the value in a CASE on its condition and
//     appends the condition's arguments to the statement args
//   - anything else takes the incoming value (EXCLUDED.col)
//
// Columns that are part of the conflict target are removed from both lists;
// the backward iteration keeps not-yet-visited indexes valid while deleting.
func getUpdateCols(cols []Column, conflictTarget, params []string, args []interface{}) (updateCols, updateVals []string, updatedArgs []interface{}) {
	updateCols = make([]string, len(cols))
	updateVals = make([]string, len(cols))
	updatedArgs = args
	for i := len(cols) - 1; i >= 0; i-- {
		col := cols[i]
		updateCols[i] = col.Name
		switch v := col.Value.(type) {
		case *onlySetValueOnInsert:
			updateVals[i] = v.Table + "." + col.Name
		case *onlySetValueInCase:
			// the condition gets the next free parameter index (bare, no "$")
			s, condArgs := v.Condition(strconv.Itoa(len(params) + 1))
			updatedArgs = append(updatedArgs, condArgs...)
			updateVals[i] = fmt.Sprintf("CASE WHEN %[1]s THEN EXCLUDED.%[2]s ELSE %[3]s.%[2]s END", s, col.Name, v.Table)
		default:
			updateVals[i] = "EXCLUDED" + "." + col.Name
		}
		// drop columns belonging to the conflict target from the update lists
		for _, conflict := range conflictTarget {
			if conflict == col.Name {
				copy(updateCols[i:], updateCols[i+1:])
				updateCols[len(updateCols)-1] = ""
				updateCols = updateCols[:len(updateCols)-1]
				copy(updateVals[i:], updateVals[i+1:])
				updateVals[len(updateVals)-1] = ""
				updateVals = updateVals[:len(updateVals)-1]
				break
			}
		}
	}
	return updateCols, updateVals, updatedArgs
}
// NewUpdateStatement builds an UPDATE statement setting values on the rows
// matching all conditions. It fails with ErrNoValues without values and with
// ErrNoCondition without conditions (guarding against accidental full-table
// updates).
func NewUpdateStatement(event eventstore.Event, values []Column, conditions []Condition, opts ...execOption) *Statement {
	cols, params, args := columnsToQuery(values)
	// where placeholders start after the value placeholders
	wheres, whereArgs := conditionsToWhere(conditions, len(args)+1)
	args = append(args, whereArgs...)
	config := execConfig{
		args: args,
	}
	if len(values) == 0 {
		config.err = ErrNoValues
	}
	if len(conditions) == 0 {
		config.err = ErrNoCondition
	}
	q := func(config execConfig) string {
		// the postgres standard does not allow to update a single column using a multi-column update
		// discussion: https://www.postgresql.org/message-id/17451.1509381766%40sss.pgh.pa.us
		// see Compatibility in https://www.postgresql.org/docs/current/sql-update.html
		if len(cols) == 1 && !strings.HasPrefix(params[0], "SELECT") {
			return "UPDATE " + config.tableName + " SET " + cols[0] + " = " + params[0] + " WHERE " + strings.Join(wheres, " AND ")
		}
		return "UPDATE " + config.tableName + " SET (" + strings.Join(cols, ", ") + ") = (" + strings.Join(params, ", ") + ") WHERE " + strings.Join(wheres, " AND ")
	}
	return NewStatement(event, exec(config, q, opts))
}
// NewDeleteStatement builds a DELETE statement for the rows matching all
// conditions. Without conditions the statement fails with ErrNoCondition,
// guarding against accidental full-table deletes.
func NewDeleteStatement(event eventstore.Event, conditions []Condition, opts ...execOption) *Statement {
	wheres, args := conditionsToWhere(conditions, 1)

	config := execConfig{args: args}
	if len(conditions) == 0 {
		config.err = ErrNoCondition
	}

	buildQuery := func(config execConfig) string {
		return "DELETE FROM " + config.tableName + " WHERE " + strings.Join(wheres, " AND ")
	}
	return NewStatement(event, exec(config, buildQuery, opts))
}
// NewNoOpStatement returns a statement without an Execute function. It
// changes no table but is used e.g. for events without a registered reducer
// (see reduce).
func NewNoOpStatement(event eventstore.Event) *Statement {
	return NewStatement(event, nil)
}
// NewSleepStatement builds a statement calling pg_sleep for the given
// duration, passed to the database as seconds.
func NewSleepStatement(event eventstore.Event, d time.Duration, opts ...execOption) *Statement {
	seconds := float64(d) / float64(time.Second)
	cfg := execConfig{args: []any{seconds}}
	buildQuery := func(execConfig) string {
		return "SELECT pg_sleep($1);"
	}
	return NewStatement(event, exec(cfg, buildQuery, opts))
}
// NewMultiStatement combines the statements produced by all opts into one
// statement executing them sequentially. Without opts a no-op statement is
// returned.
func NewMultiStatement(event eventstore.Event, opts ...func(eventstore.Event) Exec) *Statement {
	if len(opts) == 0 {
		return NewNoOpStatement(event)
	}
	execs := make([]Exec, 0, len(opts))
	for _, opt := range opts {
		execs = append(execs, opt(event))
	}
	return NewStatement(event, multiExec(execs))
}
// AddNoOpStatement adds a no-op statement to a multi statement.
func AddNoOpStatement() func(eventstore.Event) Exec {
	return func(event eventstore.Event) Exec {
		stmt := NewNoOpStatement(event)
		return stmt.Execute
	}
}
// AddCreateStatement adds an insert statement to a multi statement.
func AddCreateStatement(columns []Column, opts ...execOption) func(eventstore.Event) Exec {
	return func(event eventstore.Event) Exec {
		stmt := NewCreateStatement(event, columns, opts...)
		return stmt.Execute
	}
}
// AddUpsertStatement adds an upsert statement to a multi statement.
func AddUpsertStatement(indexCols []Column, values []Column, opts ...execOption) func(eventstore.Event) Exec {
	return func(event eventstore.Event) Exec {
		stmt := NewUpsertStatement(event, indexCols, values, opts...)
		return stmt.Execute
	}
}
// AddUpdateStatement adds an update statement to a multi statement.
func AddUpdateStatement(values []Column, conditions []Condition, opts ...execOption) func(eventstore.Event) Exec {
	return func(event eventstore.Event) Exec {
		stmt := NewUpdateStatement(event, values, conditions, opts...)
		return stmt.Execute
	}
}
// AddDeleteStatement adds a delete statement to a multi statement.
func AddDeleteStatement(conditions []Condition, opts ...execOption) func(eventstore.Event) Exec {
	return func(event eventstore.Event) Exec {
		stmt := NewDeleteStatement(event, conditions, opts...)
		return stmt.Execute
	}
}
// AddCopyStatement adds a copy (insert-from-select) statement to a multi
// statement.
func AddCopyStatement(conflict, from, to []Column, conditions []NamespacedCondition, opts ...execOption) func(eventstore.Event) Exec {
	return func(event eventstore.Event) Exec {
		stmt := NewCopyStatement(event, conflict, from, to, conditions, opts...)
		return stmt.Execute
	}
}
// AddSleepStatement adds a pg_sleep statement to a multi statement.
func AddSleepStatement(d time.Duration, opts ...execOption) func(eventstore.Event) Exec {
	return func(event eventstore.Event) Exec {
		stmt := NewSleepStatement(event, d, opts...)
		return stmt.Execute
	}
}
// NewArrayAppendCol returns a column whose update appends value to the
// array currently stored in column.
func NewArrayAppendCol(column string, value interface{}) Column {
	wrap := func(placeholder string) string {
		return "array_append(" + column + ", " + placeholder + ")"
	}
	return Column{Name: column, Value: value, ParameterOpt: wrap}
}
// NewArrayRemoveCol returns a column whose update removes value from the
// array currently stored in column.
func NewArrayRemoveCol(column string, value interface{}) Column {
	wrap := func(placeholder string) string {
		return "array_remove(" + column + ", " + placeholder + ")"
	}
	return Column{Name: column, Value: value, ParameterOpt: wrap}
}
// NewArrayIntersectCol returns a column whose update keeps only the elements
// of the stored array that are also contained in value.
func NewArrayIntersectCol(column string, value interface{}) Column {
	var arrayType string
	switch value.(type) {
	case []string, database.TextArray[string]:
		arrayType = "TEXT"
		//TODO: handle more types if necessary
	}
	wrap := func(placeholder string) string {
		return "SELECT ARRAY( SELECT UNNEST(" + column + ") INTERSECT SELECT UNNEST (" + placeholder + "::" + arrayType + "[]))"
	}
	return Column{Name: column, Value: value, ParameterOpt: wrap}
}
// NewCopyCol returns a column that copies its value from the column named
// from of the selected source row (see NewCopyStatement).
func NewCopyCol(column, from string) Column {
	return Column{Name: column, Value: NewCol(from, nil)}
}
// NewCopyStatement creates a new upsert statement which updates columns from an existing row.
// to/from name the target and source columns (must have equal, non-zero
// length). A from column with a non-nil Value is written as a static value
// instead of being copied from the selected row. nsCond are the conditions
// for the selection subquery, qualified with the source-row alias
// "copy_table". conflictCols form the conflict target; on conflict the row
// is overwritten with the same selected/static values.
func NewCopyStatement(event eventstore.Event, conflictCols, from, to []Column, nsCond []NamespacedCondition, opts ...execOption) *Statement {
	columnNames := make([]string, len(to))
	selectColumns := make([]string, len(from))
	updateColumns := make([]string, len(columnNames))
	argCounter := 0
	args := []interface{}{}
	for i, col := range from {
		columnNames[i] = to[i].Name
		selectColumns[i] = from[i].Name
		updateColumns[i] = "EXCLUDED." + col.Name
		// a static value replaces the copied column with a placeholder
		if col.Value != nil {
			argCounter++
			selectColumns[i] = "$" + strconv.Itoa(argCounter)
			updateColumns[i] = selectColumns[i]
			args = append(args, col.Value)
		}
	}
	// qualify the conditions with the source-row alias
	cond := make([]Condition, len(nsCond))
	for i := range nsCond {
		cond[i] = nsCond[i]("copy_table")
	}
	wheres, values := conditionsToWhere(cond, len(args)+1)
	args = append(args, values...)
	conflictTargets := make([]string, len(conflictCols))
	for i, conflictCol := range conflictCols {
		conflictTargets[i] = conflictCol.Name
	}
	config := execConfig{
		args: args,
	}
	if len(from) == 0 || len(to) == 0 || len(from) != len(to) {
		config.err = ErrNoValues
	}
	if len(cond) == 0 {
		config.err = ErrNoCondition
	}
	q := func(config execConfig) string {
		return "INSERT INTO " +
			config.tableName +
			" (" +
			strings.Join(columnNames, ", ") +
			") SELECT " +
			strings.Join(selectColumns, ", ") +
			" FROM " +
			config.tableName + " AS copy_table WHERE " +
			strings.Join(wheres, " AND ") +
			" ON CONFLICT (" +
			strings.Join(conflictTargets, ", ") +
			") DO UPDATE SET (" +
			strings.Join(columnNames, ", ") +
			") = (" +
			strings.Join(updateColumns, ", ") +
			")"
	}
	return NewStatement(event, exec(config, q, opts))
}
// ValueContainer lets a Column value wrap the actual statement argument,
// e.g. to control upsert behavior (see onlySetValueOnInsert and
// onlySetValueInCase).
type ValueContainer interface {
	// GetValue returns the value to bind as the statement argument
	GetValue() interface{}
}
// columnsToQuery renders cols into column names, value placeholders and the
// list of arguments to bind.
//
//   - a Column value references another column by name (no argument bound)
//   - a ValueContainer contributes its wrapped value as argument
//   - any other value is bound directly
//
// ParameterOpt, when set, wraps the "$n" placeholder (e.g. array_append).
// The returned values slice is trimmed to the number of bound arguments.
func columnsToQuery(cols []Column) (names []string, parameters []string, values []interface{}) {
	names = make([]string, len(cols))
	values = make([]interface{}, len(cols))
	parameters = make([]string, len(cols))
	var parameterIndex int
	for i, col := range cols {
		names[i] = col.Name
		switch c := col.Value.(type) {
		case Column:
			// copy from another column: the placeholder is the column name,
			// no argument is consumed
			parameters[i] = c.Name
			continue
		case ValueContainer:
			values[parameterIndex] = c.GetValue()
		default:
			values[parameterIndex] = col.Value
		}
		parameters[i] = "$" + strconv.Itoa(parameterIndex+1)
		if col.ParameterOpt != nil {
			parameters[i] = col.ParameterOpt(parameters[i])
		}
		parameterIndex++
	}
	return names, parameters, values[:parameterIndex]
}
// conditionsToWhere renders conds into parenthesized where clauses plus the
// arguments they bind. paramOffset is the index of the first free statement
// placeholder; it advances by the number of arguments each condition binds.
func conditionsToWhere(conds []Condition, paramOffset int) (wheres []string, values []interface{}) {
	wheres = make([]string, len(conds))
	values = make([]any, 0, len(conds))
	for i, cond := range conds {
		clause, condValues := cond("$" + strconv.Itoa(paramOffset))
		wheres[i] = "(" + clause + ")"
		values = append(values, condValues...)
		paramOffset += len(condValues)
	}
	return wheres, values
}
// Column describes a single column an insert/update statement writes.
type Column struct {
	// Name of the column
	Name string
	// Value to set; may be a Column (copy from another column) or a
	// ValueContainer wrapping the actual argument
	Value interface{}
	// ParameterOpt optionally wraps the generated "$n" placeholder in an
	// SQL expression (e.g. array_append)
	ParameterOpt func(string) string
}
// NewCol returns a column with the given name and value.
func NewCol(name string, value interface{}) Column {
	col := Column{Name: name, Value: value}
	return col
}
// NewJSONCol marshals value to JSON and returns a column writing the encoded
// bytes. A value that cannot be marshalled is a programming error and panics.
func NewJSONCol(name string, value interface{}) Column {
	encoded, err := json.Marshal(value)
	if err != nil {
		logging.WithFields("column", name).WithError(err).Panic("unable to marshal column")
	}
	return NewCol(name, encoded)
}
// NewIncrementCol returns a column whose update adds value to the current
// value of column.
func NewIncrementCol[Int constraints.Integer](column string, value Int) Column {
	wrap := func(placeholder string) string {
		return column + " + " + placeholder
	}
	return Column{Name: column, Value: value, ParameterOpt: wrap}
}
// Condition renders one where/merge clause. param is the placeholder the
// condition may bind arguments to; it returns the SQL fragment and the bound
// arguments.
type Condition func(param string) (string, []any)

// NamespacedCondition builds a Condition whose column is qualified with the
// namespace (table alias) chosen later.
type NamespacedCondition func(namespace string) Condition
// NewCond returns a condition matching name equal to value.
func NewCond(name string, value interface{}) Condition {
	return func(param string) (string, []any) {
		clause := name + " = " + param
		return clause, []any{value}
	}
}
// NewUnequalCond returns a condition matching name not equal to value.
func NewUnequalCond(name string, value any) Condition {
	return func(param string) (string, []any) {
		clause := name + " <> " + param
		return clause, []any{value}
	}
}
// NewNamespacedCondition returns an equality condition on name, qualified
// with the namespace chosen when the condition is materialized.
func NewNamespacedCondition(name string, value interface{}) NamespacedCondition {
	return func(namespace string) Condition {
		qualified := namespace + "." + name
		return NewCond(qualified, value)
	}
}
// NewLessThanCond returns a condition matching column less than value.
func NewLessThanCond(column string, value interface{}) Condition {
	return func(param string) (string, []any) {
		clause := column + " < " + param
		return clause, []any{value}
	}
}
// NewIsNullCond returns a Condition matching rows where column IS NULL.
// The placeholder is unused because no value needs to be bound.
func NewIsNullCond(column string) Condition {
	return func(_ string) (string, []any) {
		return column + " IS NULL", nil
	}
}
// NewIsNotNullCond returns a Condition matching rows where column IS NOT NULL.
// The placeholder is unused because no value needs to be bound.
func NewIsNotNullCond(column string) Condition {
	return func(_ string) (string, []any) {
		return column + " IS NOT NULL", nil
	}
}
// NewTextArrayContainsCond returns a Condition that checks if the column that stores an array of text contains the given value
func NewTextArrayContainsCond(column string, value string) Condition {
	return func(param string) (string, []any) {
		contains := database.TextArray[string]{value}
		return column + " @> " + param, []any{contains}
	}
}
// Not negates the given Condition.
// It is a function and not a method so that call sites stay readable,
// e.g. conditions := []Condition{Not(NewTextArrayContainsCond(col, v))}
func Not(condition Condition) Condition {
	return func(param string) (string, []any) {
		inner, values := condition(param)
		return "NOT (" + inner + ")", values
	}
}
// NewOneOfTextCond returns a Condition that checks if the column that stores a text is one of the given values
func NewOneOfTextCond(column string, values []string) Condition {
	return func(param string) (string, []any) {
		list := database.TextArray[string](values)
		return column + " = ANY(" + param + ")", []any{list}
	}
}
// Executer abstracts the subset of [database/sql] used to execute
// statements so *sql.DB, *sql.Tx or mocks can be used interchangeably.
type Executer interface {
	Exec(string, ...interface{}) (sql.Result, error)
}
// execOption mutates an execConfig before the statement is rendered.
type execOption func(*execConfig)
// execConfig carries the data needed to render and run a statement.
type execConfig struct {
	// tableName is the projection table the statement targets.
	tableName string
	// args are the parameter values passed to Exec.
	args []interface{}
	// err is a deferred error detected while building the config;
	// exec returns it instead of executing.
	err error
}
// query renders the SQL statement for the given config.
type query func(config execConfig) string
// exec returns an Exec that renders the statement via q and runs it on
// the given Executer. The projection name is required and used as table
// name; opts are applied after it has been set.
func exec(config execConfig, q query, opts []execOption) Exec {
	return func(ctx context.Context, ex Executer, projectionName string) (err error) {
		if projectionName == "" {
			return ErrNoProjection
		}
		if config.err != nil {
			return config.err
		}
		config.tableName = projectionName
		// NOTE(review): config is captured by the closure, so options that
		// append (e.g. to args) would accumulate across repeated invocations
		// of the returned Exec — confirm each Exec is only invoked once.
		for _, opt := range opts {
			opt(&config)
		}
		_, err = ex.Exec(q(config), config.args...)
		if err != nil {
			return zerrors.ThrowInternal(err, "CRDB-pKtsr", "exec failed")
		}
		return nil
	}
}
// multiExec combines the given Exec functions into a single Exec that
// runs them sequentially and stops at the first error. nil entries are
// skipped.
func multiExec(execList []Exec) Exec {
	return func(ctx context.Context, ex Executer, projectionName string) error {
		for _, execute := range execList {
			if execute == nil {
				continue
			}
			err := execute(ctx, ex, projectionName)
			if err != nil {
				return err
			}
		}
		return nil
	}
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,257 @@
package eventstore_test
import (
"context"
"encoding/json"
"os"
"testing"
"time"
pgxdecimal "github.com/jackc/pgx-shopspring-decimal"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/stdlib"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/cmd/initialise"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/dialect"
"github.com/zitadel/zitadel/internal/database/postgres"
"github.com/zitadel/zitadel/internal/eventstore"
es_sql "github.com/zitadel/zitadel/internal/eventstore/repository/sql"
new_es "github.com/zitadel/zitadel/internal/eventstore/v3"
)
var (
	// testClient is the embedded-postgres backed client shared by all tests.
	testClient *database.DB
	// queriers, pushers and clients map a human-readable implementation
	// name (used as subtest name) to the respective instance.
	queriers map[string]eventstore.Querier = make(map[string]eventstore.Querier)
	pushers map[string]eventstore.Pusher = make(map[string]eventstore.Pusher)
	clients map[string]*database.DB = make(map[string]*database.DB)
)
// TestMain starts an embedded postgres, wires up the v2 querier and the
// v3 pusher against it and — if a localhost postgres is reachable — an
// additional "singlenode" pusher, then runs all tests of the package.
func TestMain(m *testing.M) {
	os.Exit(func() int {
		config, cleanup := postgres.StartEmbedded()
		defer cleanup()
		testClient = &database.DB{
			Database: new(testDB),
		}
		connConfig, err := pgxpool.ParseConfig(config.GetConnectionURL())
		logging.OnError(err).Fatal("unable to parse db url")
		// register decimal and eventstore specific types on every new connection
		connConfig.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error {
			pgxdecimal.Register(conn.TypeMap())
			return new_es.RegisterEventstoreTypes(ctx, conn)
		}
		pool, err := pgxpool.NewWithConfig(context.Background(), connConfig)
		logging.OnError(err).Fatal("unable to create db pool")
		testClient.DB = stdlib.OpenDBFromPool(pool)
		err = testClient.Ping()
		logging.OnError(err).Fatal("unable to ping db")
		v2 := &es_sql.Postgres{DB: testClient}
		queriers["v2(inmemory)"] = v2
		clients["v2(inmemory)"] = testClient
		pushers["v3(inmemory)"] = new_es.NewEventstore(testClient)
		clients["v3(inmemory)"] = testClient
		// optionally also run against a locally running postgres; errors
		// connecting simply skip this variant
		if localDB, err := connectLocalhost(); err == nil {
			err = initDB(context.Background(), localDB)
			logging.OnError(err).Fatal("migrations failed")
			pushers["v3(singlenode)"] = new_es.NewEventstore(localDB)
			clients["v3(singlenode)"] = localDB
		}
		defer func() {
			logging.OnError(testClient.Close()).Error("unable to close db")
		}()
		err = initDB(context.Background(), &database.DB{DB: testClient.DB, Database: &postgres.Config{Database: "zitadel"}})
		logging.OnError(err).Fatal("migrations failed")
		return m.Run()
	}())
}
// initDB runs the zitadel init steps (user, database, grant, schema
// verification) on db and creates the deprecated events table used by
// tests covering the old storage format.
func initDB(ctx context.Context, db *database.DB) error {
	config := new(database.Config)
	config.SetConnector(&postgres.Config{User: postgres.User{Username: "zitadel"}, Database: "zitadel"})
	if err := initialise.ReadStmts(); err != nil {
		return err
	}
	err := initialise.Init(ctx, db,
		initialise.VerifyUser(config.Username(), ""),
		initialise.VerifyDatabase(config.DatabaseName()),
		initialise.VerifyGrant(config.DatabaseName(), config.Username()))
	if err != nil {
		return err
	}
	err = initialise.VerifyZitadel(ctx, db, *config)
	if err != nil {
		return err
	}
	// create old events
	_, err = db.Exec(oldEventsTable)
	return err
}
// connectLocalhost tries to connect to a postgres listening on
// localhost:5432 and returns a pinged client on success. The
// AfterConnect hook registers decimal and eventstore types on every
// new connection.
func connectLocalhost() (*database.DB, error) {
	config, err := pgxpool.ParseConfig("postgresql://postgres@localhost:5432/postgres?sslmode=disable")
	if err != nil {
		return nil, err
	}
	config.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error {
		pgxdecimal.Register(conn.TypeMap())
		return new_es.RegisterEventstoreTypes(ctx, conn)
	}
	pool, err := pgxpool.NewWithConfig(context.Background(), config)
	if err != nil {
		return nil, err
	}
	client := stdlib.OpenDBFromPool(pool)
	if err = client.Ping(); err != nil {
		return nil, err
	}
	return &database.DB{
		DB: client,
		Database: new(testDB),
	}, nil
}
// testDB provides the dialect information database.DB needs for the
// embedded test database.
type testDB struct{}
// Timetravel returns the AS OF SYSTEM TIME clause used in tests.
func (*testDB) Timetravel(time.Duration) string { return " AS OF SYSTEM TIME '-1 ms' " }
// DatabaseName returns the static test database name.
func (*testDB) DatabaseName() string { return "db" }
// Username returns the static test database user.
func (*testDB) Username() string { return "user" }
// Type reports the database type as postgres.
func (*testDB) Type() dialect.DatabaseType { return dialect.DatabaseTypePostgres }
// generateCommand builds a testEvent command for the given aggregate
// with fixed resource owner, service and event type, applying all opts
// before returning it.
func generateCommand(aggregateType eventstore.AggregateType, aggregateID string, opts ...func(*testEvent)) eventstore.Command {
	event := &testEvent{
		BaseEvent: eventstore.BaseEvent{
			Agg: &eventstore.Aggregate{
				ID:            aggregateID,
				Type:          aggregateType,
				ResourceOwner: "ro",
				Version:       "v1",
			},
			Service:   "svc",
			EventType: "test.created",
		},
	}
	for _, apply := range opts {
		apply(event)
	}
	return event
}
// testEvent is a minimal Command/Event implementation used by the tests.
type testEvent struct {
	eventstore.BaseEvent
	// uniqueConstraints are returned by UniqueConstraints()
	uniqueConstraints []*eventstore.UniqueConstraint
}
// Payload returns the raw data of the embedded BaseEvent.
func (e *testEvent) Payload() any {
	return e.BaseEvent.Data
}
// UniqueConstraints returns the constraints configured for this event.
func (e *testEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
	return e.uniqueConstraints
}
func canceledCtx() context.Context {
ctx, cancel := context.WithCancel(context.Background())
cancel()
return ctx
}
// fillUniqueData inserts a row into eventstore.unique_constraints so a
// following push can collide with it.
func fillUniqueData(uniqueType, field, instanceID string) error {
	const stmt = "INSERT INTO eventstore.unique_constraints (unique_type, unique_field, instance_id) VALUES ($1, $2, $3)"
	_, err := testClient.Exec(stmt, uniqueType, field, instanceID)
	return err
}
// generateAddUniqueConstraint returns an event option that appends an
// "add" unique constraint for table/uniqueField.
func generateAddUniqueConstraint(table, uniqueField string) func(e *testEvent) {
	return func(e *testEvent) {
		constraint := &eventstore.UniqueConstraint{
			UniqueType:  table,
			UniqueField: uniqueField,
			Action:      eventstore.UniqueConstraintAdd,
		}
		e.uniqueConstraints = append(e.uniqueConstraints, constraint)
	}
}
// generateRemoveUniqueConstraint returns an event option that appends a
// "remove" unique constraint for table/uniqueField.
func generateRemoveUniqueConstraint(table, uniqueField string) func(e *testEvent) {
	return func(e *testEvent) {
		constraint := &eventstore.UniqueConstraint{
			UniqueType:  table,
			UniqueField: uniqueField,
			Action:      eventstore.UniqueConstraintRemove,
		}
		e.uniqueConstraints = append(e.uniqueConstraints, constraint)
	}
}
// withTestData returns an event option that marshals data to JSON and
// sets it as the event payload. It panics when marshalling fails.
func withTestData(data any) func(e *testEvent) {
	return func(e *testEvent) {
		payload, err := json.Marshal(data)
		if err != nil {
			panic("marshal data failed")
		}
		e.BaseEvent.Data = payload
	}
}
// cleanupEventstore returns a func that truncates all eventstore tables
// written by the tests (events, events2, unique_constraints).
// Truncation failures are logged but do not abort the cleanup.
func cleanupEventstore(client *database.DB) func() {
	return func() {
		_, err := client.Exec("TRUNCATE eventstore.events")
		if err != nil {
			logging.Warnf("unable to truncate events: %v", err)
		}
		_, err = client.Exec("TRUNCATE eventstore.events2")
		if err != nil {
			// was previously logged as "events", hiding which table failed
			logging.Warnf("unable to truncate events2: %v", err)
		}
		_, err = client.Exec("TRUNCATE eventstore.unique_constraints")
		if err != nil {
			logging.Warnf("unable to truncate unique constraints: %v", err)
		}
	}
}
// oldEventsTable creates the deprecated eventstore.events table (v1
// storage schema) so tests covering the old format can insert into it.
const oldEventsTable = `CREATE TABLE IF NOT EXISTS eventstore.events (
	id UUID DEFAULT gen_random_uuid()
	, event_type TEXT NOT NULL
	, aggregate_type TEXT NOT NULL
	, aggregate_id TEXT NOT NULL
	, aggregate_version TEXT NOT NULL
	, event_sequence BIGINT NOT NULL
	, previous_aggregate_sequence BIGINT
	, previous_aggregate_type_sequence INT8
	, creation_date TIMESTAMPTZ NOT NULL DEFAULT now()
	, created_at TIMESTAMPTZ NOT NULL DEFAULT clock_timestamp()
	, event_data JSONB
	, editor_user TEXT NOT NULL
	, editor_service TEXT
	, resource_owner TEXT NOT NULL
	, instance_id TEXT NOT NULL
	, "position" DECIMAL NOT NULL
	, in_tx_order INTEGER NOT NULL
	, PRIMARY KEY (instance_id, aggregate_type, aggregate_id, event_sequence)
);`

View File

@@ -0,0 +1,55 @@
package eventstore
import (
"time"
"github.com/shopspring/decimal"
)
// ReadModel is the minimum representation of a read model.
// It implements a basic reducer
// it might be saved in a database or in memory
type ReadModel struct {
	// AggregateID is the ID of the aggregate the events were reduced from.
	AggregateID string `json:"-"`
	// ProcessedSequence is the sequence of the last reduced event.
	ProcessedSequence uint64 `json:"-"`
	// CreationDate is the creation time of the first reduced event.
	CreationDate time.Time `json:"-"`
	// ChangeDate is the creation time of the last reduced event.
	ChangeDate time.Time `json:"-"`
	// Events holds the appended, not yet reduced events.
	Events []Event `json:"-"`
	// ResourceOwner is taken from the first reduced event's aggregate.
	ResourceOwner string `json:"-"`
	// InstanceID is taken from the first reduced event's aggregate.
	InstanceID string `json:"-"`
	// Position is the global position of the last reduced event.
	Position decimal.Decimal `json:"-"`
}
// AppendEvents adds all the events to the read model.
// The function doesn't compute the new state of the read model
func (rm *ReadModel) AppendEvents(events ...Event) {
	for _, event := range events {
		rm.Events = append(rm.Events, event)
	}
}
// Reduce is the basic implementation of reducer
// If this function is extended the extending function should be the last step
func (rm *ReadModel) Reduce() error {
	if len(rm.Events) == 0 {
		return nil
	}
	first, last := rm.Events[0], rm.Events[len(rm.Events)-1]
	// identifying fields are only set once, from the first event
	if rm.AggregateID == "" {
		rm.AggregateID = first.Aggregate().ID
	}
	if rm.ResourceOwner == "" {
		rm.ResourceOwner = first.Aggregate().ResourceOwner
	}
	if rm.InstanceID == "" {
		rm.InstanceID = first.Aggregate().InstanceID
	}
	if rm.CreationDate.IsZero() {
		rm.CreationDate = first.CreatedAt()
	}
	// progress fields always reflect the newest event
	rm.ChangeDate = last.CreatedAt()
	rm.ProcessedSequence = last.Sequence()
	rm.Position = last.Position()
	// all events processed and not needed anymore
	rm.Events = rm.Events[0:0]
	return nil
}

View File

@@ -0,0 +1,24 @@
package repository
// Asset represents all information about an asset (image).
type Asset struct {
	// ID is used to refer to the asset
	ID string
	// Asset is the actual image data
	Asset []byte
	// Action defines if the asset should be added or removed
	Action AssetAction
}
// AssetAction defines whether an asset is added or removed.
type AssetAction int32
const (
	// AssetAdded indicates the asset should be stored.
	AssetAdded AssetAction = iota
	// AssetRemoved indicates the asset should be deleted.
	AssetRemoved
	// assetCount is a sentinel used by Valid to bound-check actions.
	assetCount
)
// Valid reports whether f is a defined AssetAction.
func (f AssetAction) Valid() bool {
	return f >= AssetAdded && f < assetCount
}

View File

@@ -0,0 +1,133 @@
package repository
import (
"database/sql"
"encoding/json"
"strconv"
"strings"
"time"
"github.com/shopspring/decimal"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/eventstore"
)
// compile-time check that *Event satisfies eventstore.Event
var _ eventstore.Event = (*Event)(nil)
// Event represents all information about a manipulation of an aggregate
type Event struct {
	// ID is a generated uuid for this event
	ID string
	// Seq is the sequence of the event
	Seq uint64
	// Pos is the global sequence of the event multiple events can have the same sequence
	Pos decimal.Decimal
	// CreationDate is the time the event is created
	// it's used for human readability.
	// Don't use it for event ordering,
	// time drifts in different services could cause integrity problems
	CreationDate time.Time
	// Typ describes the cause of the event (e.g. user.added)
	// it should always be in past-form
	Typ eventstore.EventType
	// Data describes the changed fields (e.g. userName = "hodor")
	// data must always be a pointer to a struct, a struct or a byte array containing json bytes
	Data []byte
	// EditorUser should be a unique identifier for the user which created the event
	// it's meant for maintainability.
	// It's recommended to use the aggregate id of the user
	EditorUser string
	// Version describes the definition of the aggregate at a certain point in time
	// it's used in read models to reduce the events in the correct definition
	Version eventstore.Version
	// AggregateID is the unique identifier of the aggregate
	// the client must generate it by its own
	AggregateID string
	// AggregateType describes the meaning of the aggregate for this event
	// it could be an object like user
	AggregateType eventstore.AggregateType
	// ResourceOwner is the organisation which owns this aggregate
	// an aggregate can only be managed by one organisation
	// use the ID of the org
	ResourceOwner sql.NullString
	// InstanceID is the instance where this event belongs to
	// use the ID of the instance
	InstanceID string
	// Constraints are the unique constraints pushed together with the event
	Constraints []*eventstore.UniqueConstraint
}
// Aggregate implements [eventstore.Event]
// It assembles an eventstore.Aggregate from the event's aggregate fields.
func (e *Event) Aggregate() *eventstore.Aggregate {
	agg := &eventstore.Aggregate{
		ID:         e.AggregateID,
		Type:       e.AggregateType,
		InstanceID: e.InstanceID,
		Version:    e.Version,
	}
	agg.ResourceOwner = e.ResourceOwner.String
	return agg
}
// Creator implements [eventstore.Event]
// It returns the id of the user which created the event.
func (e *Event) Creator() string {
	return e.EditorUser
}
// Type implements [eventstore.Event]
// It returns the event type (e.g. user.added).
func (e *Event) Type() eventstore.EventType {
	return e.Typ
}
// Revision implements [eventstore.Event]
// It parses the numeric part of the version (e.g. "v1" -> 1); on parse
// failure the error is logged at debug level and 0 is returned.
func (e *Event) Revision() uint16 {
	version := strings.TrimPrefix(string(e.Version), "v")
	revision, err := strconv.ParseUint(version, 10, 16)
	logging.OnError(err).Debug("failed to parse event revision")
	return uint16(revision)
}
// Sequence implements [eventstore.Event]
// It returns the per-aggregate sequence of the event.
func (e *Event) Sequence() uint64 {
	return e.Seq
}
// Position implements [eventstore.Event]
// It returns the global position of the event.
func (e *Event) Position() decimal.Decimal {
	return e.Pos
}
// CreatedAt implements [eventstore.Event]
// It returns the creation time of the event.
func (e *Event) CreatedAt() time.Time {
	return e.CreationDate
}
// Unmarshal implements [eventstore.Event]
// It decodes the stored JSON payload into ptr; events without payload
// leave ptr untouched and return nil.
func (e *Event) Unmarshal(ptr any) error {
	if len(e.Data) > 0 {
		return json.Unmarshal(e.Data, ptr)
	}
	return nil
}
// DataAsBytes implements [eventstore.Event]
// It returns the raw JSON payload.
func (e *Event) DataAsBytes() []byte {
	return e.Data
}
// Payload returns the raw JSON payload of the event.
func (e *Event) Payload() any {
	return e.Data
}
// UniqueConstraints returns the unique constraints pushed with the event.
func (e *Event) UniqueConstraints() []*eventstore.UniqueConstraint {
	return e.Constraints
}
// Fields always returns nil: this repository computes no search fields.
func (e *Event) Fields() []*eventstore.FieldOperation {
	return nil
}

View File

@@ -0,0 +1,3 @@
package mock
//go:generate mockgen -package mock -destination ./repository.mock.go github.com/zitadel/zitadel/internal/eventstore Querier,Pusher

View File

@@ -0,0 +1,186 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/zitadel/internal/eventstore (interfaces: Querier,Pusher)
//
// Generated by this command:
//
// mockgen -package mock -destination ./repository.mock.go github.com/zitadel/zitadel/internal/eventstore Querier,Pusher
//
// Package mock is a generated GoMock package.
package mock
import (
context "context"
reflect "reflect"
decimal "github.com/shopspring/decimal"
database "github.com/zitadel/zitadel/internal/database"
eventstore "github.com/zitadel/zitadel/internal/eventstore"
gomock "go.uber.org/mock/gomock"
)
// MockQuerier is a mock of Querier interface.
type MockQuerier struct {
ctrl *gomock.Controller
recorder *MockQuerierMockRecorder
}
// MockQuerierMockRecorder is the mock recorder for MockQuerier.
type MockQuerierMockRecorder struct {
mock *MockQuerier
}
// NewMockQuerier creates a new mock instance.
func NewMockQuerier(ctrl *gomock.Controller) *MockQuerier {
mock := &MockQuerier{ctrl: ctrl}
mock.recorder = &MockQuerierMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockQuerier) EXPECT() *MockQuerierMockRecorder {
return m.recorder
}
// Client mocks base method.
func (m *MockQuerier) Client() *database.DB {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Client")
ret0, _ := ret[0].(*database.DB)
return ret0
}
// Client indicates an expected call of Client.
func (mr *MockQuerierMockRecorder) Client() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Client", reflect.TypeOf((*MockQuerier)(nil).Client))
}
// FilterToReducer mocks base method.
func (m *MockQuerier) FilterToReducer(arg0 context.Context, arg1 *eventstore.SearchQueryBuilder, arg2 eventstore.Reducer) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "FilterToReducer", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// FilterToReducer indicates an expected call of FilterToReducer.
func (mr *MockQuerierMockRecorder) FilterToReducer(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterToReducer", reflect.TypeOf((*MockQuerier)(nil).FilterToReducer), arg0, arg1, arg2)
}
// Health mocks base method.
func (m *MockQuerier) Health(arg0 context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Health", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Health indicates an expected call of Health.
func (mr *MockQuerierMockRecorder) Health(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockQuerier)(nil).Health), arg0)
}
// InstanceIDs mocks base method.
func (m *MockQuerier) InstanceIDs(arg0 context.Context, arg1 *eventstore.SearchQueryBuilder) ([]string, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InstanceIDs", arg0, arg1)
ret0, _ := ret[0].([]string)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// InstanceIDs indicates an expected call of InstanceIDs.
func (mr *MockQuerierMockRecorder) InstanceIDs(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InstanceIDs", reflect.TypeOf((*MockQuerier)(nil).InstanceIDs), arg0, arg1)
}
// LatestPosition mocks base method.
func (m *MockQuerier) LatestPosition(arg0 context.Context, arg1 *eventstore.SearchQueryBuilder) (decimal.Decimal, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "LatestPosition", arg0, arg1)
ret0, _ := ret[0].(decimal.Decimal)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// LatestPosition indicates an expected call of LatestPosition.
func (mr *MockQuerierMockRecorder) LatestPosition(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LatestPosition", reflect.TypeOf((*MockQuerier)(nil).LatestPosition), arg0, arg1)
}
// MockPusher is a mock of Pusher interface.
type MockPusher struct {
ctrl *gomock.Controller
recorder *MockPusherMockRecorder
}
// MockPusherMockRecorder is the mock recorder for MockPusher.
type MockPusherMockRecorder struct {
mock *MockPusher
}
// NewMockPusher creates a new mock instance.
func NewMockPusher(ctrl *gomock.Controller) *MockPusher {
mock := &MockPusher{ctrl: ctrl}
mock.recorder = &MockPusherMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockPusher) EXPECT() *MockPusherMockRecorder {
return m.recorder
}
// Client mocks base method.
func (m *MockPusher) Client() *database.DB {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Client")
ret0, _ := ret[0].(*database.DB)
return ret0
}
// Client indicates an expected call of Client.
func (mr *MockPusherMockRecorder) Client() *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Client", reflect.TypeOf((*MockPusher)(nil).Client))
}
// Health mocks base method.
func (m *MockPusher) Health(arg0 context.Context) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "Health", arg0)
ret0, _ := ret[0].(error)
return ret0
}
// Health indicates an expected call of Health.
func (mr *MockPusherMockRecorder) Health(arg0 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Health", reflect.TypeOf((*MockPusher)(nil).Health), arg0)
}
// Push mocks base method.
func (m *MockPusher) Push(arg0 context.Context, arg1 database.ContextQueryExecuter, arg2 ...eventstore.Command) ([]eventstore.Event, error) {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1}
for _, a := range arg2 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Push", varargs...)
ret0, _ := ret[0].([]eventstore.Event)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Push indicates an expected call of Push.
func (mr *MockPusherMockRecorder) Push(arg0, arg1 any, arg2 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1}, arg2...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Push", reflect.TypeOf((*MockPusher)(nil).Push), varargs...)
}

View File

@@ -0,0 +1,235 @@
package mock
import (
"context"
"encoding/json"
"fmt"
"testing"
"time"
"github.com/shopspring/decimal"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/repository"
)
// MockRepository combines a mocked Pusher and Querier so a single value
// can be used wherever the eventstore repository is expected.
type MockRepository struct {
	*MockPusher
	*MockQuerier
}
// NewRepo creates a MockRepository whose expectations are verified
// against t.
func NewRepo(t *testing.T) *MockRepository {
	ctrl := gomock.NewController(t)
	repo := &MockRepository{
		MockPusher:  NewMockPusher(ctrl),
		MockQuerier: NewMockQuerier(ctrl),
	}
	return repo
}
// ExpectFilterNoEventsNoError expects one FilterToReducer call that
// feeds no events into the reducer and returns nil.
func (m *MockRepository) ExpectFilterNoEventsNoError() *MockRepository {
	m.MockQuerier.ctrl.T.Helper()
	m.MockQuerier.EXPECT().FilterToReducer(gomock.Any(), gomock.Any(), gomock.Any()).Return(nil)
	return m
}
// ExpectFilterEvents expects one FilterToReducer call which feeds the
// given events into the passed reducer, stopping at the reducer's first
// error.
func (m *MockRepository) ExpectFilterEvents(events ...eventstore.Event) *MockRepository {
	m.MockQuerier.ctrl.T.Helper()
	m.MockQuerier.EXPECT().FilterToReducer(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func(_ context.Context, _ *eventstore.SearchQueryBuilder, reduce eventstore.Reducer) error {
			for _, event := range events {
				if err := reduce(event); err != nil {
					return err
				}
			}
			return nil
		},
	)
	return m
}
// ExpectFilterEventsError expects one FilterToReducer call returning err.
func (m *MockRepository) ExpectFilterEventsError(err error) *MockRepository {
	m.MockQuerier.ctrl.T.Helper()
	m.MockQuerier.EXPECT().FilterToReducer(gomock.Any(), gomock.Any(), gomock.Any()).Return(err)
	return m
}
// ExpectInstanceIDs expects one InstanceIDs call returning instanceIDs.
// If hasFilters is not empty, the query must contain exactly these
// filters as its single sub query.
func (m *MockRepository) ExpectInstanceIDs(hasFilters []*repository.Filter, instanceIDs ...string) *MockRepository {
	m.MockQuerier.ctrl.T.Helper()
	matcher := gomock.Any()
	if len(hasFilters) > 0 {
		matcher = &filterQueryMatcher{SubQueries: [][]*repository.Filter{hasFilters}}
	}
	m.MockQuerier.EXPECT().InstanceIDs(gomock.Any(), matcher).Return(instanceIDs, nil)
	return m
}
// ExpectInstanceIDsError expects one InstanceIDs call returning err.
func (m *MockRepository) ExpectInstanceIDsError(err error) *MockRepository {
	m.MockQuerier.ctrl.T.Helper()
	m.MockQuerier.EXPECT().InstanceIDs(gomock.Any(), gomock.Any()).Return(nil, err)
	return m
}
// ExpectPush checks if the expectedCommands are send to the Push method.
// The call will sleep at least the amount of passed duration.
// On success the commands are echoed back as events.
func (m *MockRepository) ExpectPush(expectedCommands []eventstore.Command, sleep time.Duration) *MockRepository {
	// mark as helper like ExpectPushFailed does, so failures point at the caller
	m.MockPusher.ctrl.T.Helper()
	m.MockPusher.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func(ctx context.Context, _ database.ContextQueryExecuter, commands ...eventstore.Command) ([]eventstore.Event, error) {
			m.MockPusher.ctrl.T.Helper()
			time.Sleep(sleep)
			if len(expectedCommands) != len(commands) {
				return nil, fmt.Errorf("unexpected amount of commands: want %d, got %d", len(expectedCommands), len(commands))
			}
			// compare every command against its expectation field by field
			for i, expectedCommand := range expectedCommands {
				if !assert.Equal(m.MockPusher.ctrl.T, expectedCommand.Aggregate(), commands[i].Aggregate()) {
					m.MockPusher.ctrl.T.Errorf("invalid command.Aggregate [%d]: expected: %#v got: %#v", i, expectedCommand.Aggregate(), commands[i].Aggregate())
				}
				if !assert.Equal(m.MockPusher.ctrl.T, expectedCommand.Creator(), commands[i].Creator()) {
					m.MockPusher.ctrl.T.Errorf("invalid command.Creator [%d]: expected: %#v got: %#v", i, expectedCommand.Creator(), commands[i].Creator())
				}
				if !assert.Equal(m.MockPusher.ctrl.T, expectedCommand.Type(), commands[i].Type()) {
					m.MockPusher.ctrl.T.Errorf("invalid command.Type [%d]: expected: %#v got: %#v", i, expectedCommand.Type(), commands[i].Type())
				}
				if !assert.Equal(m.MockPusher.ctrl.T, expectedCommand.Revision(), commands[i].Revision()) {
					m.MockPusher.ctrl.T.Errorf("invalid command.Revision [%d]: expected: %#v got: %#v", i, expectedCommand.Revision(), commands[i].Revision())
				}
				// payloads are compared as JSON; []byte payloads are taken as-is
				var expectedPayload []byte
				expectedPayload, ok := expectedCommand.Payload().([]byte)
				if !ok {
					expectedPayload, _ = json.Marshal(expectedCommand.Payload())
				}
				if string(expectedPayload) == "" {
					expectedPayload = []byte("null")
				}
				gotPayload, _ := json.Marshal(commands[i].Payload())
				if !assert.Equal(m.MockPusher.ctrl.T, expectedPayload, gotPayload) {
					m.MockPusher.ctrl.T.Errorf("invalid command.Payload [%d]: expected: %#v got: %#v", i, expectedCommand.Payload(), commands[i].Payload())
				}
				if !assert.ElementsMatch(m.MockPusher.ctrl.T, expectedCommand.UniqueConstraints(), commands[i].UniqueConstraints()) {
					m.MockPusher.ctrl.T.Errorf("invalid command.UniqueConstraints [%d]: expected: %#v got: %#v", i, expectedCommand.UniqueConstraints(), commands[i].UniqueConstraints())
				}
			}
			events := make([]eventstore.Event, len(commands))
			for i, command := range commands {
				events[i] = &mockEvent{
					Command: command,
				}
			}
			return events, nil
		},
	)
	return m
}
// ExpectPushFailed expects one Push call that compares the received
// commands against expectedCommands and then returns err instead of
// events.
func (m *MockRepository) ExpectPushFailed(err error, expectedCommands []eventstore.Command) *MockRepository {
	m.MockPusher.ctrl.T.Helper()
	m.MockPusher.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func(ctx context.Context, _ database.ContextQueryExecuter, commands ...eventstore.Command) ([]eventstore.Event, error) {
			if len(expectedCommands) != len(commands) {
				return nil, fmt.Errorf("unexpected amount of commands: want %d, got %d", len(expectedCommands), len(commands))
			}
			// compare every command against its expectation field by field
			for i, expectedCommand := range expectedCommands {
				assert.Equal(m.MockPusher.ctrl.T, expectedCommand.Aggregate(), commands[i].Aggregate())
				assert.Equal(m.MockPusher.ctrl.T, expectedCommand.Creator(), commands[i].Creator())
				assert.Equal(m.MockPusher.ctrl.T, expectedCommand.Type(), commands[i].Type())
				assert.Equal(m.MockPusher.ctrl.T, expectedCommand.Revision(), commands[i].Revision())
				// payloads are compared as JSON; []byte payloads are taken as-is
				var expectedPayload []byte
				expectedPayload, ok := expectedCommand.Payload().([]byte)
				if !ok {
					expectedPayload, _ = json.Marshal(expectedCommand.Payload())
				}
				if string(expectedPayload) == "" {
					expectedPayload = []byte("null")
				}
				gotPayload, _ := json.Marshal(commands[i].Payload())
				assert.Equal(m.MockPusher.ctrl.T, expectedPayload, gotPayload)
				assert.ElementsMatch(m.MockPusher.ctrl.T, expectedCommand.UniqueConstraints(), commands[i].UniqueConstraints())
			}
			return nil, err
		},
	)
	return m
}
// mockEvent wraps a pushed Command so it can be returned as an Event,
// with an optional fixed sequence and creation time.
type mockEvent struct {
	eventstore.Command
	// sequence is returned by Sequence()
	sequence uint64
	// createdAt is returned by CreatedAt()
	createdAt time.Time
}
// DataAsBytes implements eventstore.Event
// It marshals the payload to JSON and panics on marshalling errors.
func (e *mockEvent) DataAsBytes() []byte {
	payload := e.Payload()
	if payload == nil {
		return nil
	}
	data, err := json.Marshal(payload)
	if err != nil {
		panic(err)
	}
	return data
}
// Unmarshal round-trips the payload through JSON into ptr.
// A nil payload leaves ptr untouched.
func (e *mockEvent) Unmarshal(ptr any) error {
	payload := e.Payload()
	if payload == nil {
		return nil
	}
	data, err := json.Marshal(payload)
	if err != nil {
		return err
	}
	return json.Unmarshal(data, ptr)
}
// Sequence returns the configured sequence (zero by default).
func (e *mockEvent) Sequence() uint64 {
	return e.sequence
}
// Position always returns the zero decimal.
func (e *mockEvent) Position() decimal.Decimal {
	return decimal.Decimal{}
}
// CreatedAt returns the configured creation time (zero by default).
func (e *mockEvent) CreatedAt() time.Time {
	return e.createdAt
}
// ExpectRandomPush expects one Push call with the same amount of
// commands as expectedCommands, without comparing their content, and
// echoes the received commands back as events.
func (m *MockRepository) ExpectRandomPush(expectedCommands []eventstore.Command) *MockRepository {
	m.MockPusher.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func(ctx context.Context, _ database.ContextQueryExecuter, commands ...eventstore.Command) ([]eventstore.Event, error) {
			assert.Len(m.MockPusher.ctrl.T, commands, len(expectedCommands))
			events := make([]eventstore.Event, len(commands))
			for i, command := range commands {
				events[i] = &mockEvent{
					Command: command,
				}
			}
			return events, nil
		},
	)
	return m
}
// ExpectRandomPushFailed expects one Push call with the same amount of
// commands as expectedEvents, without comparing their content, and
// returns err instead of events.
func (m *MockRepository) ExpectRandomPushFailed(err error, expectedEvents []eventstore.Command) *MockRepository {
	m.MockPusher.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
		func(ctx context.Context, _ database.ContextQueryExecuter, events ...eventstore.Command) ([]eventstore.Event, error) {
			assert.Len(m.MockPusher.ctrl.T, events, len(expectedEvents))
			return nil, err
		},
	)
	return m
}

View File

@@ -0,0 +1,33 @@
package mock
import (
"encoding/json"
"fmt"
"reflect"
"go.uber.org/mock/gomock"
"github.com/zitadel/zitadel/internal/eventstore/repository"
)
// compile-time checks that *filterMatcher implements the gomock interfaces
var _ gomock.Matcher = (*filterMatcher)(nil)
var _ gomock.GotFormatter = (*filterMatcher)(nil)
// filterMatcher matches a *repository.Filter by field, operation and value.
type filterMatcher repository.Filter
// String renders the expected filter, including a JSON dump of the
// value, for gomock failure messages. It panics when the value cannot
// be marshalled.
func (f *filterMatcher) String() string {
	data, err := json.Marshal(f.Value)
	if err != nil {
		panic(err)
	}
	details := fmt.Sprintf("(content=%+v,type=%T,json=%s)", f.Value, f.Value, string(data))
	return fmt.Sprintf("%d %d %s", f.Field, f.Operation, details)
}
// Matches implements gomock.Matcher.
// It reports whether x is a *repository.Filter equal to f in field,
// operation and (deeply compared) value. Arguments of any other type
// are reported as a mismatch instead of panicking on the assertion.
func (f *filterMatcher) Matches(x interface{}) bool {
	other, ok := x.(*repository.Filter)
	if !ok {
		return false
	}
	return f.Field == other.Field && f.Operation == other.Operation && reflect.DeepEqual(f.Value, other.Value)
}
// Got implements gomock.GotFormatter.
// It formats the received value like String formats the expectation;
// values that are not *repository.Filter are rendered with %v instead
// of panicking on the type assertion.
func (f *filterMatcher) Got(got interface{}) string {
	other, ok := got.(*repository.Filter)
	if !ok {
		return fmt.Sprintf("%v", got)
	}
	return (*filterMatcher)(other).String()
}

View File

@@ -0,0 +1,45 @@
package mock
import (
"fmt"
"strings"
"github.com/zitadel/zitadel/internal/eventstore/repository"
)
// filterQueryMatcher matches a *repository.SearchQuery by its sub queries.
type filterQueryMatcher repository.SearchQuery
// String renders all sub query filters for gomock failure messages.
func (f *filterQueryMatcher) String() string {
	filterLists := make([]string, 0, len(f.SubQueries))
	for _, filterSlice := range f.SubQueries {
		parts := make([]string, len(filterSlice))
		for i, filter := range filterSlice {
			parts[i] = (*filterMatcher)(filter).String()
		}
		filterLists = append(filterLists, fmt.Sprintf("[%s]", strings.Join(parts, ",")))
	}
	return fmt.Sprintf("Filters: %s", strings.Join(filterLists, " "))
}
// Matches implements gomock.Matcher.
// It reports whether x is a *repository.SearchQuery with the same
// number of sub queries, whose filters all match pairwise. Arguments of
// any other type are reported as a mismatch instead of panicking on the
// assertion.
func (f *filterQueryMatcher) Matches(x interface{}) bool {
	other, ok := x.(*repository.SearchQuery)
	if !ok {
		return false
	}
	if len(f.SubQueries) != len(other.SubQueries) {
		return false
	}
	for filterSliceIdx, filterSlice := range f.SubQueries {
		if len(filterSlice) != len(other.SubQueries[filterSliceIdx]) {
			return false
		}
		for filterIdx, filter := range filterSlice {
			if !(*filterMatcher)(filter).Matches(other.SubQueries[filterSliceIdx][filterIdx]) {
				return false
			}
		}
	}
	return true
}
// Got implements gomock.GotFormatter.
// It formats the received value like String formats the expectation;
// values that are not *repository.SearchQuery are rendered with %v
// instead of panicking on the type assertion.
func (f *filterQueryMatcher) Got(got interface{}) string {
	other, ok := got.(*repository.SearchQuery)
	if !ok {
		return fmt.Sprintf("%v", got)
	}
	return (*filterQueryMatcher)(other).String()
}

View File

@@ -0,0 +1,326 @@
package repository
import (
"database/sql"
"github.com/shopspring/decimal"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/zerrors"
)
// SearchQuery defines which data are queried and how
type SearchQuery struct {
	// Columns selects which columns of the events are returned.
	Columns eventstore.Columns
	// SubQueries holds the filter groups built from the builder's queries.
	SubQueries [][]*Filter
	// Tx is an optional transaction the query is executed in.
	Tx *sql.Tx
	// LockRows and LockOption control row locking of the selected events.
	LockRows bool
	LockOption eventstore.LockOption
	// AwaitOpenTransactions is copied from the builder; presumably it makes
	// the query wait for in-flight pushes — confirm in the query executor.
	AwaitOpenTransactions bool
	// Limit restricts the maximum amount of returned events.
	Limit uint64
	// Offset skips the first events.
	Offset uint32
	// Desc orders the events descending if true.
	Desc bool
	// The following single filters are applied in addition to SubQueries.
	InstanceID *Filter
	InstanceIDs *Filter
	ExcludedInstances *Filter
	Creator *Filter
	Owner *Filter
	Position *Filter
	Sequence *Filter
	CreatedAfter *Filter
	CreatedBefore *Filter
	ExcludeAggregateIDs []*Filter
}
// Filter represents all fields needed to compare a field of an event with a value
type Filter struct {
	// Field is the event field to compare.
	Field Field
	// Value is the value the field is compared against.
	Value interface{}
	// Operation defines how Field and Value are compared.
	Operation Operation
}
// Operation defines how fields are compared
type Operation int32
const (
	// OperationEquals compares two values for equality
	OperationEquals Operation = iota + 1
	// OperationGreater compares if the given values is greater than the stored one
	OperationGreater
	// OperationLess compares if the given values is less than the stored one
	OperationLess
	// OperationIn checks if a stored value matches one of the passed value list
	OperationIn
	// OperationJSONContains checks if a stored value matches the given json
	OperationJSONContains
	// OperationNotIn checks if a stored value does not match one of the passed value list
	OperationNotIn
	// OperationGreaterOrEquals compares if the given value is greater than or equal to the stored one
	OperationGreaterOrEquals
	// operationCount is a sentinel used by Validate to bound-check operations
	operationCount
)
// Field is the representation of a field from the event
type Field int32
const (
	// FieldAggregateType represents the aggregate type field
	FieldAggregateType Field = iota + 1
	// FieldAggregateID represents the aggregate id field
	FieldAggregateID
	// FieldSequence represents the sequence field
	FieldSequence
	// FieldResourceOwner represents the resource owner field
	FieldResourceOwner
	// FieldInstanceID represents the instance id field
	FieldInstanceID
	// FieldEditorService represents the editor service field
	FieldEditorService
	// FieldEditorUser represents the editor user field
	FieldEditorUser
	// FieldEventType represents the event type field
	FieldEventType
	// FieldEventData represents the event data field
	FieldEventData
	// FieldCreationDate represents the creation date field
	FieldCreationDate
	// FieldPosition represents the field of the global sequence
	FieldPosition
	// fieldCount is a sentinel used by Validate to bound-check fields
	fieldCount
)
// NewFilter is used in tests. Use searchQuery.*Filter() instead.
func NewFilter(field Field, value interface{}, operation Operation) *Filter {
	filter := new(Filter)
	filter.Field = field
	filter.Value = value
	filter.Operation = operation
	return filter
}
// Validate checks if the fields of the filter have valid values.
// It returns a precondition-failed error when the filter is nil or when
// field, value or operation is unset or outside its defined range.
func (f *Filter) Validate() error {
	if f == nil {
		return zerrors.ThrowPreconditionFailed(nil, "REPO-z6KcG", "filter is nil")
	}
	if f.Field <= 0 || f.Field >= fieldCount {
		// fix: error messages previously read "not definded"
		return zerrors.ThrowPreconditionFailed(nil, "REPO-zw62U", "field not defined")
	}
	if f.Value == nil {
		return zerrors.ThrowPreconditionFailed(nil, "REPO-GJ9ct", "no value defined")
	}
	if f.Operation <= 0 || f.Operation >= operationCount {
		return zerrors.ThrowPreconditionFailed(nil, "REPO-RrQTy", "operation not defined")
	}
	return nil
}
// QueryFromBuilder maps the public eventstore.SearchQueryBuilder to the
// repository SearchQuery representation. Every generated filter is validated;
// the first invalid filter aborts with its error.
func QueryFromBuilder(builder *eventstore.SearchQueryBuilder) (*SearchQuery, error) {
	if builder == nil ||
		builder.GetColumns().Validate() != nil {
		return nil, zerrors.ThrowPreconditionFailed(nil, "MODEL-4m9gs", "builder invalid")
	}
	query := &SearchQuery{
		Columns:               builder.GetColumns(),
		Limit:                 builder.GetLimit(),
		Offset:                builder.GetOffset(),
		Desc:                  builder.GetDesc(),
		Tx:                    builder.GetTx(),
		AwaitOpenTransactions: builder.GetAwaitOpenTransactions(),
		SubQueries:            make([][]*Filter, len(builder.GetQueries())),
	}
	query.LockRows, query.LockOption = builder.GetLockRows()
	// Builder-level filters: each helper stores its filter on query as a side
	// effect and returns it (nil when the builder does not define it).
	for _, f := range []func(builder *eventstore.SearchQueryBuilder, query *SearchQuery) *Filter{
		instanceIDFilter,
		instanceIDsFilter,
		editorUserFilter,
		resourceOwnerFilter,
		positionAfterFilter,
		eventSequenceGreaterFilter,
		creationDateAfterFilter,
		creationDateBeforeFilter,
	} {
		filter := f(builder, query)
		if filter == nil {
			continue
		}
		if err := filter.Validate(); err != nil {
			return nil, err
		}
	}
	// Sub queries: one AND-combined filter list per builder query.
	for i, q := range builder.GetQueries() {
		for _, f := range []func(query *eventstore.SearchQuery) *Filter{
			aggregateTypeFilter,
			aggregateIDFilter,
			eventTypeFilter,
			eventDataFilter,
			eventPositionAfterFilter,
		} {
			filter := f(q)
			if filter == nil {
				continue
			}
			if err := filter.Validate(); err != nil {
				return nil, err
			}
			query.SubQueries[i] = append(query.SubQueries[i], filter)
		}
	}
	// Exclusion query: aggregate ids matching these filters are excluded from the result.
	if excludeAggregateIDs := builder.GetExcludeAggregateIDs(); excludeAggregateIDs != nil {
		for _, f := range []func(query *eventstore.ExclusionQuery) *Filter{
			excludeAggregateTypeFilter,
			excludeEventTypeFilter,
		} {
			filter := f(excludeAggregateIDs)
			if filter == nil {
				continue
			}
			if err := filter.Validate(); err != nil {
				return nil, err
			}
			query.ExcludeAggregateIDs = append(query.ExcludeAggregateIDs, filter)
		}
	}
	return query, nil
}
// eventSequenceGreaterFilter stores a sequence bound on query when the builder
// defines one; with descending sort the comparison is inverted (less-than) so
// paging continues below the given sequence.
func eventSequenceGreaterFilter(builder *eventstore.SearchQueryBuilder, query *SearchQuery) *Filter {
	if builder.GetEventSequenceGreater() == 0 {
		return nil
	}
	op := OperationGreater
	if builder.GetDesc() {
		op = OperationLess
	}
	query.Sequence = NewFilter(FieldSequence, builder.GetEventSequenceGreater(), op)
	return query.Sequence
}
// creationDateAfterFilter stores a creation-date lower bound on query, if set.
func creationDateAfterFilter(builder *eventstore.SearchQueryBuilder, query *SearchQuery) *Filter {
	after := builder.GetCreationDateAfter()
	if after.IsZero() {
		return nil
	}
	query.CreatedAfter = NewFilter(FieldCreationDate, after, OperationGreater)
	return query.CreatedAfter
}
// creationDateBeforeFilter stores a creation-date upper bound on query, if set.
func creationDateBeforeFilter(builder *eventstore.SearchQueryBuilder, query *SearchQuery) *Filter {
	before := builder.GetCreationDateBefore()
	if before.IsZero() {
		return nil
	}
	query.CreatedBefore = NewFilter(FieldCreationDate, before, OperationLess)
	return query.CreatedBefore
}
// resourceOwnerFilter stores an equality filter on the resource owner, if set.
func resourceOwnerFilter(builder *eventstore.SearchQueryBuilder, query *SearchQuery) *Filter {
	owner := builder.GetResourceOwner()
	if owner == "" {
		return nil
	}
	query.Owner = NewFilter(FieldResourceOwner, owner, OperationEquals)
	return query.Owner
}
// editorUserFilter stores an equality filter on the editing user, if set.
func editorUserFilter(builder *eventstore.SearchQueryBuilder, query *SearchQuery) *Filter {
	editor := builder.GetEditorUser()
	if editor == "" {
		return nil
	}
	query.Creator = NewFilter(FieldEditorUser, editor, OperationEquals)
	return query.Creator
}
// instanceIDFilter stores an equality filter on the single instance id, if set.
func instanceIDFilter(builder *eventstore.SearchQueryBuilder, query *SearchQuery) *Filter {
	id := builder.GetInstanceID()
	if id == nil {
		return nil
	}
	query.InstanceID = NewFilter(FieldInstanceID, *id, OperationEquals)
	return query.InstanceID
}
// instanceIDsFilter stores an IN filter over the instance id list, if set.
func instanceIDsFilter(builder *eventstore.SearchQueryBuilder, query *SearchQuery) *Filter {
	ids := builder.GetInstanceIDs()
	if ids == nil {
		return nil
	}
	query.InstanceIDs = NewFilter(FieldInstanceID, database.TextArray[string](ids), OperationIn)
	return query.InstanceIDs
}
// positionAfterFilter stores an inclusive lower bound on the global position, if set.
func positionAfterFilter(builder *eventstore.SearchQueryBuilder, query *SearchQuery) *Filter {
	pos := builder.GetPositionAtLeast()
	if pos.IsZero() {
		return nil
	}
	query.Position = NewFilter(FieldPosition, pos, OperationGreaterOrEquals)
	return query.Position
}
// aggregateIDFilter builds an equality filter for a single aggregate id or an
// IN filter for multiple ids; nil when the sub query defines none.
func aggregateIDFilter(query *eventstore.SearchQuery) *Filter {
	ids := query.GetAggregateIDs()
	switch len(ids) {
	case 0:
		return nil
	case 1:
		return NewFilter(FieldAggregateID, ids[0], OperationEquals)
	default:
		return NewFilter(FieldAggregateID, database.TextArray[string](ids), OperationIn)
	}
}
// eventTypeFilter builds an equality filter for a single event type or an IN
// filter for multiple types; nil when the sub query defines none.
func eventTypeFilter(query *eventstore.SearchQuery) *Filter {
	types := query.GetEventTypes()
	switch len(types) {
	case 0:
		return nil
	case 1:
		return NewFilter(FieldEventType, types[0], OperationEquals)
	default:
		return NewFilter(FieldEventType, database.TextArray[eventstore.EventType](types), OperationIn)
	}
}
// aggregateTypeFilter builds an equality filter for a single aggregate type or
// an IN filter for multiple types; nil when the sub query defines none.
func aggregateTypeFilter(query *eventstore.SearchQuery) *Filter {
	types := query.GetAggregateTypes()
	switch len(types) {
	case 0:
		return nil
	case 1:
		return NewFilter(FieldAggregateType, types[0], OperationEquals)
	default:
		return NewFilter(FieldAggregateType, database.TextArray[eventstore.AggregateType](types), OperationIn)
	}
}
// eventDataFilter builds a JSON-containment filter over the event payload;
// nil when the sub query defines no event data.
func eventDataFilter(query *eventstore.SearchQuery) *Filter {
	data := query.GetEventData()
	if len(data) == 0 {
		return nil
	}
	return NewFilter(FieldEventData, data, OperationJSONContains)
}
// eventPositionAfterFilter builds a strict lower-bound filter on the global
// position; nil when the sub query does not define a position.
func eventPositionAfterFilter(query *eventstore.SearchQuery) *Filter {
	pos := query.GetPositionAfter()
	if pos.Equal(decimal.Decimal{}) {
		return nil
	}
	return NewFilter(FieldPosition, pos, OperationGreater)
}
// excludeEventTypeFilter builds the event-type filter of an exclusion query;
// nil when no event types are defined.
func excludeEventTypeFilter(query *eventstore.ExclusionQuery) *Filter {
	types := query.GetEventTypes()
	switch len(types) {
	case 0:
		return nil
	case 1:
		return NewFilter(FieldEventType, types[0], OperationEquals)
	default:
		return NewFilter(FieldEventType, database.TextArray[eventstore.EventType](types), OperationIn)
	}
}
// excludeAggregateTypeFilter builds the aggregate-type filter of an exclusion
// query; nil when no aggregate types are defined.
func excludeAggregateTypeFilter(query *eventstore.ExclusionQuery) *Filter {
	types := query.GetAggregateTypes()
	switch len(types) {
	case 0:
		return nil
	case 1:
		return NewFilter(FieldAggregateType, types[0], OperationEquals)
	default:
		return NewFilter(FieldAggregateType, database.TextArray[eventstore.AggregateType](types), OperationIn)
	}
}

View File

@@ -0,0 +1,146 @@
package repository
import (
"reflect"
"testing"
"github.com/zitadel/zitadel/internal/eventstore"
)
// TestNewFilter checks that the constructor copies field, value and operation
// into the returned Filter unchanged.
func TestNewFilter(t *testing.T) {
	for _, tc := range []struct {
		name      string
		field     Field
		value     interface{}
		operation Operation
		want      *Filter
	}{
		{
			name:      "aggregateID equals",
			field:     FieldAggregateID,
			value:     "hodor",
			operation: OperationEquals,
			want:      &Filter{Field: FieldAggregateID, Operation: OperationEquals, Value: "hodor"},
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			got := NewFilter(tc.field, tc.value, tc.operation)
			if !reflect.DeepEqual(got, tc.want) {
				t.Errorf("NewFilter() = %v, want %v", got, tc.want)
			}
		})
	}
}
func TestFilter_Validate(t *testing.T) {
type fields struct {
field Field
value interface{}
operation Operation
isNil bool
}
tests := []struct {
name string
fields fields
wantErr bool
}{
{
name: "correct filter",
fields: fields{
field: FieldSequence,
operation: OperationGreater,
value: uint64(235),
},
wantErr: false,
},
{
name: "filter is nil",
fields: fields{isNil: true},
wantErr: true,
},
{
name: "no field error",
fields: fields{
operation: OperationGreater,
value: uint64(235),
},
wantErr: true,
},
{
name: "no value error",
fields: fields{
field: FieldSequence,
operation: OperationGreater,
},
wantErr: true,
},
{
name: "no operation error",
fields: fields{
field: FieldSequence,
value: uint64(235),
},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var f *Filter
if !tt.fields.isNil {
f = &Filter{
Field: tt.fields.field,
Value: tt.fields.value,
Operation: tt.fields.operation,
}
}
if err := f.Validate(); (err != nil) != tt.wantErr {
t.Errorf("Filter.Validate() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}
// TestColumns_Validate checks the valid range of eventstore.Columns:
// a known constant passes, values below and above the range fail.
func TestColumns_Validate(t *testing.T) {
	for _, tc := range []struct {
		name    string
		columns eventstore.Columns
		wantErr bool
	}{
		{name: "correct filter", columns: eventstore.ColumnsEvent},
		{name: "columns too low", columns: 0, wantErr: true},
		{name: "columns too high", columns: 100, wantErr: true},
	} {
		t.Run(tc.name, func(t *testing.T) {
			if err := tc.columns.Validate(); (err != nil) != tc.wantErr {
				t.Errorf("Columns.Validate() error = %v, wantErr %v", err, tc.wantErr)
			}
		})
	}
}

View File

@@ -0,0 +1,115 @@
package sql
import (
"context"
"database/sql"
"os"
"testing"
"time"
pgxdecimal "github.com/jackc/pgx-shopspring-decimal"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/stdlib"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/cmd/initialise"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/dialect"
"github.com/zitadel/zitadel/internal/database/postgres"
new_es "github.com/zitadel/zitadel/internal/eventstore/v3"
)
var (
	// testClient is the shared database handle for every test in this
	// package; it is initialized in TestMain against an embedded postgres.
	testClient *sql.DB
)
// TestMain starts an embedded postgres, opens a pgx pool with the decimal and
// eventstore type codecs registered, runs the zitadel init migrations (plus
// the legacy events table) and then executes the package tests. The inner
// closure exists so the deferred cleanups run before os.Exit.
func TestMain(m *testing.M) {
	os.Exit(func() int {
		config, cleanup := postgres.StartEmbedded()
		defer cleanup()
		connConfig, err := pgxpool.ParseConfig(config.GetConnectionURL())
		logging.OnError(err).Fatal("unable to parse db url")
		// register codecs on every new pooled connection
		connConfig.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error {
			pgxdecimal.Register(conn.TypeMap())
			return new_es.RegisterEventstoreTypes(ctx, conn)
		}
		pool, err := pgxpool.NewWithConfig(context.Background(), connConfig)
		logging.OnError(err).Fatal("unable to create db pool")
		testClient = stdlib.OpenDBFromPool(pool)
		err = testClient.Ping()
		logging.OnError(err).Fatal("unable to ping db")
		defer func() {
			logging.OnError(testClient.Close()).Error("unable to close db")
		}()
		err = initDB(context.Background(), &database.DB{DB: testClient, Database: &postgres.Config{Database: "zitadel"}})
		logging.OnError(err).Fatal("migrations failed")
		return m.Run()
	}())
}
// initDB runs the zitadel initialization (user, database, grants, schema)
// against the embedded database and additionally creates the legacy v1
// events table needed by the useV1 code paths under test.
func initDB(ctx context.Context, db *database.DB) error {
	config := new(database.Config)
	config.SetConnector(&postgres.Config{User: postgres.User{Username: "zitadel"}, Database: "zitadel"})
	if err := initialise.ReadStmts(); err != nil {
		return err
	}
	err := initialise.Init(ctx, db,
		initialise.VerifyUser(config.Username(), ""),
		initialise.VerifyDatabase(config.DatabaseName()),
		initialise.VerifyGrant(config.DatabaseName(), config.Username()))
	if err != nil {
		return err
	}
	err = initialise.VerifyZitadel(context.Background(), db, *config)
	if err != nil {
		return err
	}
	// create old events
	_, err = db.Exec(oldEventsTable)
	return err
}
// testDB is a minimal stub satisfying the dialect interface in tests.
type testDB struct{}

// Timetravel returns a fixed AS OF SYSTEM TIME clause, ignoring the duration.
func (_ *testDB) Timetravel(time.Duration) string { return " AS OF SYSTEM TIME '-1 ms' " }

// DatabaseName returns a static database name for tests.
func (*testDB) DatabaseName() string { return "db" }

// Username returns a static user name for tests.
func (*testDB) Username() string { return "user" }

// Type identifies the stub as a postgres dialect.
func (*testDB) Type() dialect.DatabaseType { return dialect.DatabaseTypePostgres }
// oldEventsTable is the DDL of the legacy (v1) eventstore.events table; it is
// created by initDB so the useV1 query paths can be exercised in tests.
const oldEventsTable = `CREATE TABLE IF NOT EXISTS eventstore.events (
	id UUID DEFAULT gen_random_uuid()
	, event_type TEXT NOT NULL
	, aggregate_type TEXT NOT NULL
	, aggregate_id TEXT NOT NULL
	, aggregate_version TEXT NOT NULL
	, event_sequence BIGINT NOT NULL
	, previous_aggregate_sequence BIGINT
	, previous_aggregate_type_sequence INT8
	, creation_date TIMESTAMPTZ NOT NULL DEFAULT now()
	, created_at TIMESTAMPTZ NOT NULL DEFAULT clock_timestamp()
	, event_data JSONB
	, editor_user TEXT NOT NULL 
	, editor_service TEXT
	, resource_owner TEXT NOT NULL
	, instance_id TEXT NOT NULL
	, "position" DECIMAL NOT NULL
	, in_tx_order INTEGER NOT NULL
	, PRIMARY KEY (instance_id, aggregate_type, aggregate_id, event_sequence)
);`

View File

@@ -0,0 +1,242 @@
package sql
import (
"context"
"errors"
"regexp"
"strconv"
"github.com/jackc/pgx/v5/pgconn"
"github.com/shopspring/decimal"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/repository"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
// awaitOpenTransactions ensures event ordering: it excludes events that are
// younger than the oldest still-open push transaction, so we don't read
// events younger than open transactions. V1 compares on created_at epoch,
// V2 on the global "position".
var (
	awaitOpenTransactionsV1 = ` AND EXTRACT(EPOCH FROM created_at) < (SELECT COALESCE(EXTRACT(EPOCH FROM min(xact_start)), EXTRACT(EPOCH FROM now())) FROM pg_stat_activity WHERE datname = current_database() AND application_name = ANY(?) AND state <> 'idle')`
	awaitOpenTransactionsV2 = ` AND "position" < (SELECT COALESCE(EXTRACT(EPOCH FROM min(xact_start)), EXTRACT(EPOCH FROM now())) FROM pg_stat_activity WHERE datname = current_database() AND application_name = ANY(?) AND state <> 'idle')`
)
func awaitOpenTransactions(useV1 bool) string {
if useV1 {
return awaitOpenTransactionsV1
}
return awaitOpenTransactionsV2
}
// Postgres is the eventstore query repository backed by a postgres client.
type Postgres struct {
	*database.DB
}
// NewPostgres wraps the given database client in a Postgres repository.
func NewPostgres(client *database.DB) *Postgres {
	return &Postgres{client}
}
// Health reports database reachability by pinging the underlying client.
func (db *Postgres) Health(ctx context.Context) error { return db.Ping() }
// FilterToReducer finds all events matching the given search query and passes
// them to the reduce function. If the events2 table does not exist yet
// (SQLSTATE 42P01, undefined_table), the query is retried against the legacy
// v1 events table.
//
// The receiver is named db for consistency with the other Postgres methods
// (it was previously psql).
func (db *Postgres) FilterToReducer(ctx context.Context, searchQuery *eventstore.SearchQueryBuilder, reduce eventstore.Reducer) (err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	err = query(ctx, db, searchQuery, reduce, false)
	if err == nil {
		return nil
	}
	pgErr := new(pgconn.PgError)
	// check events2 not exists
	if errors.As(err, &pgErr) && pgErr.Code == "42P01" {
		return query(ctx, db, searchQuery, reduce, true)
	}
	return err
}
// LatestPosition returns the latest position found by the search query.
func (db *Postgres) LatestPosition(ctx context.Context, searchQuery *eventstore.SearchQueryBuilder) (decimal.Decimal, error) {
	var position decimal.Decimal
	if err := query(ctx, db, searchQuery, &position, false); err != nil {
		return position, err
	}
	return position, nil
}
// InstanceIDs returns the instance ids found by the search query.
func (db *Postgres) InstanceIDs(ctx context.Context, searchQuery *eventstore.SearchQueryBuilder) ([]string, error) {
	var ids []string
	if err := query(ctx, db, searchQuery, &ids, false); err != nil {
		return nil, err
	}
	return ids, nil
}
// Client exposes the underlying database client (querier interface).
func (db *Postgres) Client() *database.DB {
	return db.DB
}
// orderByEventSequence returns the ORDER BY clause: v1 always orders by
// event_sequence; v2 orders by "sequence" when a single aggregate is targeted
// (shouldOrderBySequence), otherwise by global "position" and in_tx_order.
func (db *Postgres) orderByEventSequence(desc, shouldOrderBySequence, useV1 bool) string {
	var asc, descending string
	switch {
	case useV1:
		asc, descending = ` ORDER BY event_sequence`, ` ORDER BY event_sequence DESC`
	case shouldOrderBySequence:
		asc, descending = ` ORDER BY "sequence"`, ` ORDER BY "sequence" DESC`
	default:
		asc, descending = ` ORDER BY "position", in_tx_order`, ` ORDER BY "position" DESC, in_tx_order DESC`
	}
	if desc {
		return descending
	}
	return asc
}
// eventQuery returns the SELECT clause for events. The column order must
// match the scan order in eventsScanner for the respective schema version.
func (db *Postgres) eventQuery(useV1 bool) string {
	if useV1 {
		return "SELECT" +
			" creation_date" +
			", event_type" +
			", event_sequence" +
			", event_data" +
			", editor_user" +
			", resource_owner" +
			", instance_id" +
			", aggregate_type" +
			", aggregate_id" +
			", aggregate_version" +
			" FROM eventstore.events"
	}
	return "SELECT" +
		" created_at" +
		", event_type" +
		`, "sequence"` +
		`, "position"` +
		", payload" +
		", creator" +
		`, "owner"` +
		", instance_id" +
		", aggregate_type" +
		", aggregate_id" +
		", revision" +
		" FROM eventstore.events2"
}
// maxPositionQuery returns the single-column SELECT used to determine the
// latest position (v1: event_sequence, v2: "position"); [query] adds
// ORDER BY ... DESC LIMIT 1 on top.
func (db *Postgres) maxPositionQuery(useV1 bool) string {
	if !useV1 {
		return `SELECT "position" FROM eventstore.events2`
	}
	return `SELECT event_sequence FROM eventstore.events`
}
// instanceIDsQuery returns the distinct instance-id SELECT for the given
// schema version.
func (db *Postgres) instanceIDsQuery(useV1 bool) string {
	if useV1 {
		return "SELECT DISTINCT instance_id FROM eventstore.events"
	}
	return "SELECT DISTINCT instance_id FROM eventstore.events2"
}
// columnName maps the abstract repository field to the physical column of the
// selected schema version. Reserved words are double-quoted. An empty string
// marks a field that does not exist in that version (the caller treats it as
// an invalid condition).
func (db *Postgres) columnName(col repository.Field, useV1 bool) string {
	switch col {
	case repository.FieldAggregateID:
		return "aggregate_id"
	case repository.FieldAggregateType:
		return "aggregate_type"
	case repository.FieldSequence:
		if useV1 {
			return "event_sequence"
		}
		return `"sequence"`
	case repository.FieldResourceOwner:
		if useV1 {
			return "resource_owner"
		}
		return `"owner"`
	case repository.FieldInstanceID:
		return "instance_id"
	case repository.FieldEditorService:
		if useV1 {
			return "editor_service"
		}
		// the editor service column was dropped in events2
		return ""
	case repository.FieldEditorUser:
		if useV1 {
			return "editor_user"
		}
		return "creator"
	case repository.FieldEventType:
		return "event_type"
	case repository.FieldEventData:
		if useV1 {
			return "event_data"
		}
		return "payload"
	case repository.FieldCreationDate:
		if useV1 {
			return "creation_date"
		}
		return "created_at"
	case repository.FieldPosition:
		return `"position"`
	default:
		return ""
	}
}
// conditionFormat returns the fmt template for a condition ("column operator
// value"). List operations wrap the placeholder in ANY/ALL so a single array
// parameter can be bound.
func (db *Postgres) conditionFormat(operation repository.Operation) string {
	if operation == repository.OperationIn {
		return "%s %s ANY(?)"
	}
	if operation == repository.OperationNotIn {
		return "%s %s ALL(?)"
	}
	return "%s %s ?"
}
// operation maps the repository operation to its SQL operator; unknown
// operations map to the empty string, which the caller rejects.
func (db *Postgres) operation(operation repository.Operation) string {
	switch operation {
	case repository.OperationEquals, repository.OperationIn:
		return "="
	case repository.OperationNotIn:
		return "<>"
	case repository.OperationGreater:
		return ">"
	case repository.OperationGreaterOrEquals:
		return ">="
	case repository.OperationLess:
		return "<"
	case repository.OperationJSONContains:
		return "@>"
	default:
		return ""
	}
}
var (
	// placeholder matches every "?" parameter marker in a statement.
	placeholder = regexp.MustCompile(`\?`)
)
// placeholder replaces all "?" with postgres placeholders ($<NUMBER>)
func (db *Postgres) placeholder(query string) string {
	idxs := placeholder.FindAllStringIndex(query, -1)
	if len(idxs) == 0 {
		return query
	}
	// start with everything before the first "?", then append "$n" plus the
	// text up to the next match for each occurrence
	out := query[:idxs[0][0]]
	for n, match := range idxs {
		end := len(query)
		if n+1 < len(idxs) {
			end = idxs[n+1][0]
		}
		out += "$" + strconv.Itoa(n+1) + query[match[1]:end]
	}
	return out
}

View File

@@ -0,0 +1,325 @@
package sql
import (
"database/sql"
"testing"
"github.com/shopspring/decimal"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/repository"
)
// TestPostgres_placeholder checks "?" to "$n" rewriting for zero, one and
// multiple markers.
func TestPostgres_placeholder(t *testing.T) {
	for _, tc := range []struct {
		name  string
		query string
		want  string
	}{
		{
			name:  "no placeholders",
			query: "SELECT * FROM eventstore.events2",
			want:  "SELECT * FROM eventstore.events2",
		},
		{
			name:  "one placeholder",
			query: "SELECT * FROM eventstore.events2 WHERE aggregate_type = ?",
			want:  "SELECT * FROM eventstore.events2 WHERE aggregate_type = $1",
		},
		{
			name:  "multiple placeholders",
			query: "SELECT * FROM eventstore.events2 WHERE aggregate_type = ? AND aggregate_id = ? LIMIT ?",
			want:  "SELECT * FROM eventstore.events2 WHERE aggregate_type = $1 AND aggregate_id = $2 LIMIT $3",
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			db := &Postgres{}
			if got := db.placeholder(tc.query); got != tc.want {
				t.Errorf("Postgres.placeholder() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestPostgres_operation checks the operation-to-SQL-operator mapping,
// including the empty result for an unknown operation.
func TestPostgres_operation(t *testing.T) {
	for _, tc := range []struct {
		name      string
		operation repository.Operation
		want      string
	}{
		{name: "no op", operation: repository.Operation(-1), want: ""},
		{name: "greater", operation: repository.OperationGreater, want: ">"},
		{name: "less", operation: repository.OperationLess, want: "<"},
		{name: "equals", operation: repository.OperationEquals, want: "="},
		{name: "in", operation: repository.OperationIn, want: "="},
	} {
		t.Run(tc.name, func(t *testing.T) {
			db := &Postgres{}
			if got := db.operation(tc.operation); got != tc.want {
				t.Errorf("Postgres.operation() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestPostgres_conditionFormat checks the condition template selection:
// ANY(?) for IN operations, the plain "?" template otherwise.
func TestPostgres_conditionFormat(t *testing.T) {
	for _, tc := range []struct {
		name      string
		operation repository.Operation
		want      string
	}{
		{name: "default", operation: repository.OperationEquals, want: "%s %s ?"},
		{name: "in", operation: repository.OperationIn, want: "%s %s ANY(?)"},
	} {
		t.Run(tc.name, func(t *testing.T) {
			db := &Postgres{}
			if got := db.conditionFormat(tc.operation); got != tc.want {
				t.Errorf("Postgres.conditionFormat() = %v, want %v", got, tc.want)
			}
		})
	}
}
// TestPostgres_columnName checks the field-to-column mapping for both the
// legacy (v1) events table and the events2 table, including fields that do
// not exist in one of the versions.
func TestPostgres_columnName(t *testing.T) {
	tests := []struct {
		name  string
		field repository.Field
		useV1 bool
		want  string
	}{
		{name: "invalid field", field: repository.Field(-1), want: ""},
		{name: "aggregate id", field: repository.FieldAggregateID, want: "aggregate_id"},
		{name: "aggregate type", field: repository.FieldAggregateType, want: "aggregate_type"},
		{name: "editor service", field: repository.FieldEditorService, useV1: true, want: "editor_service"},
		{name: "editor service v2", field: repository.FieldEditorService, want: ""},
		{name: "editor user", field: repository.FieldEditorUser, useV1: true, want: "editor_user"},
		{name: "editor user v2", field: repository.FieldEditorUser, want: "creator"},
		{name: "event type", field: repository.FieldEventType, want: "event_type"},
		{name: "latest sequence", field: repository.FieldSequence, useV1: true, want: "event_sequence"},
		{name: "latest sequence v2", field: repository.FieldSequence, want: `"sequence"`},
		{name: "resource owner", field: repository.FieldResourceOwner, useV1: true, want: "resource_owner"},
		{name: "resource owner v2", field: repository.FieldResourceOwner, want: `"owner"`},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			db := &Postgres{}
			// fix: the failure message previously reported Postgres.operation()
			if got := db.columnName(tt.field, tt.useV1); got != tt.want {
				t.Errorf("Postgres.columnName() = %v, want %v", got, tt.want)
			}
		})
	}
}
// generateEvent builds a repository.Event with test defaults for the given
// aggregate id; the opts callbacks may mutate the event before it is returned.
func generateEvent(t *testing.T, aggregateID string, opts ...func(*repository.Event)) *repository.Event {
	t.Helper()
	event := &repository.Event{
		AggregateID:   aggregateID,
		AggregateType: eventstore.AggregateType(t.Name()),
		EditorUser:    "user",
		ResourceOwner: sql.NullString{String: "ro", Valid: true},
		Typ:           "test.created",
		Version:       "v1",
		Pos:           decimal.NewFromInt(42),
	}
	for _, apply := range opts {
		apply(event)
	}
	return event
}

View File

@@ -0,0 +1,355 @@
package sql
import (
"context"
"database/sql"
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"github.com/shopspring/decimal"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/dialect"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/repository"
"github.com/zitadel/zitadel/internal/zerrors"
)
// querier abstracts the dialect-specific SQL fragments and the database
// client used by [query]. Postgres is the implementation in this package.
type querier interface {
	columnName(field repository.Field, useV1 bool) string
	operation(repository.Operation) string
	conditionFormat(repository.Operation) string
	placeholder(query string) string
	eventQuery(useV1 bool) string
	maxPositionQuery(useV1 bool) string
	instanceIDsQuery(useV1 bool) string
	Client() *database.DB
	orderByEventSequence(desc, shouldOrderBySequence, useV1 bool) string

	dialect.Database
}
// scan abstracts sql.Rows.Scan so row scanners can be tested independently.
type scan func(dest ...interface{}) error

// tx adapts a *sql.Tx to the QueryContext(ctx, scan, query, args) shape of
// database.DB so [query] can run inside an existing transaction.
type tx struct {
	*sql.Tx
}
// QueryContext executes query on the wrapped transaction and hands the rows
// to scan. The rows are always closed; the iteration error is returned last.
func (t *tx) QueryContext(ctx context.Context, scan func(rows *sql.Rows) error, query string, args ...any) error {
	rows, err := t.Tx.QueryContext(ctx, query, args...)
	if err != nil {
		return err
	}
	defer func() {
		logging.OnError(rows.Close()).Info("rows.Close failed")
	}()
	if err := scan(rows); err != nil {
		return err
	}
	return rows.Err()
}
// query builds the SQL statement from searchQuery, executes it (optionally
// inside the builder's transaction) and scans every row into dest with the
// column-specific row scanner. useV1 targets the legacy eventstore.events
// table instead of events2.
func query(ctx context.Context, criteria querier, searchQuery *eventstore.SearchQueryBuilder, dest interface{}, useV1 bool) error {
	q, err := repository.QueryFromBuilder(searchQuery)
	if err != nil {
		return err
	}
	query, rowScanner := prepareColumns(criteria, q.Columns, useV1)
	where, values := prepareConditions(criteria, q, useV1)
	if where == "" || query == "" {
		return zerrors.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
	}
	query += where
	// instead of using the max function of the database (which doesn't work for postgres)
	// we select the most recent row
	if q.Columns == eventstore.ColumnsMaxPosition {
		q.Limit = 1
		q.Desc = true
	}
	// if there is only one subquery we can optimize the query ordering by ordering by sequence
	var shouldOrderBySequence bool
	if len(q.SubQueries) == 1 {
		for _, filter := range q.SubQueries[0] {
			if filter.Field == repository.FieldAggregateID {
				shouldOrderBySequence = filter.Operation == repository.OperationEquals
			}
		}
	}
	switch q.Columns {
	case eventstore.ColumnsEvent,
		eventstore.ColumnsMaxPosition:
		query += criteria.orderByEventSequence(q.Desc, shouldOrderBySequence, useV1)
	}
	if q.Limit > 0 {
		values = append(values, q.Limit)
		query += " LIMIT ?"
	}
	if q.Offset > 0 {
		values = append(values, q.Offset)
		query += " OFFSET ?"
	}
	if q.LockRows {
		query += " FOR UPDATE"
		switch q.LockOption {
		case eventstore.LockOptionWait: // default behavior
		case eventstore.LockOptionNoWait:
			query += " NOWAIT"
		case eventstore.LockOptionSkipLocked:
			query += " SKIP LOCKED"
		}
	}
	// rewrite "?" markers to postgres-style $1, $2, ...
	query = criteria.placeholder(query)
	// run on the builder's transaction when present, otherwise on the client
	var contextQuerier interface {
		QueryContext(context.Context, func(rows *sql.Rows) error, string, ...interface{}) error
	}
	contextQuerier = criteria.Client()
	if q.Tx != nil {
		contextQuerier = &tx{Tx: q.Tx}
	}
	err = contextQuerier.QueryContext(ctx,
		func(rows *sql.Rows) error {
			for rows.Next() {
				err := rowScanner(rows.Scan, dest)
				if err != nil {
					return err
				}
			}
			return nil
		}, query, values...)
	if err != nil {
		logging.New().WithError(err).Info("query failed")
		return zerrors.ThrowInternal(err, "SQL-KyeAx", "unable to filter events")
	}
	return nil
}
// prepareColumns returns the SELECT clause and the matching row scanner for
// the requested projection; both are zero values for an unknown projection.
func prepareColumns(criteria querier, columns eventstore.Columns, useV1 bool) (string, func(s scan, dest interface{}) error) {
	switch columns {
	case eventstore.ColumnsEvent:
		return criteria.eventQuery(useV1), eventsScanner(useV1)
	case eventstore.ColumnsMaxPosition:
		return criteria.maxPositionQuery(useV1), maxPositionScanner
	case eventstore.ColumnsInstanceIDs:
		return criteria.instanceIDsQuery(useV1), instanceIDsScanner
	}
	return "", nil
}
// maxPositionScanner scans one nullable decimal into dest, which must be a
// *decimal.Decimal. sql.ErrNoRows counts as success and leaves the zero
// decimal in place.
func maxPositionScanner(row scan, dest interface{}) (err error) {
	position, ok := dest.(*decimal.Decimal)
	if !ok {
		return zerrors.ThrowInvalidArgumentf(nil, "SQL-NBjA9", "type must be pointer to decimal.Decimal got: %T", dest)
	}
	var res decimal.NullDecimal
	err = row(&res)
	if err == nil || errors.Is(err, sql.ErrNoRows) {
		*position = res.Decimal
		return nil
	}
	return zerrors.ThrowInternal(err, "SQL-bN5xg", "something went wrong")
}
// instanceIDsScanner appends a single scanned instance id to dest, which must
// be a *[]string.
func instanceIDsScanner(scanner scan, dest interface{}) (err error) {
	ids, ok := dest.(*[]string)
	if !ok {
		return zerrors.ThrowInvalidArgument(nil, "SQL-Begh2", "type must be an array of string")
	}
	var id string
	if err = scanner(&id); err != nil {
		logging.WithError(err).Warn("unable to scan row")
		return zerrors.ThrowInternal(err, "SQL-DEFGe", "unable to scan row")
	}
	*ids = append(*ids, id)
	return nil
}
// eventsScanner returns a row scanner that decodes one event row (column
// order must match eventQuery for the same schema version) and passes the
// event to dest, which must be an eventstore.Reducer.
func eventsScanner(useV1 bool) func(scanner scan, dest interface{}) (err error) {
	return func(scanner scan, dest interface{}) (err error) {
		reduce, ok := dest.(eventstore.Reducer)
		if !ok {
			return zerrors.ThrowInvalidArgumentf(nil, "SQL-4GP6F", "events scanner: invalid type %T", dest)
		}
		event := new(repository.Event)
		position := new(decimal.NullDecimal)
		if useV1 {
			err = scanner(
				&event.CreationDate,
				&event.Typ,
				&event.Seq,
				&event.Data,
				&event.EditorUser,
				&event.ResourceOwner,
				&event.InstanceID,
				&event.AggregateType,
				&event.AggregateID,
				&event.Version,
			)
		} else {
			// v2 stores the version as a numeric revision; convert it back
			// to the "v<N>" representation
			var revision uint8
			err = scanner(
				&event.CreationDate,
				&event.Typ,
				&event.Seq,
				position,
				&event.Data,
				&event.EditorUser,
				&event.ResourceOwner,
				&event.InstanceID,
				&event.AggregateType,
				&event.AggregateID,
				&revision,
			)
			event.Version = eventstore.Version("v" + strconv.Itoa(int(revision)))
		}
		if err != nil {
			logging.New().WithError(err).Warn("unable to scan row")
			return zerrors.ThrowInternal(err, "SQL-M0dsf", "unable to scan row")
		}
		// position stays the zero decimal for v1 rows (no position column scanned)
		event.Pos = position.Decimal
		return reduce(event)
	}
}
// prepareConditions assembles the WHERE clause and bind arguments from the
// search query: instance filters AND (sub queries OR-combined) AND additional
// global filters AND aggregate-id exclusion AND the open-transactions guard.
// It returns "" when any contained filter cannot be mapped (reported as an
// error in [query]).
func prepareConditions(criteria querier, query *repository.SearchQuery, useV1 bool) (_ string, args []any) {
	clauses, args := prepareQuery(criteria, useV1, query.InstanceID, query.InstanceIDs, query.ExcludedInstances)
	if clauses != "" && len(query.SubQueries) > 0 {
		clauses += " AND "
	}
	subClauses := make([]string, len(query.SubQueries))
	for i, filters := range query.SubQueries {
		var subArgs []any
		subClauses[i], subArgs = prepareQuery(criteria, useV1, filters...)
		// an error is thrown in [query]
		if subClauses[i] == "" {
			return "", nil
		}
		// parenthesize multi-condition groups so the OR combination is correct
		if len(query.SubQueries) > 1 && len(subArgs) > 1 {
			subClauses[i] = "(" + subClauses[i] + ")"
		}
		args = append(args, subArgs...)
	}
	if len(subClauses) == 1 {
		clauses += subClauses[0]
	} else if len(subClauses) > 1 {
		clauses += "(" + strings.Join(subClauses, " OR ") + ")"
	}
	additionalClauses, additionalArgs := prepareQuery(criteria, useV1,
		query.Position,
		query.Owner,
		query.Sequence,
		query.CreatedAfter,
		query.CreatedBefore,
		query.Creator,
	)
	if additionalClauses != "" {
		if clauses != "" {
			clauses += " AND "
		}
		clauses += additionalClauses
		args = append(args, additionalArgs...)
	}
	// exclusion sub-select: restrict it to the same instance/position/time
	// window so the NOT IN stays cheap and consistent
	excludeAggregateIDs := query.ExcludeAggregateIDs
	if len(excludeAggregateIDs) > 0 {
		excludeAggregateIDs = append(excludeAggregateIDs, query.InstanceID, query.InstanceIDs, query.Position, query.CreatedAfter, query.CreatedBefore)
	}
	excludeAggregateIDsClauses, excludeAggregateIDsArgs := prepareQuery(criteria, useV1, excludeAggregateIDs...)
	if excludeAggregateIDsClauses != "" {
		if clauses != "" {
			clauses += " AND "
		}
		if useV1 {
			clauses += "aggregate_id NOT IN (SELECT aggregate_id FROM eventstore.events WHERE " + excludeAggregateIDsClauses + ")"
		} else {
			clauses += "aggregate_id NOT IN (SELECT aggregate_id FROM eventstore.events2 WHERE " + excludeAggregateIDsClauses + ")"
		}
		args = append(args, excludeAggregateIDsArgs...)
	}
	if query.AwaitOpenTransactions {
		// the pusher registers with application_name "zitadel_es_pusher_<instance>"
		instanceIDs := make(database.TextArray[string], 0, 3)
		if query.InstanceID != nil {
			instanceIDs = append(instanceIDs, query.InstanceID.Value.(string))
		} else if query.InstanceIDs != nil {
			instanceIDs = append(instanceIDs, query.InstanceIDs.Value.(database.TextArray[string])...)
		}
		for i := range instanceIDs {
			instanceIDs[i] = "zitadel_es_pusher_" + instanceIDs[i]
		}
		clauses += awaitOpenTransactions(useV1)
		args = append(args, instanceIDs)
	}
	if clauses == "" {
		return "", nil
	}
	return " WHERE " + clauses, args
}
// prepareQuery renders the given filters as a single AND-joined SQL condition
// and collects the corresponding query arguments. Nil filters are skipped,
// payload filters are marshaled to JSON first. It returns ("", nil) when a
// filter cannot be mapped to a condition; the error is reported in [query].
func prepareQuery(criteria querier, useV1 bool, filters ...*repository.Filter) (_ string, args []any) {
    conditions := make([]string, 0, len(filters))
    args = make([]any, 0, len(filters))
    for _, filter := range filters {
        if filter == nil {
            continue
        }
        value := filter.Value
        if filter.Field == repository.FieldEventData {
            // payload filters are matched against the stored JSON document
            marshaled, err := json.Marshal(value)
            if err != nil {
                logging.WithError(err).Warn("unable to marshal search value")
                continue
            }
            value = marshaled
        }
        condition := getCondition(criteria, filter, useV1)
        // an empty condition means the field/operation mapping failed;
        // the error is thrown in [query]
        if condition == "" {
            return "", nil
        }
        conditions = append(conditions, condition)
        args = append(args, value)
    }
    return strings.Join(conditions, " AND "), args
}
// getCondition builds the SQL snippet for a single filter by combining the
// mapped column name with the mapped operation. An empty string signals that
// either the field or the operation could not be mapped.
func getCondition(cond querier, filter *repository.Filter, useV1 bool) string {
    var (
        field     = cond.columnName(filter.Field, useV1)
        operation = cond.operation(filter.Operation)
    )
    // unknown field or operation: report the failure with an empty condition
    if field == "" || operation == "" {
        return ""
    }
    return fmt.Sprintf(cond.conditionFormat(filter.Operation), field, operation)
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,436 @@
package eventstore
import (
"context"
"database/sql"
"time"
"github.com/shopspring/decimal"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/zerrors"
)
// SearchQueryBuilder represents the builder for your filter
// if invalid data are set the filter will fail
type SearchQueryBuilder struct {
    columns Columns
    // limit is the maximum amount of events returned, 0 means no limit
    limit  uint64
    offset uint32
    // desc sorts the result descending when true
    desc          bool
    resourceOwner string
    // instanceID filters for a single instance, instanceIDs for a set of them
    instanceID  *string
    instanceIDs []string
    editorUser  string
    // queries are the OR-connected sub queries of the filter
    queries []*SearchQuery
    // excludeAggregateIDs excludes events of matching aggregates, see [SearchQueryBuilder.ExcludeAggregateIDs]
    excludeAggregateIDs *ExclusionQuery
    // tx is an optional transaction the query runs in
    tx         *sql.Tx
    lockRows   bool
    lockOption LockOption
    positionAtLeast       decimal.Decimal
    awaitOpenTransactions bool
    creationDateAfter     time.Time
    creationDateBefore    time.Time
    eventSequenceGreater  uint64
}
// The getters below all use a pointer receiver with the uniform receiver name
// b; the previous mix of value receivers (named b and q) copied the whole
// builder on every call and was inconsistent with the rest of the type.

// GetColumns returns the requested result columns.
func (b *SearchQueryBuilder) GetColumns() Columns {
    return b.columns
}

// GetLimit returns the maximum amount of events to return, 0 means no limit.
func (b *SearchQueryBuilder) GetLimit() uint64 {
    return b.limit
}

// GetOffset returns the amount of events to skip.
func (b *SearchQueryBuilder) GetOffset() uint32 {
    return b.offset
}

// GetDesc reports whether the result is sorted descending.
func (b *SearchQueryBuilder) GetDesc() bool {
    return b.desc
}

// GetResourceOwner returns the resource owner filter.
func (b *SearchQueryBuilder) GetResourceOwner() string {
    return b.resourceOwner
}

// GetInstanceID returns the single instance filter, nil if unset.
func (b *SearchQueryBuilder) GetInstanceID() *string {
    return b.instanceID
}

// GetInstanceIDs returns the instance list filter.
func (b *SearchQueryBuilder) GetInstanceIDs() []string {
    return b.instanceIDs
}

// GetEditorUser returns the editor user filter.
func (b *SearchQueryBuilder) GetEditorUser() string {
    return b.editorUser
}

// GetQueries returns the sub queries of the builder.
func (b *SearchQueryBuilder) GetQueries() []*SearchQuery {
    return b.queries
}

// GetExcludeAggregateIDs returns the exclusion query, nil if unset.
func (b *SearchQueryBuilder) GetExcludeAggregateIDs() *ExclusionQuery {
    return b.excludeAggregateIDs
}

// GetTx returns the transaction the query should run in, nil if unset.
func (b *SearchQueryBuilder) GetTx() *sql.Tx {
    return b.tx
}

// GetPositionAtLeast returns the minimum position filter.
func (b *SearchQueryBuilder) GetPositionAtLeast() decimal.Decimal {
    return b.positionAtLeast
}

// GetAwaitOpenTransactions reports whether open transactions are awaited.
func (b *SearchQueryBuilder) GetAwaitOpenTransactions() bool {
    return b.awaitOpenTransactions
}

// GetEventSequenceGreater returns the sequence filter.
func (b *SearchQueryBuilder) GetEventSequenceGreater() uint64 {
    return b.eventSequenceGreater
}

// GetCreationDateAfter returns the lower creation date bound.
func (b *SearchQueryBuilder) GetCreationDateAfter() time.Time {
    return b.creationDateAfter
}

// GetCreationDateBefore returns the upper creation date bound.
func (b *SearchQueryBuilder) GetCreationDateBefore() time.Time {
    return b.creationDateBefore
}

// GetLockRows reports whether rows should be locked and with which option.
func (b *SearchQueryBuilder) GetLockRows() (bool, LockOption) {
    return b.lockRows, b.lockOption
}
// ensureInstanceID makes sure that the instance id is always set.
// It falls back to the instance id of the context when neither a single
// instance id nor a list of instance ids was configured explicitly.
func (b *SearchQueryBuilder) ensureInstanceID(ctx context.Context) {
    if b.instanceID != nil || len(b.instanceIDs) > 0 {
        return
    }
    // resolve the instance from the context only once instead of twice
    if instanceID := authz.GetInstance(ctx).InstanceID(); instanceID != "" {
        b.InstanceID(instanceID)
    }
}
// SearchQuery defines the criteria of a single sub query.
// All fields of a sub query are AND-connected; multiple sub queries on the
// same builder are OR-connected.
type SearchQuery struct {
    // builder is the owning SearchQueryBuilder, returned by [SearchQuery.Builder]
    builder        *SearchQueryBuilder
    aggregateTypes []AggregateType
    aggregateIDs   []string
    eventTypes     []EventType
    // eventData matches against the stored payload; slower than the other filters
    eventData     map[string]interface{}
    positionAfter decimal.Decimal
}
// GetAggregateTypes returns the aggregate type filters of the sub query.
func (q SearchQuery) GetAggregateTypes() []AggregateType {
    return q.aggregateTypes
}

// GetAggregateIDs returns the aggregate id filters of the sub query.
func (q SearchQuery) GetAggregateIDs() []string {
    return q.aggregateIDs
}

// GetEventTypes returns the event type filters of the sub query.
func (q SearchQuery) GetEventTypes() []EventType {
    return q.eventTypes
}

// GetEventData returns the payload filter of the sub query.
func (q SearchQuery) GetEventData() map[string]interface{} {
    return q.eventData
}

// GetPositionAfter returns the position filter of the sub query.
func (q SearchQuery) GetPositionAfter() decimal.Decimal {
    return q.positionAfter
}
// ExclusionQuery defines which events are excluded from the result,
// matched by aggregate type and event type.
type ExclusionQuery struct {
    // builder is the owning SearchQueryBuilder, returned by [ExclusionQuery.Builder]
    builder        *SearchQueryBuilder
    aggregateTypes []AggregateType
    eventTypes     []EventType
}

// GetAggregateTypes returns the aggregate types to exclude.
func (q ExclusionQuery) GetAggregateTypes() []AggregateType {
    return q.aggregateTypes
}

// GetEventTypes returns the event types to exclude.
func (q ExclusionQuery) GetEventTypes() []EventType {
    return q.eventTypes
}
// Columns defines which fields of the event are needed for the query
type Columns int8

// The constants are explicitly typed as Columns; previously they were untyped
// ints (bare iota), which forced conversions like Columns(ColumnsEvent) and
// lost type safety at call sites.
const (
    // ColumnsEvent represents all fields of an event
    ColumnsEvent Columns = iota + 1
    // ColumnsMaxPosition represents the latest sequence of the filtered events
    ColumnsMaxPosition
    // ColumnsInstanceIDs represents the instance ids of the filtered events
    ColumnsInstanceIDs

    // columnsCount is the open upper bound used by Validate
    columnsCount
)

// Validate returns an error if the value is outside the declared column range.
func (c Columns) Validate() error {
    if c <= 0 || c >= columnsCount {
        return zerrors.ThrowPreconditionFailed(nil, "REPOS-x8R35", "column out of range")
    }
    return nil
}
// NewSearchQueryBuilder creates a new builder for event filters.
// Filters are added through the fluent methods of the builder,
// e.g. [SearchQueryBuilder.AddQuery].
func NewSearchQueryBuilder(columns Columns) *SearchQueryBuilder {
    return &SearchQueryBuilder{
        columns: columns,
    }
}
// Matches returns the commands which fulfill the filters of the builder,
// honoring the configured offset and limit.
func (builder *SearchQueryBuilder) Matches(commands ...Command) []Command {
    matched := make([]Command, 0, len(commands))
    for idx, command := range commands {
        // stop once the configured limit is reached
        if builder.limit > 0 && uint64(len(matched)) >= builder.limit {
            break
        }
        // skip commands before the configured offset
        if builder.offset > 0 && uint32(idx) < builder.offset {
            continue
        }
        if !builder.matchCommand(command) {
            continue
        }
        matched = append(matched, command)
    }
    return matched
}
// sequencer is implemented by commands which carry an event sequence.
type sequencer interface {
    Sequence() uint64
}
// matchCommand reports whether a single command fulfills the builder's
// resource owner, instance, sequence and sub query filters.
func (builder *SearchQueryBuilder) matchCommand(command Command) bool {
    aggregate := command.Aggregate()
    if builder.resourceOwner != "" && aggregate.ResourceOwner != builder.resourceOwner {
        return false
    }
    // the instance filter only applies when both the command and the builder
    // carry a non-empty instance id
    if builder.instanceID != nil && *builder.instanceID != "" &&
        aggregate.InstanceID != "" && aggregate.InstanceID != *builder.instanceID {
        return false
    }
    if seq, ok := command.(sequencer); ok {
        if builder.eventSequenceGreater > 0 && seq.Sequence() <= builder.eventSequenceGreater {
            return false
        }
    }
    // without sub queries every remaining command matches
    if len(builder.queries) == 0 {
        return true
    }
    // sub queries are OR-connected: one match suffices
    for _, query := range builder.queries {
        if query.matches(command) {
            return true
        }
    }
    return false
}
// Columns defines which fields are set
func (builder *SearchQueryBuilder) Columns(columns Columns) *SearchQueryBuilder {
    builder.columns = columns
    return builder
}

// Limit defines how many events are returned maximally.
func (builder *SearchQueryBuilder) Limit(limit uint64) *SearchQueryBuilder {
    builder.limit = limit
    return builder
}

// Offset defines how many events are skipped before returning results.
func (builder *SearchQueryBuilder) Offset(offset uint32) *SearchQueryBuilder {
    builder.offset = offset
    return builder
}

// ResourceOwner defines the resource owner (org or instance) of the events
func (builder *SearchQueryBuilder) ResourceOwner(resourceOwner string) *SearchQueryBuilder {
    builder.resourceOwner = resourceOwner
    return builder
}

// InstanceID defines the instanceID (system) of the events
func (builder *SearchQueryBuilder) InstanceID(instanceID string) *SearchQueryBuilder {
    builder.instanceID = &instanceID
    return builder
}

// InstanceIDs defines the instanceIDs (system) of the events
func (builder *SearchQueryBuilder) InstanceIDs(instanceIDs []string) *SearchQueryBuilder {
    builder.instanceIDs = instanceIDs
    return builder
}

// OrderDesc changes the sorting order of the returned events to descending
func (builder *SearchQueryBuilder) OrderDesc() *SearchQueryBuilder {
    builder.desc = true
    return builder
}

// OrderAsc changes the sorting order of the returned events to ascending
func (builder *SearchQueryBuilder) OrderAsc() *SearchQueryBuilder {
    builder.desc = false
    return builder
}

// SetTx ensures that the eventstore library uses the existing transaction
func (builder *SearchQueryBuilder) SetTx(tx *sql.Tx) *SearchQueryBuilder {
    builder.tx = tx
    return builder
}

// EditorUser sets the id of the editor user
// (presumably used to filter events by their creator — confirm against the
// query implementation).
func (builder *SearchQueryBuilder) EditorUser(id string) *SearchQueryBuilder {
    builder.editorUser = id
    return builder
}

// PositionAtLeast filters for events which happened after the specified time
func (builder *SearchQueryBuilder) PositionAtLeast(position decimal.Decimal) *SearchQueryBuilder {
    builder.positionAtLeast = position
    return builder
}

// AwaitOpenTransactions filters for events which are older than the oldest transaction of the database
func (builder *SearchQueryBuilder) AwaitOpenTransactions() *SearchQueryBuilder {
    builder.awaitOpenTransactions = true
    return builder
}

// SequenceGreater filters for events with sequence greater the requested sequence
func (builder *SearchQueryBuilder) SequenceGreater(sequence uint64) *SearchQueryBuilder {
    builder.eventSequenceGreater = sequence
    return builder
}
// CreationDateAfter filters for events which happened after the specified time.
// Zero times and the UNIX epoch are treated as "unset" and leave the builder
// unchanged.
func (builder *SearchQueryBuilder) CreationDateAfter(creationDate time.Time) *SearchQueryBuilder {
    if creationDate.IsZero() || creationDate.Unix() == 0 {
        return builder
    }
    builder.creationDateAfter = creationDate
    return builder
}

// CreationDateBefore filters for events which happened before the specified time.
// Zero times and the UNIX epoch are treated as "unset" and leave the builder
// unchanged.
func (builder *SearchQueryBuilder) CreationDateBefore(creationDate time.Time) *SearchQueryBuilder {
    if creationDate.IsZero() || creationDate.Unix() == 0 {
        return builder
    }
    builder.creationDateBefore = creationDate
    return builder
}
// LockOption defines how concurrently locked rows are handled when locking
// the selected rows, see [SearchQueryBuilder.LockRowsDuringTx].
type LockOption int

const (
    // Wait until the previous lock on all of the selected rows is released (default)
    LockOptionWait LockOption = iota
    // With NOWAIT, the statement reports an error, rather than waiting, if a selected row cannot be locked immediately.
    LockOptionNoWait
    // With SKIP LOCKED, any selected rows that cannot be immediately locked are skipped.
    LockOptionSkipLocked
)
// LockRowsDuringTx locks the found rows for the duration of the transaction,
// using the [`FOR UPDATE`](https://www.postgresql.org/docs/17/sql-select.html#SQL-FOR-UPDATE-SHARE) lock strength.
// The lock is removed on transaction commit or rollback.
// Calling it also sets the builder's transaction, like [SearchQueryBuilder.SetTx].
func (builder *SearchQueryBuilder) LockRowsDuringTx(tx *sql.Tx, option LockOption) *SearchQueryBuilder {
    builder.tx = tx
    builder.lockRows = true
    builder.lockOption = option
    return builder
}
// AddQuery creates a new sub query and registers it on the builder.
// All fields in the sub query are AND-connected in the storage request.
// Multiple sub queries are OR-connected in the storage request.
func (builder *SearchQueryBuilder) AddQuery() *SearchQuery {
    query := &SearchQuery{
        builder: builder,
    }
    builder.queries = append(builder.queries, query)
    return query
}

// ExcludeAggregateIDs excludes events from the aggregate IDs returned by the [ExclusionQuery].
// There can be only 1 exclusion query. Subsequent calls overwrite previous definitions.
func (builder *SearchQueryBuilder) ExcludeAggregateIDs() *ExclusionQuery {
    query := &ExclusionQuery{
        builder: builder,
    }
    builder.excludeAggregateIDs = query
    return query
}
// Or creates a new sub query on the search query builder;
// sub queries are OR-connected in the storage request.
func (query SearchQuery) Or() *SearchQuery {
    return query.builder.AddQuery()
}
// AggregateTypes filters for events with the given aggregate types
func (query *SearchQuery) AggregateTypes(types ...AggregateType) *SearchQuery {
    query.aggregateTypes = types
    return query
}

// AggregateIDs filters for events with the given aggregate id's
func (query *SearchQuery) AggregateIDs(ids ...string) *SearchQuery {
    query.aggregateIDs = ids
    return query
}

// EventTypes filters for events with the given event types
func (query *SearchQuery) EventTypes(types ...EventType) *SearchQuery {
    query.eventTypes = types
    return query
}

// EventData filters for events with the given event data.
// Use this call with care as it will be slower than the other filters.
func (query *SearchQuery) EventData(data map[string]interface{}) *SearchQuery {
    query.eventData = data
    return query
}

// PositionAfter filters the sub query for events after the given position.
func (query *SearchQuery) PositionAfter(position decimal.Decimal) *SearchQuery {
    query.positionAfter = position
    return query
}

// Builder returns the SearchQueryBuilder of the sub query
func (query *SearchQuery) Builder() *SearchQueryBuilder {
    return query.builder
}
// matches reports whether the command fulfills every configured filter of the
// sub query; unset (empty) filters are ignored.
func (query *SearchQuery) matches(command Command) bool {
    aggregate := command.Aggregate()
    if len(query.aggregateTypes) > 0 && !isAggregateTypes(aggregate, query.aggregateTypes...) {
        return false
    }
    if len(query.aggregateIDs) > 0 && !isAggregateIDs(aggregate, query.aggregateIDs...) {
        return false
    }
    if len(query.eventTypes) > 0 && !isEventTypes(command, query.eventTypes...) {
        return false
    }
    return true
}
// AggregateTypes filters for events with the given aggregate types
func (query *ExclusionQuery) AggregateTypes(types ...AggregateType) *ExclusionQuery {
    query.aggregateTypes = types
    return query
}

// EventTypes filters for events with the given event types
func (query *ExclusionQuery) EventTypes(types ...EventType) *ExclusionQuery {
    query.eventTypes = types
    return query
}

// Builder returns the SearchQueryBuilder of the sub query
func (query *ExclusionQuery) Builder() *SearchQueryBuilder {
    return query.builder
}

View File

@@ -0,0 +1,707 @@
package eventstore
import (
"reflect"
"testing"
)
// testSetQuery returns a setter that applies all given builder modifiers in order.
func testSetQuery(queryFuncs ...func(*SearchQueryBuilder) *SearchQueryBuilder) func(*SearchQueryBuilder) *SearchQueryBuilder {
    return func(builder *SearchQueryBuilder) *SearchQueryBuilder {
        for _, modify := range queryFuncs {
            modify(builder)
        }
        return builder
    }
}

// testSetSequenceGreater returns a setter configuring the sequence filter.
func testSetSequenceGreater(sequence uint64) func(*SearchQueryBuilder) *SearchQueryBuilder {
    return func(builder *SearchQueryBuilder) *SearchQueryBuilder {
        return builder.SequenceGreater(sequence)
    }
}

// testAddSubQuery returns a setter that adds a sub query configured by the
// given modifiers.
func testAddSubQuery(queryFuncs ...func(*SearchQuery) *SearchQuery) func(*SearchQueryBuilder) *SearchQueryBuilder {
    return func(builder *SearchQueryBuilder) *SearchQueryBuilder {
        query := builder.AddQuery()
        for _, modify := range queryFuncs {
            modify(query)
        }
        return query.Builder()
    }
}

// testSetColumns returns a setter configuring the requested columns.
func testSetColumns(columns Columns) func(factory *SearchQueryBuilder) *SearchQueryBuilder {
    return func(factory *SearchQueryBuilder) *SearchQueryBuilder {
        return factory.Columns(columns)
    }
}

// testSetLimit returns a setter configuring the limit.
func testSetLimit(limit uint64) func(builder *SearchQueryBuilder) *SearchQueryBuilder {
    return func(builder *SearchQueryBuilder) *SearchQueryBuilder {
        return builder.Limit(limit)
    }
}

// testSetAggregateTypes returns a sub query modifier setting the aggregate types.
func testSetAggregateTypes(types ...AggregateType) func(*SearchQuery) *SearchQuery {
    return func(query *SearchQuery) *SearchQuery {
        return query.AggregateTypes(types...)
    }
}

// testSetAggregateIDs returns a sub query modifier setting the aggregate ids.
func testSetAggregateIDs(aggregateIDs ...string) func(*SearchQuery) *SearchQuery {
    return func(query *SearchQuery) *SearchQuery {
        return query.AggregateIDs(aggregateIDs...)
    }
}

// testSetEventTypes returns a sub query modifier setting the event types.
func testSetEventTypes(eventTypes ...EventType) func(*SearchQuery) *SearchQuery {
    return func(query *SearchQuery) *SearchQuery {
        return query.EventTypes(eventTypes...)
    }
}

// testSetResourceOwner returns a setter configuring the resource owner.
func testSetResourceOwner(resourceOwner string) func(*SearchQueryBuilder) *SearchQueryBuilder {
    return func(builder *SearchQueryBuilder) *SearchQueryBuilder {
        return builder.ResourceOwner(resourceOwner)
    }
}

// testSetSortOrder returns a setter configuring ascending or descending order.
func testSetSortOrder(asc bool) func(*SearchQueryBuilder) *SearchQueryBuilder {
    return func(query *SearchQueryBuilder) *SearchQueryBuilder {
        if asc {
            return query.OrderAsc()
        }
        return query.OrderDesc()
    }
}
// TestSearchQuerybuilderSetters verifies that each fluent setter of the
// SearchQueryBuilder stores its value on the expected field.
func TestSearchQuerybuilderSetters(t *testing.T) {
    type args struct {
        columns Columns
        setters []func(*SearchQueryBuilder) *SearchQueryBuilder
    }
    tests := []struct {
        name string
        args args
        res  *SearchQueryBuilder
    }{
        {
            name: "New builder",
            args: args{
                columns: ColumnsEvent,
            },
            res: &SearchQueryBuilder{
                columns: Columns(ColumnsEvent),
            },
        },
        {
            name: "set columns",
            args: args{
                setters: []func(*SearchQueryBuilder) *SearchQueryBuilder{testSetColumns(ColumnsMaxPosition)},
            },
            res: &SearchQueryBuilder{
                columns: ColumnsMaxPosition,
            },
        },
        {
            name: "set limit",
            args: args{
                setters: []func(*SearchQueryBuilder) *SearchQueryBuilder{testSetLimit(100)},
            },
            res: &SearchQueryBuilder{
                limit: 100,
            },
        },
        {
            name: "set sequence greater",
            args: args{
                setters: []func(b *SearchQueryBuilder) *SearchQueryBuilder{
                    testSetQuery(testSetSequenceGreater(90)),
                },
            },
            res: &SearchQueryBuilder{
                eventSequenceGreater: 90,
            },
        },
        {
            name: "set aggregateIDs",
            args: args{
                setters: []func(*SearchQueryBuilder) *SearchQueryBuilder{testAddSubQuery(testSetAggregateIDs("1235", "09824"))},
            },
            res: &SearchQueryBuilder{
                queries: []*SearchQuery{
                    {
                        aggregateIDs: []string{"1235", "09824"},
                    },
                },
            },
        },
        {
            name: "set eventTypes",
            args: args{
                setters: []func(*SearchQueryBuilder) *SearchQueryBuilder{testAddSubQuery(testSetEventTypes("user.created", "user.updated"))},
            },
            res: &SearchQueryBuilder{
                queries: []*SearchQuery{
                    {
                        eventTypes: []EventType{"user.created", "user.updated"},
                    },
                },
            },
        },
        {
            name: "set resource owner",
            args: args{
                setters: []func(*SearchQueryBuilder) *SearchQueryBuilder{testSetResourceOwner("hodor")},
            },
            res: &SearchQueryBuilder{
                resourceOwner: "hodor",
            },
        },
        {
            name: "default search query",
            args: args{
                setters: []func(*SearchQueryBuilder) *SearchQueryBuilder{testAddSubQuery(testSetAggregateTypes("user"), testSetAggregateIDs("1235", "024")), testSetSortOrder(false)},
            },
            res: &SearchQueryBuilder{
                desc: true,
                queries: []*SearchQuery{
                    {
                        aggregateTypes: []AggregateType{"user"},
                        aggregateIDs:   []string{"1235", "024"},
                    },
                },
            },
        },
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            // apply all setters to a fresh builder, then compare field by field
            builder := NewSearchQueryBuilder(tt.args.columns)
            for _, setter := range tt.args.setters {
                builder = setter(builder)
            }
            assertBuilder(t, tt.res, builder)
        })
    }
}
// assertBuilder compares the relevant fields of the expected and the actual
// builder and reports any mismatch on t.
func assertBuilder(t *testing.T, want, got *SearchQueryBuilder) {
    t.Helper()
    if got.columns != want.columns {
        t.Errorf("wrong column: got: %v want: %v", got.columns, want.columns)
    }
    if got.desc != want.desc {
        t.Errorf("wrong desc: got: %v want: %v", got.desc, want.desc)
    }
    if got.limit != want.limit {
        t.Errorf("wrong limit: got: %v want: %v", got.limit, want.limit)
    }
    if got.resourceOwner != want.resourceOwner {
        t.Errorf("wrong resourceOwner: got: %v want: %v", got.resourceOwner, want.resourceOwner)
    }
    if len(got.queries) != len(want.queries) {
        t.Errorf("wrong length of queries: got: %v want: %v", len(got.queries), len(want.queries))
        // return early: indexing want.queries below would panic on a
        // length mismatch when got has more queries than want
        return
    }
    for i, query := range got.queries {
        assertQuery(t, i, want.queries[i], query)
    }
}
// assertQuery compares the filter fields of the expected and the actual sub
// query at index i and reports any mismatch on t.
// The previously commented-out eventSequenceGreater comparison was removed:
// SearchQuery has no such field (the sequence filter lives on the builder).
func assertQuery(t *testing.T, i int, want, got *SearchQuery) {
    t.Helper()
    if !reflect.DeepEqual(got.aggregateIDs, want.aggregateIDs) {
        t.Errorf("wrong aggregateIDs in query %d : got: %v want: %v", i, got.aggregateIDs, want.aggregateIDs)
    }
    if !reflect.DeepEqual(got.aggregateTypes, want.aggregateTypes) {
        t.Errorf("wrong aggregateTypes in query %d : got: %v want: %v", i, got.aggregateTypes, want.aggregateTypes)
    }
    if !reflect.DeepEqual(got.eventData, want.eventData) {
        t.Errorf("wrong eventData in query %d : got: %v want: %v", i, got.eventData, want.eventData)
    }
    if !reflect.DeepEqual(got.eventTypes, want.eventTypes) {
        t.Errorf("wrong eventTypes in query %d : got: %v want: %v", i, got.eventTypes, want.eventTypes)
    }
}
func TestSearchQuery_matches(t *testing.T) {
type args struct {
event Command
}
tests := []struct {
name string
query *SearchQuery
event Command
want bool
}{
{
name: "wrong aggregate type",
query: NewSearchQueryBuilder(ColumnsEvent).AddQuery().AggregateTypes("searched"),
event: &matcherCommand{
BaseEvent{
Agg: &Aggregate{
Type: "found",
},
},
},
want: false,
},
{
name: "wrong aggregate id",
query: NewSearchQueryBuilder(ColumnsEvent).AddQuery().AggregateIDs("1", "10", "100"),
event: &matcherCommand{
BaseEvent{
Agg: &Aggregate{
ID: "2",
},
},
},
want: false,
},
{
name: "wrong event type",
query: NewSearchQueryBuilder(ColumnsEvent).AddQuery().EventTypes("event.searched.type"),
event: &matcherCommand{
BaseEvent{
EventType: "event.actual.type",
Agg: &Aggregate{},
},
},
want: false,
},
{
name: "matching",
query: NewSearchQueryBuilder(ColumnsEvent).
AddQuery().
AggregateIDs("2").
AggregateTypes("actual").
EventTypes("event.actual.type"),
event: &matcherCommand{
BaseEvent{
Seq: 55,
Agg: &Aggregate{
ID: "2",
Type: "actual",
},
EventType: "event.actual.type",
},
},
want: true,
},
{
name: "matching empty query",
query: NewSearchQueryBuilder(ColumnsEvent).AddQuery(),
event: &matcherCommand{
BaseEvent{
Seq: 55,
Agg: &Aggregate{
ID: "2",
Type: "actual",
},
EventType: "event.actual.type",
},
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
query := &SearchQuery{
aggregateTypes: tt.query.aggregateTypes,
aggregateIDs: tt.query.aggregateIDs,
eventTypes: tt.query.eventTypes,
eventData: tt.query.eventData,
}
if got := query.matches(tt.event); got != tt.want {
t.Errorf("SearchQuery.matches() = %v, want %v", got, tt.want)
}
})
}
}
// matcherCommand is a minimal Command implementation used to exercise the
// in-memory matching logic in tests.
type matcherCommand struct {
    BaseEvent
}

// Payload implements Command; test commands carry no payload.
func (matcherCommand) Payload() any { return nil }

// UniqueConstraints implements Command; test commands define no constraints.
func (matcherCommand) UniqueConstraints() []*UniqueConstraint { return nil }
// TestSearchQueryBuilder_Matches verifies offset, limit, resource owner,
// instance and sub query filtering of the in-memory command matcher.
// Fixed typos: test name "limit exeeded" -> "limit exceeded" and failure
// message "wantted len" -> "wanted len".
func TestSearchQueryBuilder_Matches(t *testing.T) {
    type args struct {
        commands []Command
    }
    tests := []struct {
        name      string
        builder   *SearchQueryBuilder
        args      args
        wantedLen int
    }{
        {
            name: "sequence too high",
            builder: NewSearchQueryBuilder(ColumnsEvent).
                SequenceGreater(60),
            args: args{
                commands: []Command{
                    &matcherCommand{
                        BaseEvent{
                            Agg: &Aggregate{
                                InstanceID: "instance",
                            },
                            Seq: 60,
                        },
                    },
                },
            },
            wantedLen: 0,
        },
        {
            name: "limit exceeded",
            builder: NewSearchQueryBuilder(ColumnsEvent).
                Limit(2),
            args: args{
                commands: []Command{
                    &matcherCommand{
                        BaseEvent{
                            Agg: &Aggregate{
                                ResourceOwner: "ro",
                                InstanceID:    "instance",
                            },
                            Seq: 1001,
                        },
                    },
                    &matcherCommand{
                        BaseEvent{
                            Agg: &Aggregate{
                                ResourceOwner: "ro",
                                InstanceID:    "instance",
                            },
                            Seq: 1001,
                        },
                    },
                    &matcherCommand{
                        BaseEvent{
                            Agg: &Aggregate{
                                ResourceOwner: "ro",
                                InstanceID:    "instance",
                            },
                            Seq: 1001,
                        },
                    },
                },
            },
            wantedLen: 2,
        },
        {
            name: "wrong resource owner",
            builder: NewSearchQueryBuilder(ColumnsEvent).
                ResourceOwner("query"),
            args: args{
                commands: []Command{
                    &matcherCommand{
                        BaseEvent{
                            Agg: &Aggregate{
                                ResourceOwner: "ro",
                            },
                        },
                    },
                },
            },
            wantedLen: 0,
        },
        {
            name: "wrong instance",
            builder: NewSearchQueryBuilder(ColumnsEvent).
                InstanceID("instance"),
            args: args{
                commands: []Command{
                    &matcherCommand{
                        BaseEvent{
                            Agg: &Aggregate{
                                InstanceID: "different instance",
                            },
                        },
                    },
                },
            },
            wantedLen: 0,
        },
        {
            name: "query failed",
            builder: NewSearchQueryBuilder(ColumnsEvent).
                SequenceGreater(1000),
            args: args{
                commands: []Command{
                    &matcherCommand{
                        BaseEvent{
                            Seq: 999,
                            Agg: &Aggregate{},
                        },
                    },
                },
            },
            wantedLen: 0,
        },
        {
            name: "matching",
            builder: NewSearchQueryBuilder(ColumnsEvent).
                Limit(1000).
                ResourceOwner("ro").
                InstanceID("instance").
                SequenceGreater(1000),
            args: args{
                commands: []Command{
                    &matcherCommand{
                        BaseEvent{
                            Agg: &Aggregate{
                                ResourceOwner: "ro",
                                InstanceID:    "instance",
                            },
                            Seq: 1001,
                        },
                    },
                },
            },
            wantedLen: 1,
        },
        {
            name: "matching builder resourceOwner and Instance",
            builder: NewSearchQueryBuilder(ColumnsEvent).
                ResourceOwner("ro").
                InstanceID("instance"),
            args: args{
                commands: []Command{
                    &matcherCommand{
                        BaseEvent{
                            Agg: &Aggregate{
                                ResourceOwner: "ro",
                                InstanceID:    "instance",
                            },
                            Seq: 1001,
                        },
                    },
                    &matcherCommand{
                        BaseEvent{
                            Agg: &Aggregate{
                                ResourceOwner: "ro2",
                                InstanceID:    "instance2",
                            },
                            Seq: 1002,
                        },
                    },
                    &matcherCommand{
                        BaseEvent{
                            Agg: &Aggregate{
                                ResourceOwner: "ro2",
                                InstanceID:    "instance",
                            },
                            Seq: 1003,
                        },
                    },
                    &matcherCommand{
                        BaseEvent{
                            Agg: &Aggregate{
                                ResourceOwner: "ro",
                                InstanceID:    "instance2",
                            },
                            Seq: 1004,
                        },
                    },
                },
            },
            wantedLen: 1,
        },
        {
            name: "matching builder resourceOwner only",
            builder: NewSearchQueryBuilder(ColumnsEvent).
                ResourceOwner("ro"),
            args: args{
                commands: []Command{
                    &matcherCommand{
                        BaseEvent{
                            Agg: &Aggregate{
                                ResourceOwner: "ro",
                            },
                            Seq: 1001,
                        },
                    },
                    &matcherCommand{
                        BaseEvent{
                            Agg: &Aggregate{
                                ResourceOwner: "ro2",
                            },
                            Seq: 1001,
                        },
                    },
                },
            },
            wantedLen: 1,
        },
        {
            name: "matching builder instanceID only",
            builder: NewSearchQueryBuilder(ColumnsEvent).
                InstanceID("instance"),
            args: args{
                commands: []Command{
                    &matcherCommand{
                        BaseEvent{
                            Agg: &Aggregate{
                                InstanceID: "instance",
                            },
                            Seq: 1001,
                        },
                    },
                    &matcherCommand{
                        BaseEvent{
                            Agg: &Aggregate{
                                InstanceID: "instance2",
                            },
                            Seq: 1001,
                        },
                    },
                },
            },
            wantedLen: 1,
        },
        {
            name: "offset too high",
            builder: NewSearchQueryBuilder(ColumnsEvent).
                Offset(2),
            args: args{
                commands: []Command{
                    &matcherCommand{
                        BaseEvent{
                            Agg: &Aggregate{
                                InstanceID: "instance",
                            },
                            Seq: 1001,
                        },
                    },
                },
            },
            wantedLen: 0,
        },
        {
            name: "offset",
            builder: NewSearchQueryBuilder(ColumnsEvent).
                Offset(1),
            args: args{
                commands: []Command{
                    &matcherCommand{
                        BaseEvent{
                            Agg: &Aggregate{
                                InstanceID: "instance",
                            },
                            Seq: 1001,
                        },
                    },
                    &matcherCommand{
                        BaseEvent{
                            Agg: &Aggregate{
                                InstanceID: "instance",
                            },
                            Seq: 1002,
                        },
                    },
                },
            },
            wantedLen: 1,
        },
        {
            name: "offset and limit",
            builder: NewSearchQueryBuilder(ColumnsEvent).
                Offset(1).
                Limit(1),
            args: args{
                commands: []Command{
                    &matcherCommand{
                        BaseEvent{
                            Agg: &Aggregate{
                                InstanceID: "instance",
                            },
                            Seq: 1001,
                        },
                    },
                    &matcherCommand{
                        BaseEvent{
                            Agg: &Aggregate{
                                InstanceID: "instance",
                            },
                            Seq: 1002,
                        },
                    },
                    &matcherCommand{
                        BaseEvent{
                            Agg: &Aggregate{
                                InstanceID: "instance",
                            },
                            Seq: 1002,
                        },
                    },
                },
            },
            wantedLen: 1,
        },
        {
            name: "sub query",
            builder: NewSearchQueryBuilder(ColumnsEvent).
                AddQuery().
                AggregateTypes("test").
                Builder(),
            args: args{
                commands: []Command{
                    &matcherCommand{
                        BaseEvent{
                            Agg: &Aggregate{
                                InstanceID: "instance",
                                Type:       "test",
                            },
                            Seq: 1001,
                        },
                    },
                    &matcherCommand{
                        BaseEvent{
                            Agg: &Aggregate{
                                InstanceID: "instance",
                                Type:       "test",
                            },
                            Seq: 1002,
                        },
                    },
                    &matcherCommand{
                        BaseEvent{
                            Agg: &Aggregate{
                                InstanceID: "instance",
                                Type:       "test2",
                            },
                            Seq: 1003,
                        },
                    },
                },
            },
            wantedLen: 2,
        },
    }
    for _, tt := range tests {
        t.Run(tt.name, func(t *testing.T) {
            if got := tt.builder.Matches(tt.args.commands...); len(got) != tt.wantedLen {
                t.Errorf("SearchQueryBuilder.Matches() = %v, wanted len %v", got, tt.wantedLen)
            }
        })
    }
}

View File

@@ -0,0 +1,106 @@
package eventstore
import (
"slices"
"sync"
"github.com/zitadel/logging"
)
var (
    // subscriptions holds all active subscriptions per aggregate type
    subscriptions = map[AggregateType][]*Subscription{}
    // subsMutex guards the subscriptions map
    subsMutex sync.RWMutex
)

// Subscription delivers matching events to its Events channel.
type Subscription struct {
    // Events receives every event matching the subscribed types
    Events chan Event
    // types maps an aggregate type to the subscribed event types;
    // a nil/empty slice subscribes to all events of that aggregate
    types map[AggregateType][]EventType
}
// SubscribeAggregates subscribes for all events on the given aggregates
func SubscribeAggregates(eventQueue chan Event, aggregates ...AggregateType) *Subscription {
    types := make(map[AggregateType][]EventType, len(aggregates))
    for _, aggregate := range aggregates {
        // nil event type slice means "all events of this aggregate"
        types[aggregate] = nil
    }
    sub := &Subscription{
        Events: eventQueue,
        types:  types,
    }

    subsMutex.Lock()
    defer subsMutex.Unlock()

    for _, aggregate := range aggregates {
        subscriptions[aggregate] = append(subscriptions[aggregate], sub)
    }

    return sub
}

// SubscribeEventTypes subscribes for the given event types
// if no event types are provided the subscription is for all events of the aggregate
func SubscribeEventTypes(eventQueue chan Event, types map[AggregateType][]EventType) *Subscription {
    sub := &Subscription{
        Events: eventQueue,
        types:  types,
    }

    subsMutex.Lock()
    defer subsMutex.Unlock()

    for aggregate := range types {
        subscriptions[aggregate] = append(subscriptions[aggregate], sub)
    }

    return sub
}
// notify distributes the given events to all subscriptions registered for
// the events' aggregate types.
func (es *Eventstore) notify(events []Event) {
    subsMutex.RLock()
    defer subsMutex.RUnlock()
    for _, event := range events {
        subs, ok := subscriptions[event.Aggregate().Type]
        if !ok {
            continue
        }
        for _, sub := range subs {
            eventTypes := sub.types[event.Aggregate().Type]
            // subscription for all events of the aggregate
            // NOTE(review): this send blocks when the subscriber's channel is
            // full, unlike the non-blocking select used below for typed
            // subscriptions — confirm whether the asymmetry is intended
            if len(eventTypes) == 0 {
                sub.Events <- event
                continue
            }
            // subscription for certain event types only: drop the event
            // rather than block when the subscriber cannot keep up
            if slices.Contains(eventTypes, event.Type()) {
                select {
                case sub.Events <- event:
                default:
                    logging.Debug("unable to push event")
                }
            }
        }
    }
}
// Unsubscribe removes the subscription from all aggregate registrations and
// closes its event channel.
func (s *Subscription) Unsubscribe() {
    subsMutex.Lock()
    defer subsMutex.Unlock()
    for aggregate := range s.types {
        subs, ok := subscriptions[aggregate]
        if !ok {
            continue
        }
        for i := len(subs) - 1; i >= 0; i-- {
            if subs[i] != s {
                continue
            }
            // swap-remove without preserving order
            subs[i] = subs[len(subs)-1]
            subs[len(subs)-1] = nil
            subs = subs[:len(subs)-1]
        }
        // store the shortened slice back; previously it was never written
        // back, so the map kept the old length with a nil tail which
        // [Eventstore.notify] would then dereference
        if len(subs) == 0 {
            delete(subscriptions, aggregate)
            continue
        }
        subscriptions[aggregate] = subs
    }
    // drain a possibly pending event without blocking before closing; the
    // previous unconditional receive blocked forever when the channel was
    // open and empty. Receiving from an already closed channel yields
    // ok == false, which prevents a double close.
    select {
    case _, ok := <-s.Events:
        if ok {
            close(s.Events)
        }
    default:
        close(s.Events)
    }
}

View File

@@ -0,0 +1,80 @@
package eventstore
// UniqueConstraint describes a uniqueness guarantee which is added or removed
// together with an event.
type UniqueConstraint struct {
    // UniqueType is the table name for the unique constraint
    UniqueType string
    // UniqueField is the unique key
    UniqueField string
    // Action defines if unique constraint should be added or removed
    Action UniqueConstraintAction
    // ErrorMessage defines the translation file key for the error message
    ErrorMessage string
    // IsGlobal defines if the unique constraint is globally unique or just within a single instance
    IsGlobal bool
}

// UniqueConstraintAction enumerates the possible constraint operations.
type UniqueConstraintAction int8

const (
    UniqueConstraintAdd UniqueConstraintAction = iota
    UniqueConstraintRemove
    UniqueConstraintInstanceRemove

    uniqueConstraintActionCount
)

// Valid reports whether the action is one of the declared constants.
func (f UniqueConstraintAction) Valid() bool {
    return f >= UniqueConstraintAdd && f < uniqueConstraintActionCount
}

// NewAddEventUniqueConstraint creates an instance-scoped "add" constraint.
func NewAddEventUniqueConstraint(
    uniqueType,
    uniqueField,
    errMessage string) *UniqueConstraint {
    return &UniqueConstraint{
        Action:       UniqueConstraintAdd,
        UniqueType:   uniqueType,
        UniqueField:  uniqueField,
        ErrorMessage: errMessage,
    }
}

// NewRemoveUniqueConstraint creates an instance-scoped "remove" constraint.
func NewRemoveUniqueConstraint(
    uniqueType,
    uniqueField string) *UniqueConstraint {
    return &UniqueConstraint{
        Action:      UniqueConstraintRemove,
        UniqueType:  uniqueType,
        UniqueField: uniqueField,
    }
}

// NewRemoveInstanceUniqueConstraints creates a constraint which removes all
// unique constraints of an instance.
func NewRemoveInstanceUniqueConstraints() *UniqueConstraint {
    return &UniqueConstraint{Action: UniqueConstraintInstanceRemove}
}

// NewAddGlobalUniqueConstraint creates a globally unique "add" constraint.
func NewAddGlobalUniqueConstraint(
    uniqueType,
    uniqueField,
    errMessage string) *UniqueConstraint {
    return &UniqueConstraint{
        Action:       UniqueConstraintAdd,
        IsGlobal:     true,
        UniqueType:   uniqueType,
        UniqueField:  uniqueField,
        ErrorMessage: errMessage,
    }
}

// NewRemoveGlobalUniqueConstraint creates a globally unique "remove" constraint.
func NewRemoveGlobalUniqueConstraint(
    uniqueType,
    uniqueField string) *UniqueConstraint {
    return &UniqueConstraint{
        Action:      UniqueConstraintRemove,
        IsGlobal:    true,
        UniqueType:  uniqueType,
        UniqueField: uniqueField,
    }
}

View File

@@ -0,0 +1,154 @@
package models
import (
"encoding/json"
"reflect"
"time"
"github.com/shopspring/decimal"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/zerrors"
)
// EventType is the type name of an event in the v1 models package.
type EventType string

// String implements fmt.Stringer.
func (et EventType) String() string {
    return string(et)
}
// compile-time check that *Event satisfies the eventstore.Event interface
var _ eventstore.Event = (*Event)(nil)

// Event is the v1 (models package) representation of a stored event.
type Event struct {
    ID string
    // Seq is the sequence returned by Sequence
    Seq uint64
    // Pos is the position returned by Position
    Pos          decimal.Decimal
    CreationDate time.Time
    // Typ is the event type returned by Type
    Typ              eventstore.EventType
    PreviousSequence uint64
    // Data is the raw JSON payload of the event
    Data []byte

    AggregateID      string
    AggregateType    eventstore.AggregateType
    AggregateVersion eventstore.Version
    // Service is the editor service returned by EditorService
    Service string
    // User is the editor user returned by Creator
    User          string
    ResourceOwner string
    InstanceID    string
}

// Aggregate implements [eventstore.Event]
func (e *Event) Aggregate() *eventstore.Aggregate {
    return &eventstore.Aggregate{
        ID:            e.AggregateID,
        Type:          e.AggregateType,
        ResourceOwner: e.ResourceOwner,
        InstanceID:    e.InstanceID,
        // Version: eventstore.Version(e.AggregateVersion),
    }
}

// CreatedAt implements [eventstore.Event]
func (e *Event) CreatedAt() time.Time {
    return e.CreationDate
}

// DataAsBytes implements [eventstore.Event]
func (e *Event) DataAsBytes() []byte {
    return e.Data
}

// Unmarshal implements [eventstore.Event].
// An empty payload is not an error; ptr is left untouched.
func (e *Event) Unmarshal(ptr any) error {
    if len(e.Data) == 0 {
        return nil
    }
    return json.Unmarshal(e.Data, ptr)
}

// EditorService implements [eventstore.Event]
func (e *Event) EditorService() string {
    return e.Service
}

// Creator implements [eventstore.action]
func (e *Event) Creator() string {
    return e.User
}

// Sequence implements [eventstore.Event]
func (e *Event) Sequence() uint64 {
    return e.Seq
}

// Position implements [eventstore.Event]
func (e *Event) Position() decimal.Decimal {
    return e.Pos
}

// Type implements [eventstore.action]
func (e *Event) Type() eventstore.EventType {
    return e.Typ
}

// Revision implements [eventstore.Event]; v1 events carry no revision.
func (e *Event) Revision() uint16 {
    return 0
}
// eventData normalizes the supported payload representations into raw JSON
// bytes. Accepted inputs are raw []byte (passed through), maps and structs
// (or pointers to structs, marshalled to JSON) and nil (no data). Every other
// type yields an invalid-argument error.
func eventData(i interface{}) ([]byte, error) {
	switch payload := i.(type) {
	case nil:
		return nil, nil
	case []byte:
		return payload, nil
	case map[string]interface{}:
		encoded, err := json.Marshal(payload)
		if err != nil {
			return nil, zerrors.ThrowInvalidArgument(err, "MODEL-s2fgE", "unable to marshal data")
		}
		return encoded, nil
	default:
		// only (pointers to) structs remain valid beyond this point
		typ := reflect.TypeOf(i)
		if typ.Kind() == reflect.Ptr {
			typ = typ.Elem()
		}
		if typ.Kind() != reflect.Struct {
			return nil, zerrors.ThrowInvalidArgument(nil, "MODEL-rjWdN", "data is not valid")
		}
		encoded, err := json.Marshal(payload)
		if err != nil {
			return nil, zerrors.ThrowInvalidArgument(err, "MODEL-Y2OpM", "unable to marshal data")
		}
		return encoded, nil
	}
}
// Validate checks that all fields required to store the event are set.
// It returns a precondition-failed error naming the first missing field,
// or nil if the event is storable.
func (e *Event) Validate() error {
	if e == nil {
		return zerrors.ThrowPreconditionFailed(nil, "MODEL-oEAG4", "event is nil")
	}
	if string(e.Typ) == "" {
		return zerrors.ThrowPreconditionFailed(nil, "MODEL-R2sB0", "type not defined")
	}
	if e.AggregateID == "" {
		return zerrors.ThrowPreconditionFailed(nil, "MODEL-A6WwL", "aggregate id not set")
	}
	if e.AggregateType == "" {
		return zerrors.ThrowPreconditionFailed(nil, "MODEL-EzdyK", "aggregate type not set")
	}
	if err := e.AggregateVersion.Validate(); err != nil {
		return zerrors.ThrowPreconditionFailed(err, "MODEL-KO71q", "version invalid")
	}
	if e.Service == "" {
		return zerrors.ThrowPreconditionFailed(nil, "MODEL-4Yqik", "editor service not set")
	}
	if e.User == "" {
		return zerrors.ThrowPreconditionFailed(nil, "MODEL-L3NHO", "editor user not set")
	}
	if e.ResourceOwner == "" {
		// fix: the message was truncated to "resource ow"
		return zerrors.ThrowPreconditionFailed(nil, "MODEL-omFVT", "resource owner not set")
	}
	return nil
}

View File

@@ -0,0 +1,198 @@
package models
import (
"reflect"
"testing"
)
// Test_eventData asserts that eventData normalizes every supported payload
// representation ([]byte, struct, pointer to struct, map, nil) to JSON bytes
// and rejects unsupported scalar types.
func Test_eventData(t *testing.T) {
	type args struct {
		i interface{}
	}
	tests := []struct {
		name    string
		args    args
		want    []byte
		wantErr bool
	}{
		{
			name:    "from bytes",
			args:    args{[]byte(`{"hodor":"asdf"}`)},
			want:    []byte(`{"hodor":"asdf"}`),
			wantErr: false,
		},
		{
			name: "from pointer",
			args: args{&struct {
				Hodor string `json:"hodor"`
			}{Hodor: "asdf"}},
			want:    []byte(`{"hodor":"asdf"}`),
			wantErr: false,
		},
		{
			name: "from struct",
			args: args{struct {
				Hodor string `json:"hodor"`
			}{Hodor: "asdf"}},
			want:    []byte(`{"hodor":"asdf"}`),
			wantErr: false,
		},
		{
			name: "from map",
			args: args{
				map[string]interface{}{"hodor": "asdf"},
			},
			want:    []byte(`{"hodor":"asdf"}`),
			wantErr: false,
		},
		{
			// nil payloads are valid and simply produce no data
			name:    "from nil",
			args:    args{},
			want:    nil,
			wantErr: false,
		},
		{
			// scalar types other than []byte must be rejected
			name:    "invalid data",
			args:    args{876},
			want:    nil,
			wantErr: true,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got, err := eventData(tt.args.i)
			if (err != nil) != tt.wantErr {
				t.Errorf("eventData() error = %v, wantErr %v", err, tt.wantErr)
				return
			}
			if !reflect.DeepEqual(got, tt.want) {
				t.Errorf("eventData() = %s, want %s", got, tt.want)
			}
		})
	}
}
// TestEvent_Validate checks that Event.Validate rejects a nil event and an
// event with any single required field missing, and accepts a fully set one.
// Each case omits exactly one field to pin down the corresponding check.
func TestEvent_Validate(t *testing.T) {
	type fields struct {
		event *Event
	}
	tests := []struct {
		name    string
		fields  fields
		wantErr bool
	}{
		{
			name:    "event nil",
			wantErr: true,
		},
		{
			name:    "event empty",
			fields:  fields{event: &Event{}},
			wantErr: true,
		},
		{
			name: "no aggregate id",
			fields: fields{event: &Event{
				AggregateType:    "user",
				AggregateVersion: "v1.0.0",
				Service:          "management",
				User:             "hodor",
				ResourceOwner:    "org",
				Typ:              "born",
			}},
			wantErr: true,
		},
		{
			name: "no aggregate type",
			fields: fields{event: &Event{
				AggregateID:      "hodor",
				AggregateVersion: "v1.0.0",
				Service:          "management",
				User:             "hodor",
				ResourceOwner:    "org",
				Typ:              "born",
			}},
			wantErr: true,
		},
		{
			name: "no aggregate version",
			fields: fields{event: &Event{
				AggregateID:   "hodor",
				AggregateType: "user",
				Service:       "management",
				User:          "hodor",
				ResourceOwner: "org",
				Typ:           "born",
			}},
			wantErr: true,
		},
		{
			name: "no editor service",
			fields: fields{event: &Event{
				AggregateID:      "hodor",
				AggregateType:    "user",
				AggregateVersion: "v1.0.0",
				User:             "hodor",
				ResourceOwner:    "org",
				Typ:              "born",
			}},
			wantErr: true,
		},
		{
			name: "no editor user",
			fields: fields{event: &Event{
				AggregateID:      "hodor",
				AggregateType:    "user",
				AggregateVersion: "v1.0.0",
				Service:          "management",
				ResourceOwner:    "org",
				Typ:              "born",
			}},
			wantErr: true,
		},
		{
			name: "no resource owner",
			fields: fields{event: &Event{
				AggregateID:      "hodor",
				AggregateType:    "user",
				AggregateVersion: "v1.0.0",
				Service:          "management",
				User:             "hodor",
				Typ:              "born",
			}},
			wantErr: true,
		},
		{
			name: "no type",
			fields: fields{event: &Event{
				AggregateID:      "hodor",
				AggregateType:    "user",
				AggregateVersion: "v1.0.0",
				Service:          "management",
				User:             "hodor",
				ResourceOwner:    "org",
			}},
			wantErr: true,
		},
		{
			name: "all fields set",
			fields: fields{event: &Event{
				AggregateID:      "hodor",
				AggregateType:    "user",
				AggregateVersion: "v1.0.0",
				Service:          "management",
				User:             "hodor",
				ResourceOwner:    "org",
				Typ:              "born",
			}},
			wantErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if err := tt.fields.event.Validate(); (err != nil) != tt.wantErr {
				t.Errorf("Event.Validate() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

View File

@@ -0,0 +1,40 @@
package models
import (
"time"
"github.com/zitadel/zitadel/internal/eventstore"
)
// ObjectRoot is the base struct of read/write models built from events.
// It tracks the identity and bookkeeping data of a single aggregate.
// All fields are excluded from JSON serialization.
type ObjectRoot struct {
	AggregateID   string    `json:"-"` // id of the aggregate
	Sequence      uint64    `json:"-"` // sequence of the last applied event
	ResourceOwner string    `json:"-"` // owner (organisation) of the aggregate
	InstanceID    string    `json:"-"` // instance the aggregate belongs to
	CreationDate  time.Time `json:"-"` // creation time of the first applied event
	ChangeDate    time.Time `json:"-"` // creation time of the last applied event
}
// AppendEvent applies the bookkeeping data of the given event to the root.
// The first event fixes the aggregate id; events of a different aggregate are
// silently ignored. Resource owner and instance id are only taken from the
// first event that provides them, the creation date from the first event only.
func (o *ObjectRoot) AppendEvent(event eventstore.Event) {
	agg := event.Aggregate()
	if o.AggregateID == "" {
		o.AggregateID = agg.ID
	} else if o.AggregateID != agg.ID {
		// event belongs to a different aggregate; ignore it
		return
	}
	if o.ResourceOwner == "" {
		o.ResourceOwner = agg.ResourceOwner
	}
	if o.InstanceID == "" {
		o.InstanceID = agg.InstanceID
	}
	o.ChangeDate = event.CreatedAt()
	if o.CreationDate.IsZero() {
		o.CreationDate = o.ChangeDate
	}
	o.Sequence = event.Sequence()
}
// IsZero reports whether the root was never initialized by an event,
// i.e. no aggregate id has been assigned yet.
func (o *ObjectRoot) IsZero() bool {
	return len(o.AggregateID) == 0
}

View File

@@ -0,0 +1,81 @@
package models
import (
"testing"
"time"
)
// TestObjectRoot_AppendEvent verifies the bookkeeping performed by
// ObjectRoot.AppendEvent: on a fresh root the creation date is taken from the
// event; on an existing root only change date and sequence advance.
func TestObjectRoot_AppendEvent(t *testing.T) {
	type fields struct {
		ID           string
		Sequence     uint64
		CreationDate time.Time
		ChangeDate   time.Time
	}
	type args struct {
		event     *Event
		isNewRoot bool
	}
	tests := []struct {
		name   string
		fields fields
		args   args
	}{
		{
			"new root",
			fields{},
			args{
				&Event{
					AggregateID:  "aggID",
					Seq:          34555,
					CreationDate: time.Now(),
				},
				true,
			},
		},
		{
			"existing root",
			fields{
				"agg",
				234,
				time.Now().Add(-24 * time.Hour),
				time.Now().Add(-12 * time.Hour),
			},
			args{
				&Event{
					AggregateID:      "agg",
					Seq:              34555425,
					CreationDate:     time.Now(),
					PreviousSequence: 22,
				},
				false,
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			o := &ObjectRoot{
				AggregateID:  tt.fields.ID,
				Sequence:     tt.fields.Sequence,
				CreationDate: tt.fields.CreationDate,
				ChangeDate:   tt.fields.ChangeDate,
			}
			o.AppendEvent(tt.args.event)
			if tt.args.isNewRoot {
				if !o.CreationDate.Equal(tt.args.event.CreationDate) {
					t.Error("creationDate should be equal to event on new root")
				}
			} else {
				if o.CreationDate.Equal(o.ChangeDate) {
					t.Error("creationDate and changedate should differ")
				}
			}
			if o.Sequence != tt.args.event.Seq {
				t.Errorf("sequence not equal to event: event: %d root: %d", tt.args.event.Seq, o.Sequence)
			}
			if !o.ChangeDate.Equal(tt.args.event.CreationDate) {
				t.Errorf("changedate should be equal to event creation date: event: %v root: %v", tt.args.event.CreationDate, o.ChangeDate)
			}
		})
	}
}

View File

@@ -0,0 +1,176 @@
package eventstore
import (
"context"
"encoding/json"
"strconv"
"time"
"github.com/shopspring/decimal"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/zerrors"
)
// Compile-time check that *event satisfies the eventstore.Event interface.
var (
	_ eventstore.Event = (*event)(nil)
)

// command is the database representation of an [eventstore.Command] as it is
// pushed to the eventstore; its fields mirror the composite type stored in
// the database.
type command struct {
	InstanceID    string
	AggregateType string
	AggregateID   string
	CommandType   string
	Revision      uint16
	Payload       Payload
	Creator       string
	Owner         string
}
// Aggregate builds the [eventstore.Aggregate] metadata of the command.
// The aggregate version is derived from the command revision ("v" + revision).
func (c *command) Aggregate() *eventstore.Aggregate {
	version := eventstore.Version("v" + strconv.Itoa(int(c.Revision)))
	return &eventstore.Aggregate{
		ID:            c.AggregateID,
		Type:          eventstore.AggregateType(c.AggregateType),
		ResourceOwner: c.Owner,
		InstanceID:    c.InstanceID,
		Version:       version,
	}
}
// event wraps a pushed command together with the metadata assigned by the
// database during the push (creation time, sequence and global position).
type event struct {
	command   *command
	createdAt time.Time
	sequence  uint64
	position  decimal.Decimal
}
// commandToEventOld converts a command into an event bound to the latest
// known sequence of its aggregate. A non-nil payload is marshalled to JSON.
// TODO: remove on v3
func commandToEventOld(sequence *latestSequence, cmd eventstore.Command) (*event, error) {
	var payload Payload
	if p := cmd.Payload(); p != nil {
		marshalled, err := json.Marshal(p)
		if err != nil {
			logging.WithError(err).Warn("marshal payload failed")
			return nil, zerrors.ThrowInternal(err, "V3-MInPK", "Errors.Internal")
		}
		payload = marshalled
	}
	cmdData := &command{
		InstanceID:    sequence.aggregate.InstanceID,
		AggregateType: string(sequence.aggregate.Type),
		AggregateID:   sequence.aggregate.ID,
		CommandType:   string(cmd.Type()),
		Revision:      cmd.Revision(),
		Payload:       payload,
		Creator:       cmd.Creator(),
		Owner:         sequence.aggregate.ResourceOwner,
	}
	return &event{
		command:  cmdData,
		sequence: sequence.sequence,
	}, nil
}
// commandsToEvents converts each command into its event representation.
// Commands whose aggregate carries no instance id inherit the instance id of
// the request context. It returns the events and the underlying commands in
// matching order, or the first conversion error.
func commandsToEvents(ctx context.Context, cmds []eventstore.Command) ([]eventstore.Event, []*command, error) {
	events := make([]eventstore.Event, len(cmds))
	commands := make([]*command, len(cmds))
	for i, cmd := range cmds {
		aggregate := cmd.Aggregate()
		if aggregate.InstanceID == "" {
			aggregate.InstanceID = authz.GetInstance(ctx).InstanceID()
		}
		converted, err := commandToEvent(cmd)
		if err != nil {
			return nil, nil, err
		}
		events[i] = converted
		commands[i] = converted.(*event).command
	}
	return events, commands, nil
}
// commandToEvent wraps the given command into an event, marshalling a non-nil
// payload to JSON. The event carries no sequence/position yet; those are
// assigned by the database during push.
func commandToEvent(cmd eventstore.Command) (eventstore.Event, error) {
	var payload Payload
	if p := cmd.Payload(); p != nil {
		marshalled, err := json.Marshal(p)
		if err != nil {
			logging.WithError(err).Warn("marshal payload failed")
			return nil, zerrors.ThrowInternal(err, "V3-MInPK", "Errors.Internal")
		}
		payload = marshalled
	}
	aggregate := cmd.Aggregate()
	return &event{
		command: &command{
			InstanceID:    aggregate.InstanceID,
			AggregateType: string(aggregate.Type),
			AggregateID:   aggregate.ID,
			CommandType:   string(cmd.Type()),
			Revision:      cmd.Revision(),
			Payload:       payload,
			Creator:       cmd.Creator(),
			Owner:         aggregate.ResourceOwner,
		},
	}, nil
}
// CreationDate implements [eventstore.Event]; it is an alias of CreatedAt.
func (e *event) CreationDate() time.Time {
	return e.CreatedAt()
}

// EditorUser implements [eventstore.Event]; it is an alias of Creator.
func (e *event) EditorUser() string {
	return e.Creator()
}

// Aggregate implements [eventstore.Event]
func (e *event) Aggregate() *eventstore.Aggregate {
	return e.command.Aggregate()
}

// Creator implements [eventstore.Event]
func (e *event) Creator() string {
	return e.command.Creator
}

// Revision implements [eventstore.Event]
func (e *event) Revision() uint16 {
	return e.command.Revision
}

// Type implements [eventstore.Event]
func (e *event) Type() eventstore.EventType {
	return eventstore.EventType(e.command.CommandType)
}

// CreatedAt implements [eventstore.Event]
func (e *event) CreatedAt() time.Time {
	return e.createdAt
}

// Sequence implements [eventstore.Event]
func (e *event) Sequence() uint64 {
	return e.sequence
}

// Position implements [eventstore.Event]
func (e *event) Position() decimal.Decimal {
	return e.position
}

// Unmarshal implements [eventstore.Event].
// It decodes the JSON payload into ptr; an empty payload is a no-op.
func (e *event) Unmarshal(ptr any) error {
	if len(e.command.Payload) == 0 {
		return nil
	}
	if err := json.Unmarshal(e.command.Payload, ptr); err != nil {
		return zerrors.ThrowInternal(err, "V3-u8qVo", "Errors.Internal")
	}
	return nil
}

// DataAsBytes implements [eventstore.Event]
func (e *event) DataAsBytes() []byte {
	return e.command.Payload
}

View File

@@ -0,0 +1,482 @@
package eventstore
import (
"context"
"encoding/json"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/eventstore"
)
// Test_commandToEvent checks payload handling of commandToEvent: nil, struct
// and pointer payloads are converted, while an unmarshalable payload (a func)
// yields an error.
func Test_commandToEvent(t *testing.T) {
	payload := struct {
		ID string
	}{
		ID: "test",
	}
	payloadMarshalled, err := json.Marshal(payload)
	if err != nil {
		t.Fatalf("marshal of payload failed: %v", err)
	}
	type args struct {
		command eventstore.Command
	}
	type want struct {
		event *event
		err   func(t *testing.T, err error)
	}
	tests := []struct {
		name string
		args args
		want want
	}{
		{
			name: "no payload",
			args: args{
				command: &mockCommand{
					aggregate: mockAggregate("V3-Red9I"),
					payload:   nil,
				},
			},
			want: want{
				event: mockEvent(
					mockAggregate("V3-Red9I"),
					0,
					nil,
				).(*event),
			},
		},
		{
			name: "struct payload",
			args: args{
				command: &mockCommand{
					aggregate: mockAggregate("V3-Red9I"),
					payload:   payload,
				},
			},
			want: want{
				event: mockEvent(
					mockAggregate("V3-Red9I"),
					0,
					payloadMarshalled,
				).(*event),
			},
		},
		{
			name: "pointer payload",
			args: args{
				command: &mockCommand{
					aggregate: mockAggregate("V3-Red9I"),
					payload:   &payload,
				},
			},
			want: want{
				event: mockEvent(
					mockAggregate("V3-Red9I"),
					0,
					payloadMarshalled,
				).(*event),
			},
		},
		{
			// func() cannot be marshalled to JSON
			name: "invalid payload",
			args: args{
				command: &mockCommand{
					aggregate: mockAggregate("V3-Red9I"),
					payload:   func() {},
				},
			},
			want: want{
				err: func(t *testing.T, err error) {
					assert.Error(t, err)
				},
			},
		},
	}
	for _, tt := range tests {
		// default expectation: conversion succeeds
		if tt.want.err == nil {
			tt.want.err = func(t *testing.T, err error) {
				require.NoError(t, err)
			}
		}
		t.Run(tt.name, func(t *testing.T) {
			got, err := commandToEvent(tt.args.command)
			tt.want.err(t, err)
			if tt.want.event == nil {
				assert.Nil(t, got)
				return
			}
			assert.Equal(t, tt.want.event, got)
		})
	}
}
// Test_commandToEventOld mirrors Test_commandToEvent for the deprecated
// sequence-based conversion path.
func Test_commandToEventOld(t *testing.T) {
	payload := struct {
		ID string
	}{
		ID: "test",
	}
	payloadMarshalled, err := json.Marshal(payload)
	if err != nil {
		t.Fatalf("marshal of payload failed: %v", err)
	}
	type args struct {
		sequence *latestSequence
		command  eventstore.Command
	}
	type want struct {
		event *event
		err   func(t *testing.T, err error)
	}
	tests := []struct {
		name string
		args args
		want want
	}{
		{
			name: "no payload",
			args: args{
				sequence: &latestSequence{
					aggregate: mockAggregate("V3-Red9I"),
					sequence:  0,
				},
				command: &mockCommand{
					aggregate: mockAggregate("V3-Red9I"),
					payload:   nil,
				},
			},
			want: want{
				event: mockEvent(
					mockAggregate("V3-Red9I"),
					0,
					nil,
				).(*event),
			},
		},
		{
			name: "struct payload",
			args: args{
				sequence: &latestSequence{
					aggregate: mockAggregate("V3-Red9I"),
					sequence:  0,
				},
				command: &mockCommand{
					aggregate: mockAggregate("V3-Red9I"),
					payload:   payload,
				},
			},
			want: want{
				event: mockEvent(
					mockAggregate("V3-Red9I"),
					0,
					payloadMarshalled,
				).(*event),
			},
		},
		{
			name: "pointer payload",
			args: args{
				sequence: &latestSequence{
					aggregate: mockAggregate("V3-Red9I"),
					sequence:  0,
				},
				command: &mockCommand{
					aggregate: mockAggregate("V3-Red9I"),
					payload:   &payload,
				},
			},
			want: want{
				event: mockEvent(
					mockAggregate("V3-Red9I"),
					0,
					payloadMarshalled,
				).(*event),
			},
		},
		{
			// func() cannot be marshalled to JSON
			name: "invalid payload",
			args: args{
				sequence: &latestSequence{
					aggregate: mockAggregate("V3-Red9I"),
					sequence:  0,
				},
				command: &mockCommand{
					aggregate: mockAggregate("V3-Red9I"),
					payload:   func() {},
				},
			},
			want: want{
				err: func(t *testing.T, err error) {
					assert.Error(t, err)
				},
			},
		},
	}
	for _, tt := range tests {
		// default expectation: conversion succeeds
		if tt.want.err == nil {
			tt.want.err = func(t *testing.T, err error) {
				require.NoError(t, err)
			}
		}
		t.Run(tt.name, func(t *testing.T) {
			got, err := commandToEventOld(tt.args.sequence, tt.args.command)
			tt.want.err(t, err)
			assert.Equal(t, tt.want.event, got)
		})
	}
}
// Test_commandsToEvents checks the batch conversion of commands to events:
// empty input, single commands with and without payload, inheritance of the
// instance id from the context, multiple commands, and error propagation for
// an unmarshalable payload.
func Test_commandsToEvents(t *testing.T) {
	ctx := context.Background()
	payload := struct {
		ID string
	}{
		ID: "test",
	}
	payloadMarshalled, err := json.Marshal(payload)
	if err != nil {
		t.Fatalf("marshal of payload failed: %v", err)
	}
	type args struct {
		ctx  context.Context
		cmds []eventstore.Command
	}
	type want struct {
		events   []eventstore.Event
		commands []*command
		err      func(t *testing.T, err error)
	}
	tests := []struct {
		name string
		args args
		want want
	}{
		{
			name: "no commands",
			args: args{
				ctx:  ctx,
				cmds: nil,
			},
			want: want{
				events:   []eventstore.Event{},
				commands: []*command{},
				err: func(t *testing.T, err error) {
					require.NoError(t, err)
				},
			},
		},
		{
			name: "single command no payload",
			args: args{
				ctx: ctx,
				cmds: []eventstore.Command{
					&mockCommand{
						aggregate: mockAggregate("V3-Red9I"),
						payload:   nil,
					},
				},
			},
			want: want{
				events: []eventstore.Event{
					mockEvent(
						mockAggregate("V3-Red9I"),
						0,
						nil,
					),
				},
				commands: []*command{
					{
						InstanceID:    "instance",
						AggregateType: "type",
						AggregateID:   "V3-Red9I",
						Owner:         "ro",
						CommandType:   "event.type",
						Revision:      1,
						Payload:       nil,
						Creator:       "creator",
					},
				},
				err: func(t *testing.T, err error) {
					require.NoError(t, err)
				},
			},
		},
		{
			// aggregates without an instance id inherit it from the context
			name: "single command no instance id",
			args: args{
				ctx: authz.WithInstanceID(ctx, "instance from ctx"),
				cmds: []eventstore.Command{
					&mockCommand{
						aggregate: mockAggregateWithInstance("V3-Red9I", ""),
						payload:   nil,
					},
				},
			},
			want: want{
				events: []eventstore.Event{
					mockEvent(
						mockAggregateWithInstance("V3-Red9I", "instance from ctx"),
						0,
						nil,
					),
				},
				commands: []*command{
					{
						InstanceID:    "instance from ctx",
						AggregateType: "type",
						AggregateID:   "V3-Red9I",
						Owner:         "ro",
						CommandType:   "event.type",
						Revision:      1,
						Payload:       nil,
						Creator:       "creator",
					},
				},
				err: func(t *testing.T, err error) {
					require.NoError(t, err)
				},
			},
		},
		{
			name: "single command with payload",
			args: args{
				ctx: ctx,
				cmds: []eventstore.Command{
					&mockCommand{
						aggregate: mockAggregate("V3-Red9I"),
						payload:   payload,
					},
				},
			},
			want: want{
				events: []eventstore.Event{
					mockEvent(
						mockAggregate("V3-Red9I"),
						0,
						payloadMarshalled,
					),
				},
				commands: []*command{
					{
						InstanceID:    "instance",
						AggregateType: "type",
						AggregateID:   "V3-Red9I",
						Owner:         "ro",
						CommandType:   "event.type",
						Revision:      1,
						Payload:       payloadMarshalled,
						Creator:       "creator",
					},
				},
				err: func(t *testing.T, err error) {
					require.NoError(t, err)
				},
			},
		},
		{
			name: "multiple commands",
			args: args{
				ctx: ctx,
				cmds: []eventstore.Command{
					&mockCommand{
						aggregate: mockAggregate("V3-Red9I"),
						payload:   payload,
					},
					&mockCommand{
						aggregate: mockAggregate("V3-Red9I"),
						payload:   nil,
					},
				},
			},
			want: want{
				events: []eventstore.Event{
					mockEvent(
						mockAggregate("V3-Red9I"),
						0,
						payloadMarshalled,
					),
					mockEvent(
						mockAggregate("V3-Red9I"),
						0,
						nil,
					),
				},
				commands: []*command{
					{
						InstanceID:    "instance",
						AggregateType: "type",
						AggregateID:   "V3-Red9I",
						CommandType:   "event.type",
						Revision:      1,
						Payload:       payloadMarshalled,
						Creator:       "creator",
						Owner:         "ro",
					},
					{
						InstanceID:    "instance",
						AggregateType: "type",
						AggregateID:   "V3-Red9I",
						CommandType:   "event.type",
						Revision:      1,
						Payload:       nil,
						Creator:       "creator",
						Owner:         "ro",
					},
				},
				err: func(t *testing.T, err error) {
					require.NoError(t, err)
				},
			},
		},
		{
			// func() cannot be marshalled to JSON
			name: "invalid command",
			args: args{
				ctx: ctx,
				cmds: []eventstore.Command{
					&mockCommand{
						aggregate: mockAggregate("V3-Red9I"),
						payload:   func() {},
					},
				},
			},
			want: want{
				events:   nil,
				commands: nil,
				err: func(t *testing.T, err error) {
					assert.Error(t, err)
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			gotEvents, gotCommands, err := commandsToEvents(tt.args.ctx, tt.args.cmds)
			tt.want.err(t, err)
			assert.Equal(t, tt.want.events, gotEvents)
			require.Len(t, gotCommands, len(tt.want.commands))
			for i, wantCommand := range tt.want.commands {
				assertCommand(t, wantCommand, gotCommands[i])
			}
		})
	}
}
// assertCommand compares two commands field by field so a mismatch reports
// the exact differing field instead of the whole struct.
func assertCommand(t *testing.T, want, got *command) {
	t.Helper()
	assert.Equal(t, want.CommandType, got.CommandType)
	assert.Equal(t, want.Payload, got.Payload)
	assert.Equal(t, want.Creator, got.Creator)
	assert.Equal(t, want.Owner, got.Owner)
	assert.Equal(t, want.AggregateID, got.AggregateID)
	assert.Equal(t, want.AggregateType, got.AggregateType)
	assert.Equal(t, want.InstanceID, got.InstanceID)
	assert.Equal(t, want.Revision, got.Revision)
}

View File

@@ -0,0 +1,202 @@
package eventstore
import (
"context"
"database/sql"
"encoding/json"
"errors"
"sync"
"github.com/DATA-DOG/go-sqlmock"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/stdlib"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/dialect"
"github.com/zitadel/zitadel/internal/eventstore"
)
// init registers the eventstore composite types on every new database
// connection of the dialect.
func init() {
	dialect.RegisterAfterConnect(RegisterEventstoreTypes)
}

var (
	// pushPlaceholderFmt defines how data are inserted into the events table
	pushPlaceholderFmt = "($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $%d)"
	// uniqueConstraintPlaceholderFmt defines the format of the unique constraint error returned from the database
	uniqueConstraintPlaceholderFmt = "(%s, %s, %s)"

	// compile-time check that *Eventstore satisfies eventstore.Pusher
	_ eventstore.Pusher = (*Eventstore)(nil)
)

// Eventstore is the SQL-backed push implementation of the eventstore.
type Eventstore struct {
	client *database.DB
}
var (
	// textType is reused for every text field of the command composite type.
	textType = &pgtype.Type{
		Name:  "text",
		OID:   pgtype.TextOID,
		Codec: pgtype.TextCodec{},
	}
	// commandType describes the composite database type a command is pushed
	// as; the field order must match the composite type declared in the
	// database schema.
	commandType = &pgtype.Type{
		Codec: &pgtype.CompositeCodec{
			Fields: []pgtype.CompositeCodecField{
				{
					Name: "instance_id",
					Type: textType,
				},
				{
					Name: "aggregate_type",
					Type: textType,
				},
				{
					Name: "aggregate_id",
					Type: textType,
				},
				{
					Name: "command_type",
					Type: textType,
				},
				{
					Name: "revision",
					Type: &pgtype.Type{
						Name:  "int2",
						OID:   pgtype.Int2OID,
						Codec: pgtype.Int2Codec{},
					},
				},
				{
					Name: "payload",
					Type: &pgtype.Type{
						Name: "jsonb",
						OID:  pgtype.JSONBOID,
						Codec: &pgtype.JSONBCodec{
							Marshal:   json.Marshal,
							Unmarshal: json.Unmarshal,
						},
					},
				},
				{
					Name: "creator",
					Type: textType,
				},
				{
					Name: "owner",
					Type: textType,
				},
			},
		},
	}
	// commandArrayCodec describes the array type of the command composite.
	commandArrayCodec = &pgtype.Type{
		Codec: &pgtype.ArrayCodec{
			ElementType: commandType,
		},
	}
)

// typeMu serializes type registration on connections; conn.TypeMap is not
// thread safe (see RegisterEventstoreTypes).
var typeMu sync.Mutex
// RegisterEventstoreTypes registers the eventstore.command composite type and
// its array type on the given connection so commands can be bound directly as
// statement parameters. It resolves the type OIDs from pg_type on first use
// and is a no-op if the types are already registered. Lookup failures are
// logged and swallowed (registration is treated as best effort).
func RegisterEventstoreTypes(ctx context.Context, conn *pgx.Conn) error {
	// conn.TypeMap is not thread safe
	typeMu.Lock()
	defer typeMu.Unlock()

	m := conn.TypeMap()

	var cmd *command
	// already registered on this connection
	if _, ok := m.TypeForValue(cmd); ok {
		return nil
	}

	// resolve the OIDs once; they are cached in the package-level type vars
	if commandType.OID == 0 || commandArrayCodec.OID == 0 {
		err := conn.QueryRow(ctx, "select oid, typarray from pg_type where typname = $1 and typnamespace = (select oid from pg_namespace where nspname = $2)", "command", "eventstore").
			Scan(&commandType.OID, &commandArrayCodec.OID)
		if err != nil {
			logging.WithError(err).Debug("failed to get oid for command type")
			return nil
		}
		if commandType.OID == 0 || commandArrayCodec.OID == 0 {
			logging.Debug("oid for command type not found")
			return nil
		}
	}

	// register the type under both its qualified and unqualified names
	m.RegisterTypes([]*pgtype.Type{
		{
			Name:  "eventstore.command",
			Codec: commandType.Codec,
			OID:   commandType.OID,
		},
		{
			Name:  "command",
			Codec: commandType.Codec,
			OID:   commandType.OID,
		},
		{
			Name:  "eventstore._command",
			Codec: commandArrayCodec.Codec,
			OID:   commandArrayCodec.OID,
		},
		{
			Name:  "_command",
			Codec: commandArrayCodec.Codec,
			OID:   commandArrayCodec.OID,
		},
	})
	dialect.RegisterDefaultPgTypeVariants[command](m, "eventstore.command", "eventstore._command")
	dialect.RegisterDefaultPgTypeVariants[command](m, "command", "_command")

	return nil
}
// Client implements the [eventstore.Pusher]; it exposes the underlying
// database client.
func (es *Eventstore) Client() *database.DB {
	return es.client
}

// NewEventstore returns an Eventstore backed by the given database client.
func NewEventstore(client *database.DB) *Eventstore {
	return &Eventstore{client: client}
}

// Health reports whether the underlying database is reachable.
func (es *Eventstore) Health(ctx context.Context) error {
	return es.client.PingContext(ctx)
}

// errTypesNotFound signals that the raw driver connection is not a pgx
// connection and the eventstore types could not be registered.
var errTypesNotFound = errors.New("types not found")
// CheckExecutionPlan registers the eventstore types on the raw pgx connection
// underlying conn. sqlmock connections (used in tests) are accepted as-is;
// any other non-pgx driver yields errTypesNotFound.
func CheckExecutionPlan(ctx context.Context, conn *sql.Conn) error {
	return conn.Raw(func(driverConn any) error {
		if _, ok := driverConn.(sqlmock.SqlmockCommon); ok {
			// test connection; nothing to register
			return nil
		}
		conn, ok := driverConn.(*stdlib.Conn)
		if !ok {
			return errTypesNotFound
		}
		return RegisterEventstoreTypes(ctx, conn.Conn())
	})
}
// pushTx returns the transaction to push events in. If client already is a
// transaction it is reused and no deferrable is returned (the caller does not
// own it). Otherwise a new read-committed transaction is begun — on client if
// it can begin transactions, else on the eventstore's own client — and a
// deferrable closing that transaction is returned for the caller to run.
func (es *Eventstore) pushTx(ctx context.Context, client database.ContextQueryExecuter) (tx database.Tx, deferrable func(err error) error, err error) {
	tx, ok := client.(database.Tx)
	if ok {
		// caller-provided transaction; caller keeps ownership
		return tx, nil, nil
	}
	beginner, ok := client.(database.Beginner)
	if !ok {
		beginner = es.client
	}

	tx, err = beginner.BeginTx(ctx, &sql.TxOptions{
		Isolation: sql.LevelReadCommitted,
		ReadOnly:  false,
	})
	if err != nil {
		return nil, nil, err
	}
	return tx, func(err error) error { return database.CloseTransaction(tx, err) }, nil
}

View File

@@ -0,0 +1,369 @@
package eventstore
import (
"context"
"database/sql"
_ "embed"
"encoding/json"
"reflect"
"slices"
"strconv"
"strings"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
// fieldValue wraps the raw JSON representation of a stored field value.
type fieldValue struct {
	value []byte
}

// Unmarshal decodes the raw JSON value into ptr.
func (v *fieldValue) Unmarshal(ptr any) error {
	return json.Unmarshal(v.value, ptr)
}
// FillFields writes the field operations of the given events into the fields
// table inside a single read-committed transaction. The transaction is
// committed on success and rolled back on any error.
func (es *Eventstore) FillFields(ctx context.Context, events ...eventstore.FillFieldsEvent) (err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer span.End()

	tx, err := es.client.BeginTx(ctx, &sql.TxOptions{Isolation: sql.LevelReadCommitted})
	if err != nil {
		return err
	}
	defer func() {
		// roll back on error, otherwise commit (commit error becomes err)
		if err != nil {
			_ = tx.Rollback()
			return
		}
		err = tx.Commit()
	}()

	return handleFieldFillEvents(ctx, tx, events)
}
// Search implements the [eventstore.Search] method.
// Each conditions map is one conjunction; multiple maps are combined with OR.
// The instance id is always taken from the context as the first filter.
func (es *Eventstore) Search(ctx context.Context, conditions ...map[eventstore.FieldType]any) (result []*eventstore.SearchResult, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	var builder strings.Builder
	args := buildSearchStatement(ctx, &builder, conditions...)

	err = es.client.QueryContext(
		ctx,
		func(rows *sql.Rows) error {
			for rows.Next() {
				var (
					res   eventstore.SearchResult
					value fieldValue
				)
				// column order must match searchQueryPrefix
				err = rows.Scan(
					&res.Aggregate.InstanceID,
					&res.Aggregate.ResourceOwner,
					&res.Aggregate.Type,
					&res.Aggregate.ID,
					&res.Object.Type,
					&res.Object.ID,
					&res.Object.Revision,
					&res.FieldName,
					&value.value,
				)
				if err != nil {
					return err
				}
				res.Value = &value
				result = append(result, &res)
			}
			return nil
		},
		builder.String(),
		args...,
	)
	if err != nil {
		return nil, err
	}
	return result, nil
}
// searchQueryPrefix selects all field columns filtered by the instance id,
// which is always bound as $1.
const searchQueryPrefix = `SELECT instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value FROM eventstore.fields WHERE instance_id = $1`

// buildSearchStatement writes the full search query into builder and returns
// the bind arguments. Conditions maps are ORed; the entries of a single map
// are ANDed (see buildSearchCondition). Parentheses are only emitted where
// more than one condition/entry requires them.
func buildSearchStatement(ctx context.Context, builder *strings.Builder, conditions ...map[eventstore.FieldType]any) []any {
	args := make([]any, 0, len(conditions)*4+1)
	args = append(args, authz.GetInstance(ctx).InstanceID())

	builder.WriteString(searchQueryPrefix)

	builder.WriteString(" AND ")

	if len(conditions) > 1 {
		builder.WriteRune('(')
	}
	for i, condition := range conditions {
		if i > 0 {
			builder.WriteString(" OR ")
		}
		if len(condition) > 1 {
			builder.WriteRune('(')
		}
		args = append(args, buildSearchCondition(builder, len(args)+1, condition)...)
		if len(condition) > 1 {
			builder.WriteRune(')')
		}
	}
	if len(conditions) > 1 {
		builder.WriteRune(')')
	}

	return args
}
// buildSearchCondition writes the AND-joined clauses of one conditions map
// into builder, numbering placeholders from index upward, and returns the
// bind arguments. Fields are sorted to keep the statement deterministic.
func buildSearchCondition(builder *strings.Builder, index int, conditions map[eventstore.FieldType]any) []any {
	fields := make([]eventstore.FieldType, 0, len(conditions))
	for field := range conditions {
		fields = append(fields, field)
	}
	slices.Sort(fields)

	args := make([]any, 0, len(conditions))
	for i, field := range fields {
		if i > 0 {
			builder.WriteString(" AND ")
		}
		value := conditions[field]
		builder.WriteString(fieldNameByType(field, value))
		builder.WriteString(" = $")
		builder.WriteString(strconv.Itoa(index + i))
		args = append(args, value)
	}

	return args
}
// handleFieldCommands applies the field operations of every command that
// declares any; commands without field operations are skipped.
func (es *Eventstore) handleFieldCommands(ctx context.Context, tx database.Tx, commands []eventstore.Command) error {
	for _, cmd := range commands {
		fields := cmd.Fields()
		if len(fields) == 0 {
			continue
		}
		if err := handleFieldOperations(ctx, tx, fields); err != nil {
			return err
		}
	}
	return nil
}
// handleFieldFillEvents applies the field operations of every event that
// declares any; events without field operations are skipped.
func handleFieldFillEvents(ctx context.Context, tx database.Tx, events []eventstore.FillFieldsEvent) error {
	for _, event := range events {
		if fields := event.Fields(); len(fields) > 0 {
			if err := handleFieldOperations(ctx, tx, fields); err != nil {
				return err
			}
		}
	}
	return nil
}
// handleFieldOperations dispatches each operation to the set or remove
// handler. An operation with both Set and Remove populated only executes the
// set (matching the original precedence).
func handleFieldOperations(ctx context.Context, tx database.Tx, operations []*eventstore.FieldOperation) error {
	for _, op := range operations {
		switch {
		case op.Set != nil:
			if err := handleFieldSet(ctx, tx, op.Set); err != nil {
				return err
			}
		case op.Remove != nil:
			if err := handleSearchDelete(ctx, tx, op.Remove); err != nil {
				return err
			}
		}
	}
	return nil
}
// handleFieldSet stores the field, upserting when conflict fields are
// declared and inserting otherwise.
func handleFieldSet(ctx context.Context, tx database.Tx, field *eventstore.Field) error {
	if len(field.UpsertConflictFields) > 0 {
		return handleSearchUpsert(ctx, tx, field)
	}
	return handleSearchInsert(ctx, tx, field)
}
const (
	// insertField is the plain insert into the fields table; the bind order
	// must match the argument order of handleSearchInsert.
	insertField = `INSERT INTO eventstore.fields (instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value, value_must_be_unique, should_index) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11)`
)

// handleSearchInsert inserts the field after marshalling its value to JSON.
func handleSearchInsert(ctx context.Context, tx database.Tx, field *eventstore.Field) error {
	value, err := json.Marshal(field.Value.Value)
	if err != nil {
		return zerrors.ThrowInvalidArgument(err, "V3-fcrW1", "unable to marshal field value")
	}
	_, err = tx.ExecContext(
		ctx,
		insertField,

		field.Aggregate.InstanceID,
		field.Aggregate.ResourceOwner,
		field.Aggregate.Type,
		field.Aggregate.ID,
		field.Object.Type,
		field.Object.ID,
		field.Object.Revision,
		field.FieldName,
		value,
		field.Value.MustBeUnique,
		field.Value.ShouldIndex,
	)
	return err
}
const (
	// fieldsUpsertPrefix/-Suffix frame an update-or-insert CTE; the caller
	// (writeUpsertField) writes the WHERE clause of the update in between.
	fieldsUpsertPrefix = `WITH upsert AS (UPDATE eventstore.fields SET (instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value, value_must_be_unique, should_index) = ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) WHERE `
	fieldsUpsertSuffix = ` RETURNING * ) INSERT INTO eventstore.fields (instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value, value_must_be_unique, should_index) SELECT $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11 WHERE NOT EXISTS (SELECT 1 FROM upsert)`
)

// handleSearchUpsert updates the field matching its upsert conflict fields or
// inserts it if no row matched. The value is marshalled to JSON first.
func handleSearchUpsert(ctx context.Context, tx database.Tx, field *eventstore.Field) error {
	value, err := json.Marshal(field.Value.Value)
	if err != nil {
		return zerrors.ThrowInvalidArgument(err, "V3-fcrW1", "unable to marshal field value")
	}

	_, err = tx.ExecContext(
		ctx,
		writeUpsertField(field.UpsertConflictFields),

		field.Aggregate.InstanceID,
		field.Aggregate.ResourceOwner,
		field.Aggregate.Type,
		field.Aggregate.ID,
		field.Object.Type,
		field.Object.ID,
		field.Object.Revision,
		field.FieldName,
		value,
		field.Value.MustBeUnique,
		field.Value.ShouldIndex,
	)
	return err
}
// writeUpsertField builds the upsert statement whose update WHERE clause
// matches the given conflict fields against the corresponding push
// placeholders ($1..$9).
func writeUpsertField(fields []eventstore.FieldType) string {
	var b strings.Builder
	b.WriteString(fieldsUpsertPrefix)
	for i, field := range fields {
		if i > 0 {
			b.WriteString(" AND ")
		}
		column, placeholder := searchFieldNameAndIndexByTypeForPush(field)
		b.WriteString(column)
		b.WriteString(" = ")
		b.WriteString(placeholder)
	}
	b.WriteString(fieldsUpsertSuffix)
	return b.String()
}
// removeSearch is the prefix of the delete statement; writeDeleteField
// appends the AND-joined clauses.
const removeSearch = `DELETE FROM eventstore.fields WHERE `

// handleSearchDelete removes all fields matching the given clauses.
// At least one clause is required to prevent an unbounded delete.
func handleSearchDelete(ctx context.Context, tx database.Tx, clauses map[eventstore.FieldType]any) error {
	if len(clauses) == 0 {
		return zerrors.ThrowInvalidArgument(nil, "V3-oqlBZ", "no conditions")
	}
	stmt, args := writeDeleteField(clauses)
	_, err := tx.ExecContext(ctx, stmt, args...)
	return err
}
// writeDeleteField builds the delete statement and its bind arguments for the
// given clauses. Fields are sorted so the generated SQL is deterministic.
func writeDeleteField(clauses map[eventstore.FieldType]any) (string, []any) {
	fields := make([]eventstore.FieldType, 0, len(clauses))
	for field := range clauses {
		fields = append(fields, field)
	}
	slices.Sort(fields)

	var b strings.Builder
	b.WriteString(removeSearch)
	args := make([]any, 0, len(clauses))
	for i, field := range fields {
		if i > 0 {
			b.WriteString(" AND ")
		}
		b.WriteString(fieldNameByType(field, clauses[field]))
		b.WriteString(" = $")
		b.WriteString(strconv.Itoa(i + 1))
		args = append(args, clauses[field])
	}

	return b.String(), args
}
// fieldNameByType maps a field type to its column in eventstore.fields.
// The value field maps to a typed column chosen from the value's kind
// (see valueColumn); unknown types yield an empty string.
func fieldNameByType(typ eventstore.FieldType, value any) string {
	switch typ {
	case eventstore.FieldTypeValue:
		return valueColumn(value)
	case eventstore.FieldTypeInstanceID:
		return "instance_id"
	case eventstore.FieldTypeResourceOwner:
		return "resource_owner"
	case eventstore.FieldTypeAggregateType:
		return "aggregate_type"
	case eventstore.FieldTypeAggregateID:
		return "aggregate_id"
	case eventstore.FieldTypeObjectType:
		return "object_type"
	case eventstore.FieldTypeObjectID:
		return "object_id"
	case eventstore.FieldTypeObjectRevision:
		return "object_revision"
	case eventstore.FieldTypeFieldName:
		return "field_name"
	default:
		return ""
	}
}
// searchFieldNameAndIndexByTypeForPush returns the column name and the fixed
// statement placeholder used for the given field type in the fields upsert
// statement. Unknown field types yield empty strings.
func searchFieldNameAndIndexByTypeForPush(typ eventstore.FieldType) (column, placeholder string) {
	switch typ {
	case eventstore.FieldTypeInstanceID:
		column, placeholder = "instance_id", "$1"
	case eventstore.FieldTypeResourceOwner:
		column, placeholder = "resource_owner", "$2"
	case eventstore.FieldTypeAggregateType:
		column, placeholder = "aggregate_type", "$3"
	case eventstore.FieldTypeAggregateID:
		column, placeholder = "aggregate_id", "$4"
	case eventstore.FieldTypeObjectType:
		column, placeholder = "object_type", "$5"
	case eventstore.FieldTypeObjectID:
		column, placeholder = "object_id", "$6"
	case eventstore.FieldTypeObjectRevision:
		column, placeholder = "object_revision", "$7"
	case eventstore.FieldTypeFieldName:
		column, placeholder = "field_name", "$8"
	case eventstore.FieldTypeValue:
		column, placeholder = "value", "$9"
	}
	return column, placeholder
}
// valueColumn returns the typed value column of eventstore.fields used to
// store the given value: bool_value for booleans, number_value for integer
// and float kinds, text_value for strings. Unsupported kinds and nil values
// yield an empty string.
func valueColumn(value any) string {
	// reflect.TypeOf(nil) returns a nil Type; calling Kind() on it would panic.
	if value == nil {
		return ""
	}
	//nolint: exhaustive
	switch reflect.TypeOf(value).Kind() {
	case reflect.Bool:
		return "bool_value"
	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64:
		return "number_value"
	case reflect.String:
		return "text_value"
	}
	return ""
}

View File

@@ -0,0 +1,260 @@
package eventstore
import (
"context"
_ "embed"
"reflect"
"strings"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/eventstore"
)
// Test_handleSearchDelete asserts the DELETE statement and arguments built by
// writeDeleteField; the argument order is deterministic because the field
// types are sorted.
func Test_handleSearchDelete(t *testing.T) {
	type args struct {
		clauses map[eventstore.FieldType]any
	}
	type want struct {
		stmt string
		args []any
	}
	tests := []struct {
		name string
		args args
		want want
	}{
		{
			name: "1 condition",
			args: args{
				clauses: map[eventstore.FieldType]any{
					eventstore.FieldTypeInstanceID: "i_id",
				},
			},
			want: want{
				stmt: "DELETE FROM eventstore.fields WHERE instance_id = $1",
				args: []any{"i_id"},
			},
		},
		{
			name: "2 conditions",
			args: args{
				clauses: map[eventstore.FieldType]any{
					eventstore.FieldTypeInstanceID:  "i_id",
					eventstore.FieldTypeAggregateID: "a_id",
				},
			},
			want: want{
				stmt: "DELETE FROM eventstore.fields WHERE aggregate_id = $1 AND instance_id = $2",
				args: []any{"a_id", "i_id"},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			stmt, args := writeDeleteField(tt.args.clauses)
			// the failure message names the function under test (writeDeleteField),
			// not the caller handleSearchDelete, to avoid misleading output
			if stmt != tt.want.stmt {
				t.Errorf("writeDeleteField() stmt = %q, want %q", stmt, tt.want.stmt)
			}
			assert.Equal(t, tt.want.args, args)
		})
	}
}
// Test_writeUpsertField asserts the full upsert statement generated for the
// eventstore.fields table; each condition column uses the fixed placeholder
// assigned to its field type by searchFieldNameAndIndexByTypeForPush.
func Test_writeUpsertField(t *testing.T) {
	type args struct {
		fields []eventstore.FieldType
	}
	tests := []struct {
		name string
		args args
		want string
	}{
		{
			name: "1 field",
			args: args{
				fields: []eventstore.FieldType{
					eventstore.FieldTypeInstanceID,
				},
			},
			want: "WITH upsert AS (UPDATE eventstore.fields SET (instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value, value_must_be_unique, should_index) = ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) WHERE instance_id = $1 RETURNING * ) INSERT INTO eventstore.fields (instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value, value_must_be_unique, should_index) SELECT $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11 WHERE NOT EXISTS (SELECT 1 FROM upsert)",
		},
		{
			name: "2 fields",
			args: args{
				fields: []eventstore.FieldType{
					eventstore.FieldTypeInstanceID,
					eventstore.FieldTypeAggregateType,
				},
			},
			want: "WITH upsert AS (UPDATE eventstore.fields SET (instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value, value_must_be_unique, should_index) = ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11) WHERE instance_id = $1 AND aggregate_type = $3 RETURNING * ) INSERT INTO eventstore.fields (instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value, value_must_be_unique, should_index) SELECT $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11 WHERE NOT EXISTS (SELECT 1 FROM upsert)",
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := writeUpsertField(tt.args.fields); got != tt.want {
				t.Errorf("writeUpsertField() = %q, want %q", got, tt.want)
			}
		})
	}
}
// Test_buildSearchCondition asserts the WHERE fragment and arguments produced
// for a single clause set; argument order is deterministic because the field
// types are sorted before rendering.
func Test_buildSearchCondition(t *testing.T) {
	type args struct {
		index      int
		conditions map[eventstore.FieldType]any
	}
	type want struct {
		stmt string
		args []any
	}
	tests := []struct {
		name string
		args args
		want want
	}{
		{
			name: "1 condition",
			args: args{
				index: 1,
				conditions: map[eventstore.FieldType]any{
					eventstore.FieldTypeAggregateID: "a_id",
				},
			},
			want: want{
				stmt: "aggregate_id = $1",
				args: []any{"a_id"},
			},
		},
		{
			name: "3 condition",
			args: args{
				index: 1,
				conditions: map[eventstore.FieldType]any{
					eventstore.FieldTypeAggregateID:   "a_id",
					eventstore.FieldTypeInstanceID:    "i_id",
					eventstore.FieldTypeAggregateType: "a_type",
				},
			},
			want: want{
				stmt: "aggregate_type = $1 AND aggregate_id = $2 AND instance_id = $3",
				args: []any{"a_type", "a_id", "i_id"},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var builder strings.Builder
			// compare against tt.want.args (the returned value), not the whole
			// want struct, so the failure message shows comparable values
			if got := buildSearchCondition(&builder, tt.args.index, tt.args.conditions); !reflect.DeepEqual(got, tt.want.args) {
				t.Errorf("buildSearchCondition() = %v, want %v", got, tt.want.args)
			}
			if tt.want.stmt != builder.String() {
				t.Errorf("buildSearchCondition() stmt = %q, want %q", builder.String(), tt.want.stmt)
			}
		})
	}
}
// Test_buildSearchStatement asserts the SELECT statement and arguments built
// for one or more clause sets. The instance id from the context is always the
// first argument, which is why it is prepended to every expected args slice.
func Test_buildSearchStatement(t *testing.T) {
	type args struct {
		index      int
		conditions []map[eventstore.FieldType]any
	}
	type want struct {
		stmt string
		args []any
	}
	tests := []struct {
		name string
		args args
		want want
	}{
		{
			name: "1 condition with 1 field",
			args: args{
				index: 1,
				conditions: []map[eventstore.FieldType]any{
					{
						eventstore.FieldTypeAggregateID: "a_id",
					},
				},
			},
			want: want{
				stmt: "SELECT instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value FROM eventstore.fields WHERE instance_id = $1 AND aggregate_id = $2",
				args: []any{"a_id"},
			},
		},
		{
			name: "1 condition with 3 fields",
			args: args{
				index: 1,
				conditions: []map[eventstore.FieldType]any{
					{
						eventstore.FieldTypeAggregateID:   "a_id",
						eventstore.FieldTypeInstanceID:    "i_id",
						eventstore.FieldTypeAggregateType: "a_type",
					},
				},
			},
			want: want{
				stmt: "SELECT instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value FROM eventstore.fields WHERE instance_id = $1 AND (aggregate_type = $2 AND aggregate_id = $3 AND instance_id = $4)",
				args: []any{"a_type", "a_id", "i_id"},
			},
		},
		{
			name: "2 condition with 1 field",
			args: args{
				index: 1,
				conditions: []map[eventstore.FieldType]any{
					{
						eventstore.FieldTypeAggregateID: "a_id",
					},
					{
						eventstore.FieldTypeAggregateType: "a_type",
					},
				},
			},
			want: want{
				stmt: "SELECT instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value FROM eventstore.fields WHERE instance_id = $1 AND (aggregate_id = $2 OR aggregate_type = $3)",
				args: []any{"a_id", "a_type"},
			},
		},
		{
			name: "2 condition with 2 fields",
			args: args{
				index: 1,
				conditions: []map[eventstore.FieldType]any{
					{
						eventstore.FieldTypeAggregateID:   "a_id1",
						eventstore.FieldTypeAggregateType: "a_type1",
					},
					{
						eventstore.FieldTypeAggregateID:   "a_id2",
						eventstore.FieldTypeAggregateType: "a_type2",
					},
				},
			},
			want: want{
				stmt: "SELECT instance_id, resource_owner, aggregate_type, aggregate_id, object_type, object_id, object_revision, field_name, value FROM eventstore.fields WHERE instance_id = $1 AND ((aggregate_type = $2 AND aggregate_id = $3) OR (aggregate_type = $4 AND aggregate_id = $5))",
				args: []any{"a_type1", "a_id1", "a_type2", "a_id2"},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			var builder strings.Builder
			// the instance id from the context is always the first argument
			tt.want.args = append([]any{"i_id"}, tt.want.args...)
			ctx := authz.WithInstanceID(context.Background(), "i_id")
			// compare against tt.want.args (the returned value), not the whole
			// want struct, so the failure message shows comparable values
			if got := buildSearchStatement(ctx, &builder, tt.args.conditions...); !reflect.DeepEqual(got, tt.want.args) {
				t.Errorf("buildSearchStatement() = %v, want %v", got, tt.want.args)
			}
			if tt.want.stmt != builder.String() {
				t.Errorf("buildSearchStatement() stmt = %q, want %q", builder.String(), tt.want.stmt)
			}
		})
	}
}

View File

@@ -0,0 +1,83 @@
package eventstore
import (
"github.com/zitadel/zitadel/internal/eventstore"
)
// compile-time check that mockCommand satisfies eventstore.Command
var _ eventstore.Command = (*mockCommand)(nil)

// mockCommand is a minimal eventstore.Command implementation for tests.
// All metadata (creator, revision, type) is fixed; only the aggregate,
// payload and unique constraints are configurable.
type mockCommand struct {
	aggregate   *eventstore.Aggregate
	payload     any
	constraints []*eventstore.UniqueConstraint
}

// Aggregate implements [eventstore.Command]
func (m *mockCommand) Aggregate() *eventstore.Aggregate {
	return m.aggregate
}

// Creator implements [eventstore.Command]
func (m *mockCommand) Creator() string {
	return "creator"
}

// Revision implements [eventstore.Command]
func (m *mockCommand) Revision() uint16 {
	return 1
}

// Type implements [eventstore.Command]
func (m *mockCommand) Type() eventstore.EventType {
	return "event.type"
}

// Payload implements [eventstore.Command]
func (m *mockCommand) Payload() any {
	return m.payload
}

// UniqueConstraints implements [eventstore.Command]
func (m *mockCommand) UniqueConstraints() []*eventstore.UniqueConstraint {
	return m.constraints
}

// Fields implements [eventstore.Command]
// receiver renamed from e to m for consistency with the other methods
func (m *mockCommand) Fields() []*eventstore.FieldOperation {
	return nil
}
// mockEvent builds a test event from the given aggregate with the fixed
// creator/revision/type used by mockCommand, the given sequence and payload.
func mockEvent(aggregate *eventstore.Aggregate, sequence uint64, payload Payload) eventstore.Event {
	return &event{
		command: &command{
			InstanceID:    aggregate.InstanceID,
			AggregateType: string(aggregate.Type),
			AggregateID:   aggregate.ID,
			Owner:         aggregate.ResourceOwner,
			Creator:       "creator",
			Revision:      1,
			CommandType:   "event.type",
			Payload:       payload,
		},
		sequence: sequence,
	}
}
// mockAggregate returns a test aggregate with the given id and the fixed
// instance id "instance"; delegates to mockAggregateWithInstance to avoid
// duplicating the literal.
func mockAggregate(id string) *eventstore.Aggregate {
	return mockAggregateWithInstance(id, "instance")
}
// mockAggregateWithInstance returns a test aggregate with the given id and
// instance id; type, resource owner and version use fixed test values.
func mockAggregateWithInstance(id, instance string) *eventstore.Aggregate {
	aggregate := &eventstore.Aggregate{
		ID:            id,
		Type:          "type",
		ResourceOwner: "ro",
		InstanceID:    instance,
		Version:       "v1",
	}
	return aggregate
}

View File

@@ -0,0 +1,108 @@
package eventstore
import (
"context"
"database/sql"
_ "embed"
"fmt"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
// pushTxOpts are the transaction options used when pushing events:
// writable with read-committed isolation.
var pushTxOpts = &sql.TxOptions{
	Isolation: sql.LevelReadCommitted,
	ReadOnly:  false,
}
// Push writes the given commands as events within a single transaction.
// If writing via the eventstore.push database function fails because setup
// step 39 was not executed yet, it falls back to the legacy push path
// (pushWithoutFunc).
func (es *Eventstore) Push(ctx context.Context, client database.ContextQueryExecuter, commands ...eventstore.Command) (events []eventstore.Event, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()
	events, err = es.writeCommands(ctx, client, commands)
	if isSetupNotExecutedError(err) {
		return es.pushWithoutFunc(ctx, client, commands...)
	}
	return events, err
}
// writeCommands acquires a connection (from the given client, or from the
// eventstore's own pool if client is nil), opens the push transaction, writes
// the events and maintains unique constraints and search fields.
func (es *Eventstore) writeCommands(ctx context.Context, client database.ContextQueryExecuter, commands []eventstore.Command) (_ []eventstore.Event, err error) {
	var conn *sql.Conn
	switch c := client.(type) {
	case database.Client:
		conn, err = c.Conn(ctx)
	case nil:
		// no client provided: take a dedicated connection from the pool
		conn, err = es.client.Conn(ctx)
		client = conn
	}
	if err != nil {
		return nil, err
	}
	if conn != nil {
		// only close connections opened by this function
		defer conn.Close()
	}
	tx, close, err := es.pushTx(ctx, client)
	if err != nil {
		return nil, err
	}
	if close != nil {
		defer func() {
			// close commits or rolls back depending on err; its result is
			// captured in the named return value
			err = close(err)
		}()
	}
	// tag the session so pushers are identifiable per instance (e.g. in pg_stat_activity)
	_, err = tx.ExecContext(ctx, fmt.Sprintf("SET LOCAL application_name = '%s'", fmt.Sprintf("zitadel_es_pusher_%s", authz.GetInstance(ctx).InstanceID())))
	if err != nil {
		return nil, err
	}
	events, err := writeEvents(ctx, tx, commands)
	if err != nil {
		return nil, err
	}
	if err = handleUniqueConstraints(ctx, tx, commands); err != nil {
		return nil, err
	}
	err = es.handleFieldCommands(ctx, tx, commands)
	if err != nil {
		return nil, err
	}
	return events, nil
}
// writeEvents maps the commands to events, pushes them through the
// eventstore.push database function and scans owner, creation time, sequence
// and position back into the returned events (one row per command, in order).
func writeEvents(ctx context.Context, tx database.Tx, commands []eventstore.Command) (_ []eventstore.Event, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()
	events, cmds, err := commandsToEvents(ctx, commands)
	if err != nil {
		return nil, err
	}
	rows, err := tx.QueryContext(ctx, `select owner, created_at, "sequence", position from eventstore.push($1::eventstore.command[])`, cmds)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	// rows are returned in command order, so index i matches events[i]
	for i := 0; rows.Next(); i++ {
		err = rows.Scan(&events[i].(*event).command.Owner, &events[i].(*event).createdAt, &events[i].(*event).sequence, &events[i].(*event).position)
		if err != nil {
			logging.WithError(err).Warn("failed to scan events")
			return nil, err
		}
	}
	if err = rows.Err(); err != nil {
		return nil, err
	}
	return events, nil
}

View File

@@ -0,0 +1,18 @@
-- Insert pushed commands into eventstore.events2 and return the
-- database-assigned creation time and position per row.
-- The placeholder row below is replaced at runtime with a comma-separated
-- list of value tuples, one per command (see writeEventsOld).
INSERT INTO eventstore.events2 (
    instance_id
    , "owner"
    , aggregate_type
    , aggregate_id
    , revision
    , creator
    , event_type
    , payload
    , "sequence"
    , created_at
    , "position"
    , in_tx_order
) VALUES
%s
RETURNING created_at, "position";

View File

@@ -0,0 +1,253 @@
package eventstore
import (
_ "embed"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/postgres"
"github.com/zitadel/zitadel/internal/eventstore"
)
// Test_mapCommands asserts the events, insert placeholders and statement
// arguments produced by mapCommands, including sequence continuation per
// aggregate and the panic when no sequence was loaded for a command.
func Test_mapCommands(t *testing.T) {
	type args struct {
		commands  []eventstore.Command
		sequences []*latestSequence
	}
	type want struct {
		events       []eventstore.Event
		placeHolders []string
		args         []any
		err          func(t *testing.T, err error)
		shouldPanic  bool
	}
	tests := []struct {
		name string
		args args
		want want
	}{
		{
			name: "no commands",
			args: args{
				commands:  []eventstore.Command{},
				sequences: []*latestSequence{},
			},
			want: want{
				events:       []eventstore.Event{},
				placeHolders: []string{},
				args:         []any{},
			},
		},
		{
			name: "one command",
			args: args{
				commands: []eventstore.Command{
					&mockCommand{
						aggregate: mockAggregate("V3-VEIvq"),
					},
				},
				sequences: []*latestSequence{
					{
						aggregate: mockAggregate("V3-VEIvq"),
						sequence:  0,
					},
				},
			},
			want: want{
				events: []eventstore.Event{
					mockEvent(
						mockAggregate("V3-VEIvq"),
						1,
						nil,
					),
				},
				placeHolders: []string{
					"($1, $2, $3, $4, $5, $6, $7, $8, $9, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $10)",
				},
				args: []any{
					"instance",
					"ro",
					"type",
					"V3-VEIvq",
					uint16(1),
					"creator",
					"event.type",
					Payload(nil),
					uint64(1),
					0,
				},
				err: func(t *testing.T, err error) {},
			},
		},
		{
			name: "multiple commands same aggregate",
			args: args{
				commands: []eventstore.Command{
					&mockCommand{
						aggregate: mockAggregate("V3-VEIvq"),
					},
					&mockCommand{
						aggregate: mockAggregate("V3-VEIvq"),
					},
				},
				sequences: []*latestSequence{
					{
						aggregate: mockAggregate("V3-VEIvq"),
						sequence:  5,
					},
				},
			},
			want: want{
				events: []eventstore.Event{
					mockEvent(
						mockAggregate("V3-VEIvq"),
						6,
						nil,
					),
					mockEvent(
						mockAggregate("V3-VEIvq"),
						7,
						nil,
					),
				},
				placeHolders: []string{
					"($1, $2, $3, $4, $5, $6, $7, $8, $9, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $10)",
					"($11, $12, $13, $14, $15, $16, $17, $18, $19, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $20)",
				},
				args: []any{
					// first event
					"instance",
					"ro",
					"type",
					"V3-VEIvq",
					uint16(1),
					"creator",
					"event.type",
					Payload(nil),
					uint64(6),
					0,
					// second event
					"instance",
					"ro",
					"type",
					"V3-VEIvq",
					uint16(1),
					"creator",
					"event.type",
					Payload(nil),
					uint64(7),
					1,
				},
				err: func(t *testing.T, err error) {},
			},
		},
		{
			name: "one command per aggregate",
			args: args{
				commands: []eventstore.Command{
					&mockCommand{
						aggregate: mockAggregate("V3-VEIvq"),
					},
					&mockCommand{
						aggregate: mockAggregate("V3-IT6VN"),
					},
				},
				sequences: []*latestSequence{
					{
						aggregate: mockAggregate("V3-VEIvq"),
						sequence:  5,
					},
					{
						aggregate: mockAggregate("V3-IT6VN"),
						sequence:  0,
					},
				},
			},
			want: want{
				events: []eventstore.Event{
					mockEvent(
						mockAggregate("V3-VEIvq"),
						6,
						nil,
					),
					mockEvent(
						mockAggregate("V3-IT6VN"),
						1,
						nil,
					),
				},
				placeHolders: []string{
					"($1, $2, $3, $4, $5, $6, $7, $8, $9, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $10)",
					"($11, $12, $13, $14, $15, $16, $17, $18, $19, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $20)",
				},
				args: []any{
					// first event
					"instance",
					"ro",
					"type",
					"V3-VEIvq",
					uint16(1),
					"creator",
					"event.type",
					Payload(nil),
					uint64(6),
					0,
					// second event
					"instance",
					"ro",
					"type",
					"V3-IT6VN",
					uint16(1),
					"creator",
					"event.type",
					Payload(nil),
					uint64(1),
					1,
				},
				err: func(t *testing.T, err error) {},
			},
		},
		{
			name: "missing sequence",
			args: args{
				commands: []eventstore.Command{
					&mockCommand{
						aggregate: mockAggregate("V3-VEIvq"),
					},
				},
				sequences: []*latestSequence{},
			},
			want: want{
				events:       []eventstore.Event{},
				placeHolders: []string{},
				args:         []any{},
				err:          func(t *testing.T, err error) {},
				shouldPanic:  true,
			},
		},
	}
	for _, tt := range tests {
		if tt.want.err == nil {
			tt.want.err = func(t *testing.T, err error) {
				require.NoError(t, err)
			}
		}
		// is used to set the [pushPlaceholderFmt]
		NewEventstore(&database.DB{Database: new(postgres.Config)})
		t.Run(tt.name, func(t *testing.T) {
			defer func() {
				// a recovered panic is expected only for the "missing sequence" case
				cause := recover()
				assert.Equal(t, tt.want.shouldPanic, cause != nil)
			}()
			gotEvents, gotPlaceHolders, gotArgs, err := mapCommands(tt.args.commands, tt.args.sequences)
			tt.want.err(t, err)
			assert.ElementsMatch(t, tt.want.events, gotEvents)
			assert.ElementsMatch(t, tt.want.placeHolders, gotPlaceHolders)
			assert.ElementsMatch(t, tt.want.args, gotArgs)
		})
	}
}

View File

@@ -0,0 +1,162 @@
package eventstore
import (
"context"
_ "embed"
"errors"
"fmt"
"strings"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/zerrors"
)
// isSetupNotExecutedError reports whether the error indicates that setup step
// 39 was not executed: either postgres does not know the eventstore.command
// type (42704) or the eventstore.push function (42883), or the types were not
// found locally.
func isSetupNotExecutedError(err error) bool {
	if err == nil {
		return false
	}
	var pgErr *pgconn.PgError
	if !errors.As(err, &pgErr) {
		return errors.Is(err, errTypesNotFound)
	}
	switch pgErr.Code {
	case "42704":
		return strings.Contains(pgErr.Message, "eventstore.command")
	case "42883":
		return strings.Contains(pgErr.Message, "eventstore.push")
	}
	return false
}
var (
	// pushStmt is the legacy INSERT template for eventstore.events2; its
	// verb is replaced with one value-tuple placeholder group per command.
	//go:embed push.sql
	pushStmt string
)
// pushWithoutFunc implements pushing events before setup step 39 was introduced:
// it loads and locks the latest sequence per aggregate, inserts the events
// directly into eventstore.events2 and maintains unique constraints and
// search fields within the same transaction.
// TODO: remove with v3
func (es *Eventstore) pushWithoutFunc(ctx context.Context, client database.ContextQueryExecuter, commands ...eventstore.Command) (events []eventstore.Event, err error) {
	tx, closeTx, err := es.pushTx(ctx, client)
	if err != nil {
		return nil, err
	}
	defer func() {
		// closeTx commits or rolls back depending on err; its result is
		// captured in the named return value
		err = closeTx(err)
	}()
	var (
		sequences []*latestSequence
	)
	sequences, err = latestSequences(ctx, tx, commands)
	if err != nil {
		return nil, err
	}
	events, err = es.writeEventsOld(ctx, tx, sequences, commands)
	if err != nil {
		return nil, err
	}
	if err = handleUniqueConstraints(ctx, tx, commands); err != nil {
		return nil, err
	}
	err = es.handleFieldCommands(ctx, tx, commands)
	if err != nil {
		return nil, err
	}
	return events, nil
}
// writeEventsOld inserts the mapped events directly into eventstore.events2
// (legacy path before the eventstore.push function) and scans created_at and
// position back into the events, one row per command in order.
func (es *Eventstore) writeEventsOld(ctx context.Context, tx database.Tx, sequences []*latestSequence, commands []eventstore.Command) ([]eventstore.Event, error) {
	events, placeholders, args, err := mapCommands(commands, sequences)
	if err != nil {
		return nil, err
	}
	rows, err := tx.QueryContext(ctx, fmt.Sprintf(pushStmt, strings.Join(placeholders, ", ")), args...)
	if err != nil {
		return nil, err
	}
	defer rows.Close()
	for i := 0; rows.Next(); i++ {
		err = rows.Scan(&events[i].(*event).createdAt, &events[i].(*event).position)
		if err != nil {
			logging.WithError(err).Warn("failed to scan events")
			return nil, err
		}
	}
	if err := rows.Err(); err != nil {
		pgErr := new(pgconn.PgError)
		if errors.As(err, &pgErr) {
			// Check if push tries to write an event just written
			// by another transaction (serialization failure)
			if pgErr.Code == "40001" {
				// TODO: @livio-a should we return the parent or not?
				return nil, zerrors.ThrowInvalidArgument(err, "V3-p5xAn", "Errors.AlreadyExists")
			}
		}
		logging.WithError(rows.Err()).Warn("failed to push events")
		return nil, zerrors.ThrowInternal(err, "V3-VGnZY", "Errors.Internal")
	}
	return events, nil
}
// argsPerCommand is the number of bound arguments per event row in pushStmt
const argsPerCommand = 10

// mapCommands converts the commands to events and builds one placeholder
// group and argsPerCommand arguments per command for the legacy insert
// statement. The tracked sequence of each aggregate is incremented per
// command; it panics if no sequence was loaded for a command's aggregate.
func mapCommands(commands []eventstore.Command, sequences []*latestSequence) (events []eventstore.Event, placeholders []string, args []any, err error) {
	events = make([]eventstore.Event, len(commands))
	args = make([]any, 0, len(commands)*argsPerCommand)
	placeholders = make([]string, len(commands))
	for i, command := range commands {
		sequence := searchSequenceByCommand(sequences, command)
		if sequence == nil {
			logging.WithFields(
				"aggType", command.Aggregate().Type,
				"aggID", command.Aggregate().ID,
				"instance", command.Aggregate().InstanceID,
			).Panic("no sequence found")
			// added return for linting
			return nil, nil, nil, nil
		}
		sequence.sequence++
		events[i], err = commandToEventOld(sequence, command)
		if err != nil {
			return nil, nil, nil, err
		}
		placeholders[i] = fmt.Sprintf(pushPlaceholderFmt,
			i*argsPerCommand+1,
			i*argsPerCommand+2,
			i*argsPerCommand+3,
			i*argsPerCommand+4,
			i*argsPerCommand+5,
			i*argsPerCommand+6,
			i*argsPerCommand+7,
			i*argsPerCommand+8,
			i*argsPerCommand+9,
			i*argsPerCommand+10,
		)
		args = append(args,
			events[i].(*event).command.InstanceID,
			events[i].(*event).command.Owner,
			events[i].(*event).command.AggregateType,
			events[i].(*event).command.AggregateID,
			events[i].(*event).command.Revision,
			events[i].(*event).command.Creator,
			events[i].(*event).command.CommandType,
			events[i].(*event).command.Payload,
			events[i].(*event).sequence,
			// in_tx_order keeps the original command order within the transaction
			i,
		)
	}
	return events, placeholders, args, nil
}

View File

@@ -0,0 +1,144 @@
package eventstore
import (
"context"
"database/sql"
_ "embed"
"fmt"
"strings"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/zerrors"
)
// latestSequence tracks the last known event sequence of a single aggregate
// while pushing new events.
type latestSequence struct {
	aggregate *eventstore.Aggregate
	sequence  uint64
}

// latestSequencesStmt selects the latest stored sequence per aggregate and
// locks the rows FOR UPDATE; its verb is replaced with UNION ALL-joined
// per-aggregate subselects.
//go:embed sequences_query.sql
var latestSequencesStmt string
// latestSequences loads the current sequence of every aggregate touched by
// the given commands, locking the underlying rows for the remainder of the
// transaction. Aggregates without stored events keep sequence 0.
func latestSequences(ctx context.Context, tx database.Tx, commands []eventstore.Command) ([]*latestSequence, error) {
	sequences := commandsToSequences(ctx, commands)
	conditions, args := sequencesToSql(sequences)
	rows, err := tx.QueryContext(ctx, fmt.Sprintf(latestSequencesStmt, strings.Join(conditions, " UNION ALL ")), args...)
	if err != nil {
		return nil, zerrors.ThrowInternal(err, "V3-5jU5z", "Errors.Internal")
	}
	defer rows.Close()
	for rows.Next() {
		if err := scanToSequence(rows, sequences); err != nil {
			return nil, zerrors.ThrowInternal(err, "V3-Ydiwv", "Errors.Internal")
		}
	}
	if rows.Err() != nil {
		return nil, zerrors.ThrowInternal(rows.Err(), "V3-XApDk", "Errors.Internal")
	}
	return sequences, nil
}
// searchSequenceByCommand returns the tracked sequence matching the command's
// aggregate (type, id and instance), or nil if none is tracked yet.
// It delegates to searchSequence instead of duplicating the matching logic.
func searchSequenceByCommand(sequences []*latestSequence, command eventstore.Command) *latestSequence {
	aggregate := command.Aggregate()
	return searchSequence(sequences, aggregate.Type, aggregate.ID, aggregate.InstanceID)
}
// searchSequence returns the tracked sequence matching the given aggregate
// type, id and instance id, or nil if none matches.
func searchSequence(sequences []*latestSequence, aggregateType eventstore.AggregateType, aggregateID, instanceID string) *latestSequence {
	for _, seq := range sequences {
		aggregate := seq.aggregate
		if aggregate.Type != aggregateType ||
			aggregate.ID != aggregateID ||
			aggregate.InstanceID != instanceID {
			continue
		}
		return seq
	}
	return nil
}
// commandsToSequences returns one sequence holder per distinct aggregate of
// the given commands. Aggregates without an instance id are assigned the
// instance id from the context (this mutates the command's aggregate).
func commandsToSequences(ctx context.Context, commands []eventstore.Command) []*latestSequence {
	sequences := make([]*latestSequence, 0, len(commands))
	for _, command := range commands {
		// skip aggregates already tracked
		if searchSequenceByCommand(sequences, command) != nil {
			continue
		}
		if command.Aggregate().InstanceID == "" {
			command.Aggregate().InstanceID = authz.GetInstance(ctx).InstanceID()
		}
		sequences = append(sequences, &latestSequence{
			aggregate: command.Aggregate(),
		})
	}
	return sequences
}
// argsPerCondition is the number of bound arguments per aggregate subselect
const argsPerCondition = 3

// sequencesToSql builds one single-row subselect per tracked aggregate that
// returns its latest stored sequence; the caller joins the conditions with
// UNION ALL. args holds instance id, aggregate type and aggregate id per
// condition.
func sequencesToSql(sequences []*latestSequence) (conditions []string, args []any) {
	// use any (matching the declared return type) instead of interface{}
	args = make([]any, 0, len(sequences)*argsPerCondition)
	conditions = make([]string, len(sequences))
	for i, sequence := range sequences {
		conditions[i] = fmt.Sprintf(`(SELECT instance_id, aggregate_type, aggregate_id, "sequence" FROM eventstore.events2 WHERE instance_id = $%d AND aggregate_type = $%d AND aggregate_id = $%d ORDER BY "sequence" DESC LIMIT 1)`,
			i*argsPerCondition+1,
			i*argsPerCondition+2,
			i*argsPerCondition+3,
		)
		args = append(args, sequence.aggregate.InstanceID, sequence.aggregate.Type, sequence.aggregate.ID)
	}
	return conditions, args
}
// scanToSequence reads one row of the sequences query into the matching
// tracked sequence and adopts the resource owner already stored for the
// aggregate. It panics if the row does not match any tracked aggregate.
func scanToSequence(rows *sql.Rows, sequences []*latestSequence) error {
	var aggregateType eventstore.AggregateType
	var aggregateID, instanceID string
	var currentSequence uint64
	var resourceOwner string
	if err := rows.Scan(&instanceID, &resourceOwner, &aggregateType, &aggregateID, &currentSequence); err != nil {
		return zerrors.ThrowInternal(err, "V3-OIWqj", "Errors.Internal")
	}
	sequence := searchSequence(sequences, aggregateType, aggregateID, instanceID)
	if sequence == nil {
		logging.WithFields(
			"aggType", aggregateType,
			"aggID", aggregateID,
			"instance", instanceID,
		).Panic("no sequence found")
		// added return for linting
		return nil
	}
	sequence.sequence = currentSequence
	// log (but do not fail) if the caller provided a resource owner that
	// differs from the one already stored for this aggregate
	if resourceOwner != "" && sequence.aggregate.ResourceOwner != "" && sequence.aggregate.ResourceOwner != resourceOwner {
		logging.WithFields(
			"current_sequence", sequence.sequence,
			"instance_id", sequence.aggregate.InstanceID,
			"agg_type", sequence.aggregate.Type,
			"agg_id", sequence.aggregate.ID,
			"current_owner", resourceOwner,
			"provided_owner", sequence.aggregate.ResourceOwner,
		).Info("would have set wrong resource owner")
	}
	// set resource owner from previous events
	if resourceOwner != "" {
		sequence.aggregate.ResourceOwner = resourceOwner
	}
	return nil
}

View File

@@ -0,0 +1,293 @@
package eventstore
import (
"context"
_ "embed"
"reflect"
"testing"
"github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/eventstore"
)
// Test_searchSequence asserts that a tracked sequence is only returned when
// aggregate type, id and instance id all match.
func Test_searchSequence(t *testing.T) {
	sequence := &latestSequence{
		aggregate: mockAggregate("V3-p1BWC"),
		sequence:  1,
	}
	type args struct {
		sequences     []*latestSequence
		aggregateType eventstore.AggregateType
		aggregateID   string
		instanceID    string
	}
	tests := []struct {
		name string
		args args
		want *latestSequence
	}{
		{
			name: "type missmatch",
			args: args{
				sequences: []*latestSequence{
					sequence,
				},
				aggregateType: "wrong",
				aggregateID:   "V3-p1BWC",
				instanceID:    "instance",
			},
			want: nil,
		},
		{
			name: "id missmatch",
			args: args{
				sequences: []*latestSequence{
					sequence,
				},
				aggregateType: "type",
				aggregateID:   "wrong",
				instanceID:    "instance",
			},
			want: nil,
		},
		{
			name: "instance missmatch",
			args: args{
				sequences: []*latestSequence{
					sequence,
				},
				aggregateType: "type",
				aggregateID:   "V3-p1BWC",
				instanceID:    "wrong",
			},
			want: nil,
		},
		{
			name: "match",
			args: args{
				sequences: []*latestSequence{
					sequence,
				},
				aggregateType: "type",
				aggregateID:   "V3-p1BWC",
				instanceID:    "instance",
			},
			want: sequence,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			if got := searchSequence(tt.args.sequences, tt.args.aggregateType, tt.args.aggregateID, tt.args.instanceID); !reflect.DeepEqual(got, tt.want) {
				t.Errorf("searchSequence() = %v, want %v", got, tt.want)
			}
		})
	}
}
// Test_commandsToSequences asserts deduplication per aggregate and that an
// aggregate without an instance id inherits the instance id of the context.
func Test_commandsToSequences(t *testing.T) {
	aggregate := mockAggregate("V3-MKHTF")
	type args struct {
		ctx      context.Context
		commands []eventstore.Command
	}
	tests := []struct {
		name string
		args args
		want []*latestSequence
	}{
		{
			name: "no command",
			args: args{
				ctx:      context.Background(),
				commands: []eventstore.Command{},
			},
			want: []*latestSequence{},
		},
		{
			name: "one command",
			args: args{
				ctx: context.Background(),
				commands: []eventstore.Command{
					&mockCommand{
						aggregate: aggregate,
					},
				},
			},
			want: []*latestSequence{
				{
					aggregate: aggregate,
				},
			},
		},
		{
			name: "two commands same aggregate",
			args: args{
				ctx: context.Background(),
				commands: []eventstore.Command{
					&mockCommand{
						aggregate: aggregate,
					},
					&mockCommand{
						aggregate: aggregate,
					},
				},
			},
			want: []*latestSequence{
				{
					aggregate: aggregate,
				},
			},
		},
		{
			name: "two commands different aggregates",
			args: args{
				ctx: context.Background(),
				commands: []eventstore.Command{
					&mockCommand{
						aggregate: aggregate,
					},
					&mockCommand{
						aggregate: mockAggregate("V3-cZkCy"),
					},
				},
			},
			want: []*latestSequence{
				{
					aggregate: aggregate,
				},
				{
					aggregate: mockAggregate("V3-cZkCy"),
				},
			},
		},
		{
			name: "instance set in command",
			args: args{
				ctx: authz.WithInstanceID(context.Background(), "V3-ANV4p"),
				commands: []eventstore.Command{
					&mockCommand{
						aggregate: &eventstore.Aggregate{
							ID:            "V3-bF0Sa",
							Type:          "type",
							ResourceOwner: "to",
							InstanceID:    "instance",
							Version:       "v1",
						},
					},
				},
			},
			want: []*latestSequence{
				{
					// the instance id of the aggregate wins over the context
					aggregate: &eventstore.Aggregate{
						ID:            "V3-bF0Sa",
						Type:          "type",
						ResourceOwner: "to",
						InstanceID:    "instance",
						Version:       "v1",
					},
				},
			},
		},
		{
			name: "instance from context",
			args: args{
				ctx: authz.WithInstanceID(context.Background(), "V3-ANV4p"),
				commands: []eventstore.Command{
					&mockCommand{
						aggregate: &eventstore.Aggregate{
							ID:            "V3-bF0Sa",
							Type:          "type",
							ResourceOwner: "to",
							Version:       "v1",
						},
					},
				},
			},
			want: []*latestSequence{
				{
					// empty instance id is filled from the context
					aggregate: &eventstore.Aggregate{
						ID:            "V3-bF0Sa",
						Type:          "type",
						ResourceOwner: "to",
						InstanceID:    "V3-ANV4p",
						Version:       "v1",
					},
				},
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			got := commandsToSequences(tt.args.ctx, tt.args.commands)
			assert.ElementsMatch(t, tt.want, got)
		})
	}
}
// Test_sequencesToSql asserts the per-aggregate subselects and their
// arguments; placeholders are numbered consecutively across conditions.
func Test_sequencesToSql(t *testing.T) {
	tests := []struct {
		name           string
		arg            []*latestSequence
		wantConditions []string
		wantArgs       []any
	}{
		{
			name:           "no sequence",
			arg:            []*latestSequence{},
			wantConditions: []string{},
			wantArgs:       []any{},
		},
		{
			name: "one",
			arg: []*latestSequence{
				{
					aggregate: mockAggregate("V3-SbpGB"),
				},
			},
			wantConditions: []string{
				`(SELECT instance_id, aggregate_type, aggregate_id, "sequence" FROM eventstore.events2 WHERE instance_id = $1 AND aggregate_type = $2 AND aggregate_id = $3 ORDER BY "sequence" DESC LIMIT 1)`,
			},
			wantArgs: []any{
				"instance",
				eventstore.AggregateType("type"),
				"V3-SbpGB",
			},
		},
		{
			name: "multiple",
			arg: []*latestSequence{
				{
					aggregate: mockAggregate("V3-SbpGB"),
				},
				{
					aggregate: mockAggregate("V3-0X3yt"),
				},
			},
			wantConditions: []string{
				`(SELECT instance_id, aggregate_type, aggregate_id, "sequence" FROM eventstore.events2 WHERE instance_id = $1 AND aggregate_type = $2 AND aggregate_id = $3 ORDER BY "sequence" DESC LIMIT 1)`,
				`(SELECT instance_id, aggregate_type, aggregate_id, "sequence" FROM eventstore.events2 WHERE instance_id = $4 AND aggregate_type = $5 AND aggregate_id = $6 ORDER BY "sequence" DESC LIMIT 1)`,
			},
			wantArgs: []any{
				"instance",
				eventstore.AggregateType("type"),
				"V3-SbpGB",
				"instance",
				eventstore.AggregateType("type"),
				"V3-0X3yt",
			},
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			gotConditions, gotArgs := sequencesToSql(tt.arg)
			if !reflect.DeepEqual(gotConditions, tt.wantConditions) {
				t.Errorf("sequencesToSql() gotConditions = %v, want %v", gotConditions, tt.wantConditions)
			}
			if !reflect.DeepEqual(gotArgs, tt.wantArgs) {
				t.Errorf("sequencesToSql() gotArgs = %v, want %v", gotArgs, tt.wantArgs)
			}
		})
	}
}

View File

@@ -0,0 +1,18 @@
-- Select the latest stored sequence per aggregate and lock the rows
-- FOR UPDATE for the remainder of the push transaction.
-- The placeholder is replaced with UNION ALL-joined per-aggregate
-- subselects (see sequencesToSql).
WITH existing AS (
%s
) SELECT
e.instance_id
, e.owner
, e.aggregate_type
, e.aggregate_id
, e.sequence
FROM
eventstore.events2 e
JOIN
existing
ON
e.instance_id = existing.instance_id
AND e.aggregate_type = existing.aggregate_type
AND e.aggregate_id = existing.aggregate_id
AND e.sequence = existing.sequence
FOR UPDATE;

View File

@@ -0,0 +1,25 @@
package eventstore
import (
	"database/sql/driver"
	"fmt"
)
// Payload represents a byte array that may be null.
// Payload implements the sql.Scanner interface
type Payload []byte

// Scan implements the Scanner interface.
// It accepts nil, []byte and string values; any other type results in an
// error instead of the previous panic on a failed type assertion.
func (data *Payload) Scan(value interface{}) error {
	switch v := value.(type) {
	case nil:
		*data = nil
	case []byte:
		*data = Payload(v)
	case string:
		// some drivers deliver text columns as string
		*data = Payload(v)
	default:
		return fmt.Errorf("cannot scan %T into Payload", value)
	}
	return nil
}

// Value implements the driver Valuer interface.
// An empty payload is stored as NULL.
func (data Payload) Value() (driver.Value, error) {
	if len(data) == 0 {
		return nil, nil
	}
	return []byte(data), nil
}

View File

@@ -0,0 +1,100 @@
package eventstore
import (
"context"
_ "embed"
"errors"
"fmt"
"strings"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
var (
//go:embed unique_constraints_delete.sql
deleteConstraintStmt string
//go:embed unique_constraints_delete_placeholders.sql
deleteConstraintPlaceholdersStmt string
//go:embed unique_constraints_add.sql
addConstraintStmt string
)
// handleUniqueConstraints collects the unique-constraint mutations of all
// given commands and applies them to the eventstore.unique_constraints table
// inside the provided transaction.
//
// Removals are executed before additions, so a constraint released by one
// command can be re-acquired by a later command of the same batch. If a
// statement fails, the error message of the violated constraint is looked up
// via constraintFromErr and returned to the caller; otherwise the generic
// "Errors.Internal" message is used.
func handleUniqueConstraints(ctx context.Context, tx database.Tx, commands []eventstore.Command) (err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()
	// placeholder fragments and the matching positional arguments for the
	// DELETE and INSERT statements built below
	deletePlaceholders := make([]string, 0)
	deleteArgs := make([]any, 0)
	addPlaceholders := make([]string, 0)
	addArgs := make([]any, 0)
	// constraints indexed by the formatted (instance id, type, field) key so
	// the violated constraint can be resolved from the pg error detail later
	addConstraints := map[string]*eventstore.UniqueConstraint{}
	deleteConstraints := map[string]*eventstore.UniqueConstraint{}
	for _, command := range commands {
		for _, constraint := range command.UniqueConstraints() {
			instanceID := command.Aggregate().InstanceID
			// global constraints are shared by all instances and are
			// therefore stored with an empty instance id
			if constraint.IsGlobal {
				instanceID = ""
			}
			switch constraint.Action {
			case eventstore.UniqueConstraintAdd:
				// new constraints are stored lowercased; removals match
				// case-insensitively (see unique_constraints_delete_placeholders.sql)
				constraint.UniqueField = strings.ToLower(constraint.UniqueField)
				// placeholder indexes are 1-based and continue after the
				// arguments already collected
				addPlaceholders = append(addPlaceholders, fmt.Sprintf("($%d, $%d, $%d)", len(addArgs)+1, len(addArgs)+2, len(addArgs)+3))
				addArgs = append(addArgs, instanceID, constraint.UniqueType, constraint.UniqueField)
				addConstraints[fmt.Sprintf(uniqueConstraintPlaceholderFmt, instanceID, constraint.UniqueType, constraint.UniqueField)] = constraint
			case eventstore.UniqueConstraintRemove:
				deletePlaceholders = append(deletePlaceholders, fmt.Sprintf(deleteConstraintPlaceholdersStmt, len(deleteArgs)+1, len(deleteArgs)+2, len(deleteArgs)+3))
				deleteArgs = append(deleteArgs, instanceID, constraint.UniqueType, constraint.UniqueField)
				deleteConstraints[fmt.Sprintf(uniqueConstraintPlaceholderFmt, instanceID, constraint.UniqueType, constraint.UniqueField)] = constraint
			case eventstore.UniqueConstraintInstanceRemove:
				// drops every constraint of the instance with a single condition
				deletePlaceholders = append(deletePlaceholders, fmt.Sprintf("(instance_id = $%d)", len(deleteArgs)+1))
				deleteArgs = append(deleteArgs, instanceID)
				deleteConstraints[fmt.Sprintf(uniqueConstraintPlaceholderFmt, instanceID, constraint.UniqueType, constraint.UniqueField)] = constraint
			}
		}
	}
	if len(deletePlaceholders) > 0 {
		_, err := tx.ExecContext(ctx, fmt.Sprintf(deleteConstraintStmt, strings.Join(deletePlaceholders, " OR ")), deleteArgs...)
		if err != nil {
			logging.WithError(err).Warn("delete unique constraint failed")
			errMessage := "Errors.Internal"
			if constraint := constraintFromErr(err, deleteConstraints); constraint != nil {
				errMessage = constraint.ErrorMessage
			}
			return zerrors.ThrowInternal(err, "V3-C8l3V", errMessage)
		}
	}
	if len(addPlaceholders) > 0 {
		_, err := tx.ExecContext(ctx, fmt.Sprintf(addConstraintStmt, strings.Join(addPlaceholders, ", ")), addArgs...)
		if err != nil {
			logging.WithError(err).Warn("add unique constraint failed")
			errMessage := "Errors.Internal"
			if constraint := constraintFromErr(err, addConstraints); constraint != nil {
				errMessage = constraint.ErrorMessage
			}
			// an insert conflict means the unique value is already taken
			return zerrors.ThrowAlreadyExists(err, "V3-DKcYh", errMessage)
		}
	}
	return nil
}
// constraintFromErr resolves the unique constraint that caused the given
// database error. It inspects the detail of the wrapped *pgconn.PgError and
// returns the first constraint whose key occurs in it, or nil when the error
// is no pg error or no key matches.
func constraintFromErr(err error, constraints map[string]*eventstore.UniqueConstraint) *eventstore.UniqueConstraint {
	var pgErr *pgconn.PgError
	if !errors.As(err, &pgErr) {
		return nil
	}
	detail := pgErr.Detail
	for key, candidate := range constraints {
		if !strings.Contains(detail, key) {
			continue
		}
		return candidate
	}
	return nil
}

View File

@@ -0,0 +1,6 @@
-- Inserts one row per added unique constraint. The placeholder below is
-- replaced (via fmt.Sprintf) with a comma-separated list of
-- (instance_id, unique_type, unique_field) value tuples.
INSERT INTO eventstore.unique_constraints (
    instance_id
    , unique_type
    , unique_field
) VALUES
%s

View File

@@ -0,0 +1 @@
-- Deletes unique constraints. The placeholder is replaced with OR-joined
-- conditions built from unique_constraints_delete_placeholders.sql.
DELETE FROM eventstore.unique_constraints WHERE %s

View File

@@ -0,0 +1,13 @@
-- the query is so complex because we accidentally stored unique constraint case sensitive
-- the query checks first if there is a case sensitive match and afterwards if there is a case insensitive match
-- the numbered placeholders are filled by fmt.Sprintf with the 1-based
-- positional argument indexes of instance_id, unique_type and unique_field
(instance_id = $%[1]d AND unique_type = $%[2]d AND unique_field = (
    SELECT unique_field from (
        SELECT instance_id, unique_type, unique_field
        FROM eventstore.unique_constraints
        WHERE instance_id = $%[1]d AND unique_type = $%[2]d AND unique_field = $%[3]d
        UNION ALL
        SELECT instance_id, unique_type, unique_field
        FROM eventstore.unique_constraints
        WHERE instance_id = $%[1]d AND unique_type = $%[2]d AND unique_field = LOWER($%[3]d)
    ) AS case_insensitive_constraints LIMIT 1)
)

View File

@@ -0,0 +1,18 @@
package eventstore
import (
"regexp"
"github.com/zitadel/zitadel/internal/zerrors"
)
// Version identifies the revision of an aggregate's event schema,
// e.g. "v1" or "v1.23.23".
type Version string

// versionRegexp accepts a lowercase v followed by one to three
// dot-separated numeric segments.
var versionRegexp = regexp.MustCompile(`^v[0-9]+(\.[0-9]+){0,2}$`)

// Validate returns an error if the version does not match the expected
// version format.
func (v Version) Validate() error {
	if versionRegexp.MatchString(string(v)) {
		return nil
	}
	return zerrors.ThrowPreconditionFailed(nil, "MODEL-luDuS", "version is not semver")
}

View File

@@ -0,0 +1,39 @@
package eventstore
import "testing"
// TestVersion_Validate checks accepted and rejected version strings.
func TestVersion_Validate(t *testing.T) {
	for _, tt := range []struct {
		name    string
		v       Version
		wantErr bool
	}{
		{"correct version", "v1.23.23", false},
		{"no v prefix", "1.2.2", true},
		{"letters in version", "v1.as.3", true},
		{"no version", "", true},
	} {
		t.Run(tt.name, func(t *testing.T) {
			err := tt.v.Validate()
			if gotErr := err != nil; gotErr != tt.wantErr {
				t.Errorf("Version.Validate() error = %v, wantErr %v", err, tt.wantErr)
			}
		})
	}
}

View File

@@ -0,0 +1,53 @@
package eventstore
import (
"time"
)
// WriteModel is the minimum representation of a command side write model.
// It implements a basic reducer
// it's purpose is to reduce events to create new ones
type WriteModel struct {
	// AggregateID is the id of the aggregate the events belong to.
	AggregateID string `json:"-"`
	// ProcessedSequence is the sequence of the latest reduced event.
	ProcessedSequence uint64 `json:"-"`
	// Events holds the appended events until they are consumed by Reduce.
	Events []Event `json:"-"`
	// ResourceOwner is taken from the latest event's aggregate if unset.
	ResourceOwner string `json:"-"`
	// InstanceID is taken from the latest event's aggregate if unset.
	InstanceID string `json:"-"`
	// ChangeDate is the creation time of the latest reduced event.
	ChangeDate time.Time `json:"-"`
}
// AppendEvents adds the given events to the write model's event list.
// It does not reduce them; call Reduce to compute the new state.
func (rm *WriteModel) AppendEvents(events ...Event) {
	for _, event := range events {
		rm.Events = append(rm.Events, event)
	}
}
// Reduce is the basic implementation of reducer.
// It fills the write model's metadata (aggregate id, resource owner,
// instance id, processed sequence and change date) from the latest appended
// event and clears the event list afterwards.
// If this function is extended the extending function should be the last step.
func (wm *WriteModel) Reduce() error {
	if len(wm.Events) == 0 {
		return nil
	}
	latestEvent := wm.Events[len(wm.Events)-1]
	if wm.AggregateID == "" {
		wm.AggregateID = latestEvent.Aggregate().ID
	}
	if wm.ResourceOwner == "" {
		wm.ResourceOwner = latestEvent.Aggregate().ResourceOwner
	}
	if wm.InstanceID == "" {
		wm.InstanceID = latestEvent.Aggregate().InstanceID
	}
	wm.ProcessedSequence = latestEvent.Sequence()
	wm.ChangeDate = latestEvent.CreatedAt()
	// all events processed and not needed anymore; reset to an empty
	// (non-nil) slice so subsequent reducers can append safely.
	// (the previous extra `wm.Events = nil` was a dead store and is removed)
	wm.Events = []Event{}
	return nil
}