Mirror of https://github.com/zitadel/zitadel.git (synced 2025-12-03 10:02:16 +00:00)
feat(eventstore): increase parallel write capabilities (#5940)
This implementation increases the parallel write capabilities of the eventstore. Please have a look at the technical advisories: [05](https://zitadel.com/docs/support/advisory/a10005) and [06](https://zitadel.com/docs/support/advisory/a10006). The implementation of eventstore.push is rewritten and stored events are migrated to a new table `eventstore.events2`. If you are using CockroachDB, make sure the ZITADEL database user has the `VIEWACTIVITY` grant; it is used to query events.
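For reference, one way to grant that option on CockroachDB is `ALTER USER zitadel WITH VIEWACTIVITY;` (this assumes the ZITADEL database user is named `zitadel`; adjust the name and statement to your setup).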
52  internal/eventstore/handler/v2/event.go  Normal file
@@ -0,0 +1,52 @@
package handler

import (
    "context"

    "github.com/zitadel/zitadel/internal/eventstore"
)

const (
    schedulerSucceeded = eventstore.EventType("system.projections.scheduler.succeeded")
    aggregateType      = eventstore.AggregateType("system")
    aggregateID        = "SYSTEM"
)

func (h *Handler) didProjectionInitialize(ctx context.Context) bool {
    events, err := h.es.Filter(ctx, eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
        InstanceID("").
        AddQuery().
        AggregateTypes(aggregateType).
        AggregateIDs(aggregateID).
        EventTypes(schedulerSucceeded).
        EventData(map[string]interface{}{
            "name": h.projection.Name(),
        }).
        Builder(),
    )
    return len(events) > 0 && err == nil
}

func (h *Handler) setSucceededOnce(ctx context.Context) error {
    _, err := h.es.Push(ctx, &ProjectionSucceededEvent{
        BaseEvent: *eventstore.NewBaseEventForPush(ctx,
            eventstore.NewAggregate(ctx, aggregateID, aggregateType, "v1"),
            schedulerSucceeded,
        ),
        Name: h.projection.Name(),
    })
    return err
}

type ProjectionSucceededEvent struct {
    eventstore.BaseEvent `json:"-"`
    Name                 string `json:"name"`
}

func (p *ProjectionSucceededEvent) Payload() interface{} {
    return p
}

func (p *ProjectionSucceededEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
    return nil
}
95  internal/eventstore/handler/v2/failed_event.go  Normal file
@@ -0,0 +1,95 @@
package handler

import (
    "database/sql"
    _ "embed"
    "time"

    "github.com/zitadel/zitadel/internal/errors"
    "github.com/zitadel/zitadel/internal/eventstore"
)

var (
    //go:embed failed_event_set.sql
    setFailedEventStmt string
    //go:embed failed_event_get_count.sql
    failureCountStmt string
)

type failure struct {
    sequence      uint64
    instance      string
    aggregateID   string
    aggregateType eventstore.AggregateType
    eventDate     time.Time
    err           error
}

func failureFromEvent(event eventstore.Event, err error) *failure {
    return &failure{
        sequence:      event.Sequence(),
        instance:      event.Aggregate().InstanceID,
        aggregateID:   event.Aggregate().ID,
        aggregateType: event.Aggregate().Type,
        eventDate:     event.CreatedAt(),
        err:           err,
    }
}

func failureFromStatement(statement *Statement, err error) *failure {
    return &failure{
        sequence:      statement.Sequence,
        instance:      statement.InstanceID,
        aggregateID:   statement.AggregateID,
        aggregateType: statement.AggregateType,
        eventDate:     statement.CreationDate,
        err:           err,
    }
}

func (h *Handler) handleFailedStmt(tx *sql.Tx, currentState *state, f *failure) (shouldContinue bool) {
    failureCount, err := h.failureCount(tx, f)
    if err != nil {
        h.logFailure(f).WithError(err).Warn("unable to get failure count")
        return false
    }
    failureCount += 1
    err = h.setFailureCount(tx, failureCount, f)
    h.logFailure(f).OnError(err).Warn("unable to update failure count")

    return failureCount >= h.maxFailureCount
}

func (h *Handler) failureCount(tx *sql.Tx, f *failure) (count uint8, err error) {
    row := tx.QueryRow(failureCountStmt,
        h.projection.Name(),
        f.instance,
        f.aggregateType,
        f.aggregateID,
        f.sequence,
    )
    if err = row.Err(); err != nil {
        return 0, errors.ThrowInternal(err, "CRDB-Unnex", "unable to update failure count")
    }
    if err = row.Scan(&count); err != nil {
        return 0, errors.ThrowInternal(err, "CRDB-RwSMV", "unable to scan count")
    }
    return count, nil
}

func (h *Handler) setFailureCount(tx *sql.Tx, count uint8, f *failure) error {
    _, err := tx.Exec(setFailedEventStmt,
        h.projection.Name(),
        f.instance,
        f.aggregateType,
        f.aggregateID,
        f.eventDate,
        f.sequence,
        count,
        f.err.Error(),
    )
    if err != nil {
        return errors.ThrowInternal(err, "CRDB-4Ht4x", "set failure count failed")
    }
    return nil
}
12  internal/eventstore/handler/v2/failed_event_get_count.sql  Normal file
@@ -0,0 +1,12 @@
WITH failures AS (
    SELECT
        failure_count
    FROM
        projections.failed_events2
    WHERE
        projection_name = $1
        AND instance_id = $2
        AND aggregate_type = $3
        AND aggregate_id = $4
        AND failed_sequence = $5
) SELECT COALESCE((SELECT failure_count FROM failures), 0) AS failure_count
31  internal/eventstore/handler/v2/failed_event_set.sql  Normal file
@@ -0,0 +1,31 @@
INSERT INTO projections.failed_events2 (
    projection_name
    , instance_id
    , aggregate_type
    , aggregate_id
    , event_creation_date
    , failed_sequence
    , failure_count
    , error
    , last_failed
) VALUES (
    $1
    , $2
    , $3
    , $4
    , $5
    , $6
    , $7
    , $8
    , now()
) ON CONFLICT (
    projection_name
    , aggregate_type
    , aggregate_id
    , failed_sequence
    , instance_id
) DO UPDATE SET
    failure_count = EXCLUDED.failure_count
    , error = EXCLUDED.error
    , last_failed = EXCLUDED.last_failed
;
465  internal/eventstore/handler/v2/handler.go  Normal file
@@ -0,0 +1,465 @@
package handler

import (
    "context"
    "database/sql"
    "errors"
    "math"
    "sync"
    "time"

    "github.com/jackc/pgconn"

    "github.com/zitadel/zitadel/internal/api/authz"
    "github.com/zitadel/zitadel/internal/api/call"
    "github.com/zitadel/zitadel/internal/database"
    "github.com/zitadel/zitadel/internal/eventstore"
    "github.com/zitadel/zitadel/internal/repository/pseudo"
    "github.com/zitadel/zitadel/internal/telemetry/tracing"
)

type EventStore interface {
    InstanceIDs(ctx context.Context, maxAge time.Duration, forceLoad bool, query *eventstore.SearchQueryBuilder) ([]string, error)
    Filter(ctx context.Context, queryFactory *eventstore.SearchQueryBuilder) ([]eventstore.Event, error)
    Push(ctx context.Context, cmds ...eventstore.Command) ([]eventstore.Event, error)
}

type Config struct {
    Client     *database.DB
    Eventstore EventStore

    BulkLimit             uint16
    RequeueEvery          time.Duration
    RetryFailedAfter      time.Duration
    HandleActiveInstances time.Duration
    TransactionDuration   time.Duration
    MaxFailureCount       uint8

    TriggerWithoutEvents Reduce
}

type Handler struct {
    client     *database.DB
    projection Projection

    es         EventStore
    bulkLimit  uint16
    eventTypes map[eventstore.AggregateType][]eventstore.EventType

    maxFailureCount       uint8
    retryFailedAfter      time.Duration
    requeueEvery          time.Duration
    handleActiveInstances time.Duration
    txDuration            time.Duration
    now                   nowFunc

    triggeredInstancesSync sync.Map

    triggerWithoutEvents Reduce
}

// nowFunc makes [time.Now] mockable
type nowFunc func() time.Time

type Projection interface {
    Name() string
    Reducers() []AggregateReducer
}

func NewHandler(
    ctx context.Context,
    config *Config,
    projection Projection,
) *Handler {
    aggregates := make(map[eventstore.AggregateType][]eventstore.EventType, len(projection.Reducers()))
    for _, reducer := range projection.Reducers() {
        eventTypes := make([]eventstore.EventType, len(reducer.EventReducers))
        for i, eventReducer := range reducer.EventReducers {
            eventTypes[i] = eventReducer.Event
        }
        if _, ok := aggregates[reducer.Aggregate]; ok {
            aggregates[reducer.Aggregate] = append(aggregates[reducer.Aggregate], eventTypes...)
            continue
        }
        aggregates[reducer.Aggregate] = eventTypes
    }

    handler := &Handler{
        projection:             projection,
        client:                 config.Client,
        es:                     config.Eventstore,
        bulkLimit:              config.BulkLimit,
        eventTypes:             aggregates,
        requeueEvery:           config.RequeueEvery,
        handleActiveInstances:  config.HandleActiveInstances,
        now:                    time.Now,
        maxFailureCount:        config.MaxFailureCount,
        retryFailedAfter:       config.RetryFailedAfter,
        triggeredInstancesSync: sync.Map{},
        triggerWithoutEvents:   config.TriggerWithoutEvents,
        txDuration:             config.TransactionDuration,
    }

    return handler
}

func (h *Handler) Start(ctx context.Context) {
    go h.schedule(ctx)
    if h.triggerWithoutEvents != nil {
        return
    }
    go h.subscribe(ctx)
}

func (h *Handler) schedule(ctx context.Context) {
    // if there was no previous run, trigger instantly
    t := time.NewTimer(0)
    didInitialize := h.didProjectionInitialize(ctx)
    if didInitialize {
        t.Reset(h.requeueEvery)
    }

    for {
        select {
        case <-ctx.Done():
            t.Stop()
            return
        case <-t.C:
            instances, err := h.queryInstances(ctx, didInitialize)
            h.log().OnError(err).Debug("unable to query instances")

            var instanceFailed bool
            scheduledCtx := call.WithTimestamp(ctx)
            for _, instance := range instances {
                instanceCtx := authz.WithInstanceID(scheduledCtx, instance)

                // simple implementation of do-while
                _, err = h.Trigger(instanceCtx)
                instanceFailed = instanceFailed || err != nil
                h.log().WithField("instance", instance).OnError(err).Info("scheduled trigger failed")
                // retry if trigger failed
                for ; err != nil; _, err = h.Trigger(instanceCtx) {
                    time.Sleep(h.retryFailedAfter)
                    instanceFailed = instanceFailed || err != nil
                    h.log().WithField("instance", instance).OnError(err).Info("scheduled trigger failed")
                    if err == nil {
                        break
                    }
                }
            }

            if !didInitialize && !instanceFailed {
                err = h.setSucceededOnce(ctx)
                h.log().OnError(err).Debug("unable to set succeeded once")
                didInitialize = err == nil
            }
            t.Reset(h.requeueEvery)
        }
    }
}

func (h *Handler) subscribe(ctx context.Context) {
    queue := make(chan eventstore.Event, 100)
    subscription := eventstore.SubscribeEventTypes(queue, h.eventTypes)
    for {
        select {
        case <-ctx.Done():
            subscription.Unsubscribe()
            h.log().Debug("shutdown")
            return
        case event := <-queue:
            events := checkAdditionalEvents(queue, event)
            solvedInstances := make([]string, 0, len(events))
            queueCtx := call.WithTimestamp(ctx)
            for _, e := range events {
                if instanceSolved(solvedInstances, e.Aggregate().InstanceID) {
                    continue
                }
                queueCtx = authz.WithInstanceID(queueCtx, e.Aggregate().InstanceID)
                _, err := h.Trigger(queueCtx)
                h.log().OnError(err).Debug("trigger of queued event failed")
                if err == nil {
                    solvedInstances = append(solvedInstances, e.Aggregate().InstanceID)
                }
            }
        }
    }
}

func instanceSolved(solvedInstances []string, instanceID string) bool {
    for _, solvedInstance := range solvedInstances {
        if solvedInstance == instanceID {
            return true
        }
    }
    return false
}

func checkAdditionalEvents(eventQueue chan eventstore.Event, event eventstore.Event) []eventstore.Event {
    events := make([]eventstore.Event, 1)
    events[0] = event
    for {
        wait := time.NewTimer(1 * time.Millisecond)
        select {
        case event := <-eventQueue:
            events = append(events, event)
        case <-wait.C:
            return events
        }
    }
}

func (h *Handler) queryInstances(ctx context.Context, didInitialize bool) ([]string, error) {
    query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsInstanceIDs).
        AwaitOpenTransactions().
        AllowTimeTravel().
        ExcludedInstanceID("")
    if didInitialize {
        query = query.
            CreationDateAfter(h.now().Add(-1 * h.handleActiveInstances))
    }
    return h.es.InstanceIDs(ctx, h.requeueEvery, !didInitialize, query)
}

type triggerConfig struct {
    awaitRunning bool
}

type triggerOpt func(conf *triggerConfig)

func WithAwaitRunning() triggerOpt {
    return func(conf *triggerConfig) {
        conf.awaitRunning = true
    }
}

func (h *Handler) Trigger(ctx context.Context, opts ...triggerOpt) (_ context.Context, err error) {
    if authz.GetInstance(ctx).InstanceID() != "" {
        var span *tracing.Span
        ctx, span = tracing.NewSpan(ctx)
        defer func() { span.EndWithError(err) }()
    }

    config := new(triggerConfig)
    for _, opt := range opts {
        opt(config)
    }

    cancel := h.lockInstance(ctx, config)
    if cancel == nil {
        return call.ResetTimestamp(ctx), nil
    }
    defer cancel()

    for i := 0; ; i++ {
        additionalIteration, err := h.processEvents(ctx, config)
        h.log().OnError(err).Warn("process events failed")
        h.log().WithField("iteration", i).Debug("trigger iteration")
        if !additionalIteration || err != nil {
            return call.ResetTimestamp(ctx), err
        }
    }
}

// lockInstance tries to lock the instance.
// If the instance is already locked by another process, no cancel function is returned
// and the instance can be skipped.
// If the lock succeeds, a deferrable unlock function is returned.
func (h *Handler) lockInstance(ctx context.Context, config *triggerConfig) func() {
    instanceID := authz.GetInstance(ctx).InstanceID()

    // Check that the instance has a mutex to lock
    instanceMu, _ := h.triggeredInstancesSync.LoadOrStore(instanceID, new(sync.Mutex))
    unlock := func() {
        instanceMu.(*sync.Mutex).Unlock()
    }
    if !instanceMu.(*sync.Mutex).TryLock() {
        instanceMu.(*sync.Mutex).Lock()
        if config.awaitRunning {
            return unlock
        }
        defer unlock()
        return nil
    }
    return unlock
}

func (h *Handler) processEvents(ctx context.Context, config *triggerConfig) (additionalIteration bool, err error) {
    defer func() {
        pgErr := new(pgconn.PgError)
        if errors.As(err, &pgErr) {
            // error returned if the row is currently locked by another connection
            if pgErr.Code == "55P03" {
                h.log().Debug("state already locked")
                err = nil
                additionalIteration = false
            }
        }
    }()

    if h.txDuration > 0 {
        var cancel func()
        ctx, cancel = context.WithTimeout(ctx, h.txDuration)
        defer cancel()
    }

    tx, err := h.client.Begin()
    if err != nil {
        return false, err
    }
    defer func() {
        if err != nil {
            rollbackErr := tx.Rollback()
            h.log().OnError(rollbackErr).Debug("unable to rollback tx")
            return
        }
        err = tx.Commit()
    }()

    currentState, err := h.currentState(ctx, tx, config)
    if err != nil {
        if errors.Is(err, errJustUpdated) {
            return false, nil
        }
        return additionalIteration, err
    }

    var statements []*Statement
    statements, additionalIteration, err = h.generateStatements(ctx, tx, currentState)
    if err != nil || len(statements) == 0 {
        return additionalIteration, err
    }

    lastProcessedIndex, err := h.executeStatements(ctx, tx, currentState, statements)
    if lastProcessedIndex < 0 {
        return false, err
    }

    currentState.position = statements[lastProcessedIndex].Position
    currentState.aggregateID = statements[lastProcessedIndex].AggregateID
    currentState.aggregateType = statements[lastProcessedIndex].AggregateType
    currentState.sequence = statements[lastProcessedIndex].Sequence
    currentState.eventTimestamp = statements[lastProcessedIndex].CreationDate
    err = h.setState(tx, currentState)

    return additionalIteration, err
}

func (h *Handler) generateStatements(ctx context.Context, tx *sql.Tx, currentState *state) (_ []*Statement, additionalIteration bool, err error) {
    if h.triggerWithoutEvents != nil {
        stmt, err := h.triggerWithoutEvents(pseudo.NewScheduledEvent(ctx, time.Now(), currentState.instanceID))
        if err != nil {
            return nil, false, err
        }
        return []*Statement{stmt}, false, nil
    }

    events, err := h.es.Filter(ctx, h.eventQuery(currentState))
    if err != nil {
        h.log().WithError(err).Debug("filter eventstore failed")
        return nil, false, err
    }
    eventAmount := len(events)
    events = skipPreviouslyReduced(events, currentState)

    if len(events) == 0 {
        h.updateLastUpdated(ctx, tx, currentState)
        return nil, false, nil
    }

    statements, err := h.eventsToStatements(tx, events, currentState)
    if len(statements) == 0 {
        return nil, false, err
    }

    additionalIteration = eventAmount == int(h.bulkLimit)
    if len(statements) < len(events) {
        // retry immediately if statements failed
        additionalIteration = true
    }

    return statements, additionalIteration, nil
}

func skipPreviouslyReduced(events []eventstore.Event, currentState *state) []eventstore.Event {
    for i, event := range events {
        if event.Position() == currentState.position &&
            event.Aggregate().ID == currentState.aggregateID &&
            event.Aggregate().Type == currentState.aggregateType &&
            event.Sequence() == currentState.sequence {
            return events[i+1:]
        }
    }
    return events
}

func (h *Handler) executeStatements(ctx context.Context, tx *sql.Tx, currentState *state, statements []*Statement) (lastProcessedIndex int, err error) {
    lastProcessedIndex = -1

    for i, statement := range statements {
        select {
        case <-ctx.Done():
            break
        default:
            err := h.executeStatement(ctx, tx, currentState, statement)
            if err != nil {
                return lastProcessedIndex, err
            }
            lastProcessedIndex = i
        }
    }
    return lastProcessedIndex, nil
}

func (h *Handler) executeStatement(ctx context.Context, tx *sql.Tx, currentState *state, statement *Statement) (err error) {
    if statement.Execute == nil {
        return nil
    }

    _, err = tx.Exec("SAVEPOINT exec")
    if err != nil {
        h.log().WithError(err).Debug("create savepoint failed")
        return err
    }
    var shouldContinue bool
    defer func() {
        _, err = tx.Exec("RELEASE SAVEPOINT exec")
    }()

    if err = statement.Execute(tx, h.projection.Name()); err != nil {
        h.log().WithError(err).Error("statement execution failed")

        shouldContinue = h.handleFailedStmt(tx, currentState, failureFromStatement(statement, err))
        if shouldContinue {
            return nil
        }

        return err
    }

    return nil
}

func (h *Handler) eventQuery(currentState *state) *eventstore.SearchQueryBuilder {
    builder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
        AwaitOpenTransactions().
        Limit(uint64(h.bulkLimit)).
        AllowTimeTravel().
        OrderAsc().
        InstanceID(currentState.instanceID)

    if currentState.position > 0 {
        builder = builder.PositionAfter(math.Float64frombits(math.Float64bits(currentState.position) - 10))
    }

    for aggregateType, eventTypes := range h.eventTypes {
        query := builder.
            AddQuery().
            AggregateTypes(aggregateType).
            EventTypes(eventTypes...)

        builder = query.Builder()
    }

    return builder
}
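To make the new API easier to follow, here is a minimal, hypothetical sketch of a projection wired into this v2 handler. The projection name, aggregate/event types, and config values are illustrative only; `exampleProjection` and `startExample` are not part of this commit, and the internal import paths are only usable from within the zitadel module.

package example

import (
    "context"
    "time"

    "github.com/zitadel/zitadel/internal/database"
    "github.com/zitadel/zitadel/internal/eventstore"
    handler "github.com/zitadel/zitadel/internal/eventstore/handler/v2"
)

// exampleProjection is a hypothetical projection used only to illustrate the Projection interface.
type exampleProjection struct{}

func (p *exampleProjection) Name() string { return "projections.example" }

func (p *exampleProjection) Reducers() []handler.AggregateReducer {
    return []handler.AggregateReducer{{
        Aggregate: eventstore.AggregateType("user"),
        EventReducers: []handler.EventReducer{{
            Event: eventstore.EventType("user.added"),
            // a real reducer would build an insert/update Statement for the projection table
            Reduce: func(event eventstore.Event) (*handler.Statement, error) {
                return handler.NewNoOpStatement(event), nil
            },
        }},
    }}
}

func startExample(ctx context.Context, client *database.DB, es handler.EventStore) {
    h := handler.NewHandler(ctx, &handler.Config{
        Client:                client,
        Eventstore:            es,
        BulkLimit:             200,
        RequeueEvery:          time.Minute,
        RetryFailedAfter:      time.Second,
        HandleActiveInstances: time.Hour,
        MaxFailureCount:       5,
    }, &exampleProjection{})
    // Start schedules periodic runs and, unless TriggerWithoutEvents is set, subscribes to pushed events.
    h.Start(ctx)
}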
413  internal/eventstore/handler/v2/init.go  Normal file
@@ -0,0 +1,413 @@
package handler

import (
    "context"
    "errors"
    "fmt"
    "strings"

    "github.com/jackc/pgconn"
    "github.com/zitadel/logging"

    errs "github.com/zitadel/zitadel/internal/errors"
    "github.com/zitadel/zitadel/internal/eventstore/handler"
)

type Table struct {
    columns     []*InitColumn
    primaryKey  PrimaryKey
    indices     []*Index
    constraints []*Constraint
    foreignKeys []*ForeignKey
}

func NewTable(columns []*InitColumn, key PrimaryKey, opts ...TableOption) *Table {
    t := &Table{
        columns:    columns,
        primaryKey: key,
    }
    for _, opt := range opts {
        opt(t)
    }
    return t
}

type SuffixedTable struct {
    Table
    suffix string
}

func NewSuffixedTable(columns []*InitColumn, key PrimaryKey, suffix string, opts ...TableOption) *SuffixedTable {
    return &SuffixedTable{
        Table:  *NewTable(columns, key, opts...),
        suffix: suffix,
    }
}

type TableOption func(*Table)

func WithIndex(index *Index) TableOption {
    return func(table *Table) {
        table.indices = append(table.indices, index)
    }
}

func WithConstraint(constraint *Constraint) TableOption {
    return func(table *Table) {
        table.constraints = append(table.constraints, constraint)
    }
}

func WithForeignKey(key *ForeignKey) TableOption {
    return func(table *Table) {
        table.foreignKeys = append(table.foreignKeys, key)
    }
}

type InitColumn struct {
    Name          string
    Type          ColumnType
    nullable      bool
    defaultValue  interface{}
    deleteCascade string
}

type ColumnOption func(*InitColumn)

func NewColumn(name string, columnType ColumnType, opts ...ColumnOption) *InitColumn {
    column := &InitColumn{
        Name:         name,
        Type:         columnType,
        nullable:     false,
        defaultValue: nil,
    }
    for _, opt := range opts {
        opt(column)
    }
    return column
}

func Nullable() ColumnOption {
    return func(c *InitColumn) {
        c.nullable = true
    }
}

func Default(value interface{}) ColumnOption {
    return func(c *InitColumn) {
        c.defaultValue = value
    }
}

func DeleteCascade(column string) ColumnOption {
    return func(c *InitColumn) {
        c.deleteCascade = column
    }
}

type PrimaryKey []string

func NewPrimaryKey(columnNames ...string) PrimaryKey {
    return columnNames
}

type ColumnType int32

const (
    ColumnTypeText ColumnType = iota
    ColumnTypeTextArray
    ColumnTypeJSONB
    ColumnTypeBytes
    ColumnTypeTimestamp
    ColumnTypeInterval
    ColumnTypeEnum
    ColumnTypeEnumArray
    ColumnTypeInt64
    ColumnTypeBool
)

func NewIndex(name string, columns []string, opts ...indexOpts) *Index {
    i := &Index{
        Name:    name,
        Columns: columns,
    }
    for _, opt := range opts {
        opt(i)
    }
    return i
}

type Index struct {
    Name     string
    Columns  []string
    includes []string
}

type indexOpts func(*Index)

func WithInclude(columns ...string) indexOpts {
    return func(i *Index) {
        i.includes = columns
    }
}

func NewConstraint(name string, columns []string) *Constraint {
    i := &Constraint{
        Name:    name,
        Columns: columns,
    }
    return i
}

type Constraint struct {
    Name    string
    Columns []string
}

func NewForeignKey(name string, columns []string, refColumns []string) *ForeignKey {
    i := &ForeignKey{
        Name:       name,
        Columns:    columns,
        RefColumns: refColumns,
    }
    return i
}

func NewForeignKeyOfPublicKeys() *ForeignKey {
    return &ForeignKey{
        Name: "",
    }
}

type ForeignKey struct {
    Name       string
    Columns    []string
    RefColumns []string
}

type initializer interface {
    Init() *handler.Check
}

func (h *Handler) Init(ctx context.Context) error {
    check, ok := h.projection.(initializer)
    if !ok || check.Init().IsNoop() {
        return nil
    }
    tx, err := h.client.BeginTx(ctx, nil)
    if err != nil {
        return errs.ThrowInternal(err, "CRDB-SAdf2", "begin failed")
    }
    for i, execute := range check.Init().Executes {
        logging.WithFields("projection", h.projection.Name(), "execute", i).Debug("executing check")
        next, err := execute(tx, h.projection.Name())
        if err != nil {
            logging.OnError(tx.Rollback()).Debug("unable to rollback")
            return err
        }
        if !next {
            logging.WithFields("projection", h.projection.Name(), "execute", i).Debug("projection set up")
            break
        }
    }
    return tx.Commit()
}

func NewTableCheck(table *Table, opts ...execOption) *handler.Check {
    config := execConfig{}
    create := func(config execConfig) string {
        return createTableStatement(table, config.tableName, "")
    }
    executes := make([]func(handler.Executer, string) (bool, error), len(table.indices)+1)
    executes[0] = execNextIfExists(config, create, opts, true)
    for i, index := range table.indices {
        executes[i+1] = execNextIfExists(config, createIndexCheck(index), opts, true)
    }
    return &handler.Check{
        Executes: executes,
    }
}

func NewMultiTableCheck(primaryTable *Table, secondaryTables ...*SuffixedTable) *handler.Check {
    config := execConfig{}
    create := func(config execConfig) string {
        stmt := createTableStatement(primaryTable, config.tableName, "")
        for _, table := range secondaryTables {
            stmt += createTableStatement(&table.Table, config.tableName, "_"+table.suffix)
        }
        return stmt
    }

    return &handler.Check{
        Executes: []func(handler.Executer, string) (bool, error){
            execNextIfExists(config, create, nil, true),
        },
    }
}

func NewViewCheck(selectStmt string, secondaryTables ...*SuffixedTable) *handler.Check {
    config := execConfig{}
    create := func(config execConfig) string {
        var stmt string
        for _, table := range secondaryTables {
            stmt += createTableStatement(&table.Table, config.tableName, "_"+table.suffix)
        }
        stmt += createViewStatement(config.tableName, selectStmt)
        return stmt
    }

    return &handler.Check{
        Executes: []func(handler.Executer, string) (bool, error){
            execNextIfExists(config, create, nil, false),
        },
    }
}

func execNextIfExists(config execConfig, q query, opts []execOption, executeNext bool) func(handler.Executer, string) (bool, error) {
    return func(handler handler.Executer, name string) (bool, error) {
        err := exec(config, q, opts)(handler, name)
        if isErrAlreadyExists(err) {
            return executeNext, nil
        }
        return false, err
    }
}

func isErrAlreadyExists(err error) bool {
    caosErr := &errs.CaosError{}
    if !errors.As(err, &caosErr) {
        return false
    }
    pgErr := new(pgconn.PgError)
    if errors.As(caosErr.Parent, &pgErr) {
        return pgErr.Code == "42P07"
    }
    return false
}

func createTableStatement(table *Table, tableName string, suffix string) string {
    stmt := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s, PRIMARY KEY (%s)",
        tableName+suffix,
        createColumnsStatement(table.columns, tableName),
        strings.Join(table.primaryKey, ", "),
    )
    for _, key := range table.foreignKeys {
        ref := tableName
        if len(key.RefColumns) > 0 {
            ref += fmt.Sprintf("(%s)", strings.Join(key.RefColumns, ","))
        }
        if len(key.Columns) == 0 {
            key.Columns = table.primaryKey
        }
        stmt += fmt.Sprintf(", CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE CASCADE", foreignKeyName(key.Name, tableName, suffix), strings.Join(key.Columns, ","), ref)
    }
    for _, constraint := range table.constraints {
        stmt += fmt.Sprintf(", CONSTRAINT %s UNIQUE (%s)", constraintName(constraint.Name, tableName, suffix), strings.Join(constraint.Columns, ","))
    }

    stmt += ");"

    for _, index := range table.indices {
        stmt += createIndexStatement(index, tableName+suffix)
    }
    return stmt
}

func createViewStatement(viewName string, selectStmt string) string {
    return fmt.Sprintf("CREATE VIEW %s AS %s",
        viewName,
        selectStmt,
    )
}

func createIndexCheck(index *Index) func(config execConfig) string {
    return func(config execConfig) string {
        return createIndexStatement(index, config.tableName)
    }
}

func createIndexStatement(index *Index, tableName string) string {
    stmt := fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s (%s)",
        indexName(index.Name, tableName),
        tableName,
        strings.Join(index.Columns, ","),
    )
    if len(index.includes) > 0 {
        stmt += " INCLUDE (" + strings.Join(index.includes, ", ") + ")"
    }
    return stmt + ";"
}

func foreignKeyName(name, tableName, suffix string) string {
    if name == "" {
        key := "fk" + suffix + "_ref_" + tableNameWithoutSchema(tableName)
        return key
    }
    return "fk_" + tableNameWithoutSchema(tableName+suffix) + "_" + name
}

func constraintName(name, tableName, suffix string) string {
    return tableNameWithoutSchema(tableName+suffix) + "_" + name + "_unique"
}

func indexName(name, tableName string) string {
    return tableNameWithoutSchema(tableName) + "_" + name + "_idx"
}

func tableNameWithoutSchema(name string) string {
    return name[strings.LastIndex(name, ".")+1:]
}

func createColumnsStatement(cols []*InitColumn, tableName string) string {
    columns := make([]string, len(cols))
    for i, col := range cols {
        column := col.Name + " " + columnType(col.Type)
        if !col.nullable {
            column += " NOT NULL"
        }
        if col.defaultValue != nil {
            column += " DEFAULT " + defaultValue(col.defaultValue)
        }
        if len(col.deleteCascade) != 0 {
            column += fmt.Sprintf(" REFERENCES %s (%s) ON DELETE CASCADE", tableName, col.deleteCascade)
        }
        columns[i] = column
    }
    return strings.Join(columns, ",")
}

func defaultValue(value interface{}) string {
    switch v := value.(type) {
    case string:
        return "'" + v + "'"
    case fmt.Stringer:
        return fmt.Sprintf("%#v", v)
    default:
        return fmt.Sprintf("%v", v)
    }
}

func columnType(columnType ColumnType) string {
    switch columnType {
    case ColumnTypeText:
        return "TEXT"
    case ColumnTypeTextArray:
        return "TEXT[]"
    case ColumnTypeTimestamp:
        return "TIMESTAMPTZ"
    case ColumnTypeInterval:
        return "INTERVAL"
    case ColumnTypeEnum:
        return "SMALLINT"
    case ColumnTypeEnumArray:
        return "SMALLINT[]"
    case ColumnTypeInt64:
        return "BIGINT"
    case ColumnTypeBool:
        return "BOOLEAN"
    case ColumnTypeJSONB:
        return "JSONB"
    case ColumnTypeBytes:
        return "BYTEA"
    default:
        panic("unknown column type")
    }
}
23  internal/eventstore/handler/v2/log.go  Normal file
@@ -0,0 +1,23 @@
package handler

import (
    "github.com/zitadel/logging"

    "github.com/zitadel/zitadel/internal/eventstore"
)

func (h *Handler) log() *logging.Entry {
    return logging.WithFields("projection", h.projection.Name())
}

func (h *Handler) logFailure(fail *failure) *logging.Entry {
    return h.log().WithField("sequence", fail.sequence).
        WithField("instance", fail.instance).
        WithField("aggregate", fail.aggregateID)
}

func (h *Handler) logEvent(event eventstore.Event) *logging.Entry {
    return h.log().WithField("sequence", event.Sequence()).
        WithField("instance", event.Aggregate().InstanceID).
        WithField("aggregate", event.Aggregate().Type)
}
18  internal/eventstore/handler/v2/mock_test.go  Normal file
@@ -0,0 +1,18 @@
package handler

var _ Projection = (*projection)(nil)

type projection struct {
    name     string
    reducers []AggregateReducer
}

// Name implements Projection
func (p *projection) Name() string {
    return p.name
}

// Reducers implements Projection
func (p *projection) Reducers() []AggregateReducer {
    return p.reducers
}
21  internal/eventstore/handler/v2/reduce.go  Normal file
@@ -0,0 +1,21 @@
package handler

import "github.com/zitadel/zitadel/internal/eventstore"

// EventReducer represents the required data
// to work with events
type EventReducer struct {
    Event  eventstore.EventType
    Reduce Reduce
}

// Reduce reduces the given event to a statement
// which is used to update the projection
type Reduce func(eventstore.Event) (*Statement, error)

// AggregateReducer represents the required data
// to work with aggregates
type AggregateReducer struct {
    Aggregate     eventstore.AggregateType
    EventReducers []EventReducer
}
119  internal/eventstore/handler/v2/state.go  Normal file
@@ -0,0 +1,119 @@
package handler

import (
    "context"
    "database/sql"
    _ "embed"
    "errors"
    "time"

    "github.com/zitadel/zitadel/internal/api/authz"
    errs "github.com/zitadel/zitadel/internal/errors"
    "github.com/zitadel/zitadel/internal/eventstore"
)

type state struct {
    instanceID     string
    position       float64
    eventTimestamp time.Time
    aggregateType  eventstore.AggregateType
    aggregateID    string
    sequence       uint64
}

var (
    //go:embed state_get.sql
    currentStateStmt string
    //go:embed state_get_await.sql
    currentStateAwaitStmt string
    //go:embed state_set.sql
    updateStateStmt string
    //go:embed state_lock.sql
    lockStateStmt string
    //go:embed state_set_last_run.sql
    updateStateLastRunStmt string

    errJustUpdated = errors.New("projection was just updated")
)

func (h *Handler) currentState(ctx context.Context, tx *sql.Tx, config *triggerConfig) (currentState *state, err error) {
    currentState = &state{
        instanceID: authz.GetInstance(ctx).InstanceID(),
    }

    var (
        aggregateID   = new(sql.NullString)
        aggregateType = new(sql.NullString)
        sequence      = new(sql.NullInt64)
        timestamp     = new(sql.NullTime)
        position      = new(sql.NullFloat64)
    )

    stateQuery := currentStateStmt
    if config.awaitRunning {
        stateQuery = currentStateAwaitStmt
    }

    row := tx.QueryRow(stateQuery, currentState.instanceID, h.projection.Name())
    err = row.Scan(
        aggregateID,
        aggregateType,
        sequence,
        timestamp,
        position,
    )
    if errors.Is(err, sql.ErrNoRows) {
        err = h.lockState(tx, currentState.instanceID)
    }
    if err != nil {
        h.log().WithError(err).Debug("unable to query current state")
        return nil, err
    }

    currentState.aggregateID = aggregateID.String
    currentState.aggregateType = eventstore.AggregateType(aggregateType.String)
    currentState.sequence = uint64(sequence.Int64)
    currentState.eventTimestamp = timestamp.Time
    currentState.position = position.Float64
    return currentState, nil
}

func (h *Handler) setState(tx *sql.Tx, updatedState *state) error {
    res, err := tx.Exec(updateStateStmt,
        h.projection.Name(),
        updatedState.instanceID,
        updatedState.aggregateID,
        updatedState.aggregateType,
        updatedState.sequence,
        updatedState.eventTimestamp,
        updatedState.position,
    )
    if err != nil {
        h.log().WithError(err).Debug("unable to update state")
        return err
    }
    if affected, err := res.RowsAffected(); affected == 0 {
        h.log().OnError(err).Error("unable to check if states are updated")
        return errs.ThrowInternal(err, "V2-FGEKi", "unable to update state")
    }
    return nil
}

func (h *Handler) updateLastUpdated(ctx context.Context, tx *sql.Tx, updatedState *state) {
    _, err := tx.ExecContext(ctx, updateStateLastRunStmt, h.projection.Name(), updatedState.instanceID)
    h.log().OnError(err).Debug("unable to update last updated")
}

func (h *Handler) lockState(tx *sql.Tx, instanceID string) error {
    res, err := tx.Exec(lockStateStmt,
        h.projection.Name(),
        instanceID,
    )
    if err != nil {
        return err
    }
    if affected, err := res.RowsAffected(); affected == 0 || err != nil {
        return errs.ThrowInternal(err, "V2-lpiK0", "projection already locked")
    }
    return nil
}
12  internal/eventstore/handler/v2/state_get.sql  Normal file
@@ -0,0 +1,12 @@
SELECT
    aggregate_id
    , aggregate_type
    , "sequence"
    , event_date
    , "position"
FROM
    projections.current_states
WHERE
    instance_id = $1
    AND projection_name = $2
FOR UPDATE NOWAIT;
12  internal/eventstore/handler/v2/state_get_await.sql  Normal file
@@ -0,0 +1,12 @@
SELECT
    aggregate_id
    , aggregate_type
    , "sequence"
    , event_date
    , "position"
FROM
    projections.current_states
WHERE
    instance_id = $1
    AND projection_name = $2
FOR UPDATE;
9  internal/eventstore/handler/v2/state_lock.sql  Normal file
@@ -0,0 +1,9 @@
INSERT INTO projections.current_states (
    projection_name
    , instance_id
    , last_updated
) VALUES (
    $1
    , $2
    , now()
) ON CONFLICT DO NOTHING;
29  internal/eventstore/handler/v2/state_set.sql  Normal file
@@ -0,0 +1,29 @@
INSERT INTO projections.current_states (
    projection_name
    , instance_id
    , aggregate_id
    , aggregate_type
    , "sequence"
    , event_date
    , "position"
    , last_updated
) VALUES (
    $1
    , $2
    , $3
    , $4
    , $5
    , $6
    , $7
    , now()
) ON CONFLICT (
    projection_name
    , instance_id
) DO UPDATE SET
    aggregate_id = $3
    , aggregate_type = $4
    , "sequence" = $5
    , event_date = $6
    , "position" = $7
    , last_updated = statement_timestamp()
;
2  internal/eventstore/handler/v2/state_set_last_run.sql  Normal file
@@ -0,0 +1,2 @@
UPDATE projections.current_states SET last_updated = now() WHERE projection_name = $1 AND instance_id = $2;
447  internal/eventstore/handler/v2/state_test.go  Normal file
@@ -0,0 +1,447 @@
package handler

import (
    "context"
    "database/sql"
    "database/sql/driver"
    _ "embed"
    "errors"
    "reflect"
    "testing"
    "time"

    "github.com/jackc/pgconn"

    "github.com/zitadel/zitadel/internal/api/authz"
    "github.com/zitadel/zitadel/internal/database/mock"
    errs "github.com/zitadel/zitadel/internal/errors"
)

func TestHandler_lockState(t *testing.T) {
    type fields struct {
        projection Projection
        mock       *mock.SQLMock
    }
    type args struct {
        instanceID string
    }
    tests := []struct {
        name   string
        fields fields
        args   args
        isErr  func(t *testing.T, err error)
    }{
        {
            name: "tx closed",
            fields: fields{
                projection: &projection{
                    name: "projection",
                },
                mock: mock.NewSQLMock(t,
                    mock.ExpectBegin(nil),
                    mock.ExcpectExec(
                        lockStateStmt,
                        mock.WithExecArgs(
                            "projection",
                            "instance",
                        ),
                        mock.WithExecErr(sql.ErrTxDone),
                    ),
                ),
            },
            args: args{
                instanceID: "instance",
            },
            isErr: func(t *testing.T, err error) {
                if !errors.Is(err, sql.ErrTxDone) {
                    t.Errorf("unexpected error, want: %v got: %v", sql.ErrTxDone, err)
                }
            },
        },
        {
            name: "no rows affected",
            fields: fields{
                projection: &projection{
                    name: "projection",
                },
                mock: mock.NewSQLMock(t,
                    mock.ExpectBegin(nil),
                    mock.ExcpectExec(
                        lockStateStmt,
                        mock.WithExecArgs(
                            "projection",
                            "instance",
                        ),
                        mock.WithExecNoRowsAffected(),
                    ),
                ),
            },
            args: args{
                instanceID: "instance",
            },
            isErr: func(t *testing.T, err error) {
                if !errors.Is(err, errs.ThrowInternal(nil, "V2-lpiK0", "")) {
                    t.Errorf("unexpected error: want internal (V2-lpiK0), got: %v", err)
                }
            },
        },
        {
            name: "rows affected",
            fields: fields{
                projection: &projection{
                    name: "projection",
                },
                mock: mock.NewSQLMock(t,
                    mock.ExpectBegin(nil),
                    mock.ExcpectExec(
                        lockStateStmt,
                        mock.WithExecArgs(
                            "projection",
                            "instance",
                        ),
                        mock.WithExecRowsAffected(1),
                    ),
                ),
            },
            args: args{
                instanceID: "instance",
            },
        },
    }
    for _, tt := range tests {
        if tt.isErr == nil {
            tt.isErr = func(t *testing.T, err error) {
                if err != nil {
                    t.Error("expected no error got:", err)
                }
            }
        }
        t.Run(tt.name, func(t *testing.T) {
            h := &Handler{
                projection: tt.fields.projection,
            }

            tx, err := tt.fields.mock.DB.Begin()
            if err != nil {
                t.Fatalf("unable to begin transaction: %v", err)
            }

            err = h.lockState(tx, tt.args.instanceID)
            tt.isErr(t, err)

            tt.fields.mock.Assert(t)
        })
    }
}

func TestHandler_updateLastUpdated(t *testing.T) {
    type fields struct {
        projection Projection
        mock       *mock.SQLMock
    }
    type args struct {
        updatedState *state
    }
    tests := []struct {
        name   string
        fields fields
        args   args
        isErr  func(t *testing.T, err error)
    }{
        {
            name: "update fails",
            fields: fields{
                projection: &projection{
                    name: "instance",
                },
                mock: mock.NewSQLMock(t,
                    mock.ExpectBegin(nil),
                    mock.ExcpectExec(updateStateStmt,
                        mock.WithExecErr(sql.ErrTxDone),
                    ),
                ),
            },
            args: args{
                updatedState: &state{
                    instanceID:     "instance",
                    eventTimestamp: time.Now(),
                    position:       42,
                },
            },
            isErr: func(t *testing.T, err error) {
                if !errors.Is(err, sql.ErrTxDone) {
                    t.Errorf("unexpected error, want: %v, got %v", sql.ErrTxDone, err)
                }
            },
        },
        {
            name: "no rows affected",
            fields: fields{
                projection: &projection{
                    name: "instance",
                },
                mock: mock.NewSQLMock(t,
                    mock.ExpectBegin(nil),
                    mock.ExcpectExec(updateStateStmt,
                        mock.WithExecNoRowsAffected(),
                    ),
                ),
            },
            args: args{
                updatedState: &state{
                    instanceID:     "instance",
                    eventTimestamp: time.Now(),
                    position:       42,
                },
            },
            isErr: func(t *testing.T, err error) {
                if !errors.Is(err, errs.ThrowInternal(nil, "V2-FGEKi", "")) {
                    t.Errorf("unexpected error, want: %v, got %v", sql.ErrTxDone, err)
                }
            },
        },
        {
            name: "success",
            fields: fields{
                projection: &projection{
                    name: "projection",
                },
                mock: mock.NewSQLMock(t,
                    mock.ExpectBegin(nil),
                    mock.ExcpectExec(updateStateStmt,
                        mock.WithExecArgs(
                            "projection",
                            "instance",
                            "aggregate id",
                            "aggregate type",
                            uint64(42),
                            mock.AnyType[time.Time]{},
                            float64(42),
                        ),
                        mock.WithExecRowsAffected(1),
                    ),
                ),
            },
            args: args{
                updatedState: &state{
                    instanceID:     "instance",
                    eventTimestamp: time.Now(),
                    position:       42,
                    aggregateType:  "aggregate type",
                    aggregateID:    "aggregate id",
                    sequence:       42,
                },
            },
        },
    }
    for _, tt := range tests {
        if tt.isErr == nil {
            tt.isErr = func(t *testing.T, err error) {
                if err != nil {
                    t.Error("expected no error got:", err)
                }
            }
        }
        t.Run(tt.name, func(t *testing.T) {
            tx, err := tt.fields.mock.DB.Begin()
            if err != nil {
                t.Fatalf("unable to begin transaction: %v", err)
            }

            h := &Handler{
                projection: tt.fields.projection,
            }
            err = h.setState(tx, tt.args.updatedState)

            tt.isErr(t, err)
            tt.fields.mock.Assert(t)
        })
    }
}

func TestHandler_currentState(t *testing.T) {
    testTime := time.Now()
    type fields struct {
        projection Projection
        mock       *mock.SQLMock
    }
    type args struct {
        ctx context.Context
    }
    type want struct {
        currentState *state
        isErr        func(t *testing.T, err error)
    }
    tests := []struct {
        name   string
        fields fields
        args   args
        want   want
    }{
        {
            name: "connection done",
            fields: fields{
                projection: &projection{
                    name: "projection",
                },
                mock: mock.NewSQLMock(t,
                    mock.ExpectBegin(nil),
                    mock.ExpectQuery(currentStateStmt,
                        mock.WithQueryArgs(
                            "instance",
                            "projection",
                        ),
                        mock.WithQueryErr(sql.ErrConnDone),
                    ),
                ),
            },
            args: args{
                ctx: authz.WithInstanceID(context.Background(), "instance"),
            },
            want: want{
                isErr: func(t *testing.T, err error) {
                    if !errors.Is(err, sql.ErrConnDone) {
                        t.Errorf("unexpected error, want: %v, got: %v", sql.ErrConnDone, err)
                    }
                },
            },
        },
        {
            name: "no row but lock err",
            fields: fields{
                projection: &projection{
                    name: "projection",
                },
                mock: mock.NewSQLMock(t,
                    mock.ExpectBegin(nil),
                    mock.ExpectQuery(currentStateStmt,
                        mock.WithQueryArgs(
                            "instance",
                            "projection",
                        ),
                        mock.WithQueryErr(sql.ErrNoRows),
                    ),
                    mock.ExcpectExec(lockStateStmt,
                        mock.WithExecArgs(
                            "projection",
                            "instance",
                        ),
                        mock.WithExecErr(sql.ErrTxDone),
                    ),
                ),
            },
            args: args{
                ctx: authz.WithInstanceID(context.Background(), "instance"),
            },
            want: want{
                isErr: func(t *testing.T, err error) {
                    if !errors.Is(err, sql.ErrTxDone) {
                        t.Errorf("unexpected error, want: %v, got: %v", sql.ErrTxDone, err)
                    }
                },
            },
        },
        {
            name: "state locked",
            fields: fields{
                projection: &projection{
                    name: "projection",
                },
                mock: mock.NewSQLMock(t,
                    mock.ExpectBegin(nil),
                    mock.ExpectQuery(currentStateStmt,
                        mock.WithQueryArgs(
                            "instance",
                            "projection",
                        ),
                        mock.WithQueryErr(&pgconn.PgError{Code: "55P03"}),
                    ),
                ),
            },
            args: args{
                ctx: authz.WithInstanceID(context.Background(), "instance"),
            },
            want: want{
                isErr: func(t *testing.T, err error) {
                    pgErr := new(pgconn.PgError)
                    if !errors.As(err, &pgErr) {
                        t.Errorf("error should be PgErr but was %T", err)
                        return
                    }
                    if pgErr.Code != "55P03" {
                        t.Errorf("expected code 55P03 got: %s", pgErr.Code)
                    }
                },
            },
        },
        {
            name: "success",
            fields: fields{
                projection: &projection{
                    name: "projection",
                },
                mock: mock.NewSQLMock(t,
                    mock.ExpectBegin(nil),
                    mock.ExpectQuery(currentStateStmt,
                        mock.WithQueryArgs(
                            "instance",
                            "projection",
                        ),
                        mock.WithQueryResult(
                            []string{"aggregate_id", "aggregate_type", "event_sequence", "event_date", "position"},
                            [][]driver.Value{
                                {
                                    "aggregate id",
                                    "aggregate type",
                                    int64(42),
                                    testTime,
                                    float64(42),
                                },
                            },
                        ),
                    ),
                ),
            },
            args: args{
                ctx: authz.WithInstanceID(context.Background(), "instance"),
            },
            want: want{
                currentState: &state{
                    instanceID:     "instance",
                    eventTimestamp: testTime,
                    position:       42,
                    aggregateType:  "aggregate type",
                    aggregateID:    "aggregate id",
                    sequence:       42,
                },
            },
        },
    }
    for _, tt := range tests {
        if tt.want.isErr == nil {
            tt.want.isErr = func(t *testing.T, err error) {
                if err != nil {
                    t.Error("expected no error got:", err)
                }
            }
        }
        t.Run(tt.name, func(t *testing.T) {
            h := &Handler{
                projection: tt.fields.projection,
            }

            tx, err := tt.fields.mock.DB.Begin()
            if err != nil {
                t.Fatalf("unable to begin transaction: %v", err)
            }

            gotCurrentState, err := h.currentState(tt.args.ctx, tx, new(triggerConfig))

            tt.want.isErr(t, err)
            if !reflect.DeepEqual(gotCurrentState, tt.want.currentState) {
                t.Errorf("Handler.currentState() gotCurrentState = %v, want %v", gotCurrentState, tt.want.currentState)
            }
            tt.fields.mock.Assert(t)
        })
    }
}
551
internal/eventstore/handler/v2/statement.go
Normal file
551
internal/eventstore/handler/v2/statement.go
Normal file
@@ -0,0 +1,551 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
errs "errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
func (h *Handler) eventsToStatements(tx *sql.Tx, events []eventstore.Event, currentState *state) (statements []*Statement, err error) {
|
||||
statements = make([]*Statement, 0, len(events))
|
||||
for _, event := range events {
|
||||
statement, err := h.reduce(event)
|
||||
if err != nil {
|
||||
h.logEvent(event).WithError(err).Error("reduce failed")
|
||||
if shouldContinue := h.handleFailedStmt(tx, currentState, failureFromEvent(event, err)); shouldContinue {
|
||||
continue
|
||||
}
|
||||
return statements, err
|
||||
}
|
||||
statements = append(statements, statement)
|
||||
}
|
||||
return statements, nil
|
||||
}
|
||||
|
||||
func (h *Handler) reduce(event eventstore.Event) (*Statement, error) {
|
||||
for _, reducer := range h.projection.Reducers() {
|
||||
if reducer.Aggregate != event.Aggregate().Type {
|
||||
continue
|
||||
}
|
||||
for _, reduce := range reducer.EventReducers {
|
||||
if reduce.Event != event.Type() {
|
||||
continue
|
||||
}
|
||||
return reduce.Reduce(event)
|
||||
}
|
||||
}
|
||||
return NewNoOpStatement(event), nil
|
||||
}

type Statement struct {
	AggregateType eventstore.AggregateType
	AggregateID   string
	Sequence      uint64
	Position      float64
	CreationDate  time.Time
	InstanceID    string

	Execute Exec
}

type Exec func(ex Executer, projectionName string) error

func WithTableSuffix(name string) func(*execConfig) {
	return func(o *execConfig) {
		o.tableName += "_" + name
	}
}

var (
	ErrNoProjection = errs.New("no projection")
	ErrNoValues     = errs.New("no values")
	ErrNoCondition  = errs.New("no condition")
)

func NewStatement(event eventstore.Event, e Exec) *Statement {
	return &Statement{
		AggregateType: event.Aggregate().Type,
		Sequence:      event.Sequence(),
		Position:      event.Position(),
		AggregateID:   event.Aggregate().ID,
		CreationDate:  event.CreatedAt(),
		InstanceID:    event.Aggregate().InstanceID,
		Execute:       e,
	}
}

func NewCreateStatement(event eventstore.Event, values []Column, opts ...execOption) *Statement {
	cols, params, args := columnsToQuery(values)
	columnNames := strings.Join(cols, ", ")
	valuesPlaceholder := strings.Join(params, ", ")

	config := execConfig{
		args: args,
	}

	if len(values) == 0 {
		config.err = ErrNoValues
	}

	q := func(config execConfig) string {
		return "INSERT INTO " + config.tableName + " (" + columnNames + ") VALUES (" + valuesPlaceholder + ")"
	}

	return NewStatement(event, exec(config, q, opts))
}

func NewUpsertStatement(event eventstore.Event, conflictCols []Column, values []Column, opts ...execOption) *Statement {
	cols, params, args := columnsToQuery(values)

	conflictTarget := make([]string, len(conflictCols))
	for i, col := range conflictCols {
		conflictTarget[i] = col.Name
	}

	config := execConfig{
		args: args,
	}

	if len(values) == 0 {
		config.err = ErrNoValues
	}

	updateCols, updateVals := getUpdateCols(cols, conflictTarget)
	if len(updateCols) == 0 || len(updateVals) == 0 {
		config.err = ErrNoValues
	}

	q := func(config execConfig) string {
		var updateStmt string
		// PostgreSQL does not allow updating a single column with the multi-column (row-value) syntax.
		// discussion: https://www.postgresql.org/message-id/17451.1509381766%40sss.pgh.pa.us
		// see Compatibility in https://www.postgresql.org/docs/current/sql-update.html
		if len(updateCols) == 1 && !strings.HasPrefix(updateVals[0], "SELECT") {
			updateStmt = "UPDATE SET " + updateCols[0] + " = " + updateVals[0]
		} else {
			updateStmt = "UPDATE SET (" + strings.Join(updateCols, ", ") + ") = (" + strings.Join(updateVals, ", ") + ")"
		}
		return "INSERT INTO " + config.tableName + " (" + strings.Join(cols, ", ") + ") VALUES (" + strings.Join(params, ", ") + ")" +
			" ON CONFLICT (" + strings.Join(conflictTarget, ", ") + ") DO " + updateStmt
	}

	return NewStatement(event, exec(config, q, opts))
}
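
To illustrate the query NewUpsertStatement builds, here is a minimal sketch, assuming the same handler package; the conflict column, value column, and rendered SQL are illustrative:

	func exampleUpsert(event eventstore.Event) *Statement {
		return NewUpsertStatement(event,
			[]Column{NewCol("id", nil)}, // conflict target, only the name is used
			[]Column{
				NewCol("id", event.Aggregate().ID),
				NewCol("username", "gigi"),
			},
		)
		// Rendered SQL (roughly):
		//   INSERT INTO <projection table> (id, username) VALUES ($1, $2)
		//   ON CONFLICT (id) DO UPDATE SET username = EXCLUDED.username
	}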

// getUpdateCols returns the columns written in the DO UPDATE part of an upsert:
// every inserted column except the conflict target columns, each paired with its EXCLUDED.<name> value.
func getUpdateCols(cols, conflictTarget []string) (updateCols, updateVals []string) {
	updateCols = make([]string, len(cols))
	updateVals = make([]string, len(cols))

	copy(updateCols, cols)

	for i := len(updateCols) - 1; i >= 0; i-- {
		updateVals[i] = "EXCLUDED." + updateCols[i]

		for _, conflict := range conflictTarget {
			if conflict == updateCols[i] {
				copy(updateCols[i:], updateCols[i+1:])
				updateCols[len(updateCols)-1] = ""
				updateCols = updateCols[:len(updateCols)-1]

				copy(updateVals[i:], updateVals[i+1:])
				updateVals[len(updateVals)-1] = ""
				updateVals = updateVals[:len(updateVals)-1]

				break
			}
		}
	}

	return updateCols, updateVals
}
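
A worked example of getUpdateCols, with illustrative inputs:

	func exampleGetUpdateCols() {
		cols := []string{"instance_id", "id", "username"}
		conflict := []string{"instance_id", "id"}
		updateCols, updateVals := getUpdateCols(cols, conflict)
		// updateCols: ["username"]
		// updateVals: ["EXCLUDED.username"]
		_, _ = updateCols, updateVals
	}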

func NewUpdateStatement(event eventstore.Event, values []Column, conditions []Condition, opts ...execOption) *Statement {
	cols, params, args := columnsToQuery(values)
	wheres, whereArgs := conditionsToWhere(conditions, len(args)+1)
	args = append(args, whereArgs...)

	config := execConfig{
		args: args,
	}

	if len(values) == 0 {
		config.err = ErrNoValues
	}

	if len(conditions) == 0 {
		config.err = ErrNoCondition
	}

	q := func(config execConfig) string {
		// PostgreSQL does not allow updating a single column with the multi-column (row-value) syntax.
		// discussion: https://www.postgresql.org/message-id/17451.1509381766%40sss.pgh.pa.us
		// see Compatibility in https://www.postgresql.org/docs/current/sql-update.html
		if len(cols) == 1 && !strings.HasPrefix(params[0], "SELECT") {
			return "UPDATE " + config.tableName + " SET " + cols[0] + " = " + params[0] + " WHERE " + strings.Join(wheres, " AND ")
		}
		return "UPDATE " + config.tableName + " SET (" + strings.Join(cols, ", ") + ") = (" + strings.Join(params, ", ") + ") WHERE " + strings.Join(wheres, " AND ")
	}

	return NewStatement(event, exec(config, q, opts))
}
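
A minimal usage sketch for NewUpdateStatement, assuming the same handler package; column and condition names are made up. It shows how the WHERE placeholders continue numbering after the SET placeholders:

	func exampleUpdate(event eventstore.Event) *Statement {
		return NewUpdateStatement(event,
			[]Column{NewCol("username", "gigi")},
			[]Condition{
				NewCond("instance_id", event.Aggregate().InstanceID),
				NewCond("id", event.Aggregate().ID),
			},
		)
		// Rendered SQL (roughly):
		//   UPDATE <projection table> SET username = $1 WHERE (instance_id = $2) AND (id = $3)
	}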

func NewDeleteStatement(event eventstore.Event, conditions []Condition, opts ...execOption) *Statement {
	wheres, args := conditionsToWhere(conditions, 1)

	wheresPlaceholders := strings.Join(wheres, " AND ")

	config := execConfig{
		args: args,
	}

	if len(conditions) == 0 {
		config.err = ErrNoCondition
	}

	q := func(config execConfig) string {
		return "DELETE FROM " + config.tableName + " WHERE " + wheresPlaceholders
	}

	return NewStatement(event, exec(config, q, opts))
}

func NewNoOpStatement(event eventstore.Event) *Statement {
	return NewStatement(event, nil)
}

func NewMultiStatement(event eventstore.Event, opts ...func(eventstore.Event) Exec) *Statement {
	if len(opts) == 0 {
		return NewNoOpStatement(event)
	}
	execs := make([]Exec, len(opts))
	for i, opt := range opts {
		execs[i] = opt(event)
	}
	return NewStatement(event, multiExec(execs))
}

func AddNoOpStatement() func(eventstore.Event) Exec {
	return func(event eventstore.Event) Exec {
		return NewNoOpStatement(event).Execute
	}
}

func AddCreateStatement(columns []Column, opts ...execOption) func(eventstore.Event) Exec {
	return func(event eventstore.Event) Exec {
		return NewCreateStatement(event, columns, opts...).Execute
	}
}

func AddUpsertStatement(indexCols []Column, values []Column, opts ...execOption) func(eventstore.Event) Exec {
	return func(event eventstore.Event) Exec {
		return NewUpsertStatement(event, indexCols, values, opts...).Execute
	}
}

func AddUpdateStatement(values []Column, conditions []Condition, opts ...execOption) func(eventstore.Event) Exec {
	return func(event eventstore.Event) Exec {
		return NewUpdateStatement(event, values, conditions, opts...).Execute
	}
}

func AddDeleteStatement(conditions []Condition, opts ...execOption) func(eventstore.Event) Exec {
	return func(event eventstore.Event) Exec {
		return NewDeleteStatement(event, conditions, opts...).Execute
	}
}

func AddCopyStatement(conflict, from, to []Column, conditions []NamespacedCondition, opts ...execOption) func(eventstore.Event) Exec {
	return func(event eventstore.Event) Exec {
		return NewCopyStatement(event, conflict, from, to, conditions, opts...).Execute
	}
}
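
These Add* helpers let one event produce several sub-statements through NewMultiStatement. A minimal sketch, assuming the same handler package; the table suffixes and columns are hypothetical:

	func exampleMulti(event eventstore.Event) *Statement {
		return NewMultiStatement(event,
			AddDeleteStatement([]Condition{NewCond("id", event.Aggregate().ID)}, WithTableSuffix("members")),
			AddCreateStatement([]Column{NewCol("id", event.Aggregate().ID)}, WithTableSuffix("owners")),
		)
		// Executing the statement runs both sub-statements in order against the
		// projection's tables ("<projection>_members" and "<projection>_owners").
	}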

func NewArrayAppendCol(column string, value interface{}) Column {
	return Column{
		Name:  column,
		Value: value,
		ParameterOpt: func(placeholder string) string {
			return "array_append(" + column + ", " + placeholder + ")"
		},
	}
}

func NewArrayRemoveCol(column string, value interface{}) Column {
	return Column{
		Name:  column,
		Value: value,
		ParameterOpt: func(placeholder string) string {
			return "array_remove(" + column + ", " + placeholder + ")"
		},
	}
}

func NewArrayIntersectCol(column string, value interface{}) Column {
	var arrayType string
	switch value.(type) {
	case []string, database.TextArray[string]:
		arrayType = "TEXT"
		//TODO: handle more types if necessary
	}
	return Column{
		Name:  column,
		Value: value,
		ParameterOpt: func(placeholder string) string {
			return "SELECT ARRAY( SELECT UNNEST(" + column + ") INTERSECT SELECT UNNEST (" + placeholder + "::" + arrayType + "[]))"
		},
	}
}
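
A minimal sketch of using an array column helper inside an update, assuming the same handler package; the "roles" TEXT[] column is hypothetical:

	func exampleArrayUpdate(event eventstore.Event) *Statement {
		return NewUpdateStatement(event,
			[]Column{NewArrayAppendCol("roles", "ORG_OWNER")},
			[]Condition{NewCond("id", event.Aggregate().ID)},
		)
		// Rendered SQL (roughly):
		//   UPDATE <projection table> SET roles = array_append(roles, $1) WHERE (id = $2)
	}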

func NewCopyCol(column, from string) Column {
	return Column{
		Name:  column,
		Value: NewCol(from, nil),
	}
}

// NewCopyStatement creates a new upsert statement which updates a column from an existing row.
// cols represent the columns which are subject to change.
// If the value of a col is empty, the data is copied from the selected row.
// If the value of a col is not empty, the data is set to that static value.
// conds represent the conditions for the selection subquery.
func NewCopyStatement(event eventstore.Event, conflictCols, from, to []Column, nsCond []NamespacedCondition, opts ...execOption) *Statement {
	columnNames := make([]string, len(to))
	selectColumns := make([]string, len(from))
	updateColumns := make([]string, len(columnNames))
	argCounter := 0
	args := []interface{}{}

	for i, col := range from {
		columnNames[i] = to[i].Name
		selectColumns[i] = from[i].Name
		updateColumns[i] = "EXCLUDED." + col.Name
		if col.Value != nil {
			argCounter++
			selectColumns[i] = "$" + strconv.Itoa(argCounter)
			updateColumns[i] = selectColumns[i]
			args = append(args, col.Value)
		}
	}
	cond := make([]Condition, len(nsCond))
	for i := range nsCond {
		cond[i] = nsCond[i]("copy_table")
	}
	wheres, values := conditionsToWhere(cond, len(args)+1)
	args = append(args, values...)

	conflictTargets := make([]string, len(conflictCols))
	for i, conflictCol := range conflictCols {
		conflictTargets[i] = conflictCol.Name
	}

	config := execConfig{
		args: args,
	}

	if len(from) == 0 || len(to) == 0 || len(from) != len(to) {
		config.err = ErrNoValues
	}

	if len(cond) == 0 {
		config.err = ErrNoCondition
	}

	q := func(config execConfig) string {
		return "INSERT INTO " +
			config.tableName +
			" (" +
			strings.Join(columnNames, ", ") +
			") SELECT " +
			strings.Join(selectColumns, ", ") +
			" FROM " +
			config.tableName + " AS copy_table WHERE " +
			strings.Join(wheres, " AND ") +
			" ON CONFLICT (" +
			strings.Join(conflictTargets, ", ") +
			") DO UPDATE SET (" +
			strings.Join(columnNames, ", ") +
			") = (" +
			strings.Join(updateColumns, ", ") +
			")"
	}

	return NewStatement(event, exec(config, q, opts))
}
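
A minimal usage sketch for NewCopyStatement, assuming the same handler package; all column names and the static value are hypothetical:

	func exampleCopy(event eventstore.Event) *Statement {
		return NewCopyStatement(event,
			[]Column{NewCol("instance_id", nil), NewCol("id", nil)}, // conflict target
			[]Column{NewCol("instance_id", nil), NewCol("id", nil), NewCol("owner_removed", true)}, // from: nil = copy, non-nil = static value
			[]Column{NewCol("instance_id", nil), NewCol("id", nil), NewCol("owner_removed", nil)},  // to
			[]NamespacedCondition{NewNamespacedCondition("instance_id", event.Aggregate().InstanceID)},
		)
		// Rendered SQL (roughly):
		//   INSERT INTO <t> (instance_id, id, owner_removed)
		//   SELECT instance_id, id, $1 FROM <t> AS copy_table WHERE (copy_table.instance_id = $2)
		//   ON CONFLICT (instance_id, id)
		//   DO UPDATE SET (instance_id, id, owner_removed) = (EXCLUDED.instance_id, EXCLUDED.id, $1)
	}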

func columnsToQuery(cols []Column) (names []string, parameters []string, values []interface{}) {
	names = make([]string, len(cols))
	values = make([]interface{}, len(cols))
	parameters = make([]string, len(cols))
	var parameterIndex int
	for i, col := range cols {
		names[i] = col.Name
		if c, ok := col.Value.(Column); ok {
			parameters[i] = c.Name
			continue
		} else {
			values[parameterIndex] = col.Value
		}
		parameters[i] = "$" + strconv.Itoa(parameterIndex+1)
		if col.ParameterOpt != nil {
			parameters[i] = col.ParameterOpt(parameters[i])
		}
		parameterIndex++
	}
	return names, parameters, values[:parameterIndex]
}

func conditionsToWhere(conds []Condition, paramOffset int) (wheres []string, values []interface{}) {
	wheres = make([]string, len(conds))
	values = make([]any, 0, len(conds))

	for i, cond := range conds {
		var args []any
		wheres[i], args = cond("$" + strconv.Itoa(paramOffset))
		paramOffset += len(args)
		values = append(values, args...)
		wheres[i] = "(" + wheres[i] + ")"
	}

	return wheres, values
}
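
A small sketch of the placeholder numbering produced by columnsToQuery and conditionsToWhere, with illustrative column names:

	func examplePlaceholders() {
		names, params, vals := columnsToQuery([]Column{
			NewCol("username", "gigi"),
			NewCopyCol("display_name", "username"), // value is a Column: no placeholder, no argument
		})
		// names: ["username", "display_name"], params: ["$1", "username"], vals: ["gigi"]

		wheres, whereVals := conditionsToWhere([]Condition{NewCond("id", "123")}, len(vals)+1)
		// wheres: ["(id = $2)"], whereVals: ["123"]
		_, _, _, _, _ = names, params, vals, wheres, whereVals
	}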

type Column struct {
	Name         string
	Value        interface{}
	ParameterOpt func(string) string
}

func NewCol(name string, value interface{}) Column {
	return Column{
		Name:  name,
		Value: value,
	}
}

func NewJSONCol(name string, value interface{}) Column {
	marshalled, err := json.Marshal(value)
	if err != nil {
		logging.WithFields("column", name).WithError(err).Panic("unable to marshal column")
	}

	return NewCol(name, marshalled)
}
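
A minimal sketch of NewJSONCol, assuming the same handler package; the settings struct and column name are made up:

	func exampleJSONCol() Column {
		type notificationSettings struct {
			Email bool `json:"email"`
			SMS   bool `json:"sms"`
		}
		// The struct is marshalled once and passed as a single parameter for a JSON column.
		return NewJSONCol("notification_settings", notificationSettings{Email: true})
	}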

type Condition func(param string) (string, []any)

type NamespacedCondition func(namespace string) Condition

func NewCond(name string, value interface{}) Condition {
	return func(param string) (string, []any) {
		return name + " = " + param, []any{value}
	}
}

func NewNamespacedCondition(name string, value interface{}) NamespacedCondition {
	return func(namespace string) Condition {
		return NewCond(namespace+"."+name, value)
	}
}

func NewLessThanCond(column string, value interface{}) Condition {
	return func(param string) (string, []any) {
		return column + " < " + param, []any{value}
	}
}

func NewIsNullCond(column string) Condition {
	return func(string) (string, []any) {
		return column + " IS NULL", nil
	}
}

// NewTextArrayContainsCond returns a Condition that checks whether the column storing an array of text contains the given value.
func NewTextArrayContainsCond(column string, value string) Condition {
	return func(param string) (string, []any) {
		return column + " @> " + param, []any{database.TextArray[string]{value}}
	}
}

// Not is a function rather than a method so that call sites read naturally,
// for example: conditions := []Condition{Not(NewTextArrayContainsCond(column, value))}
func Not(condition Condition) Condition {
	return func(param string) (string, []any) {
		cond, value := condition(param)
		return "NOT (" + cond + ")", value
	}
}
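
A minimal sketch composing conditions, including Not and an IS NULL check, assuming the same handler package; the column names are hypothetical:

	func exampleConditions(event eventstore.Event) *Statement {
		return NewDeleteStatement(event, []Condition{
			NewCond("instance_id", event.Aggregate().InstanceID),
			Not(NewTextArrayContainsCond("roles", "IAM_OWNER")),
			NewIsNullCond("removed_at"),
		})
		// Rendered SQL (roughly):
		//   DELETE FROM <projection table>
		//   WHERE (instance_id = $1) AND (NOT (roles @> $2)) AND (removed_at IS NULL)
	}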

type Executer interface {
	Exec(string, ...interface{}) (sql.Result, error)
}

type execOption func(*execConfig)
type execConfig struct {
	tableName string

	args []interface{}
	err  error
}

type query func(config execConfig) string

func exec(config execConfig, q query, opts []execOption) Exec {
	return func(ex Executer, projectionName string) (err error) {
		if projectionName == "" {
			return ErrNoProjection
		}

		if config.err != nil {
			return config.err
		}

		config.tableName = projectionName
		for _, opt := range opts {
			opt(&config)
		}

		// Each statement runs inside its own savepoint so a failing statement can be
		// rolled back without aborting the surrounding transaction.
		_, err = ex.Exec("SAVEPOINT stmt_exec")
		if err != nil {
			return errors.ThrowInternal(err, "CRDB-YdOXD", "create savepoint failed")
		}
		defer func() {
			if err != nil {
				_, rollbackErr := ex.Exec("ROLLBACK TO SAVEPOINT stmt_exec")
				logging.OnError(rollbackErr).Debug("rollback failed")
				return
			}
			_, err = ex.Exec("RELEASE SAVEPOINT stmt_exec")
		}()
		_, err = ex.Exec(q(config), config.args...)
		if err != nil {
			return errors.ThrowInternal(err, "CRDB-pKtsr", "exec failed")
		}

		return nil
	}
}
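
A minimal sketch of executing a statement, assuming the same handler package, a statement produced by one of the constructors above, an open *sql.Tx (which satisfies Executer), and a hypothetical projection table name:

	func exampleExecute(tx *sql.Tx, stmt *Statement) error {
		if stmt.Execute == nil { // no-op statements carry no Exec
			return nil
		}
		return stmt.Execute(tx, "projections.users")
	}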

func multiExec(execList []Exec) Exec {
	return func(ex Executer, projectionName string) error {
		for _, exec := range execList {
			if exec == nil {
				continue
			}
			if err := exec(ex, projectionName); err != nil {
				return err
			}
		}
		return nil
	}
}
1633
internal/eventstore/handler/v2/statement_test.go
Normal file
File diff suppressed because it is too large