mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 01:47:33 +00:00
feat(database): support for postgres (#3998)
* beginning with postgres statements * try pgx * use pgx * database * init works for postgres * arrays working * init for cockroach * init * start tests * tests * TESTS * ch * ch * chore: use go 1.18 * read stmts * fix typo * tests * connection string * add missing error handler * cleanup * start all apis * go mod tidy * old update * switch back to minute * on conflict * replace string slice with `database.StringArray` in db models * fix tests and start * update go version in dockerfile * setup go * clean up * remove notification migration * update * docs: add deploy guide for postgres * fix: revert sonyflake * use `database.StringArray` for daos * use `database.StringArray` every where * new tables * index naming, metadata primary key, project grant role key type * docs(postgres): change to beta * chore: correct compose * fix(defaults): add empty postgres config * refactor: remove unused code * docs: add postgres to self hosted * fix broken link * so? * change title * add mdx to link * fix stmt * update goreleaser in test-code * docs: improve postgres example * update more projections * fix: add beta log for postgres * revert index name change * prerelease * fix: add sequence to v1 "reduce paniced" * log if nil * add logging * fix: log output * fix(import): check if org exists and user * refactor: imports * fix(user): ignore malformed events * refactor: method naming * fix: test * refactor: correct errors.Is call * ci: don't build dev binaries on main * fix(go releaser): update version to 1.11.0 * fix(user): projection should not break * fix(user): handle error properly * docs: correct config example * Update .releaserc.js * Update .releaserc.js Co-authored-by: Livio Amstutz <livio.a@gmail.com> Co-authored-by: Elio Bischof <eliobischof@gmail.com>
This commit is contained in:
@@ -6,15 +6,15 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
const (
|
||||
currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2) FOR UPDATE`
|
||||
updateCurrentSequencesStmtFormat = `UPSERT INTO %s (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES `
|
||||
currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2) FOR UPDATE`
|
||||
updateCurrentSequencesStmtFormat = `INSERT INTO %s (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES `
|
||||
updateCurrentSequencesConflictStmt = ` ON CONFLICT (projection_name, aggregate_type, instance_id) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`
|
||||
)
|
||||
|
||||
type currentSequences map[eventstore.AggregateType][]*instanceSequence
|
||||
@@ -24,8 +24,8 @@ type instanceSequence struct {
|
||||
sequence uint64
|
||||
}
|
||||
|
||||
func (h *StatementHandler) currentSequences(ctx context.Context, query func(context.Context, string, ...interface{}) (*sql.Rows, error), instanceIDs []string) (currentSequences, error) {
|
||||
rows, err := query(ctx, h.currentSequenceStmt, h.ProjectionName, pq.StringArray(instanceIDs))
|
||||
func (h *StatementHandler) currentSequences(ctx context.Context, query func(context.Context, string, ...interface{}) (*sql.Rows, error), instanceIDs database.StringArray) (currentSequences, error) {
|
||||
rows, err := query(ctx, h.currentSequenceStmt, h.ProjectionName, instanceIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -74,7 +74,7 @@ func (h *StatementHandler) updateCurrentSequences(tx *sql.Tx, sequences currentS
|
||||
}
|
||||
}
|
||||
|
||||
res, err := tx.Exec(h.updateSequencesBaseStmt+strings.Join(valueQueries, ", "), values...)
|
||||
res, err := tx.Exec(h.updateSequencesBaseStmt+strings.Join(valueQueries, ", ")+updateCurrentSequencesConflictStmt, values...)
|
||||
if err != nil {
|
||||
return errors.ThrowInternal(err, "CRDB-TrH2Z", "unable to exec update sequence")
|
||||
}
|
||||
|
@@ -8,8 +8,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@ type mockExpectation func(sqlmock.Sqlmock)
|
||||
|
||||
func expectFailureCount(tableName string, projectionName, instanceID string, failedSeq, failureCount uint64) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`WITH failures AS \(SELECT failure_count FROM `+tableName+` WHERE projection_name = \$1 AND failed_sequence = \$2\ AND instance_id = \$3\) SELECT IF\(EXISTS\(SELECT failure_count FROM failures\), \(SELECT failure_count FROM failures\), 0\) AS failure_count`).
|
||||
m.ExpectQuery(`WITH failures AS \(SELECT failure_count FROM `+tableName+` WHERE projection_name = \$1 AND failed_sequence = \$2 AND instance_id = \$3\) SELECT COALESCE\(\(SELECT failure_count FROM failures\), 0\) AS failure_count`).
|
||||
WithArgs(projectionName, failedSeq, instanceID).
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows([]string{"failure_count"}).
|
||||
@@ -28,7 +28,7 @@ func expectFailureCount(tableName string, projectionName, instanceID string, fai
|
||||
|
||||
func expectUpdateFailureCount(tableName string, projectionName, instanceID string, seq, failureCount uint64) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec(`UPSERT INTO `+tableName+` \(projection_name, failed_sequence, failure_count, error, instance_id\) VALUES \(\$1, \$2, \$3, \$4\, \$5\)`).
|
||||
m.ExpectExec(`INSERT INTO `+tableName+` \(projection_name, failed_sequence, failure_count, error, instance_id\) VALUES \(\$1, \$2, \$3, \$4\, \$5\) ON CONFLICT \(projection_name, failed_sequence, instance_id\) DO UPDATE SET failure_count = EXCLUDED\.failure_count, error = EXCLUDED\.error`).
|
||||
WithArgs(projectionName, seq, failureCount, sqlmock.AnyArg(), instanceID).WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
}
|
||||
}
|
||||
@@ -133,7 +133,7 @@ func expectCurrentSequence(tableName, projection string, seq uint64, aggregateTy
|
||||
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
|
||||
WithArgs(
|
||||
projection,
|
||||
pq.StringArray(instanceIDs),
|
||||
database.StringArray(instanceIDs),
|
||||
).
|
||||
WillReturnRows(
|
||||
rows,
|
||||
@@ -146,7 +146,7 @@ func expectCurrentSequenceErr(tableName, projection string, instanceIDs []string
|
||||
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
|
||||
WithArgs(
|
||||
projection,
|
||||
pq.StringArray(instanceIDs),
|
||||
database.StringArray(instanceIDs),
|
||||
).
|
||||
WillReturnError(err)
|
||||
}
|
||||
@@ -157,7 +157,7 @@ func expectCurrentSequenceNoRows(tableName, projection string, instanceIDs []str
|
||||
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
|
||||
WithArgs(
|
||||
projection,
|
||||
pq.StringArray(instanceIDs),
|
||||
database.StringArray(instanceIDs),
|
||||
).
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}),
|
||||
@@ -170,7 +170,7 @@ func expectCurrentSequenceScanErr(tableName, projection string, instanceIDs []st
|
||||
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
|
||||
WithArgs(
|
||||
projection,
|
||||
pq.StringArray(instanceIDs),
|
||||
database.StringArray(instanceIDs),
|
||||
).
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}).
|
||||
@@ -182,7 +182,7 @@ func expectCurrentSequenceScanErr(tableName, projection string, instanceIDs []st
|
||||
|
||||
func expectUpdateCurrentSequence(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\)`).
|
||||
m.ExpectExec("INSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
|
||||
WithArgs(
|
||||
projection,
|
||||
aggregateType,
|
||||
@@ -213,7 +213,7 @@ func expectUpdateThreeCurrentSequence(t *testing.T, tableName, projection string
|
||||
matchers[i] = matcher
|
||||
}
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("UPSERT INTO " + tableName + ` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\), \(\$5, \$6, \$7, \$8, NOW\(\)\), \(\$9, \$10, \$11, \$12, NOW\(\)\)`).
|
||||
m.ExpectExec("INSERT INTO " + tableName + ` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\), \(\$5, \$6, \$7, \$8, NOW\(\)\), \(\$9, \$10, \$11, \$12, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
|
||||
WithArgs(
|
||||
matchers...,
|
||||
).
|
||||
@@ -262,7 +262,7 @@ func (m *currentSequenceMatcher) Match(value driver.Value) bool {
|
||||
|
||||
func expectUpdateCurrentSequenceErr(tableName, projection string, seq uint64, err error, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\)`).
|
||||
m.ExpectExec("INSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
|
||||
WithArgs(
|
||||
projection,
|
||||
aggregateType,
|
||||
@@ -275,7 +275,7 @@ func expectUpdateCurrentSequenceErr(tableName, projection string, seq uint64, er
|
||||
|
||||
func expectUpdateCurrentSequenceNoRows(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\)`).
|
||||
m.ExpectExec("INSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
|
||||
WithArgs(
|
||||
projection,
|
||||
aggregateType,
|
||||
@@ -297,10 +297,10 @@ func expectLock(lockTable, workerName string, d time.Duration, instanceID string
|
||||
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$5\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
|
||||
WithArgs(
|
||||
workerName,
|
||||
float64(d),
|
||||
d,
|
||||
projectionName,
|
||||
instanceID,
|
||||
pq.StringArray{instanceID},
|
||||
database.StringArray{instanceID},
|
||||
).
|
||||
WillReturnResult(
|
||||
sqlmock.NewResult(1, 1),
|
||||
@@ -317,11 +317,11 @@ func expectLockMultipleInstances(lockTable, workerName string, d time.Duration,
|
||||
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$6\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
|
||||
WithArgs(
|
||||
workerName,
|
||||
float64(d),
|
||||
d,
|
||||
projectionName,
|
||||
instanceID1,
|
||||
instanceID2,
|
||||
pq.StringArray{instanceID1, instanceID2},
|
||||
database.StringArray{instanceID1, instanceID2},
|
||||
).
|
||||
WillReturnResult(
|
||||
sqlmock.NewResult(1, 1),
|
||||
@@ -338,10 +338,10 @@ func expectLockNoRows(lockTable, workerName string, d time.Duration, instanceID
|
||||
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$5\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
|
||||
WithArgs(
|
||||
workerName,
|
||||
float64(d),
|
||||
d,
|
||||
projectionName,
|
||||
instanceID,
|
||||
pq.StringArray{instanceID},
|
||||
database.StringArray{instanceID},
|
||||
).
|
||||
WillReturnResult(driver.ResultNoRows)
|
||||
}
|
||||
@@ -356,10 +356,10 @@ func expectLockErr(lockTable, workerName string, d time.Duration, instanceID str
|
||||
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$5\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
|
||||
WithArgs(
|
||||
workerName,
|
||||
float64(d),
|
||||
d,
|
||||
projectionName,
|
||||
instanceID,
|
||||
pq.StringArray{instanceID},
|
||||
database.StringArray{instanceID},
|
||||
).
|
||||
WillReturnError(err)
|
||||
}
|
||||
|
@@ -10,15 +10,12 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
setFailureCountStmtFormat = "UPSERT INTO %s" +
|
||||
setFailureCountStmtFormat = "INSERT INTO %s" +
|
||||
" (projection_name, failed_sequence, failure_count, error, instance_id)" +
|
||||
" VALUES ($1, $2, $3, $4, $5)"
|
||||
" VALUES ($1, $2, $3, $4, $5) ON CONFLICT (projection_name, failed_sequence, instance_id)" +
|
||||
" DO UPDATE SET failure_count = EXCLUDED.failure_count, error = EXCLUDED.error"
|
||||
failureCountStmtFormat = "WITH failures AS (SELECT failure_count FROM %s WHERE projection_name = $1 AND failed_sequence = $2 AND instance_id = $3)" +
|
||||
" SELECT IF(" +
|
||||
"EXISTS(SELECT failure_count FROM failures)," +
|
||||
" (SELECT failure_count FROM failures)," +
|
||||
" 0" +
|
||||
") AS failure_count"
|
||||
" SELECT COALESCE((SELECT failure_count FROM failures), 0) AS failure_count"
|
||||
)
|
||||
|
||||
func (h *StatementHandler) handleFailedStmt(tx *sql.Tx, stmt *handler.Statement, execErr error) (shouldContinue bool) {
|
||||
|
@@ -72,12 +72,12 @@ func NewStatementHandler(
|
||||
aggregates: aggregateTypes,
|
||||
reduces: reduces,
|
||||
bulkLimit: config.BulkLimit,
|
||||
Locker: NewLocker(config.Client, config.LockTable, config.ProjectionHandlerConfig.ProjectionName),
|
||||
Locker: NewLocker(config.Client, config.LockTable, config.ProjectionName),
|
||||
}
|
||||
h.ProjectionHandler = handler.NewProjectionHandler(ctx, config.ProjectionHandlerConfig, h.reduce, h.Update, h.SearchQuery, h.Lock, h.Unlock)
|
||||
|
||||
err := h.Init(ctx, config.InitCheck)
|
||||
logging.OnError(err).Fatal("unable to initialize projections")
|
||||
logging.OnError(err).WithField("projection", config.ProjectionName).Fatal("unable to initialize projections")
|
||||
|
||||
h.Subscribe(h.aggregates...)
|
||||
|
||||
|
@@ -6,11 +6,10 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler"
|
||||
)
|
||||
|
||||
@@ -186,7 +185,7 @@ type ForeignKey struct {
|
||||
RefColumns []string
|
||||
}
|
||||
|
||||
//Init implements handler.Init
|
||||
// Init implements handler.Init
|
||||
func (h *StatementHandler) Init(ctx context.Context, checks ...*handler.Check) error {
|
||||
for _, check := range checks {
|
||||
if check == nil || check.IsNoop() {
|
||||
@@ -280,7 +279,7 @@ func isErrAlreadyExists(err error) bool {
|
||||
if !errors.As(err, &caosErr) {
|
||||
return false
|
||||
}
|
||||
sqlErr, ok := caosErr.GetParent().(*pq.Error)
|
||||
sqlErr, ok := caosErr.GetParent().(*pgconn.PgError)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
@@ -288,14 +287,11 @@ func isErrAlreadyExists(err error) bool {
|
||||
}
|
||||
|
||||
func createTableStatement(table *Table, tableName string, suffix string) string {
|
||||
stmt := fmt.Sprintf("CREATE TABLE %s (%s, PRIMARY KEY (%s)",
|
||||
stmt := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s, PRIMARY KEY (%s)",
|
||||
tableName+suffix,
|
||||
createColumnsStatement(table.columns, tableName),
|
||||
strings.Join(table.primaryKey, ", "),
|
||||
)
|
||||
for _, index := range table.indices {
|
||||
stmt += fmt.Sprintf(", INDEX %s (%s)", index.Name, strings.Join(index.Columns, ","))
|
||||
}
|
||||
for _, key := range table.foreignKeys {
|
||||
ref := tableName
|
||||
if len(key.RefColumns) > 0 {
|
||||
@@ -309,7 +305,13 @@ func createTableStatement(table *Table, tableName string, suffix string) string
|
||||
for _, constraint := range table.constraints {
|
||||
stmt += fmt.Sprintf(", CONSTRAINT %s UNIQUE (%s)", constraint.Name, strings.Join(constraint.Columns, ","))
|
||||
}
|
||||
return stmt + ");"
|
||||
|
||||
stmt += ");"
|
||||
|
||||
for _, index := range table.indices {
|
||||
stmt += fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s (%s);", index.Name, tableName+suffix, strings.Join(index.Columns, ","))
|
||||
}
|
||||
return stmt
|
||||
}
|
||||
|
||||
func createViewStatement(viewName string, selectStmt string) string {
|
||||
@@ -321,7 +323,7 @@ func createViewStatement(viewName string, selectStmt string) string {
|
||||
|
||||
func createIndexStatement(index *Index) func(config execConfig) string {
|
||||
return func(config execConfig) string {
|
||||
stmt := fmt.Sprintf("CREATE INDEX %s ON %s (%s)",
|
||||
stmt := fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s (%s)",
|
||||
index.Name,
|
||||
config.tableName,
|
||||
strings.Join(index.Columns, ","),
|
||||
@@ -380,9 +382,8 @@ func columnType(columnType ColumnType) string {
|
||||
case ColumnTypeJSONB:
|
||||
return "JSONB"
|
||||
case ColumnTypeBytes:
|
||||
return "BYTES"
|
||||
return "BYTEA"
|
||||
default:
|
||||
panic("unknown column type")
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
@@ -8,9 +8,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
)
|
||||
@@ -91,17 +91,17 @@ func (h *locker) Unlock(instanceIDs ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *locker) lockStatement(lockDuration time.Duration, instanceIDs []string) (string, []interface{}) {
|
||||
func (h *locker) lockStatement(lockDuration time.Duration, instanceIDs database.StringArray) (string, []interface{}) {
|
||||
valueQueries := make([]string, len(instanceIDs))
|
||||
values := make([]interface{}, len(instanceIDs)+4)
|
||||
values[0] = h.workerName
|
||||
//the unit of crdb interval is seconds (https://www.cockroachlabs.com/docs/stable/interval.html).
|
||||
values[1] = lockDuration.Seconds()
|
||||
values[1] = lockDuration
|
||||
values[2] = h.projectionName
|
||||
for i, instanceID := range instanceIDs {
|
||||
valueQueries[i] = "($1, now()+$2::INTERVAL, $3, $" + strconv.Itoa(i+4) + ")"
|
||||
values[i+3] = instanceID
|
||||
}
|
||||
values[len(values)-1] = pq.StringArray(instanceIDs)
|
||||
values[len(values)-1] = instanceIDs
|
||||
return h.lockStmt(strings.Join(valueQueries, ", "), len(values)), values
|
||||
}
|
||||
|
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
z_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
)
|
||||
|
||||
@@ -43,9 +44,9 @@ func TestStatementHandler_handleLock(t *testing.T) {
|
||||
name: "lock fails",
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectLock(lockTable, workerName, 2, "instanceID"),
|
||||
expectLock(lockTable, workerName, 2, "instanceID"),
|
||||
expectLockErr(lockTable, workerName, 2, "instanceID", errLock),
|
||||
expectLock(lockTable, workerName, 2*time.Second, "instanceID"),
|
||||
expectLock(lockTable, workerName, 2*time.Second, "instanceID"),
|
||||
expectLockErr(lockTable, workerName, 2*time.Second, "instanceID", errLock),
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
@@ -63,8 +64,8 @@ func TestStatementHandler_handleLock(t *testing.T) {
|
||||
name: "success",
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectLock(lockTable, workerName, 2, "instanceID"),
|
||||
expectLock(lockTable, workerName, 2, "instanceID"),
|
||||
expectLock(lockTable, workerName, 2*time.Second, "instanceID"),
|
||||
expectLock(lockTable, workerName, 2*time.Second, "instanceID"),
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
@@ -81,8 +82,8 @@ func TestStatementHandler_handleLock(t *testing.T) {
|
||||
name: "success with multiple",
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectLockMultipleInstances(lockTable, workerName, 2, "instanceID1", "instanceID2"),
|
||||
expectLockMultipleInstances(lockTable, workerName, 2, "instanceID1", "instanceID2"),
|
||||
expectLockMultipleInstances(lockTable, workerName, 2*time.Second, "instanceID1", "instanceID2"),
|
||||
expectLockMultipleInstances(lockTable, workerName, 2*time.Second, "instanceID1", "instanceID2"),
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
@@ -149,7 +150,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
|
||||
name: "lock fails",
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectLockErr(lockTable, workerName, 1, "instanceID", sql.ErrTxDone),
|
||||
expectLockErr(lockTable, workerName, 1*time.Second, "instanceID", sql.ErrTxDone),
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, sql.ErrTxDone)
|
||||
@@ -157,14 +158,14 @@ func TestStatementHandler_renewLock(t *testing.T) {
|
||||
},
|
||||
args: args{
|
||||
lockDuration: 1 * time.Second,
|
||||
instanceIDs: []string{"instanceID"},
|
||||
instanceIDs: database.StringArray{"instanceID"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "lock no rows",
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectLockNoRows(lockTable, workerName, 2, "instanceID"),
|
||||
expectLockNoRows(lockTable, workerName, 2*time.Second, "instanceID"),
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, renewNoRowsAffectedErr)
|
||||
@@ -172,14 +173,14 @@ func TestStatementHandler_renewLock(t *testing.T) {
|
||||
},
|
||||
args: args{
|
||||
lockDuration: 2 * time.Second,
|
||||
instanceIDs: []string{"instanceID"},
|
||||
instanceIDs: database.StringArray{"instanceID"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectLock(lockTable, workerName, 3, "instanceID"),
|
||||
expectLock(lockTable, workerName, 3*time.Second, "instanceID"),
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, nil)
|
||||
@@ -187,14 +188,14 @@ func TestStatementHandler_renewLock(t *testing.T) {
|
||||
},
|
||||
args: args{
|
||||
lockDuration: 3 * time.Second,
|
||||
instanceIDs: []string{"instanceID"},
|
||||
instanceIDs: database.StringArray{"instanceID"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success with multiple",
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectLockMultipleInstances(lockTable, workerName, 3, "instanceID1", "instanceID2"),
|
||||
expectLockMultipleInstances(lockTable, workerName, 3*time.Second, "instanceID1", "instanceID2"),
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, nil)
|
||||
|
@@ -4,8 +4,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler"
|
||||
@@ -51,10 +50,13 @@ func NewCreateStatement(event eventstore.Event, values []handler.Column, opts ..
|
||||
}
|
||||
}
|
||||
|
||||
func NewUpsertStatement(event eventstore.Event, values []handler.Column, opts ...execOption) *handler.Statement {
|
||||
func NewUpsertStatement(event eventstore.Event, conflictCols []handler.Column, values []handler.Column, opts ...execOption) *handler.Statement {
|
||||
cols, params, args := columnsToQuery(values)
|
||||
columnNames := strings.Join(cols, ", ")
|
||||
valuesPlaceholder := strings.Join(params, ", ")
|
||||
|
||||
conflictTarget := make([]string, len(conflictCols))
|
||||
for i, col := range conflictCols {
|
||||
conflictTarget[i] = col.Name
|
||||
}
|
||||
|
||||
config := execConfig{
|
||||
args: args,
|
||||
@@ -64,8 +66,14 @@ func NewUpsertStatement(event eventstore.Event, values []handler.Column, opts ..
|
||||
config.err = handler.ErrNoValues
|
||||
}
|
||||
|
||||
updateCols, updateVals := getUpdateCols(cols, conflictTarget)
|
||||
if len(updateCols) == 0 || len(updateVals) == 0 {
|
||||
config.err = handler.ErrNoValues
|
||||
}
|
||||
|
||||
q := func(config execConfig) string {
|
||||
return "UPSERT INTO " + config.tableName + " (" + columnNames + ") VALUES (" + valuesPlaceholder + ")"
|
||||
return "INSERT INTO " + config.tableName + " (" + strings.Join(cols, ", ") + ") VALUES (" + strings.Join(params, ", ") + ")" +
|
||||
" ON CONFLICT (" + strings.Join(conflictTarget, ", ") + ") DO UPDATE SET (" + strings.Join(updateCols, ", ") + ") = (" + strings.Join(updateVals, ", ") + ")"
|
||||
}
|
||||
|
||||
return &handler.Statement{
|
||||
@@ -77,15 +85,38 @@ func NewUpsertStatement(event eventstore.Event, values []handler.Column, opts ..
|
||||
}
|
||||
}
|
||||
|
||||
func getUpdateCols(cols, conflictTarget []string) (updateCols, updateVals []string) {
|
||||
updateCols = make([]string, len(cols))
|
||||
updateVals = make([]string, len(cols))
|
||||
|
||||
copy(updateCols, cols)
|
||||
|
||||
for i := len(updateCols) - 1; i >= 0; i-- {
|
||||
updateVals[i] = "EXCLUDED." + updateCols[i]
|
||||
|
||||
for _, conflict := range conflictTarget {
|
||||
if conflict == updateCols[i] {
|
||||
copy(updateCols[i:], updateCols[i+1:])
|
||||
updateCols[len(updateCols)-1] = ""
|
||||
updateCols = updateCols[:len(updateCols)-1]
|
||||
|
||||
copy(updateVals[i:], updateVals[i+1:])
|
||||
updateVals[len(updateVals)-1] = ""
|
||||
updateVals = updateVals[:len(updateVals)-1]
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return updateCols, updateVals
|
||||
}
|
||||
|
||||
func NewUpdateStatement(event eventstore.Event, values []handler.Column, conditions []handler.Condition, opts ...execOption) *handler.Statement {
|
||||
cols, params, args := columnsToQuery(values)
|
||||
wheres, whereArgs := conditionsToWhere(conditions, len(params))
|
||||
args = append(args, whereArgs...)
|
||||
|
||||
columnNames := strings.Join(cols, ", ")
|
||||
valuesPlaceholder := strings.Join(params, ", ")
|
||||
wheresPlaceholders := strings.Join(wheres, " AND ")
|
||||
|
||||
config := execConfig{
|
||||
args: args,
|
||||
}
|
||||
@@ -99,7 +130,7 @@ func NewUpdateStatement(event eventstore.Event, values []handler.Column, conditi
|
||||
}
|
||||
|
||||
q := func(config execConfig) string {
|
||||
return "UPDATE " + config.tableName + " SET (" + columnNames + ") = (" + valuesPlaceholder + ") WHERE " + wheresPlaceholders
|
||||
return "UPDATE " + config.tableName + " SET (" + strings.Join(cols, ", ") + ") = (" + strings.Join(params, ", ") + ") WHERE " + strings.Join(wheres, " AND ")
|
||||
}
|
||||
|
||||
return &handler.Statement{
|
||||
@@ -171,9 +202,9 @@ func AddCreateStatement(columns []handler.Column, opts ...execOption) func(event
|
||||
}
|
||||
}
|
||||
|
||||
func AddUpsertStatement(values []handler.Column, opts ...execOption) func(eventstore.Event) Exec {
|
||||
func AddUpsertStatement(indexCols []handler.Column, values []handler.Column, opts ...execOption) func(eventstore.Event) Exec {
|
||||
return func(event eventstore.Event) Exec {
|
||||
return NewUpsertStatement(event, values, opts...).Execute
|
||||
return NewUpsertStatement(event, indexCols, values, opts...).Execute
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,9 +220,9 @@ func AddDeleteStatement(conditions []handler.Condition, opts ...execOption) func
|
||||
}
|
||||
}
|
||||
|
||||
func AddCopyStatement(from, to []handler.Column, conditions []handler.Condition, opts ...execOption) func(eventstore.Event) Exec {
|
||||
func AddCopyStatement(conflict, from, to []handler.Column, conditions []handler.Condition, opts ...execOption) func(eventstore.Event) Exec {
|
||||
return func(event eventstore.Event) Exec {
|
||||
return NewCopyStatement(event, from, to, conditions, opts...).Execute
|
||||
return NewCopyStatement(event, conflict, from, to, conditions, opts...).Execute
|
||||
}
|
||||
}
|
||||
|
||||
@@ -218,11 +249,9 @@ func NewArrayRemoveCol(column string, value interface{}) handler.Column {
|
||||
func NewArrayIntersectCol(column string, value interface{}) handler.Column {
|
||||
var arrayType string
|
||||
switch value.(type) {
|
||||
case pq.StringArray:
|
||||
arrayType = "STRING"
|
||||
case pq.Int32Array,
|
||||
pq.Int64Array:
|
||||
arrayType = "INT"
|
||||
|
||||
case []string, database.StringArray:
|
||||
arrayType = "TEXT"
|
||||
//TODO: handle more types if necessary
|
||||
}
|
||||
return handler.Column{
|
||||
@@ -234,25 +263,29 @@ func NewArrayIntersectCol(column string, value interface{}) handler.Column {
|
||||
}
|
||||
}
|
||||
|
||||
//NewCopyStatement creates a new upsert statement which updates a column from an existing row
|
||||
// NewCopyStatement creates a new upsert statement which updates a column from an existing row
|
||||
// cols represent the columns which are objective to change.
|
||||
// if the value of a col is empty the data will be copied from the selected row
|
||||
// if the value of a col is not empty the data will be set by the static value
|
||||
// conds represent the conditions for the selection subquery
|
||||
func NewCopyStatement(event eventstore.Event, from, to []handler.Column, conds []handler.Condition, opts ...execOption) *handler.Statement {
|
||||
func NewCopyStatement(event eventstore.Event, conflictCols, from, to []handler.Column, conds []handler.Condition, opts ...execOption) *handler.Statement {
|
||||
columnNames := make([]string, len(to))
|
||||
selectColumns := make([]string, len(from))
|
||||
updateColumns := make([]string, len(columnNames))
|
||||
argCounter := 0
|
||||
args := []interface{}{}
|
||||
|
||||
for i := range from {
|
||||
for i, col := range from {
|
||||
columnNames[i] = to[i].Name
|
||||
selectColumns[i] = from[i].Name
|
||||
if from[i].Value != nil {
|
||||
updateColumns[i] = "EXCLUDED." + col.Name
|
||||
if col.Value != nil {
|
||||
argCounter++
|
||||
selectColumns[i] = "$" + strconv.Itoa(argCounter)
|
||||
args = append(args, from[i].Value)
|
||||
updateColumns[i] = selectColumns[i]
|
||||
args = append(args, col.Value)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
wheres := make([]string, len(conds))
|
||||
@@ -262,6 +295,11 @@ func NewCopyStatement(event eventstore.Event, from, to []handler.Column, conds [
|
||||
args = append(args, cond.Value)
|
||||
}
|
||||
|
||||
conflictTargets := make([]string, len(conflictCols))
|
||||
for i, conflictCol := range conflictCols {
|
||||
conflictTargets[i] = conflictCol.Name
|
||||
}
|
||||
|
||||
config := execConfig{
|
||||
args: args,
|
||||
}
|
||||
@@ -275,7 +313,7 @@ func NewCopyStatement(event eventstore.Event, from, to []handler.Column, conds [
|
||||
}
|
||||
|
||||
q := func(config execConfig) string {
|
||||
return "UPSERT INTO " +
|
||||
return "INSERT INTO " +
|
||||
config.tableName +
|
||||
" (" +
|
||||
strings.Join(columnNames, ", ") +
|
||||
@@ -283,7 +321,14 @@ func NewCopyStatement(event eventstore.Event, from, to []handler.Column, conds [
|
||||
strings.Join(selectColumns, ", ") +
|
||||
" FROM " +
|
||||
config.tableName + " AS copy_table WHERE " +
|
||||
strings.Join(wheres, " AND ")
|
||||
strings.Join(wheres, " AND ") +
|
||||
" ON CONFLICT (" +
|
||||
strings.Join(conflictTargets, ", ") +
|
||||
") DO UPDATE SET (" +
|
||||
strings.Join(columnNames, ", ") +
|
||||
") = (" +
|
||||
strings.Join(updateColumns, ", ") +
|
||||
")"
|
||||
}
|
||||
|
||||
return &handler.Statement{
|
||||
|
@@ -178,9 +178,10 @@ func TestNewCreateStatement(t *testing.T) {
|
||||
|
||||
func TestNewUpsertStatement(t *testing.T) {
|
||||
type args struct {
|
||||
table string
|
||||
event *testEvent
|
||||
values []handler.Column
|
||||
table string
|
||||
event *testEvent
|
||||
conflictCols []handler.Column
|
||||
values []handler.Column
|
||||
}
|
||||
type want struct {
|
||||
aggregateType eventstore.AggregateType
|
||||
@@ -249,7 +250,7 @@ func TestNewUpsertStatement(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "correct",
|
||||
name: "no update cols",
|
||||
args: args{
|
||||
table: "my_table",
|
||||
event: &testEvent{
|
||||
@@ -257,6 +258,9 @@ func TestNewUpsertStatement(t *testing.T) {
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
conflictCols: []handler.Column{
|
||||
handler.NewCol("col1", nil),
|
||||
},
|
||||
values: []handler.Column{
|
||||
{
|
||||
Name: "col1",
|
||||
@@ -264,6 +268,42 @@ func TestNewUpsertStatement(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
table: "my_table",
|
||||
aggregateType: "agg",
|
||||
sequence: 1,
|
||||
previousSequence: 1,
|
||||
executer: &wantExecuter{
|
||||
shouldExecute: false,
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, handler.ErrNoValues)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "correct",
|
||||
args: args{
|
||||
table: "my_table",
|
||||
event: &testEvent{
|
||||
aggregateType: "agg",
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
conflictCols: []handler.Column{
|
||||
handler.NewCol("col1", nil),
|
||||
},
|
||||
values: []handler.Column{
|
||||
{
|
||||
Name: "col1",
|
||||
Value: "val",
|
||||
},
|
||||
{
|
||||
Name: "col2",
|
||||
Value: "val",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
table: "my_table",
|
||||
aggregateType: "agg",
|
||||
@@ -272,8 +312,8 @@ func TestNewUpsertStatement(t *testing.T) {
|
||||
executer: &wantExecuter{
|
||||
params: []params{
|
||||
{
|
||||
query: "UPSERT INTO my_table (col1) VALUES ($1)",
|
||||
args: []interface{}{"val"},
|
||||
query: "INSERT INTO my_table (col1, col2) VALUES ($1, $2) ON CONFLICT (col1) DO UPDATE SET (col2) = (EXCLUDED.col2)",
|
||||
args: []interface{}{"val", "val"},
|
||||
},
|
||||
},
|
||||
shouldExecute: true,
|
||||
@@ -287,7 +327,7 @@ func TestNewUpsertStatement(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.want.executer.t = t
|
||||
stmt := NewUpsertStatement(tt.args.event, tt.args.values)
|
||||
stmt := NewUpsertStatement(tt.args.event, tt.args.conflictCols, tt.args.values)
|
||||
|
||||
err := stmt.Execute(tt.want.executer, tt.args.table)
|
||||
if !tt.want.isErr(err) {
|
||||
@@ -724,11 +764,18 @@ func TestNewMultiStatement(t *testing.T) {
|
||||
},
|
||||
}),
|
||||
AddUpsertStatement(
|
||||
[]handler.Column{
|
||||
handler.NewCol("col1", nil),
|
||||
},
|
||||
[]handler.Column{
|
||||
{
|
||||
Name: "col1",
|
||||
Value: 1,
|
||||
},
|
||||
{
|
||||
Name: "col2",
|
||||
Value: 2,
|
||||
},
|
||||
}),
|
||||
AddUpdateStatement(
|
||||
[]handler.Column{
|
||||
@@ -761,8 +808,8 @@ func TestNewMultiStatement(t *testing.T) {
|
||||
args: []interface{}{1},
|
||||
},
|
||||
{
|
||||
query: "UPSERT INTO my_table (col1) VALUES ($1)",
|
||||
args: []interface{}{1},
|
||||
query: "INSERT INTO my_table (col1, col2) VALUES ($1, $2) ON CONFLICT (col1) DO UPDATE SET (col2) = (EXCLUDED.col2)",
|
||||
args: []interface{}{1, 2},
|
||||
},
|
||||
{
|
||||
query: "UPDATE my_table SET (col1) = ($1) WHERE (col1 = $2)",
|
||||
@@ -799,11 +846,12 @@ func TestNewMultiStatement(t *testing.T) {
|
||||
|
||||
func TestNewCopyStatement(t *testing.T) {
|
||||
type args struct {
|
||||
table string
|
||||
event *testEvent
|
||||
from []handler.Column
|
||||
to []handler.Column
|
||||
conds []handler.Condition
|
||||
table string
|
||||
event *testEvent
|
||||
conflictingCols []handler.Column
|
||||
from []handler.Column
|
||||
to []handler.Column
|
||||
conds []handler.Condition
|
||||
}
|
||||
type want struct {
|
||||
aggregateType eventstore.AggregateType
|
||||
@@ -1004,7 +1052,7 @@ func TestNewCopyStatement(t *testing.T) {
|
||||
executer: &wantExecuter{
|
||||
params: []params{
|
||||
{
|
||||
query: "UPSERT INTO my_table (state, id, col_a, col_b) SELECT $1, id, col_a, col_b FROM my_table AS copy_table WHERE copy_table.id = $2 AND copy_table.state = $3",
|
||||
query: "INSERT INTO my_table (state, id, col_a, col_b) SELECT $1, id, col_a, col_b FROM my_table AS copy_table WHERE copy_table.id = $2 AND copy_table.state = $3 ON CONFLICT () DO UPDATE SET (state, id, col_a, col_b) = ($1, EXCLUDED.id, EXCLUDED.col_a, EXCLUDED.col_b)",
|
||||
args: []interface{}{1, 2, 3},
|
||||
},
|
||||
},
|
||||
@@ -1071,7 +1119,7 @@ func TestNewCopyStatement(t *testing.T) {
|
||||
executer: &wantExecuter{
|
||||
params: []params{
|
||||
{
|
||||
query: "UPSERT INTO my_table (state, id, col_c, col_d) SELECT $1, id, col_a, col_b FROM my_table AS copy_table WHERE copy_table.id = $2 AND copy_table.state = $3",
|
||||
query: "INSERT INTO my_table (state, id, col_c, col_d) SELECT $1, id, col_a, col_b FROM my_table AS copy_table WHERE copy_table.id = $2 AND copy_table.state = $3 ON CONFLICT () DO UPDATE SET (state, id, col_c, col_d) = ($1, EXCLUDED.id, EXCLUDED.col_a, EXCLUDED.col_b)",
|
||||
args: []interface{}{1, 2, 3},
|
||||
},
|
||||
},
|
||||
@@ -1086,7 +1134,7 @@ func TestNewCopyStatement(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.want.executer.t = t
|
||||
stmt := NewCopyStatement(tt.args.event, tt.args.from, tt.args.to, tt.args.conds)
|
||||
stmt := NewCopyStatement(tt.args.event, tt.args.conflictingCols, tt.args.from, tt.args.to, tt.args.conds)
|
||||
|
||||
err := stmt.Execute(tt.want.executer, tt.args.table)
|
||||
if !tt.want.isErr(err) {
|
||||
|
Reference in New Issue
Block a user