mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 21:37:32 +00:00
feat(database): support for postgres (#3998)
* beginning with postgres statements * try pgx * use pgx * database * init works for postgres * arrays working * init for cockroach * init * start tests * tests * TESTS * ch * ch * chore: use go 1.18 * read stmts * fix typo * tests * connection string * add missing error handler * cleanup * start all apis * go mod tidy * old update * switch back to minute * on conflict * replace string slice with `database.StringArray` in db models * fix tests and start * update go version in dockerfile * setup go * clean up * remove notification migration * update * docs: add deploy guide for postgres * fix: revert sonyflake * use `database.StringArray` for daos * use `database.StringArray` every where * new tables * index naming, metadata primary key, project grant role key type * docs(postgres): change to beta * chore: correct compose * fix(defaults): add empty postgres config * refactor: remove unused code * docs: add postgres to self hosted * fix broken link * so? * change title * add mdx to link * fix stmt * update goreleaser in test-code * docs: improve postgres example * update more projections * fix: add beta log for postgres * revert index name change * prerelease * fix: add sequence to v1 "reduce paniced" * log if nil * add logging * fix: log output * fix(import): check if org exists and user * refactor: imports * fix(user): ignore malformed events * refactor: method naming * fix: test * refactor: correct errors.Is call * ci: don't build dev binaries on main * fix(go releaser): update version to 1.11.0 * fix(user): projection should not break * fix(user): handle error properly * docs: correct config example * Update .releaserc.js * Update .releaserc.js Co-authored-by: Livio Amstutz <livio.a@gmail.com> Co-authored-by: Elio Bischof <eliobischof@gmail.com>
This commit is contained in:
@@ -6,15 +6,15 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
const (
|
||||
currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2) FOR UPDATE`
|
||||
updateCurrentSequencesStmtFormat = `UPSERT INTO %s (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES `
|
||||
currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2) FOR UPDATE`
|
||||
updateCurrentSequencesStmtFormat = `INSERT INTO %s (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES `
|
||||
updateCurrentSequencesConflictStmt = ` ON CONFLICT (projection_name, aggregate_type, instance_id) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`
|
||||
)
|
||||
|
||||
type currentSequences map[eventstore.AggregateType][]*instanceSequence
|
||||
@@ -24,8 +24,8 @@ type instanceSequence struct {
|
||||
sequence uint64
|
||||
}
|
||||
|
||||
func (h *StatementHandler) currentSequences(ctx context.Context, query func(context.Context, string, ...interface{}) (*sql.Rows, error), instanceIDs []string) (currentSequences, error) {
|
||||
rows, err := query(ctx, h.currentSequenceStmt, h.ProjectionName, pq.StringArray(instanceIDs))
|
||||
func (h *StatementHandler) currentSequences(ctx context.Context, query func(context.Context, string, ...interface{}) (*sql.Rows, error), instanceIDs database.StringArray) (currentSequences, error) {
|
||||
rows, err := query(ctx, h.currentSequenceStmt, h.ProjectionName, instanceIDs)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -74,7 +74,7 @@ func (h *StatementHandler) updateCurrentSequences(tx *sql.Tx, sequences currentS
|
||||
}
|
||||
}
|
||||
|
||||
res, err := tx.Exec(h.updateSequencesBaseStmt+strings.Join(valueQueries, ", "), values...)
|
||||
res, err := tx.Exec(h.updateSequencesBaseStmt+strings.Join(valueQueries, ", ")+updateCurrentSequencesConflictStmt, values...)
|
||||
if err != nil {
|
||||
return errors.ThrowInternal(err, "CRDB-TrH2Z", "unable to exec update sequence")
|
||||
}
|
||||
|
@@ -8,8 +8,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@ type mockExpectation func(sqlmock.Sqlmock)
|
||||
|
||||
func expectFailureCount(tableName string, projectionName, instanceID string, failedSeq, failureCount uint64) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectQuery(`WITH failures AS \(SELECT failure_count FROM `+tableName+` WHERE projection_name = \$1 AND failed_sequence = \$2\ AND instance_id = \$3\) SELECT IF\(EXISTS\(SELECT failure_count FROM failures\), \(SELECT failure_count FROM failures\), 0\) AS failure_count`).
|
||||
m.ExpectQuery(`WITH failures AS \(SELECT failure_count FROM `+tableName+` WHERE projection_name = \$1 AND failed_sequence = \$2 AND instance_id = \$3\) SELECT COALESCE\(\(SELECT failure_count FROM failures\), 0\) AS failure_count`).
|
||||
WithArgs(projectionName, failedSeq, instanceID).
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows([]string{"failure_count"}).
|
||||
@@ -28,7 +28,7 @@ func expectFailureCount(tableName string, projectionName, instanceID string, fai
|
||||
|
||||
func expectUpdateFailureCount(tableName string, projectionName, instanceID string, seq, failureCount uint64) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec(`UPSERT INTO `+tableName+` \(projection_name, failed_sequence, failure_count, error, instance_id\) VALUES \(\$1, \$2, \$3, \$4\, \$5\)`).
|
||||
m.ExpectExec(`INSERT INTO `+tableName+` \(projection_name, failed_sequence, failure_count, error, instance_id\) VALUES \(\$1, \$2, \$3, \$4\, \$5\) ON CONFLICT \(projection_name, failed_sequence, instance_id\) DO UPDATE SET failure_count = EXCLUDED\.failure_count, error = EXCLUDED\.error`).
|
||||
WithArgs(projectionName, seq, failureCount, sqlmock.AnyArg(), instanceID).WillReturnResult(sqlmock.NewResult(1, 1))
|
||||
}
|
||||
}
|
||||
@@ -133,7 +133,7 @@ func expectCurrentSequence(tableName, projection string, seq uint64, aggregateTy
|
||||
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
|
||||
WithArgs(
|
||||
projection,
|
||||
pq.StringArray(instanceIDs),
|
||||
database.StringArray(instanceIDs),
|
||||
).
|
||||
WillReturnRows(
|
||||
rows,
|
||||
@@ -146,7 +146,7 @@ func expectCurrentSequenceErr(tableName, projection string, instanceIDs []string
|
||||
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
|
||||
WithArgs(
|
||||
projection,
|
||||
pq.StringArray(instanceIDs),
|
||||
database.StringArray(instanceIDs),
|
||||
).
|
||||
WillReturnError(err)
|
||||
}
|
||||
@@ -157,7 +157,7 @@ func expectCurrentSequenceNoRows(tableName, projection string, instanceIDs []str
|
||||
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
|
||||
WithArgs(
|
||||
projection,
|
||||
pq.StringArray(instanceIDs),
|
||||
database.StringArray(instanceIDs),
|
||||
).
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}),
|
||||
@@ -170,7 +170,7 @@ func expectCurrentSequenceScanErr(tableName, projection string, instanceIDs []st
|
||||
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
|
||||
WithArgs(
|
||||
projection,
|
||||
pq.StringArray(instanceIDs),
|
||||
database.StringArray(instanceIDs),
|
||||
).
|
||||
WillReturnRows(
|
||||
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}).
|
||||
@@ -182,7 +182,7 @@ func expectCurrentSequenceScanErr(tableName, projection string, instanceIDs []st
|
||||
|
||||
func expectUpdateCurrentSequence(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\)`).
|
||||
m.ExpectExec("INSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
|
||||
WithArgs(
|
||||
projection,
|
||||
aggregateType,
|
||||
@@ -213,7 +213,7 @@ func expectUpdateThreeCurrentSequence(t *testing.T, tableName, projection string
|
||||
matchers[i] = matcher
|
||||
}
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("UPSERT INTO " + tableName + ` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\), \(\$5, \$6, \$7, \$8, NOW\(\)\), \(\$9, \$10, \$11, \$12, NOW\(\)\)`).
|
||||
m.ExpectExec("INSERT INTO " + tableName + ` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\), \(\$5, \$6, \$7, \$8, NOW\(\)\), \(\$9, \$10, \$11, \$12, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
|
||||
WithArgs(
|
||||
matchers...,
|
||||
).
|
||||
@@ -262,7 +262,7 @@ func (m *currentSequenceMatcher) Match(value driver.Value) bool {
|
||||
|
||||
func expectUpdateCurrentSequenceErr(tableName, projection string, seq uint64, err error, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\)`).
|
||||
m.ExpectExec("INSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
|
||||
WithArgs(
|
||||
projection,
|
||||
aggregateType,
|
||||
@@ -275,7 +275,7 @@ func expectUpdateCurrentSequenceErr(tableName, projection string, seq uint64, er
|
||||
|
||||
func expectUpdateCurrentSequenceNoRows(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\)`).
|
||||
m.ExpectExec("INSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
|
||||
WithArgs(
|
||||
projection,
|
||||
aggregateType,
|
||||
@@ -297,10 +297,10 @@ func expectLock(lockTable, workerName string, d time.Duration, instanceID string
|
||||
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$5\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
|
||||
WithArgs(
|
||||
workerName,
|
||||
float64(d),
|
||||
d,
|
||||
projectionName,
|
||||
instanceID,
|
||||
pq.StringArray{instanceID},
|
||||
database.StringArray{instanceID},
|
||||
).
|
||||
WillReturnResult(
|
||||
sqlmock.NewResult(1, 1),
|
||||
@@ -317,11 +317,11 @@ func expectLockMultipleInstances(lockTable, workerName string, d time.Duration,
|
||||
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$6\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
|
||||
WithArgs(
|
||||
workerName,
|
||||
float64(d),
|
||||
d,
|
||||
projectionName,
|
||||
instanceID1,
|
||||
instanceID2,
|
||||
pq.StringArray{instanceID1, instanceID2},
|
||||
database.StringArray{instanceID1, instanceID2},
|
||||
).
|
||||
WillReturnResult(
|
||||
sqlmock.NewResult(1, 1),
|
||||
@@ -338,10 +338,10 @@ func expectLockNoRows(lockTable, workerName string, d time.Duration, instanceID
|
||||
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$5\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
|
||||
WithArgs(
|
||||
workerName,
|
||||
float64(d),
|
||||
d,
|
||||
projectionName,
|
||||
instanceID,
|
||||
pq.StringArray{instanceID},
|
||||
database.StringArray{instanceID},
|
||||
).
|
||||
WillReturnResult(driver.ResultNoRows)
|
||||
}
|
||||
@@ -356,10 +356,10 @@ func expectLockErr(lockTable, workerName string, d time.Duration, instanceID str
|
||||
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$5\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
|
||||
WithArgs(
|
||||
workerName,
|
||||
float64(d),
|
||||
d,
|
||||
projectionName,
|
||||
instanceID,
|
||||
pq.StringArray{instanceID},
|
||||
database.StringArray{instanceID},
|
||||
).
|
||||
WillReturnError(err)
|
||||
}
|
||||
|
@@ -10,15 +10,12 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
setFailureCountStmtFormat = "UPSERT INTO %s" +
|
||||
setFailureCountStmtFormat = "INSERT INTO %s" +
|
||||
" (projection_name, failed_sequence, failure_count, error, instance_id)" +
|
||||
" VALUES ($1, $2, $3, $4, $5)"
|
||||
" VALUES ($1, $2, $3, $4, $5) ON CONFLICT (projection_name, failed_sequence, instance_id)" +
|
||||
" DO UPDATE SET failure_count = EXCLUDED.failure_count, error = EXCLUDED.error"
|
||||
failureCountStmtFormat = "WITH failures AS (SELECT failure_count FROM %s WHERE projection_name = $1 AND failed_sequence = $2 AND instance_id = $3)" +
|
||||
" SELECT IF(" +
|
||||
"EXISTS(SELECT failure_count FROM failures)," +
|
||||
" (SELECT failure_count FROM failures)," +
|
||||
" 0" +
|
||||
") AS failure_count"
|
||||
" SELECT COALESCE((SELECT failure_count FROM failures), 0) AS failure_count"
|
||||
)
|
||||
|
||||
func (h *StatementHandler) handleFailedStmt(tx *sql.Tx, stmt *handler.Statement, execErr error) (shouldContinue bool) {
|
||||
|
@@ -72,12 +72,12 @@ func NewStatementHandler(
|
||||
aggregates: aggregateTypes,
|
||||
reduces: reduces,
|
||||
bulkLimit: config.BulkLimit,
|
||||
Locker: NewLocker(config.Client, config.LockTable, config.ProjectionHandlerConfig.ProjectionName),
|
||||
Locker: NewLocker(config.Client, config.LockTable, config.ProjectionName),
|
||||
}
|
||||
h.ProjectionHandler = handler.NewProjectionHandler(ctx, config.ProjectionHandlerConfig, h.reduce, h.Update, h.SearchQuery, h.Lock, h.Unlock)
|
||||
|
||||
err := h.Init(ctx, config.InitCheck)
|
||||
logging.OnError(err).Fatal("unable to initialize projections")
|
||||
logging.OnError(err).WithField("projection", config.ProjectionName).Fatal("unable to initialize projections")
|
||||
|
||||
h.Subscribe(h.aggregates...)
|
||||
|
||||
|
@@ -6,11 +6,10 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler"
|
||||
)
|
||||
|
||||
@@ -186,7 +185,7 @@ type ForeignKey struct {
|
||||
RefColumns []string
|
||||
}
|
||||
|
||||
//Init implements handler.Init
|
||||
// Init implements handler.Init
|
||||
func (h *StatementHandler) Init(ctx context.Context, checks ...*handler.Check) error {
|
||||
for _, check := range checks {
|
||||
if check == nil || check.IsNoop() {
|
||||
@@ -280,7 +279,7 @@ func isErrAlreadyExists(err error) bool {
|
||||
if !errors.As(err, &caosErr) {
|
||||
return false
|
||||
}
|
||||
sqlErr, ok := caosErr.GetParent().(*pq.Error)
|
||||
sqlErr, ok := caosErr.GetParent().(*pgconn.PgError)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
@@ -288,14 +287,11 @@ func isErrAlreadyExists(err error) bool {
|
||||
}
|
||||
|
||||
func createTableStatement(table *Table, tableName string, suffix string) string {
|
||||
stmt := fmt.Sprintf("CREATE TABLE %s (%s, PRIMARY KEY (%s)",
|
||||
stmt := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s, PRIMARY KEY (%s)",
|
||||
tableName+suffix,
|
||||
createColumnsStatement(table.columns, tableName),
|
||||
strings.Join(table.primaryKey, ", "),
|
||||
)
|
||||
for _, index := range table.indices {
|
||||
stmt += fmt.Sprintf(", INDEX %s (%s)", index.Name, strings.Join(index.Columns, ","))
|
||||
}
|
||||
for _, key := range table.foreignKeys {
|
||||
ref := tableName
|
||||
if len(key.RefColumns) > 0 {
|
||||
@@ -309,7 +305,13 @@ func createTableStatement(table *Table, tableName string, suffix string) string
|
||||
for _, constraint := range table.constraints {
|
||||
stmt += fmt.Sprintf(", CONSTRAINT %s UNIQUE (%s)", constraint.Name, strings.Join(constraint.Columns, ","))
|
||||
}
|
||||
return stmt + ");"
|
||||
|
||||
stmt += ");"
|
||||
|
||||
for _, index := range table.indices {
|
||||
stmt += fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s (%s);", index.Name, tableName+suffix, strings.Join(index.Columns, ","))
|
||||
}
|
||||
return stmt
|
||||
}
|
||||
|
||||
func createViewStatement(viewName string, selectStmt string) string {
|
||||
@@ -321,7 +323,7 @@ func createViewStatement(viewName string, selectStmt string) string {
|
||||
|
||||
func createIndexStatement(index *Index) func(config execConfig) string {
|
||||
return func(config execConfig) string {
|
||||
stmt := fmt.Sprintf("CREATE INDEX %s ON %s (%s)",
|
||||
stmt := fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s (%s)",
|
||||
index.Name,
|
||||
config.tableName,
|
||||
strings.Join(index.Columns, ","),
|
||||
@@ -380,9 +382,8 @@ func columnType(columnType ColumnType) string {
|
||||
case ColumnTypeJSONB:
|
||||
return "JSONB"
|
||||
case ColumnTypeBytes:
|
||||
return "BYTES"
|
||||
return "BYTEA"
|
||||
default:
|
||||
panic("unknown column type")
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
@@ -8,9 +8,9 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/id"
|
||||
)
|
||||
@@ -91,17 +91,17 @@ func (h *locker) Unlock(instanceIDs ...string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (h *locker) lockStatement(lockDuration time.Duration, instanceIDs []string) (string, []interface{}) {
|
||||
func (h *locker) lockStatement(lockDuration time.Duration, instanceIDs database.StringArray) (string, []interface{}) {
|
||||
valueQueries := make([]string, len(instanceIDs))
|
||||
values := make([]interface{}, len(instanceIDs)+4)
|
||||
values[0] = h.workerName
|
||||
//the unit of crdb interval is seconds (https://www.cockroachlabs.com/docs/stable/interval.html).
|
||||
values[1] = lockDuration.Seconds()
|
||||
values[1] = lockDuration
|
||||
values[2] = h.projectionName
|
||||
for i, instanceID := range instanceIDs {
|
||||
valueQueries[i] = "($1, now()+$2::INTERVAL, $3, $" + strconv.Itoa(i+4) + ")"
|
||||
values[i+3] = instanceID
|
||||
}
|
||||
values[len(values)-1] = pq.StringArray(instanceIDs)
|
||||
values[len(values)-1] = instanceIDs
|
||||
return h.lockStmt(strings.Join(valueQueries, ", "), len(values)), values
|
||||
}
|
||||
|
@@ -10,6 +10,7 @@ import (
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
z_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
)
|
||||
|
||||
@@ -43,9 +44,9 @@ func TestStatementHandler_handleLock(t *testing.T) {
|
||||
name: "lock fails",
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectLock(lockTable, workerName, 2, "instanceID"),
|
||||
expectLock(lockTable, workerName, 2, "instanceID"),
|
||||
expectLockErr(lockTable, workerName, 2, "instanceID", errLock),
|
||||
expectLock(lockTable, workerName, 2*time.Second, "instanceID"),
|
||||
expectLock(lockTable, workerName, 2*time.Second, "instanceID"),
|
||||
expectLockErr(lockTable, workerName, 2*time.Second, "instanceID", errLock),
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
@@ -63,8 +64,8 @@ func TestStatementHandler_handleLock(t *testing.T) {
|
||||
name: "success",
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectLock(lockTable, workerName, 2, "instanceID"),
|
||||
expectLock(lockTable, workerName, 2, "instanceID"),
|
||||
expectLock(lockTable, workerName, 2*time.Second, "instanceID"),
|
||||
expectLock(lockTable, workerName, 2*time.Second, "instanceID"),
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
@@ -81,8 +82,8 @@ func TestStatementHandler_handleLock(t *testing.T) {
|
||||
name: "success with multiple",
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectLockMultipleInstances(lockTable, workerName, 2, "instanceID1", "instanceID2"),
|
||||
expectLockMultipleInstances(lockTable, workerName, 2, "instanceID1", "instanceID2"),
|
||||
expectLockMultipleInstances(lockTable, workerName, 2*time.Second, "instanceID1", "instanceID2"),
|
||||
expectLockMultipleInstances(lockTable, workerName, 2*time.Second, "instanceID1", "instanceID2"),
|
||||
},
|
||||
},
|
||||
args: args{
|
||||
@@ -149,7 +150,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
|
||||
name: "lock fails",
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectLockErr(lockTable, workerName, 1, "instanceID", sql.ErrTxDone),
|
||||
expectLockErr(lockTable, workerName, 1*time.Second, "instanceID", sql.ErrTxDone),
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, sql.ErrTxDone)
|
||||
@@ -157,14 +158,14 @@ func TestStatementHandler_renewLock(t *testing.T) {
|
||||
},
|
||||
args: args{
|
||||
lockDuration: 1 * time.Second,
|
||||
instanceIDs: []string{"instanceID"},
|
||||
instanceIDs: database.StringArray{"instanceID"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "lock no rows",
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectLockNoRows(lockTable, workerName, 2, "instanceID"),
|
||||
expectLockNoRows(lockTable, workerName, 2*time.Second, "instanceID"),
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, renewNoRowsAffectedErr)
|
||||
@@ -172,14 +173,14 @@ func TestStatementHandler_renewLock(t *testing.T) {
|
||||
},
|
||||
args: args{
|
||||
lockDuration: 2 * time.Second,
|
||||
instanceIDs: []string{"instanceID"},
|
||||
instanceIDs: database.StringArray{"instanceID"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectLock(lockTable, workerName, 3, "instanceID"),
|
||||
expectLock(lockTable, workerName, 3*time.Second, "instanceID"),
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, nil)
|
||||
@@ -187,14 +188,14 @@ func TestStatementHandler_renewLock(t *testing.T) {
|
||||
},
|
||||
args: args{
|
||||
lockDuration: 3 * time.Second,
|
||||
instanceIDs: []string{"instanceID"},
|
||||
instanceIDs: database.StringArray{"instanceID"},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success with multiple",
|
||||
want: want{
|
||||
expectations: []mockExpectation{
|
||||
expectLockMultipleInstances(lockTable, workerName, 3, "instanceID1", "instanceID2"),
|
||||
expectLockMultipleInstances(lockTable, workerName, 3*time.Second, "instanceID1", "instanceID2"),
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, nil)
|
||||
|
@@ -4,8 +4,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
caos_errs "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/handler"
|
||||
@@ -51,10 +50,13 @@ func NewCreateStatement(event eventstore.Event, values []handler.Column, opts ..
|
||||
}
|
||||
}
|
||||
|
||||
func NewUpsertStatement(event eventstore.Event, values []handler.Column, opts ...execOption) *handler.Statement {
|
||||
func NewUpsertStatement(event eventstore.Event, conflictCols []handler.Column, values []handler.Column, opts ...execOption) *handler.Statement {
|
||||
cols, params, args := columnsToQuery(values)
|
||||
columnNames := strings.Join(cols, ", ")
|
||||
valuesPlaceholder := strings.Join(params, ", ")
|
||||
|
||||
conflictTarget := make([]string, len(conflictCols))
|
||||
for i, col := range conflictCols {
|
||||
conflictTarget[i] = col.Name
|
||||
}
|
||||
|
||||
config := execConfig{
|
||||
args: args,
|
||||
@@ -64,8 +66,14 @@ func NewUpsertStatement(event eventstore.Event, values []handler.Column, opts ..
|
||||
config.err = handler.ErrNoValues
|
||||
}
|
||||
|
||||
updateCols, updateVals := getUpdateCols(cols, conflictTarget)
|
||||
if len(updateCols) == 0 || len(updateVals) == 0 {
|
||||
config.err = handler.ErrNoValues
|
||||
}
|
||||
|
||||
q := func(config execConfig) string {
|
||||
return "UPSERT INTO " + config.tableName + " (" + columnNames + ") VALUES (" + valuesPlaceholder + ")"
|
||||
return "INSERT INTO " + config.tableName + " (" + strings.Join(cols, ", ") + ") VALUES (" + strings.Join(params, ", ") + ")" +
|
||||
" ON CONFLICT (" + strings.Join(conflictTarget, ", ") + ") DO UPDATE SET (" + strings.Join(updateCols, ", ") + ") = (" + strings.Join(updateVals, ", ") + ")"
|
||||
}
|
||||
|
||||
return &handler.Statement{
|
||||
@@ -77,15 +85,38 @@ func NewUpsertStatement(event eventstore.Event, values []handler.Column, opts ..
|
||||
}
|
||||
}
|
||||
|
||||
func getUpdateCols(cols, conflictTarget []string) (updateCols, updateVals []string) {
|
||||
updateCols = make([]string, len(cols))
|
||||
updateVals = make([]string, len(cols))
|
||||
|
||||
copy(updateCols, cols)
|
||||
|
||||
for i := len(updateCols) - 1; i >= 0; i-- {
|
||||
updateVals[i] = "EXCLUDED." + updateCols[i]
|
||||
|
||||
for _, conflict := range conflictTarget {
|
||||
if conflict == updateCols[i] {
|
||||
copy(updateCols[i:], updateCols[i+1:])
|
||||
updateCols[len(updateCols)-1] = ""
|
||||
updateCols = updateCols[:len(updateCols)-1]
|
||||
|
||||
copy(updateVals[i:], updateVals[i+1:])
|
||||
updateVals[len(updateVals)-1] = ""
|
||||
updateVals = updateVals[:len(updateVals)-1]
|
||||
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return updateCols, updateVals
|
||||
}
|
||||
|
||||
func NewUpdateStatement(event eventstore.Event, values []handler.Column, conditions []handler.Condition, opts ...execOption) *handler.Statement {
|
||||
cols, params, args := columnsToQuery(values)
|
||||
wheres, whereArgs := conditionsToWhere(conditions, len(params))
|
||||
args = append(args, whereArgs...)
|
||||
|
||||
columnNames := strings.Join(cols, ", ")
|
||||
valuesPlaceholder := strings.Join(params, ", ")
|
||||
wheresPlaceholders := strings.Join(wheres, " AND ")
|
||||
|
||||
config := execConfig{
|
||||
args: args,
|
||||
}
|
||||
@@ -99,7 +130,7 @@ func NewUpdateStatement(event eventstore.Event, values []handler.Column, conditi
|
||||
}
|
||||
|
||||
q := func(config execConfig) string {
|
||||
return "UPDATE " + config.tableName + " SET (" + columnNames + ") = (" + valuesPlaceholder + ") WHERE " + wheresPlaceholders
|
||||
return "UPDATE " + config.tableName + " SET (" + strings.Join(cols, ", ") + ") = (" + strings.Join(params, ", ") + ") WHERE " + strings.Join(wheres, " AND ")
|
||||
}
|
||||
|
||||
return &handler.Statement{
|
||||
@@ -171,9 +202,9 @@ func AddCreateStatement(columns []handler.Column, opts ...execOption) func(event
|
||||
}
|
||||
}
|
||||
|
||||
func AddUpsertStatement(values []handler.Column, opts ...execOption) func(eventstore.Event) Exec {
|
||||
func AddUpsertStatement(indexCols []handler.Column, values []handler.Column, opts ...execOption) func(eventstore.Event) Exec {
|
||||
return func(event eventstore.Event) Exec {
|
||||
return NewUpsertStatement(event, values, opts...).Execute
|
||||
return NewUpsertStatement(event, indexCols, values, opts...).Execute
|
||||
}
|
||||
}
|
||||
|
||||
@@ -189,9 +220,9 @@ func AddDeleteStatement(conditions []handler.Condition, opts ...execOption) func
|
||||
}
|
||||
}
|
||||
|
||||
func AddCopyStatement(from, to []handler.Column, conditions []handler.Condition, opts ...execOption) func(eventstore.Event) Exec {
|
||||
func AddCopyStatement(conflict, from, to []handler.Column, conditions []handler.Condition, opts ...execOption) func(eventstore.Event) Exec {
|
||||
return func(event eventstore.Event) Exec {
|
||||
return NewCopyStatement(event, from, to, conditions, opts...).Execute
|
||||
return NewCopyStatement(event, conflict, from, to, conditions, opts...).Execute
|
||||
}
|
||||
}
|
||||
|
||||
@@ -218,11 +249,9 @@ func NewArrayRemoveCol(column string, value interface{}) handler.Column {
|
||||
func NewArrayIntersectCol(column string, value interface{}) handler.Column {
|
||||
var arrayType string
|
||||
switch value.(type) {
|
||||
case pq.StringArray:
|
||||
arrayType = "STRING"
|
||||
case pq.Int32Array,
|
||||
pq.Int64Array:
|
||||
arrayType = "INT"
|
||||
|
||||
case []string, database.StringArray:
|
||||
arrayType = "TEXT"
|
||||
//TODO: handle more types if necessary
|
||||
}
|
||||
return handler.Column{
|
||||
@@ -234,25 +263,29 @@ func NewArrayIntersectCol(column string, value interface{}) handler.Column {
|
||||
}
|
||||
}
|
||||
|
||||
//NewCopyStatement creates a new upsert statement which updates a column from an existing row
|
||||
// NewCopyStatement creates a new upsert statement which updates a column from an existing row
|
||||
// cols represent the columns which are objective to change.
|
||||
// if the value of a col is empty the data will be copied from the selected row
|
||||
// if the value of a col is not empty the data will be set by the static value
|
||||
// conds represent the conditions for the selection subquery
|
||||
func NewCopyStatement(event eventstore.Event, from, to []handler.Column, conds []handler.Condition, opts ...execOption) *handler.Statement {
|
||||
func NewCopyStatement(event eventstore.Event, conflictCols, from, to []handler.Column, conds []handler.Condition, opts ...execOption) *handler.Statement {
|
||||
columnNames := make([]string, len(to))
|
||||
selectColumns := make([]string, len(from))
|
||||
updateColumns := make([]string, len(columnNames))
|
||||
argCounter := 0
|
||||
args := []interface{}{}
|
||||
|
||||
for i := range from {
|
||||
for i, col := range from {
|
||||
columnNames[i] = to[i].Name
|
||||
selectColumns[i] = from[i].Name
|
||||
if from[i].Value != nil {
|
||||
updateColumns[i] = "EXCLUDED." + col.Name
|
||||
if col.Value != nil {
|
||||
argCounter++
|
||||
selectColumns[i] = "$" + strconv.Itoa(argCounter)
|
||||
args = append(args, from[i].Value)
|
||||
updateColumns[i] = selectColumns[i]
|
||||
args = append(args, col.Value)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
wheres := make([]string, len(conds))
|
||||
@@ -262,6 +295,11 @@ func NewCopyStatement(event eventstore.Event, from, to []handler.Column, conds [
|
||||
args = append(args, cond.Value)
|
||||
}
|
||||
|
||||
conflictTargets := make([]string, len(conflictCols))
|
||||
for i, conflictCol := range conflictCols {
|
||||
conflictTargets[i] = conflictCol.Name
|
||||
}
|
||||
|
||||
config := execConfig{
|
||||
args: args,
|
||||
}
|
||||
@@ -275,7 +313,7 @@ func NewCopyStatement(event eventstore.Event, from, to []handler.Column, conds [
|
||||
}
|
||||
|
||||
q := func(config execConfig) string {
|
||||
return "UPSERT INTO " +
|
||||
return "INSERT INTO " +
|
||||
config.tableName +
|
||||
" (" +
|
||||
strings.Join(columnNames, ", ") +
|
||||
@@ -283,7 +321,14 @@ func NewCopyStatement(event eventstore.Event, from, to []handler.Column, conds [
|
||||
strings.Join(selectColumns, ", ") +
|
||||
" FROM " +
|
||||
config.tableName + " AS copy_table WHERE " +
|
||||
strings.Join(wheres, " AND ")
|
||||
strings.Join(wheres, " AND ") +
|
||||
" ON CONFLICT (" +
|
||||
strings.Join(conflictTargets, ", ") +
|
||||
") DO UPDATE SET (" +
|
||||
strings.Join(columnNames, ", ") +
|
||||
") = (" +
|
||||
strings.Join(updateColumns, ", ") +
|
||||
")"
|
||||
}
|
||||
|
||||
return &handler.Statement{
|
||||
|
@@ -178,9 +178,10 @@ func TestNewCreateStatement(t *testing.T) {
|
||||
|
||||
func TestNewUpsertStatement(t *testing.T) {
|
||||
type args struct {
|
||||
table string
|
||||
event *testEvent
|
||||
values []handler.Column
|
||||
table string
|
||||
event *testEvent
|
||||
conflictCols []handler.Column
|
||||
values []handler.Column
|
||||
}
|
||||
type want struct {
|
||||
aggregateType eventstore.AggregateType
|
||||
@@ -249,7 +250,7 @@ func TestNewUpsertStatement(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "correct",
|
||||
name: "no update cols",
|
||||
args: args{
|
||||
table: "my_table",
|
||||
event: &testEvent{
|
||||
@@ -257,6 +258,9 @@ func TestNewUpsertStatement(t *testing.T) {
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
conflictCols: []handler.Column{
|
||||
handler.NewCol("col1", nil),
|
||||
},
|
||||
values: []handler.Column{
|
||||
{
|
||||
Name: "col1",
|
||||
@@ -264,6 +268,42 @@ func TestNewUpsertStatement(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
table: "my_table",
|
||||
aggregateType: "agg",
|
||||
sequence: 1,
|
||||
previousSequence: 1,
|
||||
executer: &wantExecuter{
|
||||
shouldExecute: false,
|
||||
},
|
||||
isErr: func(err error) bool {
|
||||
return errors.Is(err, handler.ErrNoValues)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "correct",
|
||||
args: args{
|
||||
table: "my_table",
|
||||
event: &testEvent{
|
||||
aggregateType: "agg",
|
||||
sequence: 1,
|
||||
previousSequence: 0,
|
||||
},
|
||||
conflictCols: []handler.Column{
|
||||
handler.NewCol("col1", nil),
|
||||
},
|
||||
values: []handler.Column{
|
||||
{
|
||||
Name: "col1",
|
||||
Value: "val",
|
||||
},
|
||||
{
|
||||
Name: "col2",
|
||||
Value: "val",
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
table: "my_table",
|
||||
aggregateType: "agg",
|
||||
@@ -272,8 +312,8 @@ func TestNewUpsertStatement(t *testing.T) {
|
||||
executer: &wantExecuter{
|
||||
params: []params{
|
||||
{
|
||||
query: "UPSERT INTO my_table (col1) VALUES ($1)",
|
||||
args: []interface{}{"val"},
|
||||
query: "INSERT INTO my_table (col1, col2) VALUES ($1, $2) ON CONFLICT (col1) DO UPDATE SET (col2) = (EXCLUDED.col2)",
|
||||
args: []interface{}{"val", "val"},
|
||||
},
|
||||
},
|
||||
shouldExecute: true,
|
||||
@@ -287,7 +327,7 @@ func TestNewUpsertStatement(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.want.executer.t = t
|
||||
stmt := NewUpsertStatement(tt.args.event, tt.args.values)
|
||||
stmt := NewUpsertStatement(tt.args.event, tt.args.conflictCols, tt.args.values)
|
||||
|
||||
err := stmt.Execute(tt.want.executer, tt.args.table)
|
||||
if !tt.want.isErr(err) {
|
||||
@@ -724,11 +764,18 @@ func TestNewMultiStatement(t *testing.T) {
|
||||
},
|
||||
}),
|
||||
AddUpsertStatement(
|
||||
[]handler.Column{
|
||||
handler.NewCol("col1", nil),
|
||||
},
|
||||
[]handler.Column{
|
||||
{
|
||||
Name: "col1",
|
||||
Value: 1,
|
||||
},
|
||||
{
|
||||
Name: "col2",
|
||||
Value: 2,
|
||||
},
|
||||
}),
|
||||
AddUpdateStatement(
|
||||
[]handler.Column{
|
||||
@@ -761,8 +808,8 @@ func TestNewMultiStatement(t *testing.T) {
|
||||
args: []interface{}{1},
|
||||
},
|
||||
{
|
||||
query: "UPSERT INTO my_table (col1) VALUES ($1)",
|
||||
args: []interface{}{1},
|
||||
query: "INSERT INTO my_table (col1, col2) VALUES ($1, $2) ON CONFLICT (col1) DO UPDATE SET (col2) = (EXCLUDED.col2)",
|
||||
args: []interface{}{1, 2},
|
||||
},
|
||||
{
|
||||
query: "UPDATE my_table SET (col1) = ($1) WHERE (col1 = $2)",
|
||||
@@ -799,11 +846,12 @@ func TestNewMultiStatement(t *testing.T) {
|
||||
|
||||
func TestNewCopyStatement(t *testing.T) {
|
||||
type args struct {
|
||||
table string
|
||||
event *testEvent
|
||||
from []handler.Column
|
||||
to []handler.Column
|
||||
conds []handler.Condition
|
||||
table string
|
||||
event *testEvent
|
||||
conflictingCols []handler.Column
|
||||
from []handler.Column
|
||||
to []handler.Column
|
||||
conds []handler.Condition
|
||||
}
|
||||
type want struct {
|
||||
aggregateType eventstore.AggregateType
|
||||
@@ -1004,7 +1052,7 @@ func TestNewCopyStatement(t *testing.T) {
|
||||
executer: &wantExecuter{
|
||||
params: []params{
|
||||
{
|
||||
query: "UPSERT INTO my_table (state, id, col_a, col_b) SELECT $1, id, col_a, col_b FROM my_table AS copy_table WHERE copy_table.id = $2 AND copy_table.state = $3",
|
||||
query: "INSERT INTO my_table (state, id, col_a, col_b) SELECT $1, id, col_a, col_b FROM my_table AS copy_table WHERE copy_table.id = $2 AND copy_table.state = $3 ON CONFLICT () DO UPDATE SET (state, id, col_a, col_b) = ($1, EXCLUDED.id, EXCLUDED.col_a, EXCLUDED.col_b)",
|
||||
args: []interface{}{1, 2, 3},
|
||||
},
|
||||
},
|
||||
@@ -1071,7 +1119,7 @@ func TestNewCopyStatement(t *testing.T) {
|
||||
executer: &wantExecuter{
|
||||
params: []params{
|
||||
{
|
||||
query: "UPSERT INTO my_table (state, id, col_c, col_d) SELECT $1, id, col_a, col_b FROM my_table AS copy_table WHERE copy_table.id = $2 AND copy_table.state = $3",
|
||||
query: "INSERT INTO my_table (state, id, col_c, col_d) SELECT $1, id, col_a, col_b FROM my_table AS copy_table WHERE copy_table.id = $2 AND copy_table.state = $3 ON CONFLICT () DO UPDATE SET (state, id, col_c, col_d) = ($1, EXCLUDED.id, EXCLUDED.col_a, EXCLUDED.col_b)",
|
||||
args: []interface{}{1, 2, 3},
|
||||
},
|
||||
},
|
||||
@@ -1086,7 +1134,7 @@ func TestNewCopyStatement(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tt.want.executer.t = t
|
||||
stmt := NewCopyStatement(tt.args.event, tt.args.from, tt.args.to, tt.args.conds)
|
||||
stmt := NewCopyStatement(tt.args.event, tt.args.conflictingCols, tt.args.from, tt.args.to, tt.args.conds)
|
||||
|
||||
err := stmt.Execute(tt.want.executer, tt.args.table)
|
||||
if !tt.want.isErr(err) {
|
||||
|
@@ -9,6 +9,8 @@ import (
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/cmd/initialise"
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/database/cockroach"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -45,13 +47,20 @@ func TestMain(m *testing.M) {
|
||||
}
|
||||
|
||||
func initDB(db *sql.DB) error {
|
||||
username := "zitadel"
|
||||
database := "zitadel"
|
||||
err := initialise.Initialise(db, initialise.VerifyUser(username, ""),
|
||||
initialise.VerifyDatabase(database),
|
||||
initialise.VerifyGrant(database, username))
|
||||
initialise.ReadStmts("cockroach")
|
||||
config := new(database.Config)
|
||||
config.SetConnector(&cockroach.Config{
|
||||
User: cockroach.User{
|
||||
Username: "zitadel",
|
||||
},
|
||||
Database: "zitadel",
|
||||
})
|
||||
err := initialise.Init(db,
|
||||
initialise.VerifyUser(config.Username(), ""),
|
||||
initialise.VerifyDatabase(config.Database()),
|
||||
initialise.VerifyGrant(config.Database(), config.Username()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return initialise.VerifyZitadel(db)
|
||||
return initialise.VerifyZitadel(db, *config)
|
||||
}
|
||||
|
@@ -9,6 +9,7 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/cockroachdb/cockroach-go/v2/crdb"
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/lib/pq"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
@@ -30,7 +31,7 @@ const (
|
||||
" SELECT MAX(event_sequence) seq, 1 join_me" +
|
||||
" FROM eventstore.events" +
|
||||
" WHERE aggregate_type = $2" +
|
||||
" AND (CASE WHEN $9::STRING IS NULL THEN instance_id is null else instance_id = $9::STRING END)" +
|
||||
" AND (CASE WHEN $9::TEXT IS NULL THEN instance_id is null else instance_id = $9::TEXT END)" +
|
||||
") AS agg_type " +
|
||||
// combined with
|
||||
"LEFT JOIN " +
|
||||
@@ -39,7 +40,7 @@ const (
|
||||
" SELECT event_sequence seq, resource_owner ro, 1 join_me" +
|
||||
" FROM eventstore.events" +
|
||||
" WHERE aggregate_type = $2 AND aggregate_id = $3" +
|
||||
" AND (CASE WHEN $9::STRING IS NULL THEN instance_id is null else instance_id = $9::STRING END)" +
|
||||
" AND (CASE WHEN $9::TEXT IS NULL THEN instance_id is null else instance_id = $9::TEXT END)" +
|
||||
" ORDER BY event_sequence DESC" +
|
||||
" LIMIT 1" +
|
||||
") AS agg USING(join_me)" +
|
||||
@@ -69,9 +70,9 @@ const (
|
||||
" $5::JSONB AS event_data," +
|
||||
" $6::VARCHAR AS editor_user," +
|
||||
" $7::VARCHAR AS editor_service," +
|
||||
" IFNULL((resource_owner), $8::VARCHAR) AS resource_owner," +
|
||||
" COALESCE((resource_owner), $8::VARCHAR) AS resource_owner," +
|
||||
" $9::VARCHAR AS instance_id," +
|
||||
" NEXTVAL(CONCAT('eventstore.', IF($9 <> '', CONCAT('i_', $9), 'system'), '_seq'))," +
|
||||
" NEXTVAL(CONCAT('eventstore.', (CASE WHEN $9 <> '' THEN CONCAT('i_', $9) ELSE 'system' END), '_seq'))," +
|
||||
" aggregate_sequence AS previous_aggregate_sequence," +
|
||||
" aggregate_type_sequence AS previous_aggregate_type_sequence " +
|
||||
"FROM previous_data " +
|
||||
@@ -156,7 +157,7 @@ func (db *CRDB) Push(ctx context.Context, events []*repository.Event, uniqueCons
|
||||
var instanceRegexp = regexp.MustCompile(`eventstore\.i_[0-9a-zA-Z]{1,}_seq`)
|
||||
|
||||
func (db *CRDB) CreateInstance(ctx context.Context, instanceID string) error {
|
||||
row := db.client.QueryRowContext(ctx, "SELECT CONCAT('eventstore.i_', $1, '_seq')", instanceID)
|
||||
row := db.client.QueryRowContext(ctx, "SELECT CONCAT('eventstore.i_', $1::TEXT, '_seq')", instanceID)
|
||||
if row.Err() != nil {
|
||||
return caos_errs.ThrowInvalidArgument(row.Err(), "SQL-7gtFA", "Errors.InvalidArgument")
|
||||
}
|
||||
@@ -218,7 +219,7 @@ func (db *CRDB) Filter(ctx context.Context, searchQuery *repository.SearchQuery)
|
||||
return events, nil
|
||||
}
|
||||
|
||||
//LatestSequence returns the latest sequence found by the search query
|
||||
// LatestSequence returns the latest sequence found by the search query
|
||||
func (db *CRDB) LatestSequence(ctx context.Context, searchQuery *repository.SearchQuery) (uint64, error) {
|
||||
var seq Sequence
|
||||
err := query(ctx, db, searchQuery, &seq)
|
||||
@@ -228,7 +229,7 @@ func (db *CRDB) LatestSequence(ctx context.Context, searchQuery *repository.Sear
|
||||
return uint64(seq), nil
|
||||
}
|
||||
|
||||
//InstanceIDs returns the instance ids found by the search query
|
||||
// InstanceIDs returns the instance ids found by the search query
|
||||
func (db *CRDB) InstanceIDs(ctx context.Context, searchQuery *repository.SearchQuery) ([]string, error) {
|
||||
var ids []string
|
||||
err := query(ctx, db, searchQuery, &ids)
|
||||
@@ -331,7 +332,7 @@ var (
|
||||
placeholder = regexp.MustCompile(`\?`)
|
||||
)
|
||||
|
||||
//placeholder replaces all "?" with postgres placeholders ($<NUMBER>)
|
||||
// placeholder replaces all "?" with postgres placeholders ($<NUMBER>)
|
||||
func (db *CRDB) placeholder(query string) string {
|
||||
occurances := placeholder.FindAllStringIndex(query, -1)
|
||||
if len(occurances) == 0 {
|
||||
@@ -355,5 +356,10 @@ func (db *CRDB) isUniqueViolationError(err error) bool {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if pgxErr, ok := err.(*pgconn.PgError); ok {
|
||||
if pgxErr.Code == "23505" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
@@ -6,8 +6,7 @@ import (
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/repository"
|
||||
)
|
||||
|
||||
@@ -279,7 +278,7 @@ func TestCRDB_Push_OneAggregate(t *testing.T) {
|
||||
uniqueCount int
|
||||
assetCount int
|
||||
aggType repository.AggregateType
|
||||
aggID []string
|
||||
aggID database.StringArray
|
||||
}
|
||||
type res struct {
|
||||
wantErr bool
|
||||
@@ -430,7 +429,7 @@ func TestCRDB_Push_OneAggregate(t *testing.T) {
|
||||
t.Errorf("CRDB.Push() error = %v, wantErr %v", err, tt.res.wantErr)
|
||||
}
|
||||
|
||||
countEventRow := testCRDBClient.QueryRow("SELECT COUNT(*) FROM eventstore.events where aggregate_type = $1 AND aggregate_id = ANY($2)", tt.res.eventsRes.aggType, pq.Array(tt.res.eventsRes.aggID))
|
||||
countEventRow := testCRDBClient.QueryRow("SELECT COUNT(*) FROM eventstore.events where aggregate_type = $1 AND aggregate_id = ANY($2)", tt.res.eventsRes.aggType, tt.res.eventsRes.aggID)
|
||||
var eventCount int
|
||||
err := countEventRow.Scan(&eventCount)
|
||||
if err != nil {
|
||||
@@ -462,8 +461,8 @@ func TestCRDB_Push_MultipleAggregate(t *testing.T) {
|
||||
}
|
||||
type eventsRes struct {
|
||||
pushedEventsCount int
|
||||
aggType []repository.AggregateType
|
||||
aggID []string
|
||||
aggType database.StringArray
|
||||
aggID database.StringArray
|
||||
}
|
||||
type res struct {
|
||||
wantErr bool
|
||||
@@ -487,7 +486,7 @@ func TestCRDB_Push_MultipleAggregate(t *testing.T) {
|
||||
eventsRes: eventsRes{
|
||||
pushedEventsCount: 2,
|
||||
aggID: []string{"100", "101"},
|
||||
aggType: []repository.AggregateType{repository.AggregateType(t.Name())},
|
||||
aggType: database.StringArray{t.Name()},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -506,7 +505,7 @@ func TestCRDB_Push_MultipleAggregate(t *testing.T) {
|
||||
eventsRes: eventsRes{
|
||||
pushedEventsCount: 4,
|
||||
aggID: []string{"102", "103"},
|
||||
aggType: []repository.AggregateType{repository.AggregateType(t.Name())},
|
||||
aggType: database.StringArray{t.Name()},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -533,7 +532,7 @@ func TestCRDB_Push_MultipleAggregate(t *testing.T) {
|
||||
eventsRes: eventsRes{
|
||||
pushedEventsCount: 12,
|
||||
aggID: []string{"106", "107", "108"},
|
||||
aggType: []repository.AggregateType{repository.AggregateType(t.Name())},
|
||||
aggType: database.StringArray{t.Name()},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -547,7 +546,7 @@ func TestCRDB_Push_MultipleAggregate(t *testing.T) {
|
||||
t.Errorf("CRDB.Push() error = %v, wantErr %v", err, tt.res.wantErr)
|
||||
}
|
||||
|
||||
countRow := testCRDBClient.QueryRow("SELECT COUNT(*) FROM eventstore.events where aggregate_type = ANY($1) AND aggregate_id = ANY($2)", pq.Array(tt.res.eventsRes.aggType), pq.Array(tt.res.eventsRes.aggID))
|
||||
countRow := testCRDBClient.QueryRow("SELECT COUNT(*) FROM eventstore.events where aggregate_type = ANY($1) AND aggregate_id = ANY($2)", tt.res.eventsRes.aggType, tt.res.eventsRes.aggID)
|
||||
var count int
|
||||
err := countRow.Scan(&count)
|
||||
if err != nil {
|
||||
@@ -645,8 +644,8 @@ func TestCRDB_Push_Parallel(t *testing.T) {
|
||||
}
|
||||
type eventsRes struct {
|
||||
pushedEventsCount int
|
||||
aggTypes []repository.AggregateType
|
||||
aggIDs []string
|
||||
aggTypes database.StringArray
|
||||
aggIDs database.StringArray
|
||||
}
|
||||
type res struct {
|
||||
errCount int
|
||||
@@ -681,7 +680,7 @@ func TestCRDB_Push_Parallel(t *testing.T) {
|
||||
eventsRes: eventsRes{
|
||||
aggIDs: []string{"200", "201", "202", "203"},
|
||||
pushedEventsCount: 9,
|
||||
aggTypes: []repository.AggregateType{repository.AggregateType(t.Name())},
|
||||
aggTypes: database.StringArray{t.Name()},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -718,7 +717,7 @@ func TestCRDB_Push_Parallel(t *testing.T) {
|
||||
eventsRes: eventsRes{
|
||||
aggIDs: []string{"204", "205", "206"},
|
||||
pushedEventsCount: 14,
|
||||
aggTypes: []repository.AggregateType{repository.AggregateType(t.Name())},
|
||||
aggTypes: database.StringArray{t.Name()},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -748,7 +747,7 @@ func TestCRDB_Push_Parallel(t *testing.T) {
|
||||
eventsRes: eventsRes{
|
||||
aggIDs: []string{"207", "208"},
|
||||
pushedEventsCount: 11,
|
||||
aggTypes: []repository.AggregateType{repository.AggregateType(t.Name())},
|
||||
aggTypes: database.StringArray{t.Name()},
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -781,7 +780,7 @@ func TestCRDB_Push_Parallel(t *testing.T) {
|
||||
t.Errorf("CRDB.Push() error count = %d, wanted err count %d, errs: %v", len(errs), tt.res.errCount, errs)
|
||||
}
|
||||
|
||||
rows, err := testCRDBClient.Query("SELECT event_data FROM eventstore.events where aggregate_type = ANY($1) AND aggregate_id = ANY($2) order by event_sequence", pq.Array(tt.res.eventsRes.aggTypes), pq.Array(tt.res.eventsRes.aggIDs))
|
||||
rows, err := testCRDBClient.Query("SELECT event_data FROM eventstore.events where aggregate_type = ANY($1) AND aggregate_id = ANY($2) order by event_sequence", tt.res.eventsRes.aggTypes, tt.res.eventsRes.aggIDs)
|
||||
if err != nil {
|
||||
t.Error("unable to query inserted rows: ", err)
|
||||
return
|
||||
@@ -993,10 +992,10 @@ func TestCRDB_Push_ResourceOwner(t *testing.T) {
|
||||
events []*repository.Event
|
||||
}
|
||||
type res struct {
|
||||
resourceOwners []string
|
||||
resourceOwners database.StringArray
|
||||
}
|
||||
type fields struct {
|
||||
aggregateIDs []string
|
||||
aggregateIDs database.StringArray
|
||||
aggregateType string
|
||||
}
|
||||
tests := []struct {
|
||||
@@ -1128,7 +1127,7 @@ func TestCRDB_Push_ResourceOwner(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
rows, err := testCRDBClient.Query("SELECT resource_owner FROM eventstore.events WHERE aggregate_type = $1 AND aggregate_id = ANY($2) ORDER BY event_sequence", tt.fields.aggregateType, pq.Array(tt.fields.aggregateIDs))
|
||||
rows, err := testCRDBClient.Query("SELECT resource_owner FROM eventstore.events WHERE aggregate_type = $1 AND aggregate_id = ANY($2) ORDER BY event_sequence", tt.fields.aggregateType, tt.fields.aggregateIDs)
|
||||
if err != nil {
|
||||
t.Error("unable to query inserted rows: ", err)
|
||||
return
|
||||
|
@@ -9,6 +9,8 @@ import (
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/cmd/initialise"
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/database/cockroach"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -42,15 +44,22 @@ func TestMain(m *testing.M) {
|
||||
}
|
||||
|
||||
func initDB(db *sql.DB) error {
|
||||
username := "zitadel"
|
||||
database := "zitadel"
|
||||
err := initialise.Initialise(db, initialise.VerifyUser(username, ""),
|
||||
initialise.VerifyDatabase(database),
|
||||
initialise.VerifyGrant(database, username))
|
||||
config := new(database.Config)
|
||||
config.SetConnector(&cockroach.Config{User: cockroach.User{Username: "zitadel"}, Database: "zitadel"})
|
||||
|
||||
if err := initialise.ReadStmts("cockroach"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err := initialise.Init(db,
|
||||
initialise.VerifyUser(config.Username(), ""),
|
||||
initialise.VerifyDatabase(config.Database()),
|
||||
initialise.VerifyGrant(config.Database(), config.Username()))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return initialise.VerifyZitadel(db)
|
||||
|
||||
return initialise.VerifyZitadel(db, *config)
|
||||
}
|
||||
|
||||
func fillUniqueData(unique_type, field, instanceID string) error {
|
||||
|
@@ -8,7 +8,6 @@ import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
z_errors "github.com/zitadel/zitadel/internal/errors"
|
||||
@@ -170,8 +169,6 @@ func prepareCondition(criteria querier, filters [][]*repository.Filter) (clause
|
||||
for _, f := range filter {
|
||||
value := f.Value
|
||||
switch value.(type) {
|
||||
case []bool, []float64, []int64, []string, []repository.AggregateType, []repository.EventType, *[]bool, *[]float64, *[]int64, *[]string, *[]repository.AggregateType, *[]repository.EventType:
|
||||
value = pq.Array(value)
|
||||
case map[string]interface{}:
|
||||
var err error
|
||||
value, err = json.Marshal(value)
|
||||
|
@@ -9,7 +9,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/repository"
|
||||
@@ -265,7 +264,7 @@ func Test_prepareCondition(t *testing.T) {
|
||||
},
|
||||
res: res{
|
||||
clause: " WHERE ( aggregate_type = ANY(?) )",
|
||||
values: []interface{}{pq.Array([]repository.AggregateType{"user", "org"})},
|
||||
values: []interface{}{[]repository.AggregateType{"user", "org"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -281,7 +280,7 @@ func Test_prepareCondition(t *testing.T) {
|
||||
},
|
||||
res: res{
|
||||
clause: " WHERE ( aggregate_type = ANY(?) AND aggregate_id = ? AND event_type = ANY(?) )",
|
||||
values: []interface{}{pq.Array([]repository.AggregateType{"user", "org"}), "1234", pq.Array([]repository.EventType{"user.created", "org.created"})},
|
||||
values: []interface{}{[]repository.AggregateType{"user", "org"}, "1234", []repository.EventType{"user.created", "org.created"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@@ -3,11 +3,12 @@ package eventstore
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/repository"
|
||||
)
|
||||
|
||||
//SearchQueryBuilder represents the builder for your filter
|
||||
// SearchQueryBuilder represents the builder for your filter
|
||||
// if invalid data are set the filter will fail
|
||||
type SearchQueryBuilder struct {
|
||||
columns repository.Columns
|
||||
@@ -79,51 +80,51 @@ func (builder *SearchQueryBuilder) Matches(event Event, existingLen int) (matche
|
||||
return false
|
||||
}
|
||||
|
||||
//Columns defines which fields are set
|
||||
// Columns defines which fields are set
|
||||
func (builder *SearchQueryBuilder) Columns(columns Columns) *SearchQueryBuilder {
|
||||
builder.columns = repository.Columns(columns)
|
||||
return builder
|
||||
}
|
||||
|
||||
//Limit defines how many events are returned maximally.
|
||||
// Limit defines how many events are returned maximally.
|
||||
func (builder *SearchQueryBuilder) Limit(limit uint64) *SearchQueryBuilder {
|
||||
builder.limit = limit
|
||||
return builder
|
||||
}
|
||||
|
||||
//ResourceOwner defines the resource owner (org) of the events
|
||||
// ResourceOwner defines the resource owner (org) of the events
|
||||
func (builder *SearchQueryBuilder) ResourceOwner(resourceOwner string) *SearchQueryBuilder {
|
||||
builder.resourceOwner = resourceOwner
|
||||
return builder
|
||||
}
|
||||
|
||||
//InstanceID defines the instanceID (system) of the events
|
||||
// InstanceID defines the instanceID (system) of the events
|
||||
func (builder *SearchQueryBuilder) InstanceID(instanceID string) *SearchQueryBuilder {
|
||||
builder.instanceID = instanceID
|
||||
return builder
|
||||
}
|
||||
|
||||
//OrderDesc changes the sorting order of the returned events to descending
|
||||
// OrderDesc changes the sorting order of the returned events to descending
|
||||
func (builder *SearchQueryBuilder) OrderDesc() *SearchQueryBuilder {
|
||||
builder.desc = true
|
||||
return builder
|
||||
}
|
||||
|
||||
//OrderAsc changes the sorting order of the returned events to ascending
|
||||
// OrderAsc changes the sorting order of the returned events to ascending
|
||||
func (builder *SearchQueryBuilder) OrderAsc() *SearchQueryBuilder {
|
||||
builder.desc = false
|
||||
return builder
|
||||
}
|
||||
|
||||
//SetTx ensures that the eventstore library uses the existing transaction
|
||||
// SetTx ensures that the eventstore library uses the existing transaction
|
||||
func (builder *SearchQueryBuilder) SetTx(tx *sql.Tx) *SearchQueryBuilder {
|
||||
builder.tx = tx
|
||||
return builder
|
||||
}
|
||||
|
||||
//AddQuery creates a new sub query.
|
||||
//All fields in the sub query are AND-connected in the storage request.
|
||||
//Multiple sub queries are OR-connected in the storage request.
|
||||
// AddQuery creates a new sub query.
|
||||
// All fields in the sub query are AND-connected in the storage request.
|
||||
// Multiple sub queries are OR-connected in the storage request.
|
||||
func (builder *SearchQueryBuilder) AddQuery() *SearchQuery {
|
||||
query := &SearchQuery{
|
||||
builder: builder,
|
||||
@@ -133,61 +134,61 @@ func (builder *SearchQueryBuilder) AddQuery() *SearchQuery {
|
||||
return query
|
||||
}
|
||||
|
||||
//Or creates a new sub query on the search query builder
|
||||
// Or creates a new sub query on the search query builder
|
||||
func (query SearchQuery) Or() *SearchQuery {
|
||||
return query.builder.AddQuery()
|
||||
}
|
||||
|
||||
//AggregateTypes filters for events with the given aggregate types
|
||||
// AggregateTypes filters for events with the given aggregate types
|
||||
func (query *SearchQuery) AggregateTypes(types ...AggregateType) *SearchQuery {
|
||||
query.aggregateTypes = types
|
||||
return query
|
||||
}
|
||||
|
||||
//SequenceGreater filters for events with sequence greater the requested sequence
|
||||
// SequenceGreater filters for events with sequence greater the requested sequence
|
||||
func (query *SearchQuery) SequenceGreater(sequence uint64) *SearchQuery {
|
||||
query.eventSequenceGreater = sequence
|
||||
return query
|
||||
}
|
||||
|
||||
//SequenceLess filters for events with sequence less the requested sequence
|
||||
// SequenceLess filters for events with sequence less the requested sequence
|
||||
func (query *SearchQuery) SequenceLess(sequence uint64) *SearchQuery {
|
||||
query.eventSequenceLess = sequence
|
||||
return query
|
||||
}
|
||||
|
||||
//AggregateIDs filters for events with the given aggregate id's
|
||||
// AggregateIDs filters for events with the given aggregate id's
|
||||
func (query *SearchQuery) AggregateIDs(ids ...string) *SearchQuery {
|
||||
query.aggregateIDs = ids
|
||||
return query
|
||||
}
|
||||
|
||||
//InstanceID filters for events with the given instanceID
|
||||
// InstanceID filters for events with the given instanceID
|
||||
func (query *SearchQuery) InstanceID(instanceID string) *SearchQuery {
|
||||
query.instanceID = instanceID
|
||||
return query
|
||||
}
|
||||
|
||||
//ExcludedInstanceID filters for events not having the given instanceIDs
|
||||
// ExcludedInstanceID filters for events not having the given instanceIDs
|
||||
func (query *SearchQuery) ExcludedInstanceID(instanceIDs ...string) *SearchQuery {
|
||||
query.excludedInstanceIDs = instanceIDs
|
||||
return query
|
||||
}
|
||||
|
||||
//EventTypes filters for events with the given event types
|
||||
// EventTypes filters for events with the given event types
|
||||
func (query *SearchQuery) EventTypes(types ...EventType) *SearchQuery {
|
||||
query.eventTypes = types
|
||||
return query
|
||||
}
|
||||
|
||||
//EventData filters for events with the given event data.
|
||||
//Use this call with care as it will be slower than the other filters.
|
||||
// EventData filters for events with the given event data.
|
||||
// Use this call with care as it will be slower than the other filters.
|
||||
func (query *SearchQuery) EventData(data map[string]interface{}) *SearchQuery {
|
||||
query.eventData = data
|
||||
return query
|
||||
}
|
||||
|
||||
//Builder returns the SearchQueryBuilder of the sub query
|
||||
// Builder returns the SearchQueryBuilder of the sub query
|
||||
func (query *SearchQuery) Builder() *SearchQueryBuilder {
|
||||
return query.builder
|
||||
}
|
||||
@@ -262,7 +263,7 @@ func (query *SearchQuery) aggregateIDFilter() *repository.Filter {
|
||||
if len(query.aggregateIDs) == 1 {
|
||||
return repository.NewFilter(repository.FieldAggregateID, query.aggregateIDs[0], repository.OperationEquals)
|
||||
}
|
||||
return repository.NewFilter(repository.FieldAggregateID, query.aggregateIDs, repository.OperationIn)
|
||||
return repository.NewFilter(repository.FieldAggregateID, database.StringArray(query.aggregateIDs), repository.OperationIn)
|
||||
}
|
||||
|
||||
func (query *SearchQuery) eventTypeFilter() *repository.Filter {
|
||||
@@ -272,9 +273,9 @@ func (query *SearchQuery) eventTypeFilter() *repository.Filter {
|
||||
if len(query.eventTypes) == 1 {
|
||||
return repository.NewFilter(repository.FieldEventType, repository.EventType(query.eventTypes[0]), repository.OperationEquals)
|
||||
}
|
||||
eventTypes := make([]repository.EventType, len(query.eventTypes))
|
||||
eventTypes := make(database.StringArray, len(query.eventTypes))
|
||||
for i, eventType := range query.eventTypes {
|
||||
eventTypes[i] = repository.EventType(eventType)
|
||||
eventTypes[i] = string(eventType)
|
||||
}
|
||||
return repository.NewFilter(repository.FieldEventType, eventTypes, repository.OperationIn)
|
||||
}
|
||||
@@ -286,9 +287,9 @@ func (query *SearchQuery) aggregateTypeFilter() *repository.Filter {
|
||||
if len(query.aggregateTypes) == 1 {
|
||||
return repository.NewFilter(repository.FieldAggregateType, repository.AggregateType(query.aggregateTypes[0]), repository.OperationEquals)
|
||||
}
|
||||
aggregateTypes := make([]repository.AggregateType, len(query.aggregateTypes))
|
||||
aggregateTypes := make(database.StringArray, len(query.aggregateTypes))
|
||||
for i, aggregateType := range query.aggregateTypes {
|
||||
aggregateTypes[i] = repository.AggregateType(aggregateType)
|
||||
aggregateTypes[i] = string(aggregateType)
|
||||
}
|
||||
return repository.NewFilter(repository.FieldAggregateType, aggregateTypes, repository.OperationIn)
|
||||
}
|
||||
@@ -326,7 +327,7 @@ func (query *SearchQuery) excludedInstanceIDFilter() *repository.Filter {
|
||||
if len(query.excludedInstanceIDs) == 0 {
|
||||
return nil
|
||||
}
|
||||
return repository.NewFilter(repository.FieldInstanceID, query.excludedInstanceIDs, repository.OperationNotIn)
|
||||
return repository.NewFilter(repository.FieldInstanceID, database.StringArray(query.excludedInstanceIDs), repository.OperationNotIn)
|
||||
}
|
||||
|
||||
func (builder *SearchQueryBuilder) resourceOwnerFilter() *repository.Filter {
|
||||
|
@@ -5,6 +5,7 @@ import (
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/eventstore/repository"
|
||||
)
|
||||
@@ -312,7 +313,7 @@ func TestSearchQuerybuilderBuild(t *testing.T) {
|
||||
Limit: 0,
|
||||
Filters: [][]*repository.Filter{
|
||||
{
|
||||
repository.NewFilter(repository.FieldAggregateType, []repository.AggregateType{"user", "org"}, repository.OperationIn),
|
||||
repository.NewFilter(repository.FieldAggregateType, database.StringArray{"user", "org"}, repository.OperationIn),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -483,7 +484,7 @@ func TestSearchQuerybuilderBuild(t *testing.T) {
|
||||
Filters: [][]*repository.Filter{
|
||||
{
|
||||
repository.NewFilter(repository.FieldAggregateType, repository.AggregateType("user"), repository.OperationEquals),
|
||||
repository.NewFilter(repository.FieldAggregateID, []string{"1234", "0815"}, repository.OperationIn),
|
||||
repository.NewFilter(repository.FieldAggregateID, database.StringArray{"1234", "0815"}, repository.OperationIn),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -561,7 +562,7 @@ func TestSearchQuerybuilderBuild(t *testing.T) {
|
||||
Filters: [][]*repository.Filter{
|
||||
{
|
||||
repository.NewFilter(repository.FieldAggregateType, repository.AggregateType("user"), repository.OperationEquals),
|
||||
repository.NewFilter(repository.FieldEventType, []repository.EventType{"user.created", "user.changed"}, repository.OperationIn),
|
||||
repository.NewFilter(repository.FieldEventType, database.StringArray{"user.created", "user.changed"}, repository.OperationIn),
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -740,10 +741,10 @@ func assertRepoQuery(t *testing.T, want, got *repository.SearchQuery) {
|
||||
if !reflect.DeepEqual(got.Columns, want.Columns) {
|
||||
t.Errorf("wrong columns in query: got: %v want: %v", got.Columns, want.Columns)
|
||||
}
|
||||
if !reflect.DeepEqual(got.Desc, want.Desc) {
|
||||
if got.Desc != want.Desc {
|
||||
t.Errorf("wrong desc in query: got: %v want: %v", got.Desc, want.Desc)
|
||||
}
|
||||
if !reflect.DeepEqual(got.Limit, want.Limit) {
|
||||
if got.Limit != want.Limit {
|
||||
t.Errorf("wrong limit in query: got: %v want: %v", got.Limit, want.Limit)
|
||||
}
|
||||
|
||||
|
@@ -2,8 +2,6 @@ package sql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
func Start(client *sql.DB) *SQL {
|
||||
|
@@ -7,7 +7,6 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/lib/pq"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
z_errors "github.com/zitadel/zitadel/internal/errors"
|
||||
@@ -72,10 +71,6 @@ func prepareCondition(filters [][]*es_models.Filter) (clause string, values []in
|
||||
subClauses := make([]string, 0, len(filter))
|
||||
for _, f := range filter {
|
||||
value := f.GetValue()
|
||||
switch value.(type) {
|
||||
case []bool, []float64, []int64, []string, []es_models.AggregateType, []es_models.EventType, *[]bool, *[]float64, *[]int64, *[]string, *[]es_models.AggregateType, *[]es_models.EventType:
|
||||
value = pq.Array(value)
|
||||
}
|
||||
|
||||
subClauses = append(subClauses, getCondition(f))
|
||||
if subClauses[len(subClauses)-1] == "" {
|
||||
|
@@ -6,8 +6,6 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/lib/pq"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models"
|
||||
)
|
||||
@@ -365,7 +363,7 @@ func Test_prepareCondition(t *testing.T) {
|
||||
},
|
||||
res: res{
|
||||
clause: " WHERE ( aggregate_type = ANY(?) )",
|
||||
values: []interface{}{pq.Array([]es_models.AggregateType{"user", "org"})},
|
||||
values: []interface{}{[]es_models.AggregateType{"user", "org"}},
|
||||
},
|
||||
},
|
||||
{
|
||||
@@ -381,7 +379,7 @@ func Test_prepareCondition(t *testing.T) {
|
||||
},
|
||||
res: res{
|
||||
clause: " WHERE ( aggregate_type = ANY(?) AND aggregate_id = ? AND event_type = ANY(?) )",
|
||||
values: []interface{}{pq.Array([]es_models.AggregateType{"user", "org"}), "1234", pq.Array([]es_models.EventType{"user.created", "org.created"})},
|
||||
values: []interface{}{[]es_models.AggregateType{"user", "org"}, "1234", []es_models.EventType{"user.created", "org.created"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
@@ -3,9 +3,6 @@ package sql
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
//sql import
|
||||
_ "github.com/lib/pq"
|
||||
)
|
||||
|
||||
type SQL struct {
|
||||
|
@@ -31,7 +31,7 @@ func Renew(dbClient *sql.DB, lockTable, lockerID, viewModel, instanceID string,
|
||||
return crdb.ExecuteTx(context.Background(), dbClient, nil, func(tx *sql.Tx) error {
|
||||
insert := fmt.Sprintf(insertStmtFormat, lockTable)
|
||||
result, err := tx.Exec(insert,
|
||||
lockerID, waitTime.Milliseconds()/millisecondsAsSeconds, viewModel, instanceID)
|
||||
lockerID, waitTime, viewModel, instanceID)
|
||||
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
|
@@ -38,7 +38,12 @@ func ReduceEvent(handler Handler, event *models.Event) {
|
||||
|
||||
if err != nil {
|
||||
handler.Subscription().Unsubscribe()
|
||||
logging.WithFields("cause", err, "stack", string(debug.Stack())).Error("reduce panicked")
|
||||
logging.WithFields(
|
||||
"cause", err,
|
||||
"stack", string(debug.Stack()),
|
||||
"sequence", event.Sequence,
|
||||
"instnace", event.InstanceID,
|
||||
).Error("reduce panicked")
|
||||
}
|
||||
}()
|
||||
currentSequence, err := handler.CurrentSequence(event.InstanceID)
|
||||
|
@@ -75,7 +75,10 @@ func (s *spooledHandler) load(workerID string) {
|
||||
err := recover()
|
||||
|
||||
if err != nil {
|
||||
logging.WithFields("cause", err, "stack", string(debug.Stack())).Error("reduce panicked")
|
||||
logging.WithFields(
|
||||
"cause", err,
|
||||
"stack", string(debug.Stack()),
|
||||
).Error("reduce panicked")
|
||||
}
|
||||
}()
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
@@ -167,7 +170,7 @@ func (s *spooledHandler) query(ctx context.Context, instanceIDs ...string) ([]*m
|
||||
return s.eventstore.FilterEvents(ctx, query)
|
||||
}
|
||||
|
||||
//lock ensures the lock on the database.
|
||||
// lock ensures the lock on the database.
|
||||
// the returned channel will be closed if ctx is done or an error occured durring lock
|
||||
func (s *spooledHandler) lock(ctx context.Context, errs chan<- error, workerID string) chan bool {
|
||||
renewTimer := time.After(0)
|
||||
|
Reference in New Issue
Block a user