feat(eventstore): increase parallel write capabilities (#5940)

This implementation increases parallel write capabilities of the eventstore.
Please have a look at the technical advisories: [05](https://zitadel.com/docs/support/advisory/a10005) and  [06](https://zitadel.com/docs/support/advisory/a10006).
The implementation of eventstore.push is rewritten and stored events are migrated to a new table `eventstore.events2`.
If you are using cockroach: make sure that the database user of ZITADEL has `VIEWACTIVITY` grant. This is used to query events.
This commit is contained in:
Silvan
2023-10-19 12:19:10 +02:00
committed by GitHub
parent 259faba3f0
commit b5564572bc
791 changed files with 30326 additions and 43202 deletions

View File

@@ -0,0 +1,52 @@
package handler
import (
"context"
"github.com/zitadel/zitadel/internal/eventstore"
)
const (
schedulerSucceeded = eventstore.EventType("system.projections.scheduler.succeeded")
aggregateType = eventstore.AggregateType("system")
aggregateID = "SYSTEM"
)
func (h *Handler) didProjectionInitialize(ctx context.Context) bool {
events, err := h.es.Filter(ctx, eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
InstanceID("").
AddQuery().
AggregateTypes(aggregateType).
AggregateIDs(aggregateID).
EventTypes(schedulerSucceeded).
EventData(map[string]interface{}{
"name": h.projection.Name(),
}).
Builder(),
)
return len(events) > 0 && err == nil
}
func (h *Handler) setSucceededOnce(ctx context.Context) error {
_, err := h.es.Push(ctx, &ProjectionSucceededEvent{
BaseEvent: *eventstore.NewBaseEventForPush(ctx,
eventstore.NewAggregate(ctx, aggregateID, aggregateType, "v1"),
schedulerSucceeded,
),
Name: h.projection.Name(),
})
return err
}
type ProjectionSucceededEvent struct {
eventstore.BaseEvent `json:"-"`
Name string `json:"name"`
}
func (p *ProjectionSucceededEvent) Payload() interface{} {
return p
}
func (p *ProjectionSucceededEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
return nil
}

View File

@@ -0,0 +1,95 @@
package handler
import (
"database/sql"
_ "embed"
"time"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
)
var (
//go:embed failed_event_set.sql
setFailedEventStmt string
//go:embed failed_event_get_count.sql
failureCountStmt string
)
type failure struct {
sequence uint64
instance string
aggregateID string
aggregateType eventstore.AggregateType
eventDate time.Time
err error
}
func failureFromEvent(event eventstore.Event, err error) *failure {
return &failure{
sequence: event.Sequence(),
instance: event.Aggregate().InstanceID,
aggregateID: event.Aggregate().ID,
aggregateType: event.Aggregate().Type,
eventDate: event.CreatedAt(),
err: err,
}
}
func failureFromStatement(statement *Statement, err error) *failure {
return &failure{
sequence: statement.Sequence,
instance: statement.InstanceID,
aggregateID: statement.AggregateID,
aggregateType: statement.AggregateType,
eventDate: statement.CreationDate,
err: err,
}
}
func (h *Handler) handleFailedStmt(tx *sql.Tx, currentState *state, f *failure) (shouldContinue bool) {
failureCount, err := h.failureCount(tx, f)
if err != nil {
h.logFailure(f).WithError(err).Warn("unable to get failure count")
return false
}
failureCount += 1
err = h.setFailureCount(tx, failureCount, f)
h.logFailure(f).OnError(err).Warn("unable to update failure count")
return failureCount >= h.maxFailureCount
}
func (h *Handler) failureCount(tx *sql.Tx, f *failure) (count uint8, err error) {
row := tx.QueryRow(failureCountStmt,
h.projection.Name(),
f.instance,
f.aggregateType,
f.aggregateID,
f.sequence,
)
if err = row.Err(); err != nil {
return 0, errors.ThrowInternal(err, "CRDB-Unnex", "unable to update failure count")
}
if err = row.Scan(&count); err != nil {
return 0, errors.ThrowInternal(err, "CRDB-RwSMV", "unable to scan count")
}
return count, nil
}
func (h *Handler) setFailureCount(tx *sql.Tx, count uint8, f *failure) error {
_, err := tx.Exec(setFailedEventStmt,
h.projection.Name(),
f.instance,
f.aggregateType,
f.aggregateID,
f.eventDate,
f.sequence,
count,
f.err.Error(),
)
if err != nil {
return errors.ThrowInternal(err, "CRDB-4Ht4x", "set failure count failed")
}
return nil
}

View File

@@ -0,0 +1,12 @@
WITH failures AS (
SELECT
failure_count
FROM
projections.failed_events2
WHERE
projection_name = $1
AND instance_id = $2
AND aggregate_type = $3
AND aggregate_id = $4
AND failed_sequence = $5
) SELECT COALESCE((SELECT failure_count FROM failures), 0) AS failure_count

View File

@@ -0,0 +1,31 @@
INSERT INTO projections.failed_events2 (
projection_name
, instance_id
, aggregate_type
, aggregate_id
, event_creation_date
, failed_sequence
, failure_count
, error
, last_failed
) VALUES (
$1
, $2
, $3
, $4
, $5
, $6
, $7
, $8
, now()
) ON CONFLICT (
projection_name
, aggregate_type
, aggregate_id
, failed_sequence
, instance_id
) DO UPDATE SET
failure_count = EXCLUDED.failure_count
, error = EXCLUDED.error
, last_failed = EXCLUDED.last_failed
;

View File

@@ -0,0 +1,465 @@
package handler
import (
"context"
"database/sql"
"errors"
"math"
"sync"
"time"
"github.com/jackc/pgconn"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/pseudo"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
type EventStore interface {
InstanceIDs(ctx context.Context, maxAge time.Duration, forceLoad bool, query *eventstore.SearchQueryBuilder) ([]string, error)
Filter(ctx context.Context, queryFactory *eventstore.SearchQueryBuilder) ([]eventstore.Event, error)
Push(ctx context.Context, cmds ...eventstore.Command) ([]eventstore.Event, error)
}
type Config struct {
Client *database.DB
Eventstore EventStore
BulkLimit uint16
RequeueEvery time.Duration
RetryFailedAfter time.Duration
HandleActiveInstances time.Duration
TransactionDuration time.Duration
MaxFailureCount uint8
TriggerWithoutEvents Reduce
}
type Handler struct {
client *database.DB
projection Projection
es EventStore
bulkLimit uint16
eventTypes map[eventstore.AggregateType][]eventstore.EventType
maxFailureCount uint8
retryFailedAfter time.Duration
requeueEvery time.Duration
handleActiveInstances time.Duration
txDuration time.Duration
now nowFunc
triggeredInstancesSync sync.Map
triggerWithoutEvents Reduce
}
// nowFunc makes [time.Now] mockable
type nowFunc func() time.Time
type Projection interface {
Name() string
Reducers() []AggregateReducer
}
func NewHandler(
ctx context.Context,
config *Config,
projection Projection,
) *Handler {
aggregates := make(map[eventstore.AggregateType][]eventstore.EventType, len(projection.Reducers()))
for _, reducer := range projection.Reducers() {
eventTypes := make([]eventstore.EventType, len(reducer.EventReducers))
for i, eventReducer := range reducer.EventReducers {
eventTypes[i] = eventReducer.Event
}
if _, ok := aggregates[reducer.Aggregate]; ok {
aggregates[reducer.Aggregate] = append(aggregates[reducer.Aggregate], eventTypes...)
continue
}
aggregates[reducer.Aggregate] = eventTypes
}
handler := &Handler{
projection: projection,
client: config.Client,
es: config.Eventstore,
bulkLimit: config.BulkLimit,
eventTypes: aggregates,
requeueEvery: config.RequeueEvery,
handleActiveInstances: config.HandleActiveInstances,
now: time.Now,
maxFailureCount: config.MaxFailureCount,
retryFailedAfter: config.RetryFailedAfter,
triggeredInstancesSync: sync.Map{},
triggerWithoutEvents: config.TriggerWithoutEvents,
txDuration: config.TransactionDuration,
}
return handler
}
func (h *Handler) Start(ctx context.Context) {
go h.schedule(ctx)
if h.triggerWithoutEvents != nil {
return
}
go h.subscribe(ctx)
}
func (h *Handler) schedule(ctx context.Context) {
// if there was no run before trigger instantly
t := time.NewTimer(0)
didInitialize := h.didProjectionInitialize(ctx)
if didInitialize {
t.Reset(h.requeueEvery)
}
for {
select {
case <-ctx.Done():
t.Stop()
return
case <-t.C:
instances, err := h.queryInstances(ctx, didInitialize)
h.log().OnError(err).Debug("unable to query instances")
var instanceFailed bool
scheduledCtx := call.WithTimestamp(ctx)
for _, instance := range instances {
instanceCtx := authz.WithInstanceID(scheduledCtx, instance)
// simple implementation of do while
_, err = h.Trigger(instanceCtx)
instanceFailed = instanceFailed || err != nil
h.log().WithField("instance", instance).OnError(err).Info("scheduled trigger failed")
// retry if trigger failed
for ; err != nil; _, err = h.Trigger(instanceCtx) {
time.Sleep(h.retryFailedAfter)
instanceFailed = instanceFailed || err != nil
h.log().WithField("instance", instance).OnError(err).Info("scheduled trigger failed")
if err == nil {
break
}
}
}
if !didInitialize && !instanceFailed {
err = h.setSucceededOnce(ctx)
h.log().OnError(err).Debug("unable to set succeeded once")
didInitialize = err == nil
}
t.Reset(h.requeueEvery)
}
}
}
func (h *Handler) subscribe(ctx context.Context) {
queue := make(chan eventstore.Event, 100)
subscription := eventstore.SubscribeEventTypes(queue, h.eventTypes)
for {
select {
case <-ctx.Done():
subscription.Unsubscribe()
h.log().Debug("shutdown")
return
case event := <-queue:
events := checkAdditionalEvents(queue, event)
solvedInstances := make([]string, 0, len(events))
queueCtx := call.WithTimestamp(ctx)
for _, e := range events {
if instanceSolved(solvedInstances, e.Aggregate().InstanceID) {
continue
}
queueCtx = authz.WithInstanceID(queueCtx, e.Aggregate().InstanceID)
_, err := h.Trigger(queueCtx)
h.log().OnError(err).Debug("trigger of queued event failed")
if err == nil {
solvedInstances = append(solvedInstances, e.Aggregate().InstanceID)
}
}
}
}
}
func instanceSolved(solvedInstances []string, instanceID string) bool {
for _, solvedInstance := range solvedInstances {
if solvedInstance == instanceID {
return true
}
}
return false
}
func checkAdditionalEvents(eventQueue chan eventstore.Event, event eventstore.Event) []eventstore.Event {
events := make([]eventstore.Event, 1)
events[0] = event
for {
wait := time.NewTimer(1 * time.Millisecond)
select {
case event := <-eventQueue:
events = append(events, event)
case <-wait.C:
return events
}
}
}
func (h *Handler) queryInstances(ctx context.Context, didInitialize bool) ([]string, error) {
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsInstanceIDs).
AwaitOpenTransactions().
AllowTimeTravel().
ExcludedInstanceID("")
if didInitialize {
query = query.
CreationDateAfter(h.now().Add(-1 * h.handleActiveInstances))
}
return h.es.InstanceIDs(ctx, h.requeueEvery, !didInitialize, query)
}
type triggerConfig struct {
awaitRunning bool
}
type triggerOpt func(conf *triggerConfig)
func WithAwaitRunning() triggerOpt {
return func(conf *triggerConfig) {
conf.awaitRunning = true
}
}
func (h *Handler) Trigger(ctx context.Context, opts ...triggerOpt) (_ context.Context, err error) {
if authz.GetInstance(ctx).InstanceID() != "" {
var span *tracing.Span
ctx, span = tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
}
config := new(triggerConfig)
for _, opt := range opts {
opt(config)
}
cancel := h.lockInstance(ctx, config)
if cancel == nil {
return call.ResetTimestamp(ctx), nil
}
defer cancel()
for i := 0; ; i++ {
additionalIteration, err := h.processEvents(ctx, config)
h.log().OnError(err).Warn("process events failed")
h.log().WithField("iteration", i).Debug("trigger iteration")
if !additionalIteration || err != nil {
return call.ResetTimestamp(ctx), err
}
}
}
// lockInstances tries to lock the instance.
// If the instance is already locked from another process no cancel function is returned
// the instance can be skipped then
// If the instance is locked, an unlock deferable function is returned
func (h *Handler) lockInstance(ctx context.Context, config *triggerConfig) func() {
instanceID := authz.GetInstance(ctx).InstanceID()
// Check that the instance has a mutex to lock
instanceMu, _ := h.triggeredInstancesSync.LoadOrStore(instanceID, new(sync.Mutex))
unlock := func() {
instanceMu.(*sync.Mutex).Unlock()
}
if !instanceMu.(*sync.Mutex).TryLock() {
instanceMu.(*sync.Mutex).Lock()
if config.awaitRunning {
return unlock
}
defer unlock()
return nil
}
return unlock
}
func (h *Handler) processEvents(ctx context.Context, config *triggerConfig) (additionalIteration bool, err error) {
defer func() {
pgErr := new(pgconn.PgError)
if errors.As(err, &pgErr) {
// error returned if the row is currently locked by another connection
if pgErr.Code == "55P03" {
h.log().Debug("state already locked")
err = nil
additionalIteration = false
}
}
}()
if h.txDuration > 0 {
var cancel func()
ctx, cancel = context.WithTimeout(ctx, h.txDuration)
defer cancel()
}
tx, err := h.client.Begin()
if err != nil {
return false, err
}
defer func() {
if err != nil {
rollbackErr := tx.Rollback()
h.log().OnError(rollbackErr).Debug("unable to rollback tx")
return
}
err = tx.Commit()
}()
currentState, err := h.currentState(ctx, tx, config)
if err != nil {
if errors.Is(err, errJustUpdated) {
return false, nil
}
return additionalIteration, err
}
var statements []*Statement
statements, additionalIteration, err = h.generateStatements(ctx, tx, currentState)
if err != nil || len(statements) == 0 {
return additionalIteration, err
}
lastProcessedIndex, err := h.executeStatements(ctx, tx, currentState, statements)
if lastProcessedIndex < 0 {
return false, err
}
currentState.position = statements[lastProcessedIndex].Position
currentState.aggregateID = statements[lastProcessedIndex].AggregateID
currentState.aggregateType = statements[lastProcessedIndex].AggregateType
currentState.sequence = statements[lastProcessedIndex].Sequence
currentState.eventTimestamp = statements[lastProcessedIndex].CreationDate
err = h.setState(tx, currentState)
return additionalIteration, err
}
func (h *Handler) generateStatements(ctx context.Context, tx *sql.Tx, currentState *state) (_ []*Statement, additionalIteration bool, err error) {
if h.triggerWithoutEvents != nil {
stmt, err := h.triggerWithoutEvents(pseudo.NewScheduledEvent(ctx, time.Now(), currentState.instanceID))
if err != nil {
return nil, false, err
}
return []*Statement{stmt}, false, nil
}
events, err := h.es.Filter(ctx, h.eventQuery(currentState))
if err != nil {
h.log().WithError(err).Debug("filter eventstore failed")
return nil, false, err
}
eventAmount := len(events)
events = skipPreviouslyReduced(events, currentState)
if len(events) == 0 {
h.updateLastUpdated(ctx, tx, currentState)
return nil, false, nil
}
statements, err := h.eventsToStatements(tx, events, currentState)
if len(statements) == 0 {
return nil, false, err
}
additionalIteration = eventAmount == int(h.bulkLimit)
if len(statements) < len(events) {
// retry imediatly if statements failed
additionalIteration = true
}
return statements, additionalIteration, nil
}
func skipPreviouslyReduced(events []eventstore.Event, currentState *state) []eventstore.Event {
for i, event := range events {
if event.Position() == currentState.position &&
event.Aggregate().ID == currentState.aggregateID &&
event.Aggregate().Type == currentState.aggregateType &&
event.Sequence() == currentState.sequence {
return events[i+1:]
}
}
return events
}
func (h *Handler) executeStatements(ctx context.Context, tx *sql.Tx, currentState *state, statements []*Statement) (lastProcessedIndex int, err error) {
lastProcessedIndex = -1
for i, statement := range statements {
select {
case <-ctx.Done():
break
default:
err := h.executeStatement(ctx, tx, currentState, statement)
if err != nil {
return lastProcessedIndex, err
}
lastProcessedIndex = i
}
}
return lastProcessedIndex, nil
}
func (h *Handler) executeStatement(ctx context.Context, tx *sql.Tx, currentState *state, statement *Statement) (err error) {
if statement.Execute == nil {
return nil
}
_, err = tx.Exec("SAVEPOINT exec")
if err != nil {
h.log().WithError(err).Debug("create savepoint failed")
return err
}
var shouldContinue bool
defer func() {
_, err = tx.Exec("RELEASE SAVEPOINT exec")
}()
if err = statement.Execute(tx, h.projection.Name()); err != nil {
h.log().WithError(err).Error("statement execution failed")
shouldContinue = h.handleFailedStmt(tx, currentState, failureFromStatement(statement, err))
if shouldContinue {
return nil
}
return err
}
return nil
}
func (h *Handler) eventQuery(currentState *state) *eventstore.SearchQueryBuilder {
builder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AwaitOpenTransactions().
Limit(uint64(h.bulkLimit)).
AllowTimeTravel().
OrderAsc().
InstanceID(currentState.instanceID)
if currentState.position > 0 {
builder = builder.PositionAfter(math.Float64frombits(math.Float64bits(currentState.position) - 10))
}
for aggregateType, eventTypes := range h.eventTypes {
query := builder.
AddQuery().
AggregateTypes(aggregateType).
EventTypes(eventTypes...)
builder = query.Builder()
}
return builder
}

View File

@@ -0,0 +1,413 @@
package handler
import (
"context"
"errors"
"fmt"
"strings"
"github.com/jackc/pgconn"
"github.com/zitadel/logging"
errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore/handler"
)
type Table struct {
columns []*InitColumn
primaryKey PrimaryKey
indices []*Index
constraints []*Constraint
foreignKeys []*ForeignKey
}
func NewTable(columns []*InitColumn, key PrimaryKey, opts ...TableOption) *Table {
t := &Table{
columns: columns,
primaryKey: key,
}
for _, opt := range opts {
opt(t)
}
return t
}
type SuffixedTable struct {
Table
suffix string
}
func NewSuffixedTable(columns []*InitColumn, key PrimaryKey, suffix string, opts ...TableOption) *SuffixedTable {
return &SuffixedTable{
Table: *NewTable(columns, key, opts...),
suffix: suffix,
}
}
type TableOption func(*Table)
func WithIndex(index *Index) TableOption {
return func(table *Table) {
table.indices = append(table.indices, index)
}
}
func WithConstraint(constraint *Constraint) TableOption {
return func(table *Table) {
table.constraints = append(table.constraints, constraint)
}
}
func WithForeignKey(key *ForeignKey) TableOption {
return func(table *Table) {
table.foreignKeys = append(table.foreignKeys, key)
}
}
type InitColumn struct {
Name string
Type ColumnType
nullable bool
defaultValue interface{}
deleteCascade string
}
type ColumnOption func(*InitColumn)
func NewColumn(name string, columnType ColumnType, opts ...ColumnOption) *InitColumn {
column := &InitColumn{
Name: name,
Type: columnType,
nullable: false,
defaultValue: nil,
}
for _, opt := range opts {
opt(column)
}
return column
}
func Nullable() ColumnOption {
return func(c *InitColumn) {
c.nullable = true
}
}
func Default(value interface{}) ColumnOption {
return func(c *InitColumn) {
c.defaultValue = value
}
}
func DeleteCascade(column string) ColumnOption {
return func(c *InitColumn) {
c.deleteCascade = column
}
}
type PrimaryKey []string
func NewPrimaryKey(columnNames ...string) PrimaryKey {
return columnNames
}
type ColumnType int32
const (
ColumnTypeText ColumnType = iota
ColumnTypeTextArray
ColumnTypeJSONB
ColumnTypeBytes
ColumnTypeTimestamp
ColumnTypeInterval
ColumnTypeEnum
ColumnTypeEnumArray
ColumnTypeInt64
ColumnTypeBool
)
func NewIndex(name string, columns []string, opts ...indexOpts) *Index {
i := &Index{
Name: name,
Columns: columns,
}
for _, opt := range opts {
opt(i)
}
return i
}
type Index struct {
Name string
Columns []string
includes []string
}
type indexOpts func(*Index)
func WithInclude(columns ...string) indexOpts {
return func(i *Index) {
i.includes = columns
}
}
func NewConstraint(name string, columns []string) *Constraint {
i := &Constraint{
Name: name,
Columns: columns,
}
return i
}
type Constraint struct {
Name string
Columns []string
}
func NewForeignKey(name string, columns []string, refColumns []string) *ForeignKey {
i := &ForeignKey{
Name: name,
Columns: columns,
RefColumns: refColumns,
}
return i
}
func NewForeignKeyOfPublicKeys() *ForeignKey {
return &ForeignKey{
Name: "",
}
}
type ForeignKey struct {
Name string
Columns []string
RefColumns []string
}
type initializer interface {
Init() *handler.Check
}
func (h *Handler) Init(ctx context.Context) error {
check, ok := h.projection.(initializer)
if !ok || check.Init().IsNoop() {
return nil
}
tx, err := h.client.BeginTx(ctx, nil)
if err != nil {
return errs.ThrowInternal(err, "CRDB-SAdf2", "begin failed")
}
for i, execute := range check.Init().Executes {
logging.WithFields("projection", h.projection.Name(), "execute", i).Debug("executing check")
next, err := execute(tx, h.projection.Name())
if err != nil {
logging.OnError(tx.Rollback()).Debug("unable to rollback")
return err
}
if !next {
logging.WithFields("projection", h.projection.Name(), "execute", i).Debug("projection set up")
break
}
}
return tx.Commit()
}
func NewTableCheck(table *Table, opts ...execOption) *handler.Check {
config := execConfig{}
create := func(config execConfig) string {
return createTableStatement(table, config.tableName, "")
}
executes := make([]func(handler.Executer, string) (bool, error), len(table.indices)+1)
executes[0] = execNextIfExists(config, create, opts, true)
for i, index := range table.indices {
executes[i+1] = execNextIfExists(config, createIndexCheck(index), opts, true)
}
return &handler.Check{
Executes: executes,
}
}
func NewMultiTableCheck(primaryTable *Table, secondaryTables ...*SuffixedTable) *handler.Check {
config := execConfig{}
create := func(config execConfig) string {
stmt := createTableStatement(primaryTable, config.tableName, "")
for _, table := range secondaryTables {
stmt += createTableStatement(&table.Table, config.tableName, "_"+table.suffix)
}
return stmt
}
return &handler.Check{
Executes: []func(handler.Executer, string) (bool, error){
execNextIfExists(config, create, nil, true),
},
}
}
func NewViewCheck(selectStmt string, secondaryTables ...*SuffixedTable) *handler.Check {
config := execConfig{}
create := func(config execConfig) string {
var stmt string
for _, table := range secondaryTables {
stmt += createTableStatement(&table.Table, config.tableName, "_"+table.suffix)
}
stmt += createViewStatement(config.tableName, selectStmt)
return stmt
}
return &handler.Check{
Executes: []func(handler.Executer, string) (bool, error){
execNextIfExists(config, create, nil, false),
},
}
}
func execNextIfExists(config execConfig, q query, opts []execOption, executeNext bool) func(handler.Executer, string) (bool, error) {
return func(handler handler.Executer, name string) (bool, error) {
err := exec(config, q, opts)(handler, name)
if isErrAlreadyExists(err) {
return executeNext, nil
}
return false, err
}
}
func isErrAlreadyExists(err error) bool {
caosErr := &errs.CaosError{}
if !errors.As(err, &caosErr) {
return false
}
pgErr := new(pgconn.PgError)
if errors.As(caosErr.Parent, &pgErr) {
return pgErr.Code == "42P07"
}
return false
}
func createTableStatement(table *Table, tableName string, suffix string) string {
stmt := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s, PRIMARY KEY (%s)",
tableName+suffix,
createColumnsStatement(table.columns, tableName),
strings.Join(table.primaryKey, ", "),
)
for _, key := range table.foreignKeys {
ref := tableName
if len(key.RefColumns) > 0 {
ref += fmt.Sprintf("(%s)", strings.Join(key.RefColumns, ","))
}
if len(key.Columns) == 0 {
key.Columns = table.primaryKey
}
stmt += fmt.Sprintf(", CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE CASCADE", foreignKeyName(key.Name, tableName, suffix), strings.Join(key.Columns, ","), ref)
}
for _, constraint := range table.constraints {
stmt += fmt.Sprintf(", CONSTRAINT %s UNIQUE (%s)", constraintName(constraint.Name, tableName, suffix), strings.Join(constraint.Columns, ","))
}
stmt += ");"
for _, index := range table.indices {
stmt += createIndexStatement(index, tableName+suffix)
}
return stmt
}
func createViewStatement(viewName string, selectStmt string) string {
return fmt.Sprintf("CREATE VIEW %s AS %s",
viewName,
selectStmt,
)
}
func createIndexCheck(index *Index) func(config execConfig) string {
return func(config execConfig) string {
return createIndexStatement(index, config.tableName)
}
}
func createIndexStatement(index *Index, tableName string) string {
stmt := fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s (%s)",
indexName(index.Name, tableName),
tableName,
strings.Join(index.Columns, ","),
)
if len(index.includes) > 0 {
stmt += " INCLUDE (" + strings.Join(index.includes, ", ") + ")"
}
return stmt + ";"
}
func foreignKeyName(name, tableName, suffix string) string {
if name == "" {
key := "fk" + suffix + "_ref_" + tableNameWithoutSchema(tableName)
return key
}
return "fk_" + tableNameWithoutSchema(tableName+suffix) + "_" + name
}
func constraintName(name, tableName, suffix string) string {
return tableNameWithoutSchema(tableName+suffix) + "_" + name + "_unique"
}
func indexName(name, tableName string) string {
return tableNameWithoutSchema(tableName) + "_" + name + "_idx"
}
func tableNameWithoutSchema(name string) string {
return name[strings.LastIndex(name, ".")+1:]
}
func createColumnsStatement(cols []*InitColumn, tableName string) string {
columns := make([]string, len(cols))
for i, col := range cols {
column := col.Name + " " + columnType(col.Type)
if !col.nullable {
column += " NOT NULL"
}
if col.defaultValue != nil {
column += " DEFAULT " + defaultValue(col.defaultValue)
}
if len(col.deleteCascade) != 0 {
column += fmt.Sprintf(" REFERENCES %s (%s) ON DELETE CASCADE", tableName, col.deleteCascade)
}
columns[i] = column
}
return strings.Join(columns, ",")
}
func defaultValue(value interface{}) string {
switch v := value.(type) {
case string:
return "'" + v + "'"
case fmt.Stringer:
return fmt.Sprintf("%#v", v)
default:
return fmt.Sprintf("%v", v)
}
}
func columnType(columnType ColumnType) string {
switch columnType {
case ColumnTypeText:
return "TEXT"
case ColumnTypeTextArray:
return "TEXT[]"
case ColumnTypeTimestamp:
return "TIMESTAMPTZ"
case ColumnTypeInterval:
return "INTERVAL"
case ColumnTypeEnum:
return "SMALLINT"
case ColumnTypeEnumArray:
return "SMALLINT[]"
case ColumnTypeInt64:
return "BIGINT"
case ColumnTypeBool:
return "BOOLEAN"
case ColumnTypeJSONB:
return "JSONB"
case ColumnTypeBytes:
return "BYTEA"
default:
panic("unknown column type")
}
}

View File

@@ -0,0 +1,23 @@
package handler
import (
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/eventstore"
)
func (h *Handler) log() *logging.Entry {
return logging.WithFields("projection", h.projection.Name())
}
func (h *Handler) logFailure(fail *failure) *logging.Entry {
return h.log().WithField("sequence", fail.sequence).
WithField("instance", fail.instance).
WithField("aggregate", fail.aggregateID)
}
func (h *Handler) logEvent(event eventstore.Event) *logging.Entry {
return h.log().WithField("sequence", event.Sequence()).
WithField("instance", event.Aggregate().InstanceID).
WithField("aggregate", event.Aggregate().Type)
}

View File

@@ -0,0 +1,18 @@
package handler
var _ Projection = (*projection)(nil)
type projection struct {
name string
reducers []AggregateReducer
}
// Name implements Projection
func (p *projection) Name() string {
return p.name
}
// Reducers implements Projection
func (p *projection) Reducers() []AggregateReducer {
return p.reducers
}

View File

@@ -0,0 +1,21 @@
package handler
import "github.com/zitadel/zitadel/internal/eventstore"
// EventReducer represents the required data
// to work with events
type EventReducer struct {
Event eventstore.EventType
Reduce Reduce
}
// Reduce reduces the given event to a statement
// which is used to update the projection
type Reduce func(eventstore.Event) (*Statement, error)
// EventReducer represents the required data
// to work with aggregates
type AggregateReducer struct {
Aggregate eventstore.AggregateType
EventReducers []EventReducer
}

View File

@@ -0,0 +1,119 @@
package handler
import (
"context"
"database/sql"
_ "embed"
"errors"
"time"
"github.com/zitadel/zitadel/internal/api/authz"
errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
)
type state struct {
instanceID string
position float64
eventTimestamp time.Time
aggregateType eventstore.AggregateType
aggregateID string
sequence uint64
}
var (
//go:embed state_get.sql
currentStateStmt string
//go:embed state_get_await.sql
currentStateAwaitStmt string
//go:embed state_set.sql
updateStateStmt string
//go:embed state_lock.sql
lockStateStmt string
//go:embed state_set_last_run.sql
updateStateLastRunStmt string
errJustUpdated = errors.New("projection was just updated")
)
func (h *Handler) currentState(ctx context.Context, tx *sql.Tx, config *triggerConfig) (currentState *state, err error) {
currentState = &state{
instanceID: authz.GetInstance(ctx).InstanceID(),
}
var (
aggregateID = new(sql.NullString)
aggregateType = new(sql.NullString)
sequence = new(sql.NullInt64)
timestamp = new(sql.NullTime)
position = new(sql.NullFloat64)
)
stateQuery := currentStateStmt
if config.awaitRunning {
stateQuery = currentStateAwaitStmt
}
row := tx.QueryRow(stateQuery, currentState.instanceID, h.projection.Name())
err = row.Scan(
aggregateID,
aggregateType,
sequence,
timestamp,
position,
)
if errors.Is(err, sql.ErrNoRows) {
err = h.lockState(tx, currentState.instanceID)
}
if err != nil {
h.log().WithError(err).Debug("unable to query current state")
return nil, err
}
currentState.aggregateID = aggregateID.String
currentState.aggregateType = eventstore.AggregateType(aggregateType.String)
currentState.sequence = uint64(sequence.Int64)
currentState.eventTimestamp = timestamp.Time
currentState.position = position.Float64
return currentState, nil
}
func (h *Handler) setState(tx *sql.Tx, updatedState *state) error {
res, err := tx.Exec(updateStateStmt,
h.projection.Name(),
updatedState.instanceID,
updatedState.aggregateID,
updatedState.aggregateType,
updatedState.sequence,
updatedState.eventTimestamp,
updatedState.position,
)
if err != nil {
h.log().WithError(err).Debug("unable to update state")
return err
}
if affected, err := res.RowsAffected(); affected == 0 {
h.log().OnError(err).Error("unable to check if states are updated")
return errs.ThrowInternal(err, "V2-FGEKi", "unable to update state")
}
return nil
}
func (h *Handler) updateLastUpdated(ctx context.Context, tx *sql.Tx, updatedState *state) {
_, err := tx.ExecContext(ctx, updateStateLastRunStmt, h.projection.Name(), updatedState.instanceID)
h.log().OnError(err).Debug("unable to update last updated")
}
func (h *Handler) lockState(tx *sql.Tx, instanceID string) error {
res, err := tx.Exec(lockStateStmt,
h.projection.Name(),
instanceID,
)
if err != nil {
return err
}
if affected, err := res.RowsAffected(); affected == 0 || err != nil {
return errs.ThrowInternal(err, "V2-lpiK0", "projection already locked")
}
return nil
}

View File

@@ -0,0 +1,12 @@
SELECT
aggregate_id
, aggregate_type
, "sequence"
, event_date
, "position"
FROM
projections.current_states
WHERE
instance_id = $1
AND projection_name = $2
FOR UPDATE NOWAIT;

View File

@@ -0,0 +1,12 @@
SELECT
aggregate_id
, aggregate_type
, "sequence"
, event_date
, "position"
FROM
projections.current_states
WHERE
instance_id = $1
AND projection_name = $2
FOR UPDATE;

View File

@@ -0,0 +1,9 @@
INSERT INTO projections.current_states (
projection_name
, instance_id
, last_updated
) VALUES (
$1
, $2
, now()
) ON CONFLICT DO NOTHING;

View File

@@ -0,0 +1,29 @@
INSERT INTO projections.current_states (
projection_name
, instance_id
, aggregate_id
, aggregate_type
, "sequence"
, event_date
, "position"
, last_updated
) VALUES (
$1
, $2
, $3
, $4
, $5
, $6
, $7
, now()
) ON CONFLICT (
projection_name
, instance_id
) DO UPDATE SET
aggregate_id = $3
, aggregate_type = $4
, "sequence" = $5
, event_date = $6
, "position" = $7
, last_updated = statement_timestamp()
;

View File

@@ -0,0 +1,2 @@
UPDATE projections.current_states SET last_updated = now() WHERE projection_name = $1 AND instance_id = $2;

View File

@@ -0,0 +1,447 @@
package handler
import (
"context"
"database/sql"
"database/sql/driver"
_ "embed"
"errors"
"reflect"
"testing"
"time"
"github.com/jackc/pgconn"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database/mock"
errs "github.com/zitadel/zitadel/internal/errors"
)
func TestHandler_lockState(t *testing.T) {
type fields struct {
projection Projection
mock *mock.SQLMock
}
type args struct {
instanceID string
}
tests := []struct {
name string
fields fields
args args
isErr func(t *testing.T, err error)
}{
{
name: "tx closed",
fields: fields{
projection: &projection{
name: "projection",
},
mock: mock.NewSQLMock(t,
mock.ExpectBegin(nil),
mock.ExcpectExec(
lockStateStmt,
mock.WithExecArgs(
"projection",
"instance",
),
mock.WithExecErr(sql.ErrTxDone),
),
),
},
args: args{
instanceID: "instance",
},
isErr: func(t *testing.T, err error) {
if !errors.Is(err, sql.ErrTxDone) {
t.Errorf("unexpected error, want: %v got: %v", sql.ErrTxDone, err)
}
},
},
{
name: "no rows affeced",
fields: fields{
projection: &projection{
name: "projection",
},
mock: mock.NewSQLMock(t,
mock.ExpectBegin(nil),
mock.ExcpectExec(
lockStateStmt,
mock.WithExecArgs(
"projection",
"instance",
),
mock.WithExecNoRowsAffected(),
),
),
},
args: args{
instanceID: "instance",
},
isErr: func(t *testing.T, err error) {
if !errors.Is(err, errs.ThrowInternal(nil, "V2-lpiK0", "")) {
t.Errorf("unexpected error: want internal (V2lpiK0), got: %v", err)
}
},
},
{
name: "rows affected",
fields: fields{
projection: &projection{
name: "projection",
},
mock: mock.NewSQLMock(t,
mock.ExpectBegin(nil),
mock.ExcpectExec(
lockStateStmt,
mock.WithExecArgs(
"projection",
"instance",
),
mock.WithExecRowsAffected(1),
),
),
},
args: args{
instanceID: "instance",
},
},
}
for _, tt := range tests {
if tt.isErr == nil {
tt.isErr = func(t *testing.T, err error) {
if err != nil {
t.Error("expected no error got:", err)
}
}
}
t.Run(tt.name, func(t *testing.T) {
h := &Handler{
projection: tt.fields.projection,
}
tx, err := tt.fields.mock.DB.Begin()
if err != nil {
t.Fatalf("unable to begin transaction: %v", err)
}
err = h.lockState(tx, tt.args.instanceID)
tt.isErr(t, err)
tt.fields.mock.Assert(t)
})
}
}
func TestHandler_updateLastUpdated(t *testing.T) {
type fields struct {
projection Projection
mock *mock.SQLMock
}
type args struct {
updatedState *state
}
tests := []struct {
name string
fields fields
args args
isErr func(t *testing.T, err error)
}{
{
name: "update fails",
fields: fields{
projection: &projection{
name: "instance",
},
mock: mock.NewSQLMock(t,
mock.ExpectBegin(nil),
mock.ExcpectExec(updateStateStmt,
mock.WithExecErr(sql.ErrTxDone),
),
),
},
args: args{
updatedState: &state{
instanceID: "instance",
eventTimestamp: time.Now(),
position: 42,
},
},
isErr: func(t *testing.T, err error) {
if !errors.Is(err, sql.ErrTxDone) {
t.Errorf("unexpected error, want: %v, got %v", sql.ErrTxDone, err)
}
},
},
{
name: "no rows affected",
fields: fields{
projection: &projection{
name: "instance",
},
mock: mock.NewSQLMock(t,
mock.ExpectBegin(nil),
mock.ExcpectExec(updateStateStmt,
mock.WithExecNoRowsAffected(),
),
),
},
args: args{
updatedState: &state{
instanceID: "instance",
eventTimestamp: time.Now(),
position: 42,
},
},
isErr: func(t *testing.T, err error) {
if !errors.Is(err, errs.ThrowInternal(nil, "V2-FGEKi", "")) {
t.Errorf("unexpected error, want: %v, got %v", sql.ErrTxDone, err)
}
},
},
{
name: "success",
fields: fields{
projection: &projection{
name: "projection",
},
mock: mock.NewSQLMock(t,
mock.ExpectBegin(nil),
mock.ExcpectExec(updateStateStmt,
mock.WithExecArgs(
"projection",
"instance",
"aggregate id",
"aggregate type",
uint64(42),
mock.AnyType[time.Time]{},
float64(42),
),
mock.WithExecRowsAffected(1),
),
),
},
args: args{
updatedState: &state{
instanceID: "instance",
eventTimestamp: time.Now(),
position: 42,
aggregateType: "aggregate type",
aggregateID: "aggregate id",
sequence: 42,
},
},
},
}
for _, tt := range tests {
if tt.isErr == nil {
tt.isErr = func(t *testing.T, err error) {
if err != nil {
t.Error("expected no error got:", err)
}
}
}
t.Run(tt.name, func(t *testing.T) {
tx, err := tt.fields.mock.DB.Begin()
if err != nil {
t.Fatalf("unable to begin transaction: %v", err)
}
h := &Handler{
projection: tt.fields.projection,
}
err = h.setState(tx, tt.args.updatedState)
tt.isErr(t, err)
tt.fields.mock.Assert(t)
})
}
}
func TestHandler_currentState(t *testing.T) {
testTime := time.Now()
type fields struct {
projection Projection
mock *mock.SQLMock
}
type args struct {
ctx context.Context
}
type want struct {
currentState *state
isErr func(t *testing.T, err error)
}
tests := []struct {
name string
fields fields
args args
want want
}{
{
name: "connection done",
fields: fields{
projection: &projection{
name: "projection",
},
mock: mock.NewSQLMock(t,
mock.ExpectBegin(nil),
mock.ExpectQuery(currentStateStmt,
mock.WithQueryArgs(
"instance",
"projection",
),
mock.WithQueryErr(sql.ErrConnDone),
),
),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "instance"),
},
want: want{
isErr: func(t *testing.T, err error) {
if !errors.Is(err, sql.ErrConnDone) {
t.Errorf("unexpected error, want: %v, got: %v", sql.ErrConnDone, err)
}
},
},
},
{
name: "no row but lock err",
fields: fields{
projection: &projection{
name: "projection",
},
mock: mock.NewSQLMock(t,
mock.ExpectBegin(nil),
mock.ExpectQuery(currentStateStmt,
mock.WithQueryArgs(
"instance",
"projection",
),
mock.WithQueryErr(sql.ErrNoRows),
),
mock.ExcpectExec(lockStateStmt,
mock.WithExecArgs(
"projection",
"instance",
),
mock.WithExecErr(sql.ErrTxDone),
),
),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "instance"),
},
want: want{
isErr: func(t *testing.T, err error) {
if !errors.Is(err, sql.ErrTxDone) {
t.Errorf("unexpected error, want: %v, got: %v", sql.ErrTxDone, err)
}
},
},
},
{
name: "state locked",
fields: fields{
projection: &projection{
name: "projection",
},
mock: mock.NewSQLMock(t,
mock.ExpectBegin(nil),
mock.ExpectQuery(currentStateStmt,
mock.WithQueryArgs(
"instance",
"projection",
),
mock.WithQueryErr(&pgconn.PgError{Code: "55P03"}),
),
),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "instance"),
},
want: want{
isErr: func(t *testing.T, err error) {
pgErr := new(pgconn.PgError)
if !errors.As(err, &pgErr) {
t.Errorf("error should be PgErr but was %T", err)
return
}
if pgErr.Code != "55P03" {
t.Errorf("expected code 55P03 got: %s", pgErr.Code)
}
},
},
},
{
name: "success",
fields: fields{
projection: &projection{
name: "projection",
},
mock: mock.NewSQLMock(t,
mock.ExpectBegin(nil),
mock.ExpectQuery(currentStateStmt,
mock.WithQueryArgs(
"instance",
"projection",
),
mock.WithQueryResult(
[]string{"aggregate_id", "aggregate_type", "event_sequence", "event_date", "position"},
[][]driver.Value{
{
"aggregate id",
"aggregate type",
int64(42),
testTime,
float64(42),
},
},
),
),
),
},
args: args{
ctx: authz.WithInstanceID(context.Background(), "instance"),
},
want: want{
currentState: &state{
instanceID: "instance",
eventTimestamp: testTime,
position: 42,
aggregateType: "aggregate type",
aggregateID: "aggregate id",
sequence: 42,
},
},
},
}
for _, tt := range tests {
if tt.want.isErr == nil {
tt.want.isErr = func(t *testing.T, err error) {
if err != nil {
t.Error("expected no error got:", err)
}
}
}
t.Run(tt.name, func(t *testing.T) {
h := &Handler{
projection: tt.fields.projection,
}
tx, err := tt.fields.mock.DB.Begin()
if err != nil {
t.Fatalf("unable to begin transaction: %v", err)
}
gotCurrentState, err := h.currentState(tt.args.ctx, tx, new(triggerConfig))
tt.want.isErr(t, err)
if !reflect.DeepEqual(gotCurrentState, tt.want.currentState) {
t.Errorf("Handler.currentState() gotCurrentState = %v, want %v", gotCurrentState, tt.want.currentState)
}
tt.fields.mock.Assert(t)
})
}
}

View File

@@ -0,0 +1,551 @@
package handler
import (
"database/sql"
"encoding/json"
errs "errors"
"strconv"
"strings"
"time"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
)
func (h *Handler) eventsToStatements(tx *sql.Tx, events []eventstore.Event, currentState *state) (statements []*Statement, err error) {
statements = make([]*Statement, 0, len(events))
for _, event := range events {
statement, err := h.reduce(event)
if err != nil {
h.logEvent(event).WithError(err).Error("reduce failed")
if shouldContinue := h.handleFailedStmt(tx, currentState, failureFromEvent(event, err)); shouldContinue {
continue
}
return statements, err
}
statements = append(statements, statement)
}
return statements, nil
}
func (h *Handler) reduce(event eventstore.Event) (*Statement, error) {
for _, reducer := range h.projection.Reducers() {
if reducer.Aggregate != event.Aggregate().Type {
continue
}
for _, reduce := range reducer.EventReducers {
if reduce.Event != event.Type() {
continue
}
return reduce.Reduce(event)
}
}
return NewNoOpStatement(event), nil
}
type Statement struct {
AggregateType eventstore.AggregateType
AggregateID string
Sequence uint64
Position float64
CreationDate time.Time
InstanceID string
Execute Exec
}
type Exec func(ex Executer, projectionName string) error
func WithTableSuffix(name string) func(*execConfig) {
return func(o *execConfig) {
o.tableName += "_" + name
}
}
var (
ErrNoProjection = errs.New("no projection")
ErrNoValues = errs.New("no values")
ErrNoCondition = errs.New("no condition")
)
func NewStatement(event eventstore.Event, e Exec) *Statement {
return &Statement{
AggregateType: event.Aggregate().Type,
Sequence: event.Sequence(),
Position: event.Position(),
AggregateID: event.Aggregate().ID,
CreationDate: event.CreatedAt(),
InstanceID: event.Aggregate().InstanceID,
Execute: e,
}
}
func NewCreateStatement(event eventstore.Event, values []Column, opts ...execOption) *Statement {
cols, params, args := columnsToQuery(values)
columnNames := strings.Join(cols, ", ")
valuesPlaceholder := strings.Join(params, ", ")
config := execConfig{
args: args,
}
if len(values) == 0 {
config.err = ErrNoValues
}
q := func(config execConfig) string {
return "INSERT INTO " + config.tableName + " (" + columnNames + ") VALUES (" + valuesPlaceholder + ")"
}
return NewStatement(event, exec(config, q, opts))
}
func NewUpsertStatement(event eventstore.Event, conflictCols []Column, values []Column, opts ...execOption) *Statement {
cols, params, args := columnsToQuery(values)
conflictTarget := make([]string, len(conflictCols))
for i, col := range conflictCols {
conflictTarget[i] = col.Name
}
config := execConfig{
args: args,
}
if len(values) == 0 {
config.err = ErrNoValues
}
updateCols, updateVals := getUpdateCols(cols, conflictTarget)
if len(updateCols) == 0 || len(updateVals) == 0 {
config.err = ErrNoValues
}
q := func(config execConfig) string {
var updateStmt string
// the postgres standard does not allow to update a single column using a multi-column update
// discussion: https://www.postgresql.org/message-id/17451.1509381766%40sss.pgh.pa.us
// see Compatibility in https://www.postgresql.org/docs/current/sql-update.html
if len(updateCols) == 1 && !strings.HasPrefix(updateVals[0], "SELECT") {
updateStmt = "UPDATE SET " + updateCols[0] + " = " + updateVals[0]
} else {
updateStmt = "UPDATE SET (" + strings.Join(updateCols, ", ") + ") = (" + strings.Join(updateVals, ", ") + ")"
}
return "INSERT INTO " + config.tableName + " (" + strings.Join(cols, ", ") + ") VALUES (" + strings.Join(params, ", ") + ")" +
" ON CONFLICT (" + strings.Join(conflictTarget, ", ") + ") DO " + updateStmt
}
return NewStatement(event, exec(config, q, opts))
}
func getUpdateCols(cols, conflictTarget []string) (updateCols, updateVals []string) {
updateCols = make([]string, len(cols))
updateVals = make([]string, len(cols))
copy(updateCols, cols)
for i := len(updateCols) - 1; i >= 0; i-- {
updateVals[i] = "EXCLUDED." + updateCols[i]
for _, conflict := range conflictTarget {
if conflict == updateCols[i] {
copy(updateCols[i:], updateCols[i+1:])
updateCols[len(updateCols)-1] = ""
updateCols = updateCols[:len(updateCols)-1]
copy(updateVals[i:], updateVals[i+1:])
updateVals[len(updateVals)-1] = ""
updateVals = updateVals[:len(updateVals)-1]
break
}
}
}
return updateCols, updateVals
}
func NewUpdateStatement(event eventstore.Event, values []Column, conditions []Condition, opts ...execOption) *Statement {
cols, params, args := columnsToQuery(values)
wheres, whereArgs := conditionsToWhere(conditions, len(args)+1)
args = append(args, whereArgs...)
config := execConfig{
args: args,
}
if len(values) == 0 {
config.err = ErrNoValues
}
if len(conditions) == 0 {
config.err = ErrNoCondition
}
q := func(config execConfig) string {
// the postgres standard does not allow to update a single column using a multi-column update
// discussion: https://www.postgresql.org/message-id/17451.1509381766%40sss.pgh.pa.us
// see Compatibility in https://www.postgresql.org/docs/current/sql-update.html
if len(cols) == 1 && !strings.HasPrefix(params[0], "SELECT") {
return "UPDATE " + config.tableName + " SET " + cols[0] + " = " + params[0] + " WHERE " + strings.Join(wheres, " AND ")
}
return "UPDATE " + config.tableName + " SET (" + strings.Join(cols, ", ") + ") = (" + strings.Join(params, ", ") + ") WHERE " + strings.Join(wheres, " AND ")
}
return NewStatement(event, exec(config, q, opts))
}
func NewDeleteStatement(event eventstore.Event, conditions []Condition, opts ...execOption) *Statement {
wheres, args := conditionsToWhere(conditions, 1)
wheresPlaceholders := strings.Join(wheres, " AND ")
config := execConfig{
args: args,
}
if len(conditions) == 0 {
config.err = ErrNoCondition
}
q := func(config execConfig) string {
return "DELETE FROM " + config.tableName + " WHERE " + wheresPlaceholders
}
return NewStatement(event, exec(config, q, opts))
}
func NewNoOpStatement(event eventstore.Event) *Statement {
return NewStatement(event, nil)
}
func NewMultiStatement(event eventstore.Event, opts ...func(eventstore.Event) Exec) *Statement {
if len(opts) == 0 {
return NewNoOpStatement(event)
}
execs := make([]Exec, len(opts))
for i, opt := range opts {
execs[i] = opt(event)
}
return NewStatement(event, multiExec(execs))
}
func AddNoOpStatement() func(eventstore.Event) Exec {
return func(event eventstore.Event) Exec {
return NewNoOpStatement(event).Execute
}
}
func AddCreateStatement(columns []Column, opts ...execOption) func(eventstore.Event) Exec {
return func(event eventstore.Event) Exec {
return NewCreateStatement(event, columns, opts...).Execute
}
}
func AddUpsertStatement(indexCols []Column, values []Column, opts ...execOption) func(eventstore.Event) Exec {
return func(event eventstore.Event) Exec {
return NewUpsertStatement(event, indexCols, values, opts...).Execute
}
}
func AddUpdateStatement(values []Column, conditions []Condition, opts ...execOption) func(eventstore.Event) Exec {
return func(event eventstore.Event) Exec {
return NewUpdateStatement(event, values, conditions, opts...).Execute
}
}
func AddDeleteStatement(conditions []Condition, opts ...execOption) func(eventstore.Event) Exec {
return func(event eventstore.Event) Exec {
return NewDeleteStatement(event, conditions, opts...).Execute
}
}
func AddCopyStatement(conflict, from, to []Column, conditions []NamespacedCondition, opts ...execOption) func(eventstore.Event) Exec {
return func(event eventstore.Event) Exec {
return NewCopyStatement(event, conflict, from, to, conditions, opts...).Execute
}
}
func NewArrayAppendCol(column string, value interface{}) Column {
return Column{
Name: column,
Value: value,
ParameterOpt: func(placeholder string) string {
return "array_append(" + column + ", " + placeholder + ")"
},
}
}
func NewArrayRemoveCol(column string, value interface{}) Column {
return Column{
Name: column,
Value: value,
ParameterOpt: func(placeholder string) string {
return "array_remove(" + column + ", " + placeholder + ")"
},
}
}
func NewArrayIntersectCol(column string, value interface{}) Column {
var arrayType string
switch value.(type) {
case []string, database.TextArray[string]:
arrayType = "TEXT"
//TODO: handle more types if necessary
}
return Column{
Name: column,
Value: value,
ParameterOpt: func(placeholder string) string {
return "SELECT ARRAY( SELECT UNNEST(" + column + ") INTERSECT SELECT UNNEST (" + placeholder + "::" + arrayType + "[]))"
},
}
}
func NewCopyCol(column, from string) Column {
return Column{
Name: column,
Value: NewCol(from, nil),
}
}
// NewCopyStatement creates a new upsert statement which updates a column from an existing row
// cols represent the columns which are objective to change.
// if the value of a col is empty the data will be copied from the selected row
// if the value of a col is not empty the data will be set by the static value
// conds represent the conditions for the selection subquery
func NewCopyStatement(event eventstore.Event, conflictCols, from, to []Column, nsCond []NamespacedCondition, opts ...execOption) *Statement {
columnNames := make([]string, len(to))
selectColumns := make([]string, len(from))
updateColumns := make([]string, len(columnNames))
argCounter := 0
args := []interface{}{}
for i, col := range from {
columnNames[i] = to[i].Name
selectColumns[i] = from[i].Name
updateColumns[i] = "EXCLUDED." + col.Name
if col.Value != nil {
argCounter++
selectColumns[i] = "$" + strconv.Itoa(argCounter)
updateColumns[i] = selectColumns[i]
args = append(args, col.Value)
}
}
cond := make([]Condition, len(nsCond))
for i := range nsCond {
cond[i] = nsCond[i]("copy_table")
}
wheres, values := conditionsToWhere(cond, len(args)+1)
args = append(args, values...)
conflictTargets := make([]string, len(conflictCols))
for i, conflictCol := range conflictCols {
conflictTargets[i] = conflictCol.Name
}
config := execConfig{
args: args,
}
if len(from) == 0 || len(to) == 0 || len(from) != len(to) {
config.err = ErrNoValues
}
if len(cond) == 0 {
config.err = ErrNoCondition
}
q := func(config execConfig) string {
return "INSERT INTO " +
config.tableName +
" (" +
strings.Join(columnNames, ", ") +
") SELECT " +
strings.Join(selectColumns, ", ") +
" FROM " +
config.tableName + " AS copy_table WHERE " +
strings.Join(wheres, " AND ") +
" ON CONFLICT (" +
strings.Join(conflictTargets, ", ") +
") DO UPDATE SET (" +
strings.Join(columnNames, ", ") +
") = (" +
strings.Join(updateColumns, ", ") +
")"
}
return NewStatement(event, exec(config, q, opts))
}
func columnsToQuery(cols []Column) (names []string, parameters []string, values []interface{}) {
names = make([]string, len(cols))
values = make([]interface{}, len(cols))
parameters = make([]string, len(cols))
var parameterIndex int
for i, col := range cols {
names[i] = col.Name
if c, ok := col.Value.(Column); ok {
parameters[i] = c.Name
continue
} else {
values[parameterIndex] = col.Value
}
parameters[i] = "$" + strconv.Itoa(parameterIndex+1)
if col.ParameterOpt != nil {
parameters[i] = col.ParameterOpt(parameters[i])
}
parameterIndex++
}
return names, parameters, values[:parameterIndex]
}
func conditionsToWhere(conds []Condition, paramOffset int) (wheres []string, values []interface{}) {
wheres = make([]string, len(conds))
values = make([]any, 0, len(conds))
for i, cond := range conds {
var args []any
wheres[i], args = cond("$" + strconv.Itoa(paramOffset))
paramOffset += len(args)
values = append(values, args...)
wheres[i] = "(" + wheres[i] + ")"
}
return wheres, values
}
type Column struct {
Name string
Value interface{}
ParameterOpt func(string) string
}
func NewCol(name string, value interface{}) Column {
return Column{
Name: name,
Value: value,
}
}
func NewJSONCol(name string, value interface{}) Column {
marshalled, err := json.Marshal(value)
if err != nil {
logging.WithFields("column", name).WithError(err).Panic("unable to marshal column")
}
return NewCol(name, marshalled)
}
type Condition func(param string) (string, []any)
type NamespacedCondition func(namespace string) Condition
func NewCond(name string, value interface{}) Condition {
return func(param string) (string, []any) {
return name + " = " + param, []any{value}
}
}
func NewNamespacedCondition(name string, value interface{}) NamespacedCondition {
return func(namespace string) Condition {
return NewCond(namespace+"."+name, value)
}
}
func NewLessThanCond(column string, value interface{}) Condition {
return func(param string) (string, []any) {
return column + " < " + param, []any{value}
}
}
func NewIsNullCond(column string) Condition {
return func(string) (string, []any) {
return column + " IS NULL", nil
}
}
// NewTextArrayContainsCond returns a Condition that checks if the column that stores an array of text contains the given value
func NewTextArrayContainsCond(column string, value string) Condition {
return func(param string) (string, []any) {
return column + " @> " + param, []any{database.TextArray[string]{value}}
}
}
// Not is a function and not a method, so that calling it is well readable
// For example conditions := []Condition{ Not(NewTextArrayContainsCond())}
func Not(condition Condition) Condition {
return func(param string) (string, []any) {
cond, value := condition(param)
return "NOT (" + cond + ")", value
}
}
type Executer interface {
Exec(string, ...interface{}) (sql.Result, error)
}
type execOption func(*execConfig)
type execConfig struct {
tableName string
args []interface{}
err error
}
type query func(config execConfig) string
func exec(config execConfig, q query, opts []execOption) Exec {
return func(ex Executer, projectionName string) (err error) {
if projectionName == "" {
return ErrNoProjection
}
if config.err != nil {
return config.err
}
config.tableName = projectionName
for _, opt := range opts {
opt(&config)
}
_, err = ex.Exec("SAVEPOINT stmt_exec")
if err != nil {
return errors.ThrowInternal(err, "CRDB-YdOXD", "create savepoint failed")
}
defer func() {
if err != nil {
_, rollbackErr := ex.Exec("ROLLBACK TO SAVEPOINT stmt_exec")
logging.OnError(rollbackErr).Debug("rollback failed")
return
}
_, err = ex.Exec("RELEASE SAVEPOINT stmt_exec")
}()
_, err = ex.Exec(q(config), config.args...)
if err != nil {
return errors.ThrowInternal(err, "CRDB-pKtsr", "exec failed")
}
return nil
}
}
func multiExec(execList []Exec) Exec {
return func(ex Executer, projectionName string) error {
for _, exec := range execList {
if exec == nil {
continue
}
if err := exec(ex, projectionName); err != nil {
return err
}
}
return nil
}
}

File diff suppressed because it is too large Load Diff