mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 01:47:33 +00:00
feat(eventstore): increase parallel write capabilities (#5940)
This implementation increases parallel write capabilities of the eventstore. Please have a look at the technical advisories: [05](https://zitadel.com/docs/support/advisory/a10005) and [06](https://zitadel.com/docs/support/advisory/a10006). The implementation of eventstore.push is rewritten and stored events are migrated to a new table `eventstore.events2`. If you are using cockroach: make sure that the database user of ZITADEL has `VIEWACTIVITY` grant. This is used to query events.
This commit is contained in:
@@ -1,83 +0,0 @@
|
||||
package crdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
const (
|
||||
currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2) FOR UPDATE`
|
||||
currentSequenceStmtWithoutLockFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2)`
|
||||
updateCurrentSequencesStmtFormat = `INSERT INTO %s (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES `
|
||||
updateCurrentSequencesConflictStmt = ` ON CONFLICT (projection_name, aggregate_type, instance_id) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`
|
||||
)
|
||||
|
||||
type currentSequences map[eventstore.AggregateType][]*instanceSequence
|
||||
|
||||
type instanceSequence struct {
|
||||
instanceID string
|
||||
sequence uint64
|
||||
}
|
||||
|
||||
func (h *StatementHandler) currentSequences(ctx context.Context, isTx bool, query func(context.Context, func(*sql.Rows) error, string, ...interface{}) error, instanceIDs database.StringArray) (currentSequences, error) {
|
||||
stmt := h.currentSequenceStmt
|
||||
if !isTx {
|
||||
stmt = h.currentSequenceWithoutLockStmt
|
||||
}
|
||||
|
||||
sequences := make(currentSequences, len(h.aggregates))
|
||||
err := query(ctx,
|
||||
func(rows *sql.Rows) error {
|
||||
for rows.Next() {
|
||||
var (
|
||||
aggregateType eventstore.AggregateType
|
||||
sequence uint64
|
||||
instanceID string
|
||||
)
|
||||
|
||||
err := rows.Scan(&sequence, &aggregateType, &instanceID)
|
||||
if err != nil {
|
||||
return errors.ThrowInternal(err, "CRDB-dbatK", "scan failed")
|
||||
}
|
||||
|
||||
sequences[aggregateType] = append(sequences[aggregateType], &instanceSequence{
|
||||
sequence: sequence,
|
||||
instanceID: instanceID,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
},
|
||||
stmt, h.ProjectionName, instanceIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sequences, nil
|
||||
}
|
||||
|
||||
func (h *StatementHandler) updateCurrentSequences(tx *sql.Tx, sequences currentSequences) error {
|
||||
valueQueries := make([]string, 0, len(sequences))
|
||||
valueCounter := 0
|
||||
values := make([]interface{}, 0, len(sequences)*3)
|
||||
for aggregate, instanceSequence := range sequences {
|
||||
for _, sequence := range instanceSequence {
|
||||
valueQueries = append(valueQueries, "($"+strconv.Itoa(valueCounter+1)+", $"+strconv.Itoa(valueCounter+2)+", $"+strconv.Itoa(valueCounter+3)+", $"+strconv.Itoa(valueCounter+4)+", NOW())")
|
||||
valueCounter += 4
|
||||
values = append(values, h.ProjectionName, aggregate, sequence.sequence, sequence.instanceID)
|
||||
}
|
||||
}
|
||||
|
||||
res, err := tx.Exec(h.updateSequencesBaseStmt+strings.Join(valueQueries, ", ")+updateCurrentSequencesConflictStmt, values...)
|
||||
if err != nil {
|
||||
return errors.ThrowInternal(err, "CRDB-TrH2Z", "unable to exec update sequence")
|
||||
}
|
||||
if rows, _ := res.RowsAffected(); rows < 1 {
|
||||
return errSeqNotUpdated
|
||||
}
|
||||
return nil
|
||||
}
|
@@ -1,301 +1,15 @@
|
||||
package crdb
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
type mockExpectation func(sqlmock.Sqlmock)
|
||||
|
||||
func expectFailureCount(tableName string, projectionName, instanceID string, failedSeq, failureCount uint64) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`WITH failures AS \(SELECT failure_count FROM `+tableName+` WHERE projection_name = \$1 AND failed_sequence = \$2 AND instance_id = \$3\) SELECT COALESCE\(\(SELECT failure_count FROM failures\), 0\) AS failure_count`).
|
||||
WithArgs(projectionName, failedSeq, instanceID).
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows([]string{"failure_count"}).
|
||||
AddRow(failureCount),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func expectUpdateFailureCount(tableName string, projectionName, instanceID string, seq, failureCount uint64) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec(`INSERT INTO `+tableName+` \(projection_name, failed_sequence, failure_count, error, instance_id, last_failed\) VALUES \(\$1, \$2, \$3, \$4\, \$5\, \$6\) ON CONFLICT \(projection_name, failed_sequence, instance_id\) DO UPDATE SET failure_count = EXCLUDED\.failure_count, error = EXCLUDED\.error, last_failed = EXCLUDED\.last_failed`).
|
||||
WithArgs(projectionName, seq, failureCount, sqlmock.AnyArg(), instanceID, "NOW()").WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
}
|
||||
}
|
||||
|
||||
func expectCreate(projectionName string, columnNames, placeholders []string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
args := make([]driver.Value, len(columnNames))
|
||||
for i := 0; i < len(columnNames); i++ {
|
||||
args[i] = sqlmock.AnyArg()
|
||||
placeholders[i] = `\` + placeholders[i]
|
||||
}
|
||||
m.ExpectExec("INSERT INTO " + projectionName + ` \(` + strings.Join(columnNames, ", ") + `\) VALUES \(` + strings.Join(placeholders, ", ") + `\)`).
|
||||
WithArgs(args...).
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
}
|
||||
}
|
||||
|
||||
func expectCreateErr(projectionName string, columnNames, placeholders []string, err error) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
args := make([]driver.Value, len(columnNames))
|
||||
for i := 0; i < len(columnNames); i++ {
|
||||
args[i] = sqlmock.AnyArg()
|
||||
placeholders[i] = `\` + placeholders[i]
|
||||
}
|
||||
m.ExpectExec("INSERT INTO " + projectionName + ` \(` + strings.Join(columnNames, ", ") + `\) VALUES \(` + strings.Join(placeholders, ", ") + `\)`).
|
||||
WithArgs(args...).
|
||||
WillReturnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func expectBegin() func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectBegin()
|
||||
}
|
||||
}
|
||||
|
||||
func expectBeginErr(err error) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectBegin().WillReturnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func expectCommit() func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectCommit()
|
||||
}
|
||||
}
|
||||
|
||||
func expectCommitErr(err error) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectCommit().WillReturnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func expectRollback() func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectRollback()
|
||||
}
|
||||
}
|
||||
|
||||
func expectSavePoint() func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("SAVEPOINT push_stmt").
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
}
|
||||
}
|
||||
|
||||
func expectSavePointErr(err error) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("SAVEPOINT push_stmt").
|
||||
WillReturnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func expectSavePointRollback() func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("ROLLBACK TO SAVEPOINT push_stmt").
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
}
|
||||
}
|
||||
|
||||
func expectSavePointRollbackErr(err error) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("ROLLBACK TO SAVEPOINT push_stmt").
|
||||
WillReturnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func expectSavePointRelease() func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("RELEASE push_stmt").
|
||||
WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
}
|
||||
}
|
||||
|
||||
func expectCurrentSequence(isTx bool, tableName, projection string, seq uint64, aggregateType string, instanceIDs []string) func(sqlmock.Sqlmock) {
|
||||
rows := sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"})
|
||||
for _, instanceID := range instanceIDs {
|
||||
rows.AddRow(seq, aggregateType, instanceID)
|
||||
}
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
stmt := `SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\)`
|
||||
if isTx {
|
||||
stmt += " FOR UPDATE"
|
||||
}
|
||||
m.ExpectQuery(stmt).
|
||||
WithArgs(
|
||||
projection,
|
||||
database.StringArray(instanceIDs),
|
||||
).
|
||||
WillReturnRows(
|
||||
rows,
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func expectCurrentSequenceErr(isTx bool, tableName, projection string, instanceIDs []string, err error) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
stmt := `SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\)`
|
||||
if isTx {
|
||||
stmt += " FOR UPDATE"
|
||||
}
|
||||
m.ExpectQuery(stmt).
|
||||
WithArgs(
|
||||
projection,
|
||||
database.StringArray(instanceIDs),
|
||||
).
|
||||
WillReturnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func expectCurrentSequenceNoRows(tableName, projection string, instanceIDs []string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
|
||||
WithArgs(
|
||||
projection,
|
||||
database.StringArray(instanceIDs),
|
||||
).
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func expectCurrentSequenceScanErr(tableName, projection string, instanceIDs []string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
|
||||
WithArgs(
|
||||
projection,
|
||||
database.StringArray(instanceIDs),
|
||||
).
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}).
|
||||
RowError(0, sql.ErrTxDone).
|
||||
AddRow(0, "agg", "instanceID"),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func expectUpdateCurrentSequence(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("INSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
|
||||
WithArgs(
|
||||
projection,
|
||||
aggregateType,
|
||||
seq,
|
||||
instanceID,
|
||||
).
|
||||
WillReturnResult(
|
||||
sqlmock.NewResult(1, 1),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func expectUpdateThreeCurrentSequence(t *testing.T, tableName, projection string, sequences currentSequences) func(sqlmock.Sqlmock) {
|
||||
args := make([][]interface{}, 0)
|
||||
for aggregateType, instanceSequences := range sequences {
|
||||
for _, sequence := range instanceSequences {
|
||||
args = append(args, []interface{}{
|
||||
projection,
|
||||
aggregateType,
|
||||
sequence.sequence,
|
||||
sequence.instanceID,
|
||||
})
|
||||
}
|
||||
}
|
||||
matcher := ¤tSequenceMatcher{t: t, seq: args}
|
||||
matchers := make([]driver.Value, len(args)*4)
|
||||
for i := 0; i < len(args)*4; i++ {
|
||||
matchers[i] = matcher
|
||||
}
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("INSERT INTO " + tableName + ` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\), \(\$5, \$6, \$7, \$8, NOW\(\)\), \(\$9, \$10, \$11, \$12, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
|
||||
WithArgs(
|
||||
matchers...,
|
||||
).
|
||||
WillReturnResult(
|
||||
sqlmock.NewResult(1, 1),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
type currentSequenceMatcher struct {
|
||||
seq [][]interface{}
|
||||
i int
|
||||
t *testing.T
|
||||
}
|
||||
|
||||
func (m *currentSequenceMatcher) Match(value driver.Value) bool {
|
||||
if m.i%4 == 0 {
|
||||
m.i = 0
|
||||
}
|
||||
left := make([]interface{}, 0, len(m.seq))
|
||||
for _, seq := range m.seq {
|
||||
found := seq[m.i]
|
||||
if found == nil {
|
||||
continue
|
||||
}
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
if found == v || found == eventstore.AggregateType(v) {
|
||||
seq[m.i] = nil
|
||||
m.i++
|
||||
return true
|
||||
}
|
||||
case int64:
|
||||
if found == uint64(v) {
|
||||
seq[m.i] = nil
|
||||
m.i++
|
||||
return true
|
||||
}
|
||||
}
|
||||
left = append(left, found)
|
||||
}
|
||||
m.t.Errorf("expected: %v, possible left values: %v", value, left)
|
||||
m.t.FailNow()
|
||||
return false
|
||||
}
|
||||
|
||||
func expectUpdateCurrentSequenceErr(tableName, projection string, seq uint64, err error, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("INSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
|
||||
WithArgs(
|
||||
projection,
|
||||
aggregateType,
|
||||
seq,
|
||||
instanceID,
|
||||
).
|
||||
WillReturnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func expectUpdateCurrentSequenceNoRows(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("INSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
|
||||
WithArgs(
|
||||
projection,
|
||||
aggregateType,
|
||||
seq,
|
||||
instanceID,
|
||||
).
|
||||
WillReturnResult(
|
||||
sqlmock.NewResult(0, 0),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
func expectLock(lockTable, workerName string, d time.Duration, instanceID string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec(`INSERT INTO `+lockTable+
|
||||
@@ -308,7 +22,7 @@ func expectLock(lockTable, workerName string, d time.Duration, instanceID string
|
||||
d,
|
||||
projectionName,
|
||||
instanceID,
|
||||
database.StringArray{instanceID},
|
||||
database.TextArray[string]{instanceID},
|
||||
).
|
||||
WillReturnResult(
|
||||
sqlmock.NewResult(1, 1),
|
||||
@@ -329,7 +43,7 @@ func expectLockMultipleInstances(lockTable, workerName string, d time.Duration,
|
||||
projectionName,
|
||||
instanceID1,
|
||||
instanceID2,
|
||||
database.StringArray{instanceID1, instanceID2},
|
||||
database.TextArray[string]{instanceID1, instanceID2},
|
||||
).
|
||||
WillReturnResult(
|
||||
sqlmock.NewResult(1, 1),
|
||||
@@ -349,7 +63,7 @@ func expectLockNoRows(lockTable, workerName string, d time.Duration, instanceID
|
||||
d,
|
||||
projectionName,
|
||||
instanceID,
|
||||
database.StringArray{instanceID},
|
||||
database.TextArray[string]{instanceID},
|
||||
).
|
||||
WillReturnResult(driver.ResultNoRows)
|
||||
}
|
||||
@@ -367,7 +81,7 @@ func expectLockErr(lockTable, workerName string, d time.Duration, instanceID str
|
||||
d,
|
||||
projectionName,
|
||||
instanceID,
|
||||
database.StringArray{instanceID},
|
||||
database.TextArray[string]{instanceID},
|
||||
).
|
||||
WillReturnError(err)
|
||||
}
|
||||
|
@@ -1,51 +0,0 @@
|
||||
package crdb
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler"
|
||||
)
|
||||
|
||||
const (
|
||||
setFailureCountStmtFormat = "INSERT INTO %s" +
|
||||
" (projection_name, failed_sequence, failure_count, error, instance_id, last_failed)" +
|
||||
" VALUES ($1, $2, $3, $4, $5, $6) ON CONFLICT (projection_name, failed_sequence, instance_id)" +
|
||||
" DO UPDATE SET failure_count = EXCLUDED.failure_count, error = EXCLUDED.error, last_failed = EXCLUDED.last_failed"
|
||||
failureCountStmtFormat = "WITH failures AS (SELECT failure_count FROM %s WHERE projection_name = $1 AND failed_sequence = $2 AND instance_id = $3)" +
|
||||
" SELECT COALESCE((SELECT failure_count FROM failures), 0) AS failure_count"
|
||||
)
|
||||
|
||||
func (h *StatementHandler) handleFailedStmt(tx *sql.Tx, stmt *handler.Statement, execErr error) (shouldContinue bool) {
|
||||
failureCount, err := h.failureCount(tx, stmt.Sequence, stmt.InstanceID)
|
||||
if err != nil {
|
||||
logging.WithFields("projection", h.ProjectionName, "sequence", stmt.Sequence).WithError(err).Warn("unable to get failure count")
|
||||
return false
|
||||
}
|
||||
failureCount += 1
|
||||
err = h.setFailureCount(tx, stmt.Sequence, failureCount, execErr, stmt.InstanceID)
|
||||
logging.WithFields("projection", h.ProjectionName, "sequence", stmt.Sequence).OnError(err).Warn("unable to update failure count")
|
||||
|
||||
return failureCount >= h.maxFailureCount
|
||||
}
|
||||
|
||||
func (h *StatementHandler) failureCount(tx *sql.Tx, seq uint64, instanceID string) (count uint, err error) {
|
||||
row := tx.QueryRow(h.failureCountStmt, h.ProjectionName, seq, instanceID)
|
||||
if err = row.Err(); err != nil {
|
||||
return 0, errors.ThrowInternal(err, "CRDB-Unnex", "unable to update failure count")
|
||||
}
|
||||
if err = row.Scan(&count); err != nil {
|
||||
return 0, errors.ThrowInternal(err, "CRDB-RwSMV", "unable to scan count")
|
||||
}
|
||||
return count, nil
|
||||
}
|
||||
|
||||
func (h *StatementHandler) setFailureCount(tx *sql.Tx, seq uint64, count uint, err error, instanceID string) error {
|
||||
_, dbErr := tx.Exec(h.setFailureCountStmt, h.ProjectionName, seq, count, err.Error(), instanceID, "NOW()")
|
||||
if dbErr != nil {
|
||||
return errors.ThrowInternal(dbErr, "CRDB-4Ht4x", "set failure count failed")
|
||||
}
|
||||
return nil
|
||||
}
|
@@ -1,347 +0,0 @@
|
||||
package crdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler"
|
||||
"github.com/zitadel/zitadel/internal/repository/pseudo"
|
||||
)
|
||||
|
||||
var (
|
||||
errSeqNotUpdated = errors.ThrowInternal(nil, "CRDB-79GWt", "current sequence not updated")
|
||||
)
|
||||
|
||||
type StatementHandlerConfig struct {
|
||||
handler.ProjectionHandlerConfig
|
||||
|
||||
Client *database.DB
|
||||
SequenceTable string
|
||||
LockTable string
|
||||
FailedEventsTable string
|
||||
MaxFailureCount uint
|
||||
BulkLimit uint64
|
||||
|
||||
Reducers []handler.AggregateReducer
|
||||
InitCheck *handler.Check
|
||||
}
|
||||
|
||||
type StatementHandler struct {
|
||||
*handler.ProjectionHandler
|
||||
Locker
|
||||
|
||||
client *database.DB
|
||||
sequenceTable string
|
||||
currentSequenceStmt string
|
||||
currentSequenceWithoutLockStmt string
|
||||
updateSequencesBaseStmt string
|
||||
maxFailureCount uint
|
||||
failureCountStmt string
|
||||
setFailureCountStmt string
|
||||
|
||||
aggregates []eventstore.AggregateType
|
||||
reduces map[eventstore.EventType]handler.Reduce
|
||||
initCheck *handler.Check
|
||||
initialized chan bool
|
||||
|
||||
bulkLimit uint64
|
||||
|
||||
reduceScheduledPseudoEvent bool
|
||||
}
|
||||
|
||||
func NewStatementHandler(
|
||||
ctx context.Context,
|
||||
config StatementHandlerConfig,
|
||||
) StatementHandler {
|
||||
aggregateTypes := make([]eventstore.AggregateType, 0, len(config.Reducers))
|
||||
reduces := make(map[eventstore.EventType]handler.Reduce, len(config.Reducers))
|
||||
reduceScheduledPseudoEvent := false
|
||||
for _, aggReducer := range config.Reducers {
|
||||
aggregateTypes = append(aggregateTypes, aggReducer.Aggregate)
|
||||
if aggReducer.Aggregate == pseudo.AggregateType {
|
||||
reduceScheduledPseudoEvent = true
|
||||
if len(config.Reducers) != 1 ||
|
||||
len(aggReducer.EventRedusers) != 1 ||
|
||||
aggReducer.EventRedusers[0].Event != pseudo.ScheduledEventType {
|
||||
panic("if a pseudo.AggregateType is reduced, exactly one event reducer for pseudo.ScheduledEventType is supported and no other aggregate can be reduced")
|
||||
}
|
||||
}
|
||||
for _, eventReducer := range aggReducer.EventRedusers {
|
||||
reduces[eventReducer.Event] = eventReducer.Reduce
|
||||
}
|
||||
}
|
||||
|
||||
h := StatementHandler{
|
||||
client: config.Client,
|
||||
sequenceTable: config.SequenceTable,
|
||||
maxFailureCount: config.MaxFailureCount,
|
||||
currentSequenceStmt: fmt.Sprintf(currentSequenceStmtFormat, config.SequenceTable),
|
||||
currentSequenceWithoutLockStmt: fmt.Sprintf(currentSequenceStmtWithoutLockFormat, config.SequenceTable),
|
||||
updateSequencesBaseStmt: fmt.Sprintf(updateCurrentSequencesStmtFormat, config.SequenceTable),
|
||||
failureCountStmt: fmt.Sprintf(failureCountStmtFormat, config.FailedEventsTable),
|
||||
setFailureCountStmt: fmt.Sprintf(setFailureCountStmtFormat, config.FailedEventsTable),
|
||||
aggregates: aggregateTypes,
|
||||
reduces: reduces,
|
||||
bulkLimit: config.BulkLimit,
|
||||
Locker: NewLocker(config.Client.DB, config.LockTable, config.ProjectionName),
|
||||
initCheck: config.InitCheck,
|
||||
initialized: make(chan bool),
|
||||
reduceScheduledPseudoEvent: reduceScheduledPseudoEvent,
|
||||
}
|
||||
|
||||
h.ProjectionHandler = handler.NewProjectionHandler(ctx, config.ProjectionHandlerConfig, h.reduce, h.Update, h.searchQuery, h.Lock, h.Unlock, h.initialized, reduceScheduledPseudoEvent)
|
||||
|
||||
return h
|
||||
}
|
||||
|
||||
func (h *StatementHandler) Start() {
|
||||
h.initialized <- true
|
||||
close(h.initialized)
|
||||
if !h.reduceScheduledPseudoEvent {
|
||||
h.Subscribe(h.aggregates...)
|
||||
}
|
||||
}
|
||||
|
||||
func (h *StatementHandler) searchQuery(ctx context.Context, instanceIDs []string) (*eventstore.SearchQueryBuilder, uint64, error) {
|
||||
if h.reduceScheduledPseudoEvent {
|
||||
return nil, 1, nil
|
||||
}
|
||||
return h.dbSearchQuery(ctx, instanceIDs)
|
||||
}
|
||||
|
||||
func (h *StatementHandler) dbSearchQuery(ctx context.Context, instanceIDs []string) (*eventstore.SearchQueryBuilder, uint64, error) {
|
||||
sequences, err := h.currentSequences(ctx, false, h.client.QueryContext, instanceIDs)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
queryBuilder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).Limit(h.bulkLimit).AllowTimeTravel()
|
||||
|
||||
for _, aggregateType := range h.aggregates {
|
||||
for _, instanceID := range instanceIDs {
|
||||
var seq uint64
|
||||
for _, sequence := range sequences[aggregateType] {
|
||||
if sequence.instanceID == instanceID {
|
||||
seq = sequence.sequence
|
||||
break
|
||||
}
|
||||
}
|
||||
queryBuilder.
|
||||
AddQuery().
|
||||
AggregateTypes(aggregateType).
|
||||
SequenceGreater(seq).
|
||||
InstanceID(instanceID)
|
||||
}
|
||||
}
|
||||
return queryBuilder, h.bulkLimit, nil
|
||||
}
|
||||
|
||||
type transaction struct {
|
||||
*sql.Tx
|
||||
}
|
||||
|
||||
func (t *transaction) QueryContext(ctx context.Context, scan func(rows *sql.Rows) error, query string, args ...any) error {
|
||||
rows, err := t.Tx.QueryContext(ctx, query, args...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
closeErr := rows.Close()
|
||||
logging.OnError(closeErr).Info("rows.Close failed")
|
||||
}()
|
||||
|
||||
if err = scan(rows); err != nil {
|
||||
return err
|
||||
}
|
||||
return rows.Err()
|
||||
}
|
||||
|
||||
// Update implements handler.Update
|
||||
func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (index int, err error) {
|
||||
if len(stmts) == 0 {
|
||||
return -1, nil
|
||||
}
|
||||
instanceIDs := make([]string, 0, len(stmts))
|
||||
for _, stmt := range stmts {
|
||||
instanceIDs = appendToInstanceIDs(instanceIDs, stmt.InstanceID)
|
||||
}
|
||||
tx, err := h.client.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return -1, errors.ThrowInternal(err, "CRDB-e89Gq", "begin failed")
|
||||
}
|
||||
|
||||
sequences, err := h.currentSequences(ctx, true, (&transaction{Tx: tx}).QueryContext, instanceIDs)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return -1, err
|
||||
}
|
||||
|
||||
//checks for events between create statement and current sequence
|
||||
// because there could be events between current sequence and a creation event
|
||||
// and we cannot check via stmt.PreviousSequence
|
||||
if stmts[0].PreviousSequence == 0 {
|
||||
previousStmts, err := h.fetchPreviousStmts(ctx, tx, stmts[0].Sequence, stmts[0].InstanceID, sequences, reduce)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return -1, err
|
||||
}
|
||||
stmts = append(previousStmts, stmts...)
|
||||
}
|
||||
|
||||
lastSuccessfulIdx := h.executeStmts(tx, &stmts, sequences)
|
||||
|
||||
if lastSuccessfulIdx >= 0 {
|
||||
err = h.updateCurrentSequences(tx, sequences)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return -1, err
|
||||
}
|
||||
}
|
||||
|
||||
if err = tx.Commit(); err != nil {
|
||||
return -1, err
|
||||
}
|
||||
|
||||
if lastSuccessfulIdx < len(stmts)-1 {
|
||||
return lastSuccessfulIdx, handler.ErrSomeStmtsFailed
|
||||
}
|
||||
|
||||
return lastSuccessfulIdx, nil
|
||||
}
|
||||
|
||||
func (h *StatementHandler) fetchPreviousStmts(ctx context.Context, tx *sql.Tx, stmtSeq uint64, instanceID string, sequences currentSequences, reduce handler.Reduce) (previousStmts []*handler.Statement, err error) {
|
||||
query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).SetTx(tx)
|
||||
queriesAdded := false
|
||||
for _, aggregateType := range h.aggregates {
|
||||
for _, sequence := range sequences[aggregateType] {
|
||||
if stmtSeq <= sequence.sequence && instanceID == sequence.instanceID {
|
||||
continue
|
||||
}
|
||||
|
||||
query.
|
||||
AddQuery().
|
||||
AggregateTypes(aggregateType).
|
||||
SequenceGreater(sequence.sequence).
|
||||
SequenceLess(stmtSeq).
|
||||
InstanceID(sequence.instanceID)
|
||||
|
||||
queriesAdded = true
|
||||
}
|
||||
}
|
||||
|
||||
if !queriesAdded {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
events, err := h.Eventstore.Filter(ctx, query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for _, event := range events {
|
||||
stmt, err := reduce(event)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
previousStmts = append(previousStmts, stmt)
|
||||
}
|
||||
return previousStmts, nil
|
||||
}
|
||||
|
||||
func (h *StatementHandler) executeStmts(
|
||||
tx *sql.Tx,
|
||||
stmts *[]*handler.Statement,
|
||||
sequences currentSequences,
|
||||
) int {
|
||||
|
||||
lastSuccessfulIdx := -1
|
||||
stmts:
|
||||
for i := 0; i < len(*stmts); i++ {
|
||||
stmt := (*stmts)[i]
|
||||
for _, sequence := range sequences[stmt.AggregateType] {
|
||||
if stmt.Sequence <= sequence.sequence && stmt.InstanceID == sequence.instanceID {
|
||||
logging.WithFields("statement", stmt, "currentSequence", sequence).Debug("statement dropped")
|
||||
if i < len(*stmts)-1 {
|
||||
copy((*stmts)[i:], (*stmts)[i+1:])
|
||||
}
|
||||
*stmts = (*stmts)[:len(*stmts)-1]
|
||||
i--
|
||||
continue stmts
|
||||
}
|
||||
if stmt.PreviousSequence > 0 && stmt.PreviousSequence != sequence.sequence && stmt.InstanceID == sequence.instanceID {
|
||||
logging.WithFields("projection", h.ProjectionName, "aggregateType", stmt.AggregateType, "sequence", stmt.Sequence, "prevSeq", stmt.PreviousSequence, "currentSeq", sequence.sequence).Warn("sequences do not match")
|
||||
break stmts
|
||||
}
|
||||
}
|
||||
err := h.executeStmt(tx, stmt)
|
||||
if err == nil {
|
||||
updateSequences(sequences, stmt)
|
||||
lastSuccessfulIdx = i
|
||||
continue
|
||||
}
|
||||
|
||||
shouldContinue := h.handleFailedStmt(tx, stmt, err)
|
||||
if !shouldContinue {
|
||||
break
|
||||
}
|
||||
|
||||
updateSequences(sequences, stmt)
|
||||
lastSuccessfulIdx = i
|
||||
continue
|
||||
}
|
||||
return lastSuccessfulIdx
|
||||
}
|
||||
|
||||
// executeStmt handles sql statements
|
||||
// an error is returned if the statement could not be inserted properly
|
||||
func (h *StatementHandler) executeStmt(tx *sql.Tx, stmt *handler.Statement) error {
|
||||
if stmt.IsNoop() {
|
||||
return nil
|
||||
}
|
||||
_, err := tx.Exec("SAVEPOINT push_stmt")
|
||||
if err != nil {
|
||||
return errors.ThrowInternal(err, "CRDB-i1wp6", "unable to create savepoint")
|
||||
}
|
||||
err = stmt.Execute(tx, h.ProjectionName)
|
||||
if err != nil {
|
||||
logging.WithError(err).Error()
|
||||
_, rollbackErr := tx.Exec("ROLLBACK TO SAVEPOINT push_stmt")
|
||||
if rollbackErr != nil {
|
||||
return errors.ThrowInternal(rollbackErr, "CRDB-zzp3P", "rollback to savepoint failed")
|
||||
}
|
||||
return errors.ThrowInternal(err, "CRDB-oRkaN", "unable execute stmt")
|
||||
}
|
||||
_, err = tx.Exec("RELEASE push_stmt")
|
||||
if err != nil {
|
||||
return errors.ThrowInternal(err, "CRDB-qWgwT", "unable to release savepoint")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func updateSequences(sequences currentSequences, stmt *handler.Statement) {
|
||||
for _, sequence := range sequences[stmt.AggregateType] {
|
||||
if sequence.instanceID == stmt.InstanceID {
|
||||
sequence.sequence = stmt.Sequence
|
||||
return
|
||||
}
|
||||
}
|
||||
sequences[stmt.AggregateType] = append(sequences[stmt.AggregateType], &instanceSequence{
|
||||
instanceID: stmt.InstanceID,
|
||||
sequence: stmt.Sequence,
|
||||
})
|
||||
}
|
||||
|
||||
func appendToInstanceIDs(instances []string, id string) []string {
|
||||
for _, instance := range instances {
|
||||
if instance == id {
|
||||
return instances
|
||||
}
|
||||
}
|
||||
return append(instances, id)
|
||||
}
|
File diff suppressed because it is too large
Load Diff
@@ -1,412 +0,0 @@
|
||||
package crdb
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler"
|
||||
)
|
||||
|
||||
type Table struct {
|
||||
columns []*Column
|
||||
primaryKey PrimaryKey
|
||||
indices []*Index
|
||||
constraints []*Constraint
|
||||
foreignKeys []*ForeignKey
|
||||
}
|
||||
|
||||
func NewTable(columns []*Column, key PrimaryKey, opts ...TableOption) *Table {
|
||||
t := &Table{
|
||||
columns: columns,
|
||||
primaryKey: key,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(t)
|
||||
}
|
||||
return t
|
||||
}
|
||||
|
||||
type SuffixedTable struct {
|
||||
Table
|
||||
suffix string
|
||||
}
|
||||
|
||||
func NewSuffixedTable(columns []*Column, key PrimaryKey, suffix string, opts ...TableOption) *SuffixedTable {
|
||||
return &SuffixedTable{
|
||||
Table: *NewTable(columns, key, opts...),
|
||||
suffix: suffix,
|
||||
}
|
||||
}
|
||||
|
||||
type TableOption func(*Table)
|
||||
|
||||
func WithIndex(index *Index) TableOption {
|
||||
return func(table *Table) {
|
||||
table.indices = append(table.indices, index)
|
||||
}
|
||||
}
|
||||
|
||||
func WithConstraint(constraint *Constraint) TableOption {
|
||||
return func(table *Table) {
|
||||
table.constraints = append(table.constraints, constraint)
|
||||
}
|
||||
}
|
||||
|
||||
func WithForeignKey(key *ForeignKey) TableOption {
|
||||
return func(table *Table) {
|
||||
table.foreignKeys = append(table.foreignKeys, key)
|
||||
}
|
||||
}
|
||||
|
||||
type Column struct {
|
||||
Name string
|
||||
Type ColumnType
|
||||
nullable bool
|
||||
defaultValue interface{}
|
||||
deleteCascade string
|
||||
}
|
||||
|
||||
type ColumnOption func(*Column)
|
||||
|
||||
func NewColumn(name string, columnType ColumnType, opts ...ColumnOption) *Column {
|
||||
column := &Column{
|
||||
Name: name,
|
||||
Type: columnType,
|
||||
nullable: false,
|
||||
defaultValue: nil,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(column)
|
||||
}
|
||||
return column
|
||||
}
|
||||
|
||||
func Nullable() ColumnOption {
|
||||
return func(c *Column) {
|
||||
c.nullable = true
|
||||
}
|
||||
}
|
||||
|
||||
func Default(value interface{}) ColumnOption {
|
||||
return func(c *Column) {
|
||||
c.defaultValue = value
|
||||
}
|
||||
}
|
||||
|
||||
func DeleteCascade(column string) ColumnOption {
|
||||
return func(c *Column) {
|
||||
c.deleteCascade = column
|
||||
}
|
||||
}
|
||||
|
||||
type PrimaryKey []string
|
||||
|
||||
func NewPrimaryKey(columnNames ...string) PrimaryKey {
|
||||
return columnNames
|
||||
}
|
||||
|
||||
type ColumnType int32
|
||||
|
||||
const (
|
||||
ColumnTypeText ColumnType = iota
|
||||
ColumnTypeTextArray
|
||||
ColumnTypeJSONB
|
||||
ColumnTypeBytes
|
||||
ColumnTypeTimestamp
|
||||
ColumnTypeInterval
|
||||
ColumnTypeEnum
|
||||
ColumnTypeEnumArray
|
||||
ColumnTypeInt64
|
||||
ColumnTypeBool
|
||||
)
|
||||
|
||||
func NewIndex(name string, columns []string, opts ...indexOpts) *Index {
|
||||
i := &Index{
|
||||
Name: name,
|
||||
Columns: columns,
|
||||
bucketCount: 0,
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(i)
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
type Index struct {
|
||||
Name string
|
||||
Columns []string
|
||||
bucketCount uint16
|
||||
}
|
||||
|
||||
type indexOpts func(*Index)
|
||||
|
||||
func Hash(bucketsCount uint16) indexOpts {
|
||||
return func(i *Index) {
|
||||
i.bucketCount = bucketsCount
|
||||
}
|
||||
}
|
||||
|
||||
func NewConstraint(name string, columns []string) *Constraint {
|
||||
i := &Constraint{
|
||||
Name: name,
|
||||
Columns: columns,
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
type Constraint struct {
|
||||
Name string
|
||||
Columns []string
|
||||
}
|
||||
|
||||
func NewForeignKey(name string, columns []string, refColumns []string) *ForeignKey {
|
||||
i := &ForeignKey{
|
||||
Name: name,
|
||||
Columns: columns,
|
||||
RefColumns: refColumns,
|
||||
}
|
||||
return i
|
||||
}
|
||||
|
||||
func NewForeignKeyOfPublicKeys() *ForeignKey {
|
||||
return &ForeignKey{
|
||||
Name: "",
|
||||
}
|
||||
}
|
||||
|
||||
type ForeignKey struct {
|
||||
Name string
|
||||
Columns []string
|
||||
RefColumns []string
|
||||
}
|
||||
|
||||
// Init implements handler.Init
|
||||
func (h *StatementHandler) Init(ctx context.Context) error {
|
||||
check := h.initCheck
|
||||
if check == nil || check.IsNoop() {
|
||||
return nil
|
||||
}
|
||||
tx, err := h.client.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return caos_errs.ThrowInternal(err, "CRDB-SAdf2", "begin failed")
|
||||
}
|
||||
for i, execute := range check.Executes {
|
||||
logging.WithFields("projection", h.ProjectionName, "execute", i).Debug("executing check")
|
||||
next, err := execute(h.client, h.ProjectionName)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return err
|
||||
}
|
||||
if !next {
|
||||
logging.WithFields("projection", h.ProjectionName, "execute", i).Debug("projection set up")
|
||||
break
|
||||
}
|
||||
}
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func NewTableCheck(table *Table, opts ...execOption) *handler.Check {
|
||||
config := execConfig{}
|
||||
create := func(config execConfig) string {
|
||||
return createTableStatement(table, config.tableName, "")
|
||||
}
|
||||
executes := make([]func(handler.Executer, string) (bool, error), len(table.indices)+1)
|
||||
executes[0] = execNextIfExists(config, create, opts, true)
|
||||
for i, index := range table.indices {
|
||||
executes[i+1] = execNextIfExists(config, createIndexCheck(index), opts, true)
|
||||
}
|
||||
return &handler.Check{
|
||||
Executes: executes,
|
||||
}
|
||||
}
|
||||
|
||||
func NewMultiTableCheck(primaryTable *Table, secondaryTables ...*SuffixedTable) *handler.Check {
|
||||
config := execConfig{}
|
||||
create := func(config execConfig) string {
|
||||
stmt := createTableStatement(primaryTable, config.tableName, "")
|
||||
for _, table := range secondaryTables {
|
||||
stmt += createTableStatement(&table.Table, config.tableName, "_"+table.suffix)
|
||||
}
|
||||
return stmt
|
||||
}
|
||||
|
||||
return &handler.Check{
|
||||
Executes: []func(handler.Executer, string) (bool, error){
|
||||
execNextIfExists(config, create, nil, true),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewViewCheck(selectStmt string, secondaryTables ...*SuffixedTable) *handler.Check {
|
||||
config := execConfig{}
|
||||
create := func(config execConfig) string {
|
||||
var stmt string
|
||||
for _, table := range secondaryTables {
|
||||
stmt += createTableStatement(&table.Table, config.tableName, "_"+table.suffix)
|
||||
}
|
||||
stmt += createViewStatement(config.tableName, selectStmt)
|
||||
return stmt
|
||||
}
|
||||
|
||||
return &handler.Check{
|
||||
Executes: []func(handler.Executer, string) (bool, error){
|
||||
execNextIfExists(config, create, nil, false),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func execNextIfExists(config execConfig, q query, opts []execOption, executeNext bool) func(handler.Executer, string) (bool, error) {
|
||||
return func(handler handler.Executer, name string) (bool, error) {
|
||||
err := exec(config, q, opts)(handler, name)
|
||||
if isErrAlreadyExists(err) {
|
||||
return executeNext, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
func isErrAlreadyExists(err error) bool {
|
||||
caosErr := &caos_errs.CaosError{}
|
||||
if !errors.As(err, &caosErr) {
|
||||
return false
|
||||
}
|
||||
sqlErr, ok := caosErr.GetParent().(*pgconn.PgError)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return sqlErr.Code == "42P07"
|
||||
}
|
||||
|
||||
func createTableStatement(table *Table, tableName string, suffix string) string {
|
||||
stmt := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s, PRIMARY KEY (%s)",
|
||||
tableName+suffix,
|
||||
createColumnsStatement(table.columns, tableName),
|
||||
strings.Join(table.primaryKey, ", "),
|
||||
)
|
||||
for _, key := range table.foreignKeys {
|
||||
ref := tableName
|
||||
if len(key.RefColumns) > 0 {
|
||||
ref += fmt.Sprintf("(%s)", strings.Join(key.RefColumns, ","))
|
||||
}
|
||||
if len(key.Columns) == 0 {
|
||||
key.Columns = table.primaryKey
|
||||
}
|
||||
stmt += fmt.Sprintf(", CONSTRAINT %s FOREIGN KEY (%s) REFERENCES %s ON DELETE CASCADE", foreignKeyName(key.Name, tableName, suffix), strings.Join(key.Columns, ","), ref)
|
||||
}
|
||||
for _, constraint := range table.constraints {
|
||||
stmt += fmt.Sprintf(", CONSTRAINT %s UNIQUE (%s)", constraintName(constraint.Name, tableName, suffix), strings.Join(constraint.Columns, ","))
|
||||
}
|
||||
|
||||
stmt += ");"
|
||||
|
||||
for _, index := range table.indices {
|
||||
stmt += createIndexStatement(index, tableName+suffix)
|
||||
}
|
||||
return stmt
|
||||
}
|
||||
|
||||
func createViewStatement(viewName string, selectStmt string) string {
|
||||
return fmt.Sprintf("CREATE VIEW %s AS %s",
|
||||
viewName,
|
||||
selectStmt,
|
||||
)
|
||||
}
|
||||
|
||||
func createIndexCheck(index *Index) func(config execConfig) string {
|
||||
return func(config execConfig) string {
|
||||
return createIndexStatement(index, config.tableName)
|
||||
}
|
||||
}
|
||||
|
||||
func createIndexStatement(index *Index, tableName string) string {
|
||||
stmt := fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s (%s)",
|
||||
indexName(index.Name, tableName),
|
||||
tableName,
|
||||
strings.Join(index.Columns, ","),
|
||||
)
|
||||
if index.bucketCount == 0 {
|
||||
return stmt + ";"
|
||||
}
|
||||
return fmt.Sprintf("SET experimental_enable_hash_sharded_indexes=on; %s USING HASH WITH BUCKET_COUNT = %d;",
|
||||
stmt, index.bucketCount)
|
||||
}
|
||||
|
||||
func foreignKeyName(name, tableName, suffix string) string {
|
||||
if name == "" {
|
||||
key := "fk" + suffix + "_ref_" + tableNameWithoutSchema(tableName)
|
||||
return key
|
||||
}
|
||||
return "fk_" + tableNameWithoutSchema(tableName+suffix) + "_" + name
|
||||
}
|
||||
func constraintName(name, tableName, suffix string) string {
|
||||
return tableNameWithoutSchema(tableName+suffix) + "_" + name + "_unique"
|
||||
}
|
||||
func indexName(name, tableName string) string {
|
||||
return tableNameWithoutSchema(tableName) + "_" + name + "_idx"
|
||||
}
|
||||
|
||||
func tableNameWithoutSchema(name string) string {
|
||||
return name[strings.LastIndex(name, ".")+1:]
|
||||
}
|
||||
|
||||
func createColumnsStatement(cols []*Column, tableName string) string {
|
||||
columns := make([]string, len(cols))
|
||||
for i, col := range cols {
|
||||
column := col.Name + " " + columnType(col.Type)
|
||||
if !col.nullable {
|
||||
column += " NOT NULL"
|
||||
}
|
||||
if col.defaultValue != nil {
|
||||
column += " DEFAULT " + defaultValue(col.defaultValue)
|
||||
}
|
||||
if len(col.deleteCascade) != 0 {
|
||||
column += fmt.Sprintf(" REFERENCES %s (%s) ON DELETE CASCADE", tableName, col.deleteCascade)
|
||||
}
|
||||
columns[i] = column
|
||||
}
|
||||
return strings.Join(columns, ",")
|
||||
}
|
||||
|
||||
func defaultValue(value interface{}) string {
|
||||
switch v := value.(type) {
|
||||
case string:
|
||||
return "'" + v + "'"
|
||||
case fmt.Stringer:
|
||||
return fmt.Sprintf("%#v", v)
|
||||
default:
|
||||
return fmt.Sprintf("%v", v)
|
||||
}
|
||||
}
|
||||
|
||||
func columnType(columnType ColumnType) string {
|
||||
switch columnType {
|
||||
case ColumnTypeText:
|
||||
return "TEXT"
|
||||
case ColumnTypeTextArray:
|
||||
return "TEXT[]"
|
||||
case ColumnTypeTimestamp:
|
||||
return "TIMESTAMPTZ"
|
||||
case ColumnTypeInterval:
|
||||
return "INTERVAL"
|
||||
case ColumnTypeEnum:
|
||||
return "SMALLINT"
|
||||
case ColumnTypeEnumArray:
|
||||
return "SMALLINT[]"
|
||||
case ColumnTypeInt64:
|
||||
return "BIGINT"
|
||||
case ColumnTypeBool:
|
||||
return "BOOLEAN"
|
||||
case ColumnTypeJSONB:
|
||||
return "JSONB"
|
||||
case ColumnTypeBytes:
|
||||
return "BYTEA"
|
||||
default:
|
||||
panic("unknown column type")
|
||||
}
|
||||
}
|
@@ -1,49 +0,0 @@
|
||||
package crdb
|
||||
|
||||
import "testing"
|
||||
|
||||
func Test_defaultValue(t *testing.T) {
|
||||
type args struct {
|
||||
value interface{}
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "string",
|
||||
args: args{
|
||||
value: "asdf",
|
||||
},
|
||||
want: "'asdf'",
|
||||
},
|
||||
{
|
||||
name: "primitive non string",
|
||||
args: args{
|
||||
value: 1,
|
||||
},
|
||||
want: "1",
|
||||
},
|
||||
{
|
||||
name: "stringer",
|
||||
args: args{
|
||||
value: testStringer(0),
|
||||
},
|
||||
want: "0",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := defaultValue(tt.args.value); got != tt.want {
|
||||
t.Errorf("defaultValue() = %v, want %v", got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type testStringer int
|
||||
|
||||
func (t testStringer) String() string {
|
||||
return "0529958243"
|
||||
}
|
@@ -91,7 +91,7 @@ func (h *locker) Unlock(instanceIDs ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *locker) lockStatement(lockDuration time.Duration, instanceIDs database.StringArray) (string, []interface{}) {
|
||||
func (h *locker) lockStatement(lockDuration time.Duration, instanceIDs database.TextArray[string]) (string, []interface{}) {
|
||||
valueQueries := make([]string, len(instanceIDs))
|
||||
values := make([]interface{}, len(instanceIDs)+4)
|
||||
values[0] = h.workerName
|
||||
|
@@ -158,7 +158,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
|
||||
},
|
||||
args: args{
|
||||
lockDuration: 1 * time.Second,
|
||||
instanceIDs: database.StringArray{"instanceID"},
|
||||
instanceIDs: database.TextArray[string]{"instanceID"},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -173,7 +173,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
|
||||
},
|
||||
args: args{
|
||||
lockDuration: 2 * time.Second,
|
||||
instanceIDs: database.StringArray{"instanceID"},
|
||||
instanceIDs: database.TextArray[string]{"instanceID"},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -188,7 +188,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
|
||||
},
|
||||
args: args{
|
||||
lockDuration: 3 * time.Second,
|
||||
instanceIDs: database.StringArray{"instanceID"},
|
||||
instanceIDs: database.TextArray[string]{"instanceID"},
|
||||
},
|
||||
},
|
||||
{
|
||||
|
@@ -1,16 +0,0 @@
|
||||
package crdb
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler"
|
||||
)
|
||||
|
||||
//reduce implements handler.Reduce function
|
||||
func (h *StatementHandler) reduce(event eventstore.Event) (*handler.Statement, error) {
|
||||
reduce, ok := h.reduces[event.Type()]
|
||||
if !ok {
|
||||
return NewNoOpStatement(event), nil
|
||||
}
|
||||
|
||||
return reduce(event)
|
||||
}
|
@@ -1,470 +0,0 @@
|
||||
package crdb
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
zitadel_errors "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler"
|
||||
)
|
||||
|
||||
type execOption func(*execConfig)
|
||||
type execConfig struct {
|
||||
tableName string
|
||||
|
||||
args []interface{}
|
||||
err error
|
||||
ignoreNotFound bool
|
||||
}
|
||||
|
||||
func WithTableSuffix(name string) func(*execConfig) {
|
||||
return func(o *execConfig) {
|
||||
o.tableName += "_" + name
|
||||
}
|
||||
}
|
||||
|
||||
func WithIgnoreNotFound() func(*execConfig) {
|
||||
return func(o *execConfig) {
|
||||
o.ignoreNotFound = true
|
||||
}
|
||||
}
|
||||
|
||||
func NewCreateStatement(event eventstore.Event, values []handler.Column, opts ...execOption) *handler.Statement {
|
||||
cols, params, args := columnsToQuery(values)
|
||||
columnNames := strings.Join(cols, ", ")
|
||||
valuesPlaceholder := strings.Join(params, ", ")
|
||||
|
||||
config := execConfig{
|
||||
args: args,
|
||||
}
|
||||
|
||||
if len(values) == 0 {
|
||||
config.err = handler.ErrNoValues
|
||||
}
|
||||
|
||||
q := func(config execConfig) string {
|
||||
return "INSERT INTO " + config.tableName + " (" + columnNames + ") VALUES (" + valuesPlaceholder + ")"
|
||||
}
|
||||
|
||||
return &handler.Statement{
|
||||
AggregateType: event.Aggregate().Type,
|
||||
Sequence: event.Sequence(),
|
||||
PreviousSequence: event.PreviousAggregateTypeSequence(),
|
||||
InstanceID: event.Aggregate().InstanceID,
|
||||
Execute: exec(config, q, opts),
|
||||
}
|
||||
}
|
||||
|
||||
func NewUpsertStatement(event eventstore.Event, conflictCols []handler.Column, values []handler.Column, opts ...execOption) *handler.Statement {
|
||||
cols, params, args := columnsToQuery(values)
|
||||
|
||||
conflictTarget := make([]string, len(conflictCols))
|
||||
for i, col := range conflictCols {
|
||||
conflictTarget[i] = col.Name
|
||||
}
|
||||
|
||||
config := execConfig{
|
||||
args: args,
|
||||
}
|
||||
|
||||
if len(values) == 0 {
|
||||
config.err = handler.ErrNoValues
|
||||
}
|
||||
|
||||
updateCols, updateVals := getUpdateCols(cols, conflictTarget)
|
||||
if len(updateCols) == 0 || len(updateVals) == 0 {
|
||||
config.err = handler.ErrNoValues
|
||||
}
|
||||
|
||||
q := func(config execConfig) string {
|
||||
var updateStmt string
|
||||
// the postgres standard does not allow to update a single column using a multi-column update
|
||||
// discussion: https://www.postgresql.org/message-id/17451.1509381766%40sss.pgh.pa.us
|
||||
// see Compatibility in https://www.postgresql.org/docs/current/sql-update.html
|
||||
if len(updateCols) == 1 && !strings.HasPrefix(updateVals[0], "SELECT") {
|
||||
updateStmt = "UPDATE SET " + updateCols[0] + " = " + updateVals[0]
|
||||
} else {
|
||||
updateStmt = "UPDATE SET (" + strings.Join(updateCols, ", ") + ") = (" + strings.Join(updateVals, ", ") + ")"
|
||||
}
|
||||
return "INSERT INTO " + config.tableName + " (" + strings.Join(cols, ", ") + ") VALUES (" + strings.Join(params, ", ") + ")" +
|
||||
" ON CONFLICT (" + strings.Join(conflictTarget, ", ") + ") DO " + updateStmt
|
||||
}
|
||||
|
||||
return &handler.Statement{
|
||||
AggregateType: event.Aggregate().Type,
|
||||
Sequence: event.Sequence(),
|
||||
PreviousSequence: event.PreviousAggregateTypeSequence(),
|
||||
InstanceID: event.Aggregate().InstanceID,
|
||||
Execute: exec(config, q, opts),
|
||||
}
|
||||
}
|
||||
|
||||
func getUpdateCols(cols, conflictTarget []string) (updateCols, updateVals []string) {
|
||||
updateCols = make([]string, len(cols))
|
||||
updateVals = make([]string, len(cols))
|
||||
|
||||
copy(updateCols, cols)
|
||||
|
||||
for i := len(updateCols) - 1; i >= 0; i-- {
|
||||
updateVals[i] = "EXCLUDED." + updateCols[i]
|
||||
|
||||
for _, conflict := range conflictTarget {
|
||||
if conflict == updateCols[i] {
|
||||
copy(updateCols[i:], updateCols[i+1:])
|
||||
updateCols[len(updateCols)-1] = ""
|
||||
updateCols = updateCols[:len(updateCols)-1]
|
||||
|
||||
copy(updateVals[i:], updateVals[i+1:])
|
||||
updateVals[len(updateVals)-1] = ""
|
||||
updateVals = updateVals[:len(updateVals)-1]
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return updateCols, updateVals
|
||||
}
|
||||
|
||||
func NewUpdateStatement(event eventstore.Event, values []handler.Column, conditions []handler.Condition, opts ...execOption) *handler.Statement {
|
||||
cols, params, args := columnsToQuery(values)
|
||||
wheres, whereArgs := conditionsToWhere(conditions, len(args))
|
||||
args = append(args, whereArgs...)
|
||||
|
||||
config := execConfig{
|
||||
args: args,
|
||||
}
|
||||
|
||||
if len(values) == 0 {
|
||||
config.err = handler.ErrNoValues
|
||||
}
|
||||
|
||||
if len(conditions) == 0 {
|
||||
config.err = handler.ErrNoCondition
|
||||
}
|
||||
|
||||
q := func(config execConfig) string {
|
||||
// the postgres standard does not allow to update a single column using a multi-column update
|
||||
// discussion: https://www.postgresql.org/message-id/17451.1509381766%40sss.pgh.pa.us
|
||||
// see Compatibility in https://www.postgresql.org/docs/current/sql-update.html
|
||||
if len(cols) == 1 && !strings.HasPrefix(params[0], "SELECT") {
|
||||
return "UPDATE " + config.tableName + " SET " + cols[0] + " = " + params[0] + " WHERE " + strings.Join(wheres, " AND ")
|
||||
}
|
||||
return "UPDATE " + config.tableName + " SET (" + strings.Join(cols, ", ") + ") = (" + strings.Join(params, ", ") + ") WHERE " + strings.Join(wheres, " AND ")
|
||||
}
|
||||
|
||||
return &handler.Statement{
|
||||
AggregateType: event.Aggregate().Type,
|
||||
Sequence: event.Sequence(),
|
||||
PreviousSequence: event.PreviousAggregateTypeSequence(),
|
||||
InstanceID: event.Aggregate().InstanceID,
|
||||
Execute: exec(config, q, opts),
|
||||
}
|
||||
}
|
||||
|
||||
func NewDeleteStatement(event eventstore.Event, conditions []handler.Condition, opts ...execOption) *handler.Statement {
|
||||
wheres, args := conditionsToWhere(conditions, 0)
|
||||
|
||||
wheresPlaceholders := strings.Join(wheres, " AND ")
|
||||
|
||||
config := execConfig{
|
||||
args: args,
|
||||
}
|
||||
|
||||
if len(conditions) == 0 {
|
||||
config.err = handler.ErrNoCondition
|
||||
}
|
||||
|
||||
q := func(config execConfig) string {
|
||||
return "DELETE FROM " + config.tableName + " WHERE " + wheresPlaceholders
|
||||
}
|
||||
|
||||
return &handler.Statement{
|
||||
AggregateType: event.Aggregate().Type,
|
||||
Sequence: event.Sequence(),
|
||||
PreviousSequence: event.PreviousAggregateTypeSequence(),
|
||||
InstanceID: event.Aggregate().InstanceID,
|
||||
Execute: exec(config, q, opts),
|
||||
}
|
||||
}
|
||||
|
||||
func NewNoOpStatement(event eventstore.Event) *handler.Statement {
|
||||
return &handler.Statement{
|
||||
AggregateType: event.Aggregate().Type,
|
||||
Sequence: event.Sequence(),
|
||||
PreviousSequence: event.PreviousAggregateTypeSequence(),
|
||||
InstanceID: event.Aggregate().InstanceID,
|
||||
}
|
||||
}
|
||||
|
||||
func NewMultiStatement(event eventstore.Event, opts ...func(eventstore.Event) Exec) *handler.Statement {
|
||||
if len(opts) == 0 {
|
||||
return NewNoOpStatement(event)
|
||||
}
|
||||
execs := make([]Exec, len(opts))
|
||||
for i, opt := range opts {
|
||||
execs[i] = opt(event)
|
||||
}
|
||||
return &handler.Statement{
|
||||
AggregateType: event.Aggregate().Type,
|
||||
Sequence: event.Sequence(),
|
||||
PreviousSequence: event.PreviousAggregateTypeSequence(),
|
||||
InstanceID: event.Aggregate().InstanceID,
|
||||
Execute: multiExec(execs),
|
||||
}
|
||||
}
|
||||
|
||||
type Exec func(ex handler.Executer, projectionName string) error
|
||||
|
||||
func AddCreateStatement(columns []handler.Column, opts ...execOption) func(eventstore.Event) Exec {
|
||||
return func(event eventstore.Event) Exec {
|
||||
return NewCreateStatement(event, columns, opts...).Execute
|
||||
}
|
||||
}
|
||||
|
||||
func AddUpsertStatement(indexCols []handler.Column, values []handler.Column, opts ...execOption) func(eventstore.Event) Exec {
|
||||
return func(event eventstore.Event) Exec {
|
||||
return NewUpsertStatement(event, indexCols, values, opts...).Execute
|
||||
}
|
||||
}
|
||||
|
||||
func AddUpdateStatement(values []handler.Column, conditions []handler.Condition, opts ...execOption) func(eventstore.Event) Exec {
|
||||
return func(event eventstore.Event) Exec {
|
||||
return NewUpdateStatement(event, values, conditions, opts...).Execute
|
||||
}
|
||||
}
|
||||
|
||||
func AddDeleteStatement(conditions []handler.Condition, opts ...execOption) func(eventstore.Event) Exec {
|
||||
return func(event eventstore.Event) Exec {
|
||||
return NewDeleteStatement(event, conditions, opts...).Execute
|
||||
}
|
||||
}
|
||||
|
||||
func NewArrayAppendCol(column string, value interface{}) handler.Column {
|
||||
return handler.Column{
|
||||
Name: column,
|
||||
Value: value,
|
||||
ParameterOpt: func(placeholder string) string {
|
||||
return "array_append(" + column + ", " + placeholder + ")"
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewArrayRemoveCol(column string, value interface{}) handler.Column {
|
||||
return handler.Column{
|
||||
Name: column,
|
||||
Value: value,
|
||||
ParameterOpt: func(placeholder string) string {
|
||||
return "array_remove(" + column + ", " + placeholder + ")"
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewArrayIntersectCol(column string, value interface{}) handler.Column {
|
||||
var arrayType string
|
||||
switch value.(type) {
|
||||
|
||||
case []string, database.StringArray:
|
||||
arrayType = "TEXT"
|
||||
//TODO: handle more types if necessary
|
||||
}
|
||||
return handler.Column{
|
||||
Name: column,
|
||||
Value: value,
|
||||
ParameterOpt: func(placeholder string) string {
|
||||
return "SELECT ARRAY( SELECT UNNEST(" + column + ") INTERSECT SELECT UNNEST (" + placeholder + "::" + arrayType + "[]))"
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func NewCopyCol(column, from string) handler.Column {
|
||||
return handler.Column{
|
||||
Name: column,
|
||||
Value: handler.NewCol(from, nil),
|
||||
}
|
||||
}
|
||||
|
||||
func NewLessThanCond(column string, value interface{}) handler.Condition {
|
||||
return func(param string) (string, interface{}) {
|
||||
return column + " < " + param, value
|
||||
}
|
||||
}
|
||||
|
||||
func NewIsNullCond(column string) handler.Condition {
|
||||
return func(param string) (string, interface{}) {
|
||||
return column + " IS NULL", nil
|
||||
}
|
||||
}
|
||||
|
||||
// NewTextArrayContainsCond returns a handler.Condition that checks if the column that stores an array of text contains the given value
|
||||
func NewTextArrayContainsCond(column string, value string) handler.Condition {
|
||||
return func(param string) (string, interface{}) {
|
||||
return column + " @> " + param, database.StringArray{value}
|
||||
}
|
||||
}
|
||||
|
||||
// Not is a function and not a method, so that calling it is well readable
|
||||
// For example conditions := []handler.Condition{ Not(NewTextArrayContainsCond())}
|
||||
func Not(condition handler.Condition) handler.Condition {
|
||||
return func(param string) (string, interface{}) {
|
||||
cond, value := condition(param)
|
||||
return "NOT (" + cond + ")", value
|
||||
}
|
||||
}
|
||||
|
||||
// NewCopyStatement creates a new upsert statement which updates a column from an existing row
|
||||
// cols represent the columns which are objective to change.
|
||||
// if the value of a col is empty the data will be copied from the selected row
|
||||
// if the value of a col is not empty the data will be set by the static value
|
||||
// conds represent the conditions for the selection subquery
|
||||
func NewCopyStatement(event eventstore.Event, conflictCols, from, to []handler.Column, nsCond []handler.NamespacedCondition, opts ...execOption) *handler.Statement {
|
||||
columnNames := make([]string, len(to))
|
||||
selectColumns := make([]string, len(from))
|
||||
updateColumns := make([]string, len(columnNames))
|
||||
argCounter := 0
|
||||
args := []interface{}{}
|
||||
|
||||
for i, col := range from {
|
||||
columnNames[i] = to[i].Name
|
||||
selectColumns[i] = from[i].Name
|
||||
updateColumns[i] = "EXCLUDED." + col.Name
|
||||
if col.Value != nil {
|
||||
argCounter++
|
||||
selectColumns[i] = "$" + strconv.Itoa(argCounter)
|
||||
updateColumns[i] = selectColumns[i]
|
||||
args = append(args, col.Value)
|
||||
}
|
||||
|
||||
}
|
||||
cond := make([]handler.Condition, len(nsCond))
|
||||
for i := range nsCond {
|
||||
cond[i] = nsCond[i]("copy_table")
|
||||
}
|
||||
wheres, values := conditionsToWhere(cond, len(args))
|
||||
args = append(args, values...)
|
||||
|
||||
conflictTargets := make([]string, len(conflictCols))
|
||||
for i, conflictCol := range conflictCols {
|
||||
conflictTargets[i] = conflictCol.Name
|
||||
}
|
||||
|
||||
config := execConfig{
|
||||
args: args,
|
||||
}
|
||||
|
||||
if len(from) == 0 || len(to) == 0 || len(from) != len(to) {
|
||||
config.err = handler.ErrNoValues
|
||||
}
|
||||
|
||||
if len(cond) == 0 {
|
||||
config.err = handler.ErrNoCondition
|
||||
}
|
||||
|
||||
q := func(config execConfig) string {
|
||||
return "INSERT INTO " +
|
||||
config.tableName +
|
||||
" (" +
|
||||
strings.Join(columnNames, ", ") +
|
||||
") SELECT " +
|
||||
strings.Join(selectColumns, ", ") +
|
||||
" FROM " +
|
||||
config.tableName + " AS copy_table WHERE " +
|
||||
strings.Join(wheres, " AND ") +
|
||||
" ON CONFLICT (" +
|
||||
strings.Join(conflictTargets, ", ") +
|
||||
") DO UPDATE SET (" +
|
||||
strings.Join(columnNames, ", ") +
|
||||
") = (" +
|
||||
strings.Join(updateColumns, ", ") +
|
||||
")"
|
||||
}
|
||||
|
||||
return &handler.Statement{
|
||||
AggregateType: event.Aggregate().Type,
|
||||
Sequence: event.Sequence(),
|
||||
PreviousSequence: event.PreviousAggregateTypeSequence(),
|
||||
InstanceID: event.Aggregate().InstanceID,
|
||||
Execute: exec(config, q, opts),
|
||||
}
|
||||
}
|
||||
|
||||
func columnsToQuery(cols []handler.Column) (names []string, parameters []string, values []interface{}) {
|
||||
names = make([]string, len(cols))
|
||||
values = make([]interface{}, len(cols))
|
||||
parameters = make([]string, len(cols))
|
||||
var parameterIndex int
|
||||
for i, col := range cols {
|
||||
names[i] = col.Name
|
||||
if c, ok := col.Value.(handler.Column); ok {
|
||||
parameters[i] = c.Name
|
||||
continue
|
||||
} else {
|
||||
values[parameterIndex] = col.Value
|
||||
}
|
||||
parameters[i] = "$" + strconv.Itoa(parameterIndex+1)
|
||||
if col.ParameterOpt != nil {
|
||||
parameters[i] = col.ParameterOpt(parameters[i])
|
||||
}
|
||||
parameterIndex++
|
||||
}
|
||||
return names, parameters, values[:parameterIndex]
|
||||
}
|
||||
|
||||
func conditionsToWhere(conditions []handler.Condition, paramOffset int) (wheres []string, values []interface{}) {
|
||||
wheres = make([]string, len(conditions))
|
||||
values = make([]interface{}, 0, len(conditions))
|
||||
for i, conditionFunc := range conditions {
|
||||
condition, value := conditionFunc("$" + strconv.Itoa(i+1+paramOffset))
|
||||
wheres[i] = "(" + condition + ")"
|
||||
if value != nil {
|
||||
values = append(values, value)
|
||||
}
|
||||
}
|
||||
return wheres, values
|
||||
}
|
||||
|
||||
type query func(config execConfig) string
|
||||
|
||||
func exec(config execConfig, q query, opts []execOption) Exec {
|
||||
return func(ex handler.Executer, projectionName string) error {
|
||||
if projectionName == "" {
|
||||
return handler.ErrNoProjection
|
||||
}
|
||||
|
||||
if config.err != nil {
|
||||
return config.err
|
||||
}
|
||||
|
||||
config.tableName = projectionName
|
||||
for _, opt := range opts {
|
||||
opt(&config)
|
||||
}
|
||||
|
||||
if _, err := ex.Exec(q(config), config.args...); err != nil {
|
||||
if config.ignoreNotFound && errors.Is(err, sql.ErrNoRows) {
|
||||
logging.WithError(err).Debugf("ignored not found: %v", err)
|
||||
return nil
|
||||
}
|
||||
return zitadel_errors.ThrowInternal(err, "CRDB-pKtsr", "exec failed")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func multiExec(execList []Exec) Exec {
|
||||
return func(ex handler.Executer, projectionName string) error {
|
||||
for _, exec := range execList {
|
||||
if err := exec(ex, projectionName); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user