feat(database): support for postgres (#3998)

* beginning with postgres statements

* try pgx

* use pgx

* database

* init works for postgres

* arrays working

* init for cockroach

* init

* start tests

* tests

* TESTS

* ch

* ch

* chore: use go 1.18

* read stmts

* fix typo

* tests

* connection string

* add missing error handler

* cleanup

* start all apis

* go mod tidy

* old update

* switch back to minute

* on conflict

* replace string slice with `database.StringArray` in db models

* fix tests and start

* update go version in dockerfile

* setup go

* clean up

* remove notification migration

* update

* docs: add deploy guide for postgres

* fix: revert sonyflake

* use `database.StringArray` for daos

* use `database.StringArray` every where

* new tables

* index naming,
metadata primary key,
project grant role key type

* docs(postgres): change to beta

* chore: correct compose

* fix(defaults): add empty postgres config

* refactor: remove unused code

* docs: add postgres to self hosted

* fix broken link

* so?

* change title

* add mdx to link

* fix stmt

* update goreleaser in test-code

* docs: improve postgres example

* update more projections

* fix: add beta log for postgres

* revert index name change

* prerelease

* fix: add sequence to v1 "reduce paniced"

* log if nil

* add logging

* fix: log output

* fix(import): check if org exists and user

* refactor: imports

* fix(user): ignore malformed events

* refactor: method naming

* fix: test

* refactor: correct errors.Is call

* ci: don't build dev binaries on main

* fix(go releaser): update version to 1.11.0

* fix(user): projection should not break

* fix(user): handle error properly

* docs: correct config example

* Update .releaserc.js

* Update .releaserc.js

Co-authored-by: Livio Amstutz <livio.a@gmail.com>
Co-authored-by: Elio Bischof <eliobischof@gmail.com>
This commit is contained in:
Silvan
2022-08-31 09:52:43 +02:00
committed by GitHub
parent d6c9815945
commit 77b4fc5487
189 changed files with 3401 additions and 2956 deletions

View File

@@ -6,15 +6,15 @@ import (
"strconv"
"strings"
"github.com/lib/pq"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
)
const (
currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2) FOR UPDATE`
updateCurrentSequencesStmtFormat = `UPSERT INTO %s (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES `
currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2) FOR UPDATE`
updateCurrentSequencesStmtFormat = `INSERT INTO %s (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES `
updateCurrentSequencesConflictStmt = ` ON CONFLICT (projection_name, aggregate_type, instance_id) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`
)
type currentSequences map[eventstore.AggregateType][]*instanceSequence
@@ -24,8 +24,8 @@ type instanceSequence struct {
sequence uint64
}
func (h *StatementHandler) currentSequences(ctx context.Context, query func(context.Context, string, ...interface{}) (*sql.Rows, error), instanceIDs []string) (currentSequences, error) {
rows, err := query(ctx, h.currentSequenceStmt, h.ProjectionName, pq.StringArray(instanceIDs))
func (h *StatementHandler) currentSequences(ctx context.Context, query func(context.Context, string, ...interface{}) (*sql.Rows, error), instanceIDs database.StringArray) (currentSequences, error) {
rows, err := query(ctx, h.currentSequenceStmt, h.ProjectionName, instanceIDs)
if err != nil {
return nil, err
}
@@ -74,7 +74,7 @@ func (h *StatementHandler) updateCurrentSequences(tx *sql.Tx, sequences currentS
}
}
res, err := tx.Exec(h.updateSequencesBaseStmt+strings.Join(valueQueries, ", "), values...)
res, err := tx.Exec(h.updateSequencesBaseStmt+strings.Join(valueQueries, ", ")+updateCurrentSequencesConflictStmt, values...)
if err != nil {
return errors.ThrowInternal(err, "CRDB-TrH2Z", "unable to exec update sequence")
}

View File

@@ -8,8 +8,8 @@ import (
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/lib/pq"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
)
@@ -17,7 +17,7 @@ type mockExpectation func(sqlmock.Sqlmock)
func expectFailureCount(tableName string, projectionName, instanceID string, failedSeq, failureCount uint64) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`WITH failures AS \(SELECT failure_count FROM `+tableName+` WHERE projection_name = \$1 AND failed_sequence = \$2\ AND instance_id = \$3\) SELECT IF\(EXISTS\(SELECT failure_count FROM failures\), \(SELECT failure_count FROM failures\), 0\) AS failure_count`).
m.ExpectQuery(`WITH failures AS \(SELECT failure_count FROM `+tableName+` WHERE projection_name = \$1 AND failed_sequence = \$2 AND instance_id = \$3\) SELECT COALESCE\(\(SELECT failure_count FROM failures\), 0\) AS failure_count`).
WithArgs(projectionName, failedSeq, instanceID).
WillReturnRows(
sqlmock.NewRows([]string{"failure_count"}).
@@ -28,7 +28,7 @@ func expectFailureCount(tableName string, projectionName, instanceID string, fai
func expectUpdateFailureCount(tableName string, projectionName, instanceID string, seq, failureCount uint64) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec(`UPSERT INTO `+tableName+` \(projection_name, failed_sequence, failure_count, error, instance_id\) VALUES \(\$1, \$2, \$3, \$4\, \$5\)`).
m.ExpectExec(`INSERT INTO `+tableName+` \(projection_name, failed_sequence, failure_count, error, instance_id\) VALUES \(\$1, \$2, \$3, \$4\, \$5\) ON CONFLICT \(projection_name, failed_sequence, instance_id\) DO UPDATE SET failure_count = EXCLUDED\.failure_count, error = EXCLUDED\.error`).
WithArgs(projectionName, seq, failureCount, sqlmock.AnyArg(), instanceID).WillReturnResult(sqlmock.NewResult(1, 1))
}
}
@@ -133,7 +133,7 @@ func expectCurrentSequence(tableName, projection string, seq uint64, aggregateTy
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
WithArgs(
projection,
pq.StringArray(instanceIDs),
database.StringArray(instanceIDs),
).
WillReturnRows(
rows,
@@ -146,7 +146,7 @@ func expectCurrentSequenceErr(tableName, projection string, instanceIDs []string
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
WithArgs(
projection,
pq.StringArray(instanceIDs),
database.StringArray(instanceIDs),
).
WillReturnError(err)
}
@@ -157,7 +157,7 @@ func expectCurrentSequenceNoRows(tableName, projection string, instanceIDs []str
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
WithArgs(
projection,
pq.StringArray(instanceIDs),
database.StringArray(instanceIDs),
).
WillReturnRows(
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}),
@@ -170,7 +170,7 @@ func expectCurrentSequenceScanErr(tableName, projection string, instanceIDs []st
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
WithArgs(
projection,
pq.StringArray(instanceIDs),
database.StringArray(instanceIDs),
).
WillReturnRows(
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}).
@@ -182,7 +182,7 @@ func expectCurrentSequenceScanErr(tableName, projection string, instanceIDs []st
func expectUpdateCurrentSequence(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\)`).
m.ExpectExec("INSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
WithArgs(
projection,
aggregateType,
@@ -213,7 +213,7 @@ func expectUpdateThreeCurrentSequence(t *testing.T, tableName, projection string
matchers[i] = matcher
}
return func(m sqlmock.Sqlmock) {
m.ExpectExec("UPSERT INTO " + tableName + ` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\), \(\$5, \$6, \$7, \$8, NOW\(\)\), \(\$9, \$10, \$11, \$12, NOW\(\)\)`).
m.ExpectExec("INSERT INTO " + tableName + ` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\), \(\$5, \$6, \$7, \$8, NOW\(\)\), \(\$9, \$10, \$11, \$12, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
WithArgs(
matchers...,
).
@@ -262,7 +262,7 @@ func (m *currentSequenceMatcher) Match(value driver.Value) bool {
func expectUpdateCurrentSequenceErr(tableName, projection string, seq uint64, err error, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\)`).
m.ExpectExec("INSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
WithArgs(
projection,
aggregateType,
@@ -275,7 +275,7 @@ func expectUpdateCurrentSequenceErr(tableName, projection string, seq uint64, er
func expectUpdateCurrentSequenceNoRows(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\)`).
m.ExpectExec("INSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
WithArgs(
projection,
aggregateType,
@@ -297,10 +297,10 @@ func expectLock(lockTable, workerName string, d time.Duration, instanceID string
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$5\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
WithArgs(
workerName,
float64(d),
d,
projectionName,
instanceID,
pq.StringArray{instanceID},
database.StringArray{instanceID},
).
WillReturnResult(
sqlmock.NewResult(1, 1),
@@ -317,11 +317,11 @@ func expectLockMultipleInstances(lockTable, workerName string, d time.Duration,
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$6\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
WithArgs(
workerName,
float64(d),
d,
projectionName,
instanceID1,
instanceID2,
pq.StringArray{instanceID1, instanceID2},
database.StringArray{instanceID1, instanceID2},
).
WillReturnResult(
sqlmock.NewResult(1, 1),
@@ -338,10 +338,10 @@ func expectLockNoRows(lockTable, workerName string, d time.Duration, instanceID
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$5\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
WithArgs(
workerName,
float64(d),
d,
projectionName,
instanceID,
pq.StringArray{instanceID},
database.StringArray{instanceID},
).
WillReturnResult(driver.ResultNoRows)
}
@@ -356,10 +356,10 @@ func expectLockErr(lockTable, workerName string, d time.Duration, instanceID str
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$5\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
WithArgs(
workerName,
float64(d),
d,
projectionName,
instanceID,
pq.StringArray{instanceID},
database.StringArray{instanceID},
).
WillReturnError(err)
}

View File

@@ -10,15 +10,12 @@ import (
)
const (
setFailureCountStmtFormat = "UPSERT INTO %s" +
setFailureCountStmtFormat = "INSERT INTO %s" +
" (projection_name, failed_sequence, failure_count, error, instance_id)" +
" VALUES ($1, $2, $3, $4, $5)"
" VALUES ($1, $2, $3, $4, $5) ON CONFLICT (projection_name, failed_sequence, instance_id)" +
" DO UPDATE SET failure_count = EXCLUDED.failure_count, error = EXCLUDED.error"
failureCountStmtFormat = "WITH failures AS (SELECT failure_count FROM %s WHERE projection_name = $1 AND failed_sequence = $2 AND instance_id = $3)" +
" SELECT IF(" +
"EXISTS(SELECT failure_count FROM failures)," +
" (SELECT failure_count FROM failures)," +
" 0" +
") AS failure_count"
" SELECT COALESCE((SELECT failure_count FROM failures), 0) AS failure_count"
)
func (h *StatementHandler) handleFailedStmt(tx *sql.Tx, stmt *handler.Statement, execErr error) (shouldContinue bool) {

View File

@@ -72,12 +72,12 @@ func NewStatementHandler(
aggregates: aggregateTypes,
reduces: reduces,
bulkLimit: config.BulkLimit,
Locker: NewLocker(config.Client, config.LockTable, config.ProjectionHandlerConfig.ProjectionName),
Locker: NewLocker(config.Client, config.LockTable, config.ProjectionName),
}
h.ProjectionHandler = handler.NewProjectionHandler(ctx, config.ProjectionHandlerConfig, h.reduce, h.Update, h.SearchQuery, h.Lock, h.Unlock)
err := h.Init(ctx, config.InitCheck)
logging.OnError(err).Fatal("unable to initialize projections")
logging.OnError(err).WithField("projection", config.ProjectionName).Fatal("unable to initialize projections")
h.Subscribe(h.aggregates...)

View File

@@ -6,11 +6,10 @@ import (
"fmt"
"strings"
"github.com/lib/pq"
"github.com/jackc/pgconn"
"github.com/zitadel/logging"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore/handler"
)
@@ -186,7 +185,7 @@ type ForeignKey struct {
RefColumns []string
}
//Init implements handler.Init
// Init implements handler.Init
func (h *StatementHandler) Init(ctx context.Context, checks ...*handler.Check) error {
for _, check := range checks {
if check == nil || check.IsNoop() {
@@ -280,7 +279,7 @@ func isErrAlreadyExists(err error) bool {
if !errors.As(err, &caosErr) {
return false
}
sqlErr, ok := caosErr.GetParent().(*pq.Error)
sqlErr, ok := caosErr.GetParent().(*pgconn.PgError)
if !ok {
return false
}
@@ -288,14 +287,11 @@ func isErrAlreadyExists(err error) bool {
}
func createTableStatement(table *Table, tableName string, suffix string) string {
stmt := fmt.Sprintf("CREATE TABLE %s (%s, PRIMARY KEY (%s)",
stmt := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s, PRIMARY KEY (%s)",
tableName+suffix,
createColumnsStatement(table.columns, tableName),
strings.Join(table.primaryKey, ", "),
)
for _, index := range table.indices {
stmt += fmt.Sprintf(", INDEX %s (%s)", index.Name, strings.Join(index.Columns, ","))
}
for _, key := range table.foreignKeys {
ref := tableName
if len(key.RefColumns) > 0 {
@@ -309,7 +305,13 @@ func createTableStatement(table *Table, tableName string, suffix string) string
for _, constraint := range table.constraints {
stmt += fmt.Sprintf(", CONSTRAINT %s UNIQUE (%s)", constraint.Name, strings.Join(constraint.Columns, ","))
}
return stmt + ");"
stmt += ");"
for _, index := range table.indices {
stmt += fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s (%s);", index.Name, tableName+suffix, strings.Join(index.Columns, ","))
}
return stmt
}
func createViewStatement(viewName string, selectStmt string) string {
@@ -321,7 +323,7 @@ func createViewStatement(viewName string, selectStmt string) string {
func createIndexStatement(index *Index) func(config execConfig) string {
return func(config execConfig) string {
stmt := fmt.Sprintf("CREATE INDEX %s ON %s (%s)",
stmt := fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s (%s)",
index.Name,
config.tableName,
strings.Join(index.Columns, ","),
@@ -380,9 +382,8 @@ func columnType(columnType ColumnType) string {
case ColumnTypeJSONB:
return "JSONB"
case ColumnTypeBytes:
return "BYTES"
return "BYTEA"
default:
panic("unknown column type")
return ""
}
}

View File

@@ -8,9 +8,9 @@ import (
"strings"
"time"
"github.com/lib/pq"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/id"
)
@@ -91,17 +91,17 @@ func (h *locker) Unlock(instanceIDs ...string) error {
return nil
}
func (h *locker) lockStatement(lockDuration time.Duration, instanceIDs []string) (string, []interface{}) {
func (h *locker) lockStatement(lockDuration time.Duration, instanceIDs database.StringArray) (string, []interface{}) {
valueQueries := make([]string, len(instanceIDs))
values := make([]interface{}, len(instanceIDs)+4)
values[0] = h.workerName
//the unit of crdb interval is seconds (https://www.cockroachlabs.com/docs/stable/interval.html).
values[1] = lockDuration.Seconds()
values[1] = lockDuration
values[2] = h.projectionName
for i, instanceID := range instanceIDs {
valueQueries[i] = "($1, now()+$2::INTERVAL, $3, $" + strconv.Itoa(i+4) + ")"
values[i+3] = instanceID
}
values[len(values)-1] = pq.StringArray(instanceIDs)
values[len(values)-1] = instanceIDs
return h.lockStmt(strings.Join(valueQueries, ", "), len(values)), values
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/zitadel/zitadel/internal/database"
z_errs "github.com/zitadel/zitadel/internal/errors"
)
@@ -43,9 +44,9 @@ func TestStatementHandler_handleLock(t *testing.T) {
name: "lock fails",
want: want{
expectations: []mockExpectation{
expectLock(lockTable, workerName, 2, "instanceID"),
expectLock(lockTable, workerName, 2, "instanceID"),
expectLockErr(lockTable, workerName, 2, "instanceID", errLock),
expectLock(lockTable, workerName, 2*time.Second, "instanceID"),
expectLock(lockTable, workerName, 2*time.Second, "instanceID"),
expectLockErr(lockTable, workerName, 2*time.Second, "instanceID", errLock),
},
},
args: args{
@@ -63,8 +64,8 @@ func TestStatementHandler_handleLock(t *testing.T) {
name: "success",
want: want{
expectations: []mockExpectation{
expectLock(lockTable, workerName, 2, "instanceID"),
expectLock(lockTable, workerName, 2, "instanceID"),
expectLock(lockTable, workerName, 2*time.Second, "instanceID"),
expectLock(lockTable, workerName, 2*time.Second, "instanceID"),
},
},
args: args{
@@ -81,8 +82,8 @@ func TestStatementHandler_handleLock(t *testing.T) {
name: "success with multiple",
want: want{
expectations: []mockExpectation{
expectLockMultipleInstances(lockTable, workerName, 2, "instanceID1", "instanceID2"),
expectLockMultipleInstances(lockTable, workerName, 2, "instanceID1", "instanceID2"),
expectLockMultipleInstances(lockTable, workerName, 2*time.Second, "instanceID1", "instanceID2"),
expectLockMultipleInstances(lockTable, workerName, 2*time.Second, "instanceID1", "instanceID2"),
},
},
args: args{
@@ -149,7 +150,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
name: "lock fails",
want: want{
expectations: []mockExpectation{
expectLockErr(lockTable, workerName, 1, "instanceID", sql.ErrTxDone),
expectLockErr(lockTable, workerName, 1*time.Second, "instanceID", sql.ErrTxDone),
},
isErr: func(err error) bool {
return errors.Is(err, sql.ErrTxDone)
@@ -157,14 +158,14 @@ func TestStatementHandler_renewLock(t *testing.T) {
},
args: args{
lockDuration: 1 * time.Second,
instanceIDs: []string{"instanceID"},
instanceIDs: database.StringArray{"instanceID"},
},
},
{
name: "lock no rows",
want: want{
expectations: []mockExpectation{
expectLockNoRows(lockTable, workerName, 2, "instanceID"),
expectLockNoRows(lockTable, workerName, 2*time.Second, "instanceID"),
},
isErr: func(err error) bool {
return errors.Is(err, renewNoRowsAffectedErr)
@@ -172,14 +173,14 @@ func TestStatementHandler_renewLock(t *testing.T) {
},
args: args{
lockDuration: 2 * time.Second,
instanceIDs: []string{"instanceID"},
instanceIDs: database.StringArray{"instanceID"},
},
},
{
name: "success",
want: want{
expectations: []mockExpectation{
expectLock(lockTable, workerName, 3, "instanceID"),
expectLock(lockTable, workerName, 3*time.Second, "instanceID"),
},
isErr: func(err error) bool {
return errors.Is(err, nil)
@@ -187,14 +188,14 @@ func TestStatementHandler_renewLock(t *testing.T) {
},
args: args{
lockDuration: 3 * time.Second,
instanceIDs: []string{"instanceID"},
instanceIDs: database.StringArray{"instanceID"},
},
},
{
name: "success with multiple",
want: want{
expectations: []mockExpectation{
expectLockMultipleInstances(lockTable, workerName, 3, "instanceID1", "instanceID2"),
expectLockMultipleInstances(lockTable, workerName, 3*time.Second, "instanceID1", "instanceID2"),
},
isErr: func(err error) bool {
return errors.Is(err, nil)

View File

@@ -4,8 +4,7 @@ import (
"strconv"
"strings"
"github.com/lib/pq"
"github.com/zitadel/zitadel/internal/database"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler"
@@ -51,10 +50,13 @@ func NewCreateStatement(event eventstore.Event, values []handler.Column, opts ..
}
}
func NewUpsertStatement(event eventstore.Event, values []handler.Column, opts ...execOption) *handler.Statement {
func NewUpsertStatement(event eventstore.Event, conflictCols []handler.Column, values []handler.Column, opts ...execOption) *handler.Statement {
cols, params, args := columnsToQuery(values)
columnNames := strings.Join(cols, ", ")
valuesPlaceholder := strings.Join(params, ", ")
conflictTarget := make([]string, len(conflictCols))
for i, col := range conflictCols {
conflictTarget[i] = col.Name
}
config := execConfig{
args: args,
@@ -64,8 +66,14 @@ func NewUpsertStatement(event eventstore.Event, values []handler.Column, opts ..
config.err = handler.ErrNoValues
}
updateCols, updateVals := getUpdateCols(cols, conflictTarget)
if len(updateCols) == 0 || len(updateVals) == 0 {
config.err = handler.ErrNoValues
}
q := func(config execConfig) string {
return "UPSERT INTO " + config.tableName + " (" + columnNames + ") VALUES (" + valuesPlaceholder + ")"
return "INSERT INTO " + config.tableName + " (" + strings.Join(cols, ", ") + ") VALUES (" + strings.Join(params, ", ") + ")" +
" ON CONFLICT (" + strings.Join(conflictTarget, ", ") + ") DO UPDATE SET (" + strings.Join(updateCols, ", ") + ") = (" + strings.Join(updateVals, ", ") + ")"
}
return &handler.Statement{
@@ -77,15 +85,38 @@ func NewUpsertStatement(event eventstore.Event, values []handler.Column, opts ..
}
}
func getUpdateCols(cols, conflictTarget []string) (updateCols, updateVals []string) {
updateCols = make([]string, len(cols))
updateVals = make([]string, len(cols))
copy(updateCols, cols)
for i := len(updateCols) - 1; i >= 0; i-- {
updateVals[i] = "EXCLUDED." + updateCols[i]
for _, conflict := range conflictTarget {
if conflict == updateCols[i] {
copy(updateCols[i:], updateCols[i+1:])
updateCols[len(updateCols)-1] = ""
updateCols = updateCols[:len(updateCols)-1]
copy(updateVals[i:], updateVals[i+1:])
updateVals[len(updateVals)-1] = ""
updateVals = updateVals[:len(updateVals)-1]
break
}
}
}
return updateCols, updateVals
}
func NewUpdateStatement(event eventstore.Event, values []handler.Column, conditions []handler.Condition, opts ...execOption) *handler.Statement {
cols, params, args := columnsToQuery(values)
wheres, whereArgs := conditionsToWhere(conditions, len(params))
args = append(args, whereArgs...)
columnNames := strings.Join(cols, ", ")
valuesPlaceholder := strings.Join(params, ", ")
wheresPlaceholders := strings.Join(wheres, " AND ")
config := execConfig{
args: args,
}
@@ -99,7 +130,7 @@ func NewUpdateStatement(event eventstore.Event, values []handler.Column, conditi
}
q := func(config execConfig) string {
return "UPDATE " + config.tableName + " SET (" + columnNames + ") = (" + valuesPlaceholder + ") WHERE " + wheresPlaceholders
return "UPDATE " + config.tableName + " SET (" + strings.Join(cols, ", ") + ") = (" + strings.Join(params, ", ") + ") WHERE " + strings.Join(wheres, " AND ")
}
return &handler.Statement{
@@ -171,9 +202,9 @@ func AddCreateStatement(columns []handler.Column, opts ...execOption) func(event
}
}
func AddUpsertStatement(values []handler.Column, opts ...execOption) func(eventstore.Event) Exec {
func AddUpsertStatement(indexCols []handler.Column, values []handler.Column, opts ...execOption) func(eventstore.Event) Exec {
return func(event eventstore.Event) Exec {
return NewUpsertStatement(event, values, opts...).Execute
return NewUpsertStatement(event, indexCols, values, opts...).Execute
}
}
@@ -189,9 +220,9 @@ func AddDeleteStatement(conditions []handler.Condition, opts ...execOption) func
}
}
func AddCopyStatement(from, to []handler.Column, conditions []handler.Condition, opts ...execOption) func(eventstore.Event) Exec {
func AddCopyStatement(conflict, from, to []handler.Column, conditions []handler.Condition, opts ...execOption) func(eventstore.Event) Exec {
return func(event eventstore.Event) Exec {
return NewCopyStatement(event, from, to, conditions, opts...).Execute
return NewCopyStatement(event, conflict, from, to, conditions, opts...).Execute
}
}
@@ -218,11 +249,9 @@ func NewArrayRemoveCol(column string, value interface{}) handler.Column {
func NewArrayIntersectCol(column string, value interface{}) handler.Column {
var arrayType string
switch value.(type) {
case pq.StringArray:
arrayType = "STRING"
case pq.Int32Array,
pq.Int64Array:
arrayType = "INT"
case []string, database.StringArray:
arrayType = "TEXT"
//TODO: handle more types if necessary
}
return handler.Column{
@@ -234,25 +263,29 @@ func NewArrayIntersectCol(column string, value interface{}) handler.Column {
}
}
//NewCopyStatement creates a new upsert statement which updates a column from an existing row
// NewCopyStatement creates a new upsert statement which updates a column from an existing row
// cols represent the columns which are objective to change.
// if the value of a col is empty the data will be copied from the selected row
// if the value of a col is not empty the data will be set by the static value
// conds represent the conditions for the selection subquery
func NewCopyStatement(event eventstore.Event, from, to []handler.Column, conds []handler.Condition, opts ...execOption) *handler.Statement {
func NewCopyStatement(event eventstore.Event, conflictCols, from, to []handler.Column, conds []handler.Condition, opts ...execOption) *handler.Statement {
columnNames := make([]string, len(to))
selectColumns := make([]string, len(from))
updateColumns := make([]string, len(columnNames))
argCounter := 0
args := []interface{}{}
for i := range from {
for i, col := range from {
columnNames[i] = to[i].Name
selectColumns[i] = from[i].Name
if from[i].Value != nil {
updateColumns[i] = "EXCLUDED." + col.Name
if col.Value != nil {
argCounter++
selectColumns[i] = "$" + strconv.Itoa(argCounter)
args = append(args, from[i].Value)
updateColumns[i] = selectColumns[i]
args = append(args, col.Value)
}
}
wheres := make([]string, len(conds))
@@ -262,6 +295,11 @@ func NewCopyStatement(event eventstore.Event, from, to []handler.Column, conds [
args = append(args, cond.Value)
}
conflictTargets := make([]string, len(conflictCols))
for i, conflictCol := range conflictCols {
conflictTargets[i] = conflictCol.Name
}
config := execConfig{
args: args,
}
@@ -275,7 +313,7 @@ func NewCopyStatement(event eventstore.Event, from, to []handler.Column, conds [
}
q := func(config execConfig) string {
return "UPSERT INTO " +
return "INSERT INTO " +
config.tableName +
" (" +
strings.Join(columnNames, ", ") +
@@ -283,7 +321,14 @@ func NewCopyStatement(event eventstore.Event, from, to []handler.Column, conds [
strings.Join(selectColumns, ", ") +
" FROM " +
config.tableName + " AS copy_table WHERE " +
strings.Join(wheres, " AND ")
strings.Join(wheres, " AND ") +
" ON CONFLICT (" +
strings.Join(conflictTargets, ", ") +
") DO UPDATE SET (" +
strings.Join(columnNames, ", ") +
") = (" +
strings.Join(updateColumns, ", ") +
")"
}
return &handler.Statement{

View File

@@ -178,9 +178,10 @@ func TestNewCreateStatement(t *testing.T) {
func TestNewUpsertStatement(t *testing.T) {
type args struct {
table string
event *testEvent
values []handler.Column
table string
event *testEvent
conflictCols []handler.Column
values []handler.Column
}
type want struct {
aggregateType eventstore.AggregateType
@@ -249,7 +250,7 @@ func TestNewUpsertStatement(t *testing.T) {
},
},
{
name: "correct",
name: "no update cols",
args: args{
table: "my_table",
event: &testEvent{
@@ -257,6 +258,9 @@ func TestNewUpsertStatement(t *testing.T) {
sequence: 1,
previousSequence: 0,
},
conflictCols: []handler.Column{
handler.NewCol("col1", nil),
},
values: []handler.Column{
{
Name: "col1",
@@ -264,6 +268,42 @@ func TestNewUpsertStatement(t *testing.T) {
},
},
},
want: want{
table: "my_table",
aggregateType: "agg",
sequence: 1,
previousSequence: 1,
executer: &wantExecuter{
shouldExecute: false,
},
isErr: func(err error) bool {
return errors.Is(err, handler.ErrNoValues)
},
},
},
{
name: "correct",
args: args{
table: "my_table",
event: &testEvent{
aggregateType: "agg",
sequence: 1,
previousSequence: 0,
},
conflictCols: []handler.Column{
handler.NewCol("col1", nil),
},
values: []handler.Column{
{
Name: "col1",
Value: "val",
},
{
Name: "col2",
Value: "val",
},
},
},
want: want{
table: "my_table",
aggregateType: "agg",
@@ -272,8 +312,8 @@ func TestNewUpsertStatement(t *testing.T) {
executer: &wantExecuter{
params: []params{
{
query: "UPSERT INTO my_table (col1) VALUES ($1)",
args: []interface{}{"val"},
query: "INSERT INTO my_table (col1, col2) VALUES ($1, $2) ON CONFLICT (col1) DO UPDATE SET (col2) = (EXCLUDED.col2)",
args: []interface{}{"val", "val"},
},
},
shouldExecute: true,
@@ -287,7 +327,7 @@ func TestNewUpsertStatement(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.want.executer.t = t
stmt := NewUpsertStatement(tt.args.event, tt.args.values)
stmt := NewUpsertStatement(tt.args.event, tt.args.conflictCols, tt.args.values)
err := stmt.Execute(tt.want.executer, tt.args.table)
if !tt.want.isErr(err) {
@@ -724,11 +764,18 @@ func TestNewMultiStatement(t *testing.T) {
},
}),
AddUpsertStatement(
[]handler.Column{
handler.NewCol("col1", nil),
},
[]handler.Column{
{
Name: "col1",
Value: 1,
},
{
Name: "col2",
Value: 2,
},
}),
AddUpdateStatement(
[]handler.Column{
@@ -761,8 +808,8 @@ func TestNewMultiStatement(t *testing.T) {
args: []interface{}{1},
},
{
query: "UPSERT INTO my_table (col1) VALUES ($1)",
args: []interface{}{1},
query: "INSERT INTO my_table (col1, col2) VALUES ($1, $2) ON CONFLICT (col1) DO UPDATE SET (col2) = (EXCLUDED.col2)",
args: []interface{}{1, 2},
},
{
query: "UPDATE my_table SET (col1) = ($1) WHERE (col1 = $2)",
@@ -799,11 +846,12 @@ func TestNewMultiStatement(t *testing.T) {
func TestNewCopyStatement(t *testing.T) {
type args struct {
table string
event *testEvent
from []handler.Column
to []handler.Column
conds []handler.Condition
table string
event *testEvent
conflictingCols []handler.Column
from []handler.Column
to []handler.Column
conds []handler.Condition
}
type want struct {
aggregateType eventstore.AggregateType
@@ -1004,7 +1052,7 @@ func TestNewCopyStatement(t *testing.T) {
executer: &wantExecuter{
params: []params{
{
query: "UPSERT INTO my_table (state, id, col_a, col_b) SELECT $1, id, col_a, col_b FROM my_table AS copy_table WHERE copy_table.id = $2 AND copy_table.state = $3",
query: "INSERT INTO my_table (state, id, col_a, col_b) SELECT $1, id, col_a, col_b FROM my_table AS copy_table WHERE copy_table.id = $2 AND copy_table.state = $3 ON CONFLICT () DO UPDATE SET (state, id, col_a, col_b) = ($1, EXCLUDED.id, EXCLUDED.col_a, EXCLUDED.col_b)",
args: []interface{}{1, 2, 3},
},
},
@@ -1071,7 +1119,7 @@ func TestNewCopyStatement(t *testing.T) {
executer: &wantExecuter{
params: []params{
{
query: "UPSERT INTO my_table (state, id, col_c, col_d) SELECT $1, id, col_a, col_b FROM my_table AS copy_table WHERE copy_table.id = $2 AND copy_table.state = $3",
query: "INSERT INTO my_table (state, id, col_c, col_d) SELECT $1, id, col_a, col_b FROM my_table AS copy_table WHERE copy_table.id = $2 AND copy_table.state = $3 ON CONFLICT () DO UPDATE SET (state, id, col_c, col_d) = ($1, EXCLUDED.id, EXCLUDED.col_a, EXCLUDED.col_b)",
args: []interface{}{1, 2, 3},
},
},
@@ -1086,7 +1134,7 @@ func TestNewCopyStatement(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.want.executer.t = t
stmt := NewCopyStatement(tt.args.event, tt.args.from, tt.args.to, tt.args.conds)
stmt := NewCopyStatement(tt.args.event, tt.args.conflictingCols, tt.args.from, tt.args.to, tt.args.conds)
err := stmt.Execute(tt.want.executer, tt.args.table)
if !tt.want.isErr(err) {