feat(database): support for postgres (#3998)

* beginning with postgres statements

* try pgx

* use pgx

* database

* init works for postgres

* arrays working

* init for cockroach

* init

* start tests

* tests

* TESTS

* ch

* ch

* chore: use go 1.18

* read stmts

* fix typo

* tests

* connection string

* add missing error handler

* cleanup

* start all apis

* go mod tidy

* old update

* switch back to minute

* on conflict

* replace string slice with `database.StringArray` in db models

* fix tests and start

* update go version in dockerfile

* setup go

* clean up

* remove notification migration

* update

* docs: add deploy guide for postgres

* fix: revert sonyflake

* use `database.StringArray` for daos

* use `database.StringArray` every where

* new tables

* index naming,
metadata primary key,
project grant role key type

* docs(postgres): change to beta

* chore: correct compose

* fix(defaults): add empty postgres config

* refactor: remove unused code

* docs: add postgres to self hosted

* fix broken link

* so?

* change title

* add mdx to link

* fix stmt

* update goreleaser in test-code

* docs: improve postgres example

* update more projections

* fix: add beta log for postgres

* revert index name change

* prerelease

* fix: add sequence to v1 "reduce paniced"

* log if nil

* add logging

* fix: log output

* fix(import): check if org exists and user

* refactor: imports

* fix(user): ignore malformed events

* refactor: method naming

* fix: test

* refactor: correct errors.Is call

* ci: don't build dev binaries on main

* fix(go releaser): update version to 1.11.0

* fix(user): projection should not break

* fix(user): handle error properly

* docs: correct config example

* Update .releaserc.js

* Update .releaserc.js

Co-authored-by: Livio Amstutz <livio.a@gmail.com>
Co-authored-by: Elio Bischof <eliobischof@gmail.com>
This commit is contained in:
Silvan
2022-08-31 09:52:43 +02:00
committed by GitHub
parent d6c9815945
commit 77b4fc5487
189 changed files with 3401 additions and 2956 deletions

View File

@@ -6,15 +6,15 @@ import (
"strconv"
"strings"
"github.com/lib/pq"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
)
const (
currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2) FOR UPDATE`
updateCurrentSequencesStmtFormat = `UPSERT INTO %s (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES `
currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2) FOR UPDATE`
updateCurrentSequencesStmtFormat = `INSERT INTO %s (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES `
updateCurrentSequencesConflictStmt = ` ON CONFLICT (projection_name, aggregate_type, instance_id) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`
)
type currentSequences map[eventstore.AggregateType][]*instanceSequence
@@ -24,8 +24,8 @@ type instanceSequence struct {
sequence uint64
}
func (h *StatementHandler) currentSequences(ctx context.Context, query func(context.Context, string, ...interface{}) (*sql.Rows, error), instanceIDs []string) (currentSequences, error) {
rows, err := query(ctx, h.currentSequenceStmt, h.ProjectionName, pq.StringArray(instanceIDs))
func (h *StatementHandler) currentSequences(ctx context.Context, query func(context.Context, string, ...interface{}) (*sql.Rows, error), instanceIDs database.StringArray) (currentSequences, error) {
rows, err := query(ctx, h.currentSequenceStmt, h.ProjectionName, instanceIDs)
if err != nil {
return nil, err
}
@@ -74,7 +74,7 @@ func (h *StatementHandler) updateCurrentSequences(tx *sql.Tx, sequences currentS
}
}
res, err := tx.Exec(h.updateSequencesBaseStmt+strings.Join(valueQueries, ", "), values...)
res, err := tx.Exec(h.updateSequencesBaseStmt+strings.Join(valueQueries, ", ")+updateCurrentSequencesConflictStmt, values...)
if err != nil {
return errors.ThrowInternal(err, "CRDB-TrH2Z", "unable to exec update sequence")
}

View File

@@ -8,8 +8,8 @@ import (
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/lib/pq"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
)
@@ -17,7 +17,7 @@ type mockExpectation func(sqlmock.Sqlmock)
func expectFailureCount(tableName string, projectionName, instanceID string, failedSeq, failureCount uint64) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`WITH failures AS \(SELECT failure_count FROM `+tableName+` WHERE projection_name = \$1 AND failed_sequence = \$2\ AND instance_id = \$3\) SELECT IF\(EXISTS\(SELECT failure_count FROM failures\), \(SELECT failure_count FROM failures\), 0\) AS failure_count`).
m.ExpectQuery(`WITH failures AS \(SELECT failure_count FROM `+tableName+` WHERE projection_name = \$1 AND failed_sequence = \$2 AND instance_id = \$3\) SELECT COALESCE\(\(SELECT failure_count FROM failures\), 0\) AS failure_count`).
WithArgs(projectionName, failedSeq, instanceID).
WillReturnRows(
sqlmock.NewRows([]string{"failure_count"}).
@@ -28,7 +28,7 @@ func expectFailureCount(tableName string, projectionName, instanceID string, fai
func expectUpdateFailureCount(tableName string, projectionName, instanceID string, seq, failureCount uint64) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec(`UPSERT INTO `+tableName+` \(projection_name, failed_sequence, failure_count, error, instance_id\) VALUES \(\$1, \$2, \$3, \$4\, \$5\)`).
m.ExpectExec(`INSERT INTO `+tableName+` \(projection_name, failed_sequence, failure_count, error, instance_id\) VALUES \(\$1, \$2, \$3, \$4\, \$5\) ON CONFLICT \(projection_name, failed_sequence, instance_id\) DO UPDATE SET failure_count = EXCLUDED\.failure_count, error = EXCLUDED\.error`).
WithArgs(projectionName, seq, failureCount, sqlmock.AnyArg(), instanceID).WillReturnResult(sqlmock.NewResult(1, 1))
}
}
@@ -133,7 +133,7 @@ func expectCurrentSequence(tableName, projection string, seq uint64, aggregateTy
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
WithArgs(
projection,
pq.StringArray(instanceIDs),
database.StringArray(instanceIDs),
).
WillReturnRows(
rows,
@@ -146,7 +146,7 @@ func expectCurrentSequenceErr(tableName, projection string, instanceIDs []string
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
WithArgs(
projection,
pq.StringArray(instanceIDs),
database.StringArray(instanceIDs),
).
WillReturnError(err)
}
@@ -157,7 +157,7 @@ func expectCurrentSequenceNoRows(tableName, projection string, instanceIDs []str
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
WithArgs(
projection,
pq.StringArray(instanceIDs),
database.StringArray(instanceIDs),
).
WillReturnRows(
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}),
@@ -170,7 +170,7 @@ func expectCurrentSequenceScanErr(tableName, projection string, instanceIDs []st
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`).
WithArgs(
projection,
pq.StringArray(instanceIDs),
database.StringArray(instanceIDs),
).
WillReturnRows(
sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}).
@@ -182,7 +182,7 @@ func expectCurrentSequenceScanErr(tableName, projection string, instanceIDs []st
func expectUpdateCurrentSequence(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\)`).
m.ExpectExec("INSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
WithArgs(
projection,
aggregateType,
@@ -213,7 +213,7 @@ func expectUpdateThreeCurrentSequence(t *testing.T, tableName, projection string
matchers[i] = matcher
}
return func(m sqlmock.Sqlmock) {
m.ExpectExec("UPSERT INTO " + tableName + ` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\), \(\$5, \$6, \$7, \$8, NOW\(\)\), \(\$9, \$10, \$11, \$12, NOW\(\)\)`).
m.ExpectExec("INSERT INTO " + tableName + ` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\), \(\$5, \$6, \$7, \$8, NOW\(\)\), \(\$9, \$10, \$11, \$12, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
WithArgs(
matchers...,
).
@@ -262,7 +262,7 @@ func (m *currentSequenceMatcher) Match(value driver.Value) bool {
func expectUpdateCurrentSequenceErr(tableName, projection string, seq uint64, err error, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\)`).
m.ExpectExec("INSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
WithArgs(
projection,
aggregateType,
@@ -275,7 +275,7 @@ func expectUpdateCurrentSequenceErr(tableName, projection string, seq uint64, er
func expectUpdateCurrentSequenceNoRows(tableName, projection string, seq uint64, aggregateType, instanceID string) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectExec("UPSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\)`).
m.ExpectExec("INSERT INTO "+tableName+` \(projection_name, aggregate_type, current_sequence, instance_id, timestamp\) VALUES \(\$1, \$2, \$3, \$4, NOW\(\)\) ON CONFLICT \(projection_name, aggregate_type, instance_id\) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`).
WithArgs(
projection,
aggregateType,
@@ -297,10 +297,10 @@ func expectLock(lockTable, workerName string, d time.Duration, instanceID string
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$5\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
WithArgs(
workerName,
float64(d),
d,
projectionName,
instanceID,
pq.StringArray{instanceID},
database.StringArray{instanceID},
).
WillReturnResult(
sqlmock.NewResult(1, 1),
@@ -317,11 +317,11 @@ func expectLockMultipleInstances(lockTable, workerName string, d time.Duration,
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$6\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
WithArgs(
workerName,
float64(d),
d,
projectionName,
instanceID1,
instanceID2,
pq.StringArray{instanceID1, instanceID2},
database.StringArray{instanceID1, instanceID2},
).
WillReturnResult(
sqlmock.NewResult(1, 1),
@@ -338,10 +338,10 @@ func expectLockNoRows(lockTable, workerName string, d time.Duration, instanceID
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$5\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
WithArgs(
workerName,
float64(d),
d,
projectionName,
instanceID,
pq.StringArray{instanceID},
database.StringArray{instanceID},
).
WillReturnResult(driver.ResultNoRows)
}
@@ -356,10 +356,10 @@ func expectLockErr(lockTable, workerName string, d time.Duration, instanceID str
` WHERE `+lockTable+`\.projection_name = \$3 AND `+lockTable+`\.instance_id = ANY \(\$5\) AND \(`+lockTable+`\.locker_id = \$1 OR `+lockTable+`\.locked_until < now\(\)\)`).
WithArgs(
workerName,
float64(d),
d,
projectionName,
instanceID,
pq.StringArray{instanceID},
database.StringArray{instanceID},
).
WillReturnError(err)
}

View File

@@ -10,15 +10,12 @@ import (
)
const (
setFailureCountStmtFormat = "UPSERT INTO %s" +
setFailureCountStmtFormat = "INSERT INTO %s" +
" (projection_name, failed_sequence, failure_count, error, instance_id)" +
" VALUES ($1, $2, $3, $4, $5)"
" VALUES ($1, $2, $3, $4, $5) ON CONFLICT (projection_name, failed_sequence, instance_id)" +
" DO UPDATE SET failure_count = EXCLUDED.failure_count, error = EXCLUDED.error"
failureCountStmtFormat = "WITH failures AS (SELECT failure_count FROM %s WHERE projection_name = $1 AND failed_sequence = $2 AND instance_id = $3)" +
" SELECT IF(" +
"EXISTS(SELECT failure_count FROM failures)," +
" (SELECT failure_count FROM failures)," +
" 0" +
") AS failure_count"
" SELECT COALESCE((SELECT failure_count FROM failures), 0) AS failure_count"
)
func (h *StatementHandler) handleFailedStmt(tx *sql.Tx, stmt *handler.Statement, execErr error) (shouldContinue bool) {

View File

@@ -72,12 +72,12 @@ func NewStatementHandler(
aggregates: aggregateTypes,
reduces: reduces,
bulkLimit: config.BulkLimit,
Locker: NewLocker(config.Client, config.LockTable, config.ProjectionHandlerConfig.ProjectionName),
Locker: NewLocker(config.Client, config.LockTable, config.ProjectionName),
}
h.ProjectionHandler = handler.NewProjectionHandler(ctx, config.ProjectionHandlerConfig, h.reduce, h.Update, h.SearchQuery, h.Lock, h.Unlock)
err := h.Init(ctx, config.InitCheck)
logging.OnError(err).Fatal("unable to initialize projections")
logging.OnError(err).WithField("projection", config.ProjectionName).Fatal("unable to initialize projections")
h.Subscribe(h.aggregates...)

View File

@@ -6,11 +6,10 @@ import (
"fmt"
"strings"
"github.com/lib/pq"
"github.com/jackc/pgconn"
"github.com/zitadel/logging"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore/handler"
)
@@ -186,7 +185,7 @@ type ForeignKey struct {
RefColumns []string
}
//Init implements handler.Init
// Init implements handler.Init
func (h *StatementHandler) Init(ctx context.Context, checks ...*handler.Check) error {
for _, check := range checks {
if check == nil || check.IsNoop() {
@@ -280,7 +279,7 @@ func isErrAlreadyExists(err error) bool {
if !errors.As(err, &caosErr) {
return false
}
sqlErr, ok := caosErr.GetParent().(*pq.Error)
sqlErr, ok := caosErr.GetParent().(*pgconn.PgError)
if !ok {
return false
}
@@ -288,14 +287,11 @@ func isErrAlreadyExists(err error) bool {
}
func createTableStatement(table *Table, tableName string, suffix string) string {
stmt := fmt.Sprintf("CREATE TABLE %s (%s, PRIMARY KEY (%s)",
stmt := fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (%s, PRIMARY KEY (%s)",
tableName+suffix,
createColumnsStatement(table.columns, tableName),
strings.Join(table.primaryKey, ", "),
)
for _, index := range table.indices {
stmt += fmt.Sprintf(", INDEX %s (%s)", index.Name, strings.Join(index.Columns, ","))
}
for _, key := range table.foreignKeys {
ref := tableName
if len(key.RefColumns) > 0 {
@@ -309,7 +305,13 @@ func createTableStatement(table *Table, tableName string, suffix string) string
for _, constraint := range table.constraints {
stmt += fmt.Sprintf(", CONSTRAINT %s UNIQUE (%s)", constraint.Name, strings.Join(constraint.Columns, ","))
}
return stmt + ");"
stmt += ");"
for _, index := range table.indices {
stmt += fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s (%s);", index.Name, tableName+suffix, strings.Join(index.Columns, ","))
}
return stmt
}
func createViewStatement(viewName string, selectStmt string) string {
@@ -321,7 +323,7 @@ func createViewStatement(viewName string, selectStmt string) string {
func createIndexStatement(index *Index) func(config execConfig) string {
return func(config execConfig) string {
stmt := fmt.Sprintf("CREATE INDEX %s ON %s (%s)",
stmt := fmt.Sprintf("CREATE INDEX IF NOT EXISTS %s ON %s (%s)",
index.Name,
config.tableName,
strings.Join(index.Columns, ","),
@@ -380,9 +382,8 @@ func columnType(columnType ColumnType) string {
case ColumnTypeJSONB:
return "JSONB"
case ColumnTypeBytes:
return "BYTES"
return "BYTEA"
default:
panic("unknown column type")
return ""
}
}

View File

@@ -8,9 +8,9 @@ import (
"strings"
"time"
"github.com/lib/pq"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/id"
)
@@ -91,17 +91,17 @@ func (h *locker) Unlock(instanceIDs ...string) error {
return nil
}
func (h *locker) lockStatement(lockDuration time.Duration, instanceIDs []string) (string, []interface{}) {
func (h *locker) lockStatement(lockDuration time.Duration, instanceIDs database.StringArray) (string, []interface{}) {
valueQueries := make([]string, len(instanceIDs))
values := make([]interface{}, len(instanceIDs)+4)
values[0] = h.workerName
//the unit of crdb interval is seconds (https://www.cockroachlabs.com/docs/stable/interval.html).
values[1] = lockDuration.Seconds()
values[1] = lockDuration
values[2] = h.projectionName
for i, instanceID := range instanceIDs {
valueQueries[i] = "($1, now()+$2::INTERVAL, $3, $" + strconv.Itoa(i+4) + ")"
values[i+3] = instanceID
}
values[len(values)-1] = pq.StringArray(instanceIDs)
values[len(values)-1] = instanceIDs
return h.lockStmt(strings.Join(valueQueries, ", "), len(values)), values
}

View File

@@ -10,6 +10,7 @@ import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/zitadel/zitadel/internal/database"
z_errs "github.com/zitadel/zitadel/internal/errors"
)
@@ -43,9 +44,9 @@ func TestStatementHandler_handleLock(t *testing.T) {
name: "lock fails",
want: want{
expectations: []mockExpectation{
expectLock(lockTable, workerName, 2, "instanceID"),
expectLock(lockTable, workerName, 2, "instanceID"),
expectLockErr(lockTable, workerName, 2, "instanceID", errLock),
expectLock(lockTable, workerName, 2*time.Second, "instanceID"),
expectLock(lockTable, workerName, 2*time.Second, "instanceID"),
expectLockErr(lockTable, workerName, 2*time.Second, "instanceID", errLock),
},
},
args: args{
@@ -63,8 +64,8 @@ func TestStatementHandler_handleLock(t *testing.T) {
name: "success",
want: want{
expectations: []mockExpectation{
expectLock(lockTable, workerName, 2, "instanceID"),
expectLock(lockTable, workerName, 2, "instanceID"),
expectLock(lockTable, workerName, 2*time.Second, "instanceID"),
expectLock(lockTable, workerName, 2*time.Second, "instanceID"),
},
},
args: args{
@@ -81,8 +82,8 @@ func TestStatementHandler_handleLock(t *testing.T) {
name: "success with multiple",
want: want{
expectations: []mockExpectation{
expectLockMultipleInstances(lockTable, workerName, 2, "instanceID1", "instanceID2"),
expectLockMultipleInstances(lockTable, workerName, 2, "instanceID1", "instanceID2"),
expectLockMultipleInstances(lockTable, workerName, 2*time.Second, "instanceID1", "instanceID2"),
expectLockMultipleInstances(lockTable, workerName, 2*time.Second, "instanceID1", "instanceID2"),
},
},
args: args{
@@ -149,7 +150,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
name: "lock fails",
want: want{
expectations: []mockExpectation{
expectLockErr(lockTable, workerName, 1, "instanceID", sql.ErrTxDone),
expectLockErr(lockTable, workerName, 1*time.Second, "instanceID", sql.ErrTxDone),
},
isErr: func(err error) bool {
return errors.Is(err, sql.ErrTxDone)
@@ -157,14 +158,14 @@ func TestStatementHandler_renewLock(t *testing.T) {
},
args: args{
lockDuration: 1 * time.Second,
instanceIDs: []string{"instanceID"},
instanceIDs: database.StringArray{"instanceID"},
},
},
{
name: "lock no rows",
want: want{
expectations: []mockExpectation{
expectLockNoRows(lockTable, workerName, 2, "instanceID"),
expectLockNoRows(lockTable, workerName, 2*time.Second, "instanceID"),
},
isErr: func(err error) bool {
return errors.Is(err, renewNoRowsAffectedErr)
@@ -172,14 +173,14 @@ func TestStatementHandler_renewLock(t *testing.T) {
},
args: args{
lockDuration: 2 * time.Second,
instanceIDs: []string{"instanceID"},
instanceIDs: database.StringArray{"instanceID"},
},
},
{
name: "success",
want: want{
expectations: []mockExpectation{
expectLock(lockTable, workerName, 3, "instanceID"),
expectLock(lockTable, workerName, 3*time.Second, "instanceID"),
},
isErr: func(err error) bool {
return errors.Is(err, nil)
@@ -187,14 +188,14 @@ func TestStatementHandler_renewLock(t *testing.T) {
},
args: args{
lockDuration: 3 * time.Second,
instanceIDs: []string{"instanceID"},
instanceIDs: database.StringArray{"instanceID"},
},
},
{
name: "success with multiple",
want: want{
expectations: []mockExpectation{
expectLockMultipleInstances(lockTable, workerName, 3, "instanceID1", "instanceID2"),
expectLockMultipleInstances(lockTable, workerName, 3*time.Second, "instanceID1", "instanceID2"),
},
isErr: func(err error) bool {
return errors.Is(err, nil)

View File

@@ -4,8 +4,7 @@ import (
"strconv"
"strings"
"github.com/lib/pq"
"github.com/zitadel/zitadel/internal/database"
caos_errs "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler"
@@ -51,10 +50,13 @@ func NewCreateStatement(event eventstore.Event, values []handler.Column, opts ..
}
}
func NewUpsertStatement(event eventstore.Event, values []handler.Column, opts ...execOption) *handler.Statement {
func NewUpsertStatement(event eventstore.Event, conflictCols []handler.Column, values []handler.Column, opts ...execOption) *handler.Statement {
cols, params, args := columnsToQuery(values)
columnNames := strings.Join(cols, ", ")
valuesPlaceholder := strings.Join(params, ", ")
conflictTarget := make([]string, len(conflictCols))
for i, col := range conflictCols {
conflictTarget[i] = col.Name
}
config := execConfig{
args: args,
@@ -64,8 +66,14 @@ func NewUpsertStatement(event eventstore.Event, values []handler.Column, opts ..
config.err = handler.ErrNoValues
}
updateCols, updateVals := getUpdateCols(cols, conflictTarget)
if len(updateCols) == 0 || len(updateVals) == 0 {
config.err = handler.ErrNoValues
}
q := func(config execConfig) string {
return "UPSERT INTO " + config.tableName + " (" + columnNames + ") VALUES (" + valuesPlaceholder + ")"
return "INSERT INTO " + config.tableName + " (" + strings.Join(cols, ", ") + ") VALUES (" + strings.Join(params, ", ") + ")" +
" ON CONFLICT (" + strings.Join(conflictTarget, ", ") + ") DO UPDATE SET (" + strings.Join(updateCols, ", ") + ") = (" + strings.Join(updateVals, ", ") + ")"
}
return &handler.Statement{
@@ -77,15 +85,38 @@ func NewUpsertStatement(event eventstore.Event, values []handler.Column, opts ..
}
}
func getUpdateCols(cols, conflictTarget []string) (updateCols, updateVals []string) {
updateCols = make([]string, len(cols))
updateVals = make([]string, len(cols))
copy(updateCols, cols)
for i := len(updateCols) - 1; i >= 0; i-- {
updateVals[i] = "EXCLUDED." + updateCols[i]
for _, conflict := range conflictTarget {
if conflict == updateCols[i] {
copy(updateCols[i:], updateCols[i+1:])
updateCols[len(updateCols)-1] = ""
updateCols = updateCols[:len(updateCols)-1]
copy(updateVals[i:], updateVals[i+1:])
updateVals[len(updateVals)-1] = ""
updateVals = updateVals[:len(updateVals)-1]
break
}
}
}
return updateCols, updateVals
}
func NewUpdateStatement(event eventstore.Event, values []handler.Column, conditions []handler.Condition, opts ...execOption) *handler.Statement {
cols, params, args := columnsToQuery(values)
wheres, whereArgs := conditionsToWhere(conditions, len(params))
args = append(args, whereArgs...)
columnNames := strings.Join(cols, ", ")
valuesPlaceholder := strings.Join(params, ", ")
wheresPlaceholders := strings.Join(wheres, " AND ")
config := execConfig{
args: args,
}
@@ -99,7 +130,7 @@ func NewUpdateStatement(event eventstore.Event, values []handler.Column, conditi
}
q := func(config execConfig) string {
return "UPDATE " + config.tableName + " SET (" + columnNames + ") = (" + valuesPlaceholder + ") WHERE " + wheresPlaceholders
return "UPDATE " + config.tableName + " SET (" + strings.Join(cols, ", ") + ") = (" + strings.Join(params, ", ") + ") WHERE " + strings.Join(wheres, " AND ")
}
return &handler.Statement{
@@ -171,9 +202,9 @@ func AddCreateStatement(columns []handler.Column, opts ...execOption) func(event
}
}
func AddUpsertStatement(values []handler.Column, opts ...execOption) func(eventstore.Event) Exec {
func AddUpsertStatement(indexCols []handler.Column, values []handler.Column, opts ...execOption) func(eventstore.Event) Exec {
return func(event eventstore.Event) Exec {
return NewUpsertStatement(event, values, opts...).Execute
return NewUpsertStatement(event, indexCols, values, opts...).Execute
}
}
@@ -189,9 +220,9 @@ func AddDeleteStatement(conditions []handler.Condition, opts ...execOption) func
}
}
func AddCopyStatement(from, to []handler.Column, conditions []handler.Condition, opts ...execOption) func(eventstore.Event) Exec {
func AddCopyStatement(conflict, from, to []handler.Column, conditions []handler.Condition, opts ...execOption) func(eventstore.Event) Exec {
return func(event eventstore.Event) Exec {
return NewCopyStatement(event, from, to, conditions, opts...).Execute
return NewCopyStatement(event, conflict, from, to, conditions, opts...).Execute
}
}
@@ -218,11 +249,9 @@ func NewArrayRemoveCol(column string, value interface{}) handler.Column {
func NewArrayIntersectCol(column string, value interface{}) handler.Column {
var arrayType string
switch value.(type) {
case pq.StringArray:
arrayType = "STRING"
case pq.Int32Array,
pq.Int64Array:
arrayType = "INT"
case []string, database.StringArray:
arrayType = "TEXT"
//TODO: handle more types if necessary
}
return handler.Column{
@@ -234,25 +263,29 @@ func NewArrayIntersectCol(column string, value interface{}) handler.Column {
}
}
//NewCopyStatement creates a new upsert statement which updates a column from an existing row
// NewCopyStatement creates a new upsert statement which updates a column from an existing row
// cols represent the columns which are objective to change.
// if the value of a col is empty the data will be copied from the selected row
// if the value of a col is not empty the data will be set by the static value
// conds represent the conditions for the selection subquery
func NewCopyStatement(event eventstore.Event, from, to []handler.Column, conds []handler.Condition, opts ...execOption) *handler.Statement {
func NewCopyStatement(event eventstore.Event, conflictCols, from, to []handler.Column, conds []handler.Condition, opts ...execOption) *handler.Statement {
columnNames := make([]string, len(to))
selectColumns := make([]string, len(from))
updateColumns := make([]string, len(columnNames))
argCounter := 0
args := []interface{}{}
for i := range from {
for i, col := range from {
columnNames[i] = to[i].Name
selectColumns[i] = from[i].Name
if from[i].Value != nil {
updateColumns[i] = "EXCLUDED." + col.Name
if col.Value != nil {
argCounter++
selectColumns[i] = "$" + strconv.Itoa(argCounter)
args = append(args, from[i].Value)
updateColumns[i] = selectColumns[i]
args = append(args, col.Value)
}
}
wheres := make([]string, len(conds))
@@ -262,6 +295,11 @@ func NewCopyStatement(event eventstore.Event, from, to []handler.Column, conds [
args = append(args, cond.Value)
}
conflictTargets := make([]string, len(conflictCols))
for i, conflictCol := range conflictCols {
conflictTargets[i] = conflictCol.Name
}
config := execConfig{
args: args,
}
@@ -275,7 +313,7 @@ func NewCopyStatement(event eventstore.Event, from, to []handler.Column, conds [
}
q := func(config execConfig) string {
return "UPSERT INTO " +
return "INSERT INTO " +
config.tableName +
" (" +
strings.Join(columnNames, ", ") +
@@ -283,7 +321,14 @@ func NewCopyStatement(event eventstore.Event, from, to []handler.Column, conds [
strings.Join(selectColumns, ", ") +
" FROM " +
config.tableName + " AS copy_table WHERE " +
strings.Join(wheres, " AND ")
strings.Join(wheres, " AND ") +
" ON CONFLICT (" +
strings.Join(conflictTargets, ", ") +
") DO UPDATE SET (" +
strings.Join(columnNames, ", ") +
") = (" +
strings.Join(updateColumns, ", ") +
")"
}
return &handler.Statement{

View File

@@ -178,9 +178,10 @@ func TestNewCreateStatement(t *testing.T) {
func TestNewUpsertStatement(t *testing.T) {
type args struct {
table string
event *testEvent
values []handler.Column
table string
event *testEvent
conflictCols []handler.Column
values []handler.Column
}
type want struct {
aggregateType eventstore.AggregateType
@@ -249,7 +250,7 @@ func TestNewUpsertStatement(t *testing.T) {
},
},
{
name: "correct",
name: "no update cols",
args: args{
table: "my_table",
event: &testEvent{
@@ -257,6 +258,9 @@ func TestNewUpsertStatement(t *testing.T) {
sequence: 1,
previousSequence: 0,
},
conflictCols: []handler.Column{
handler.NewCol("col1", nil),
},
values: []handler.Column{
{
Name: "col1",
@@ -264,6 +268,42 @@ func TestNewUpsertStatement(t *testing.T) {
},
},
},
want: want{
table: "my_table",
aggregateType: "agg",
sequence: 1,
previousSequence: 1,
executer: &wantExecuter{
shouldExecute: false,
},
isErr: func(err error) bool {
return errors.Is(err, handler.ErrNoValues)
},
},
},
{
name: "correct",
args: args{
table: "my_table",
event: &testEvent{
aggregateType: "agg",
sequence: 1,
previousSequence: 0,
},
conflictCols: []handler.Column{
handler.NewCol("col1", nil),
},
values: []handler.Column{
{
Name: "col1",
Value: "val",
},
{
Name: "col2",
Value: "val",
},
},
},
want: want{
table: "my_table",
aggregateType: "agg",
@@ -272,8 +312,8 @@ func TestNewUpsertStatement(t *testing.T) {
executer: &wantExecuter{
params: []params{
{
query: "UPSERT INTO my_table (col1) VALUES ($1)",
args: []interface{}{"val"},
query: "INSERT INTO my_table (col1, col2) VALUES ($1, $2) ON CONFLICT (col1) DO UPDATE SET (col2) = (EXCLUDED.col2)",
args: []interface{}{"val", "val"},
},
},
shouldExecute: true,
@@ -287,7 +327,7 @@ func TestNewUpsertStatement(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.want.executer.t = t
stmt := NewUpsertStatement(tt.args.event, tt.args.values)
stmt := NewUpsertStatement(tt.args.event, tt.args.conflictCols, tt.args.values)
err := stmt.Execute(tt.want.executer, tt.args.table)
if !tt.want.isErr(err) {
@@ -724,11 +764,18 @@ func TestNewMultiStatement(t *testing.T) {
},
}),
AddUpsertStatement(
[]handler.Column{
handler.NewCol("col1", nil),
},
[]handler.Column{
{
Name: "col1",
Value: 1,
},
{
Name: "col2",
Value: 2,
},
}),
AddUpdateStatement(
[]handler.Column{
@@ -761,8 +808,8 @@ func TestNewMultiStatement(t *testing.T) {
args: []interface{}{1},
},
{
query: "UPSERT INTO my_table (col1) VALUES ($1)",
args: []interface{}{1},
query: "INSERT INTO my_table (col1, col2) VALUES ($1, $2) ON CONFLICT (col1) DO UPDATE SET (col2) = (EXCLUDED.col2)",
args: []interface{}{1, 2},
},
{
query: "UPDATE my_table SET (col1) = ($1) WHERE (col1 = $2)",
@@ -799,11 +846,12 @@ func TestNewMultiStatement(t *testing.T) {
func TestNewCopyStatement(t *testing.T) {
type args struct {
table string
event *testEvent
from []handler.Column
to []handler.Column
conds []handler.Condition
table string
event *testEvent
conflictingCols []handler.Column
from []handler.Column
to []handler.Column
conds []handler.Condition
}
type want struct {
aggregateType eventstore.AggregateType
@@ -1004,7 +1052,7 @@ func TestNewCopyStatement(t *testing.T) {
executer: &wantExecuter{
params: []params{
{
query: "UPSERT INTO my_table (state, id, col_a, col_b) SELECT $1, id, col_a, col_b FROM my_table AS copy_table WHERE copy_table.id = $2 AND copy_table.state = $3",
query: "INSERT INTO my_table (state, id, col_a, col_b) SELECT $1, id, col_a, col_b FROM my_table AS copy_table WHERE copy_table.id = $2 AND copy_table.state = $3 ON CONFLICT () DO UPDATE SET (state, id, col_a, col_b) = ($1, EXCLUDED.id, EXCLUDED.col_a, EXCLUDED.col_b)",
args: []interface{}{1, 2, 3},
},
},
@@ -1071,7 +1119,7 @@ func TestNewCopyStatement(t *testing.T) {
executer: &wantExecuter{
params: []params{
{
query: "UPSERT INTO my_table (state, id, col_c, col_d) SELECT $1, id, col_a, col_b FROM my_table AS copy_table WHERE copy_table.id = $2 AND copy_table.state = $3",
query: "INSERT INTO my_table (state, id, col_c, col_d) SELECT $1, id, col_a, col_b FROM my_table AS copy_table WHERE copy_table.id = $2 AND copy_table.state = $3 ON CONFLICT () DO UPDATE SET (state, id, col_c, col_d) = ($1, EXCLUDED.id, EXCLUDED.col_a, EXCLUDED.col_b)",
args: []interface{}{1, 2, 3},
},
},
@@ -1086,7 +1134,7 @@ func TestNewCopyStatement(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
tt.want.executer.t = t
stmt := NewCopyStatement(tt.args.event, tt.args.from, tt.args.to, tt.args.conds)
stmt := NewCopyStatement(tt.args.event, tt.args.conflictingCols, tt.args.from, tt.args.to, tt.args.conds)
err := stmt.Execute(tt.want.executer, tt.args.table)
if !tt.want.isErr(err) {

View File

@@ -9,6 +9,8 @@ import (
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/cmd/initialise"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/cockroach"
)
var (
@@ -45,13 +47,20 @@ func TestMain(m *testing.M) {
}
func initDB(db *sql.DB) error {
username := "zitadel"
database := "zitadel"
err := initialise.Initialise(db, initialise.VerifyUser(username, ""),
initialise.VerifyDatabase(database),
initialise.VerifyGrant(database, username))
initialise.ReadStmts("cockroach")
config := new(database.Config)
config.SetConnector(&cockroach.Config{
User: cockroach.User{
Username: "zitadel",
},
Database: "zitadel",
})
err := initialise.Init(db,
initialise.VerifyUser(config.Username(), ""),
initialise.VerifyDatabase(config.Database()),
initialise.VerifyGrant(config.Database(), config.Username()))
if err != nil {
return err
}
return initialise.VerifyZitadel(db)
return initialise.VerifyZitadel(db, *config)
}

View File

@@ -9,6 +9,7 @@ import (
"strings"
"github.com/cockroachdb/cockroach-go/v2/crdb"
"github.com/jackc/pgconn"
"github.com/lib/pq"
"github.com/zitadel/logging"
@@ -30,7 +31,7 @@ const (
" SELECT MAX(event_sequence) seq, 1 join_me" +
" FROM eventstore.events" +
" WHERE aggregate_type = $2" +
" AND (CASE WHEN $9::STRING IS NULL THEN instance_id is null else instance_id = $9::STRING END)" +
" AND (CASE WHEN $9::TEXT IS NULL THEN instance_id is null else instance_id = $9::TEXT END)" +
") AS agg_type " +
// combined with
"LEFT JOIN " +
@@ -39,7 +40,7 @@ const (
" SELECT event_sequence seq, resource_owner ro, 1 join_me" +
" FROM eventstore.events" +
" WHERE aggregate_type = $2 AND aggregate_id = $3" +
" AND (CASE WHEN $9::STRING IS NULL THEN instance_id is null else instance_id = $9::STRING END)" +
" AND (CASE WHEN $9::TEXT IS NULL THEN instance_id is null else instance_id = $9::TEXT END)" +
" ORDER BY event_sequence DESC" +
" LIMIT 1" +
") AS agg USING(join_me)" +
@@ -69,9 +70,9 @@ const (
" $5::JSONB AS event_data," +
" $6::VARCHAR AS editor_user," +
" $7::VARCHAR AS editor_service," +
" IFNULL((resource_owner), $8::VARCHAR) AS resource_owner," +
" COALESCE((resource_owner), $8::VARCHAR) AS resource_owner," +
" $9::VARCHAR AS instance_id," +
" NEXTVAL(CONCAT('eventstore.', IF($9 <> '', CONCAT('i_', $9), 'system'), '_seq'))," +
" NEXTVAL(CONCAT('eventstore.', (CASE WHEN $9 <> '' THEN CONCAT('i_', $9) ELSE 'system' END), '_seq'))," +
" aggregate_sequence AS previous_aggregate_sequence," +
" aggregate_type_sequence AS previous_aggregate_type_sequence " +
"FROM previous_data " +
@@ -156,7 +157,7 @@ func (db *CRDB) Push(ctx context.Context, events []*repository.Event, uniqueCons
var instanceRegexp = regexp.MustCompile(`eventstore\.i_[0-9a-zA-Z]{1,}_seq`)
func (db *CRDB) CreateInstance(ctx context.Context, instanceID string) error {
row := db.client.QueryRowContext(ctx, "SELECT CONCAT('eventstore.i_', $1, '_seq')", instanceID)
row := db.client.QueryRowContext(ctx, "SELECT CONCAT('eventstore.i_', $1::TEXT, '_seq')", instanceID)
if row.Err() != nil {
return caos_errs.ThrowInvalidArgument(row.Err(), "SQL-7gtFA", "Errors.InvalidArgument")
}
@@ -218,7 +219,7 @@ func (db *CRDB) Filter(ctx context.Context, searchQuery *repository.SearchQuery)
return events, nil
}
//LatestSequence returns the latest sequence found by the search query
// LatestSequence returns the latest sequence found by the search query
func (db *CRDB) LatestSequence(ctx context.Context, searchQuery *repository.SearchQuery) (uint64, error) {
var seq Sequence
err := query(ctx, db, searchQuery, &seq)
@@ -228,7 +229,7 @@ func (db *CRDB) LatestSequence(ctx context.Context, searchQuery *repository.Sear
return uint64(seq), nil
}
//InstanceIDs returns the instance ids found by the search query
// InstanceIDs returns the instance ids found by the search query
func (db *CRDB) InstanceIDs(ctx context.Context, searchQuery *repository.SearchQuery) ([]string, error) {
var ids []string
err := query(ctx, db, searchQuery, &ids)
@@ -331,7 +332,7 @@ var (
placeholder = regexp.MustCompile(`\?`)
)
//placeholder replaces all "?" with postgres placeholders ($<NUMBER>)
// placeholder replaces all "?" with postgres placeholders ($<NUMBER>)
func (db *CRDB) placeholder(query string) string {
occurances := placeholder.FindAllStringIndex(query, -1)
if len(occurances) == 0 {
@@ -355,5 +356,10 @@ func (db *CRDB) isUniqueViolationError(err error) bool {
return true
}
}
if pgxErr, ok := err.(*pgconn.PgError); ok {
if pgxErr.Code == "23505" {
return true
}
}
return false
}

View File

@@ -6,8 +6,7 @@ import (
"sync"
"testing"
"github.com/lib/pq"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore/repository"
)
@@ -279,7 +278,7 @@ func TestCRDB_Push_OneAggregate(t *testing.T) {
uniqueCount int
assetCount int
aggType repository.AggregateType
aggID []string
aggID database.StringArray
}
type res struct {
wantErr bool
@@ -430,7 +429,7 @@ func TestCRDB_Push_OneAggregate(t *testing.T) {
t.Errorf("CRDB.Push() error = %v, wantErr %v", err, tt.res.wantErr)
}
countEventRow := testCRDBClient.QueryRow("SELECT COUNT(*) FROM eventstore.events where aggregate_type = $1 AND aggregate_id = ANY($2)", tt.res.eventsRes.aggType, pq.Array(tt.res.eventsRes.aggID))
countEventRow := testCRDBClient.QueryRow("SELECT COUNT(*) FROM eventstore.events where aggregate_type = $1 AND aggregate_id = ANY($2)", tt.res.eventsRes.aggType, tt.res.eventsRes.aggID)
var eventCount int
err := countEventRow.Scan(&eventCount)
if err != nil {
@@ -462,8 +461,8 @@ func TestCRDB_Push_MultipleAggregate(t *testing.T) {
}
type eventsRes struct {
pushedEventsCount int
aggType []repository.AggregateType
aggID []string
aggType database.StringArray
aggID database.StringArray
}
type res struct {
wantErr bool
@@ -487,7 +486,7 @@ func TestCRDB_Push_MultipleAggregate(t *testing.T) {
eventsRes: eventsRes{
pushedEventsCount: 2,
aggID: []string{"100", "101"},
aggType: []repository.AggregateType{repository.AggregateType(t.Name())},
aggType: database.StringArray{t.Name()},
},
},
},
@@ -506,7 +505,7 @@ func TestCRDB_Push_MultipleAggregate(t *testing.T) {
eventsRes: eventsRes{
pushedEventsCount: 4,
aggID: []string{"102", "103"},
aggType: []repository.AggregateType{repository.AggregateType(t.Name())},
aggType: database.StringArray{t.Name()},
},
},
},
@@ -533,7 +532,7 @@ func TestCRDB_Push_MultipleAggregate(t *testing.T) {
eventsRes: eventsRes{
pushedEventsCount: 12,
aggID: []string{"106", "107", "108"},
aggType: []repository.AggregateType{repository.AggregateType(t.Name())},
aggType: database.StringArray{t.Name()},
},
},
},
@@ -547,7 +546,7 @@ func TestCRDB_Push_MultipleAggregate(t *testing.T) {
t.Errorf("CRDB.Push() error = %v, wantErr %v", err, tt.res.wantErr)
}
countRow := testCRDBClient.QueryRow("SELECT COUNT(*) FROM eventstore.events where aggregate_type = ANY($1) AND aggregate_id = ANY($2)", pq.Array(tt.res.eventsRes.aggType), pq.Array(tt.res.eventsRes.aggID))
countRow := testCRDBClient.QueryRow("SELECT COUNT(*) FROM eventstore.events where aggregate_type = ANY($1) AND aggregate_id = ANY($2)", tt.res.eventsRes.aggType, tt.res.eventsRes.aggID)
var count int
err := countRow.Scan(&count)
if err != nil {
@@ -645,8 +644,8 @@ func TestCRDB_Push_Parallel(t *testing.T) {
}
type eventsRes struct {
pushedEventsCount int
aggTypes []repository.AggregateType
aggIDs []string
aggTypes database.StringArray
aggIDs database.StringArray
}
type res struct {
errCount int
@@ -681,7 +680,7 @@ func TestCRDB_Push_Parallel(t *testing.T) {
eventsRes: eventsRes{
aggIDs: []string{"200", "201", "202", "203"},
pushedEventsCount: 9,
aggTypes: []repository.AggregateType{repository.AggregateType(t.Name())},
aggTypes: database.StringArray{t.Name()},
},
},
},
@@ -718,7 +717,7 @@ func TestCRDB_Push_Parallel(t *testing.T) {
eventsRes: eventsRes{
aggIDs: []string{"204", "205", "206"},
pushedEventsCount: 14,
aggTypes: []repository.AggregateType{repository.AggregateType(t.Name())},
aggTypes: database.StringArray{t.Name()},
},
},
},
@@ -748,7 +747,7 @@ func TestCRDB_Push_Parallel(t *testing.T) {
eventsRes: eventsRes{
aggIDs: []string{"207", "208"},
pushedEventsCount: 11,
aggTypes: []repository.AggregateType{repository.AggregateType(t.Name())},
aggTypes: database.StringArray{t.Name()},
},
},
},
@@ -781,7 +780,7 @@ func TestCRDB_Push_Parallel(t *testing.T) {
t.Errorf("CRDB.Push() error count = %d, wanted err count %d, errs: %v", len(errs), tt.res.errCount, errs)
}
rows, err := testCRDBClient.Query("SELECT event_data FROM eventstore.events where aggregate_type = ANY($1) AND aggregate_id = ANY($2) order by event_sequence", pq.Array(tt.res.eventsRes.aggTypes), pq.Array(tt.res.eventsRes.aggIDs))
rows, err := testCRDBClient.Query("SELECT event_data FROM eventstore.events where aggregate_type = ANY($1) AND aggregate_id = ANY($2) order by event_sequence", tt.res.eventsRes.aggTypes, tt.res.eventsRes.aggIDs)
if err != nil {
t.Error("unable to query inserted rows: ", err)
return
@@ -993,10 +992,10 @@ func TestCRDB_Push_ResourceOwner(t *testing.T) {
events []*repository.Event
}
type res struct {
resourceOwners []string
resourceOwners database.StringArray
}
type fields struct {
aggregateIDs []string
aggregateIDs database.StringArray
aggregateType string
}
tests := []struct {
@@ -1128,7 +1127,7 @@ func TestCRDB_Push_ResourceOwner(t *testing.T) {
}
}
rows, err := testCRDBClient.Query("SELECT resource_owner FROM eventstore.events WHERE aggregate_type = $1 AND aggregate_id = ANY($2) ORDER BY event_sequence", tt.fields.aggregateType, pq.Array(tt.fields.aggregateIDs))
rows, err := testCRDBClient.Query("SELECT resource_owner FROM eventstore.events WHERE aggregate_type = $1 AND aggregate_id = ANY($2) ORDER BY event_sequence", tt.fields.aggregateType, tt.fields.aggregateIDs)
if err != nil {
t.Error("unable to query inserted rows: ", err)
return

View File

@@ -9,6 +9,8 @@ import (
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/cmd/initialise"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/cockroach"
)
var (
@@ -42,15 +44,22 @@ func TestMain(m *testing.M) {
}
func initDB(db *sql.DB) error {
username := "zitadel"
database := "zitadel"
err := initialise.Initialise(db, initialise.VerifyUser(username, ""),
initialise.VerifyDatabase(database),
initialise.VerifyGrant(database, username))
config := new(database.Config)
config.SetConnector(&cockroach.Config{User: cockroach.User{Username: "zitadel"}, Database: "zitadel"})
if err := initialise.ReadStmts("cockroach"); err != nil {
return err
}
err := initialise.Init(db,
initialise.VerifyUser(config.Username(), ""),
initialise.VerifyDatabase(config.Database()),
initialise.VerifyGrant(config.Database(), config.Username()))
if err != nil {
return err
}
return initialise.VerifyZitadel(db)
return initialise.VerifyZitadel(db, *config)
}
func fillUniqueData(unique_type, field, instanceID string) error {

View File

@@ -8,7 +8,6 @@ import (
"fmt"
"strings"
"github.com/lib/pq"
"github.com/zitadel/logging"
z_errors "github.com/zitadel/zitadel/internal/errors"
@@ -170,8 +169,6 @@ func prepareCondition(criteria querier, filters [][]*repository.Filter) (clause
for _, f := range filter {
value := f.Value
switch value.(type) {
case []bool, []float64, []int64, []string, []repository.AggregateType, []repository.EventType, *[]bool, *[]float64, *[]int64, *[]string, *[]repository.AggregateType, *[]repository.EventType:
value = pq.Array(value)
case map[string]interface{}:
var err error
value, err = json.Marshal(value)

View File

@@ -9,7 +9,6 @@ import (
"time"
"github.com/DATA-DOG/go-sqlmock"
"github.com/lib/pq"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore/repository"
@@ -265,7 +264,7 @@ func Test_prepareCondition(t *testing.T) {
},
res: res{
clause: " WHERE ( aggregate_type = ANY(?) )",
values: []interface{}{pq.Array([]repository.AggregateType{"user", "org"})},
values: []interface{}{[]repository.AggregateType{"user", "org"}},
},
},
{
@@ -281,7 +280,7 @@ func Test_prepareCondition(t *testing.T) {
},
res: res{
clause: " WHERE ( aggregate_type = ANY(?) AND aggregate_id = ? AND event_type = ANY(?) )",
values: []interface{}{pq.Array([]repository.AggregateType{"user", "org"}), "1234", pq.Array([]repository.EventType{"user.created", "org.created"})},
values: []interface{}{[]repository.AggregateType{"user", "org"}, "1234", []repository.EventType{"user.created", "org.created"}},
},
},
}

View File

@@ -3,11 +3,12 @@ package eventstore
import (
"database/sql"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore/repository"
)
//SearchQueryBuilder represents the builder for your filter
// SearchQueryBuilder represents the builder for your filter
// if invalid data are set the filter will fail
type SearchQueryBuilder struct {
columns repository.Columns
@@ -79,51 +80,51 @@ func (builder *SearchQueryBuilder) Matches(event Event, existingLen int) (matche
return false
}
//Columns defines which fields are set
// Columns defines which fields are set
func (builder *SearchQueryBuilder) Columns(columns Columns) *SearchQueryBuilder {
builder.columns = repository.Columns(columns)
return builder
}
//Limit defines how many events are returned maximally.
// Limit defines how many events are returned maximally.
func (builder *SearchQueryBuilder) Limit(limit uint64) *SearchQueryBuilder {
builder.limit = limit
return builder
}
//ResourceOwner defines the resource owner (org) of the events
// ResourceOwner defines the resource owner (org) of the events
func (builder *SearchQueryBuilder) ResourceOwner(resourceOwner string) *SearchQueryBuilder {
builder.resourceOwner = resourceOwner
return builder
}
//InstanceID defines the instanceID (system) of the events
// InstanceID defines the instanceID (system) of the events
func (builder *SearchQueryBuilder) InstanceID(instanceID string) *SearchQueryBuilder {
builder.instanceID = instanceID
return builder
}
//OrderDesc changes the sorting order of the returned events to descending
// OrderDesc changes the sorting order of the returned events to descending
func (builder *SearchQueryBuilder) OrderDesc() *SearchQueryBuilder {
builder.desc = true
return builder
}
//OrderAsc changes the sorting order of the returned events to ascending
// OrderAsc changes the sorting order of the returned events to ascending
func (builder *SearchQueryBuilder) OrderAsc() *SearchQueryBuilder {
builder.desc = false
return builder
}
//SetTx ensures that the eventstore library uses the existing transaction
// SetTx ensures that the eventstore library uses the existing transaction
func (builder *SearchQueryBuilder) SetTx(tx *sql.Tx) *SearchQueryBuilder {
builder.tx = tx
return builder
}
//AddQuery creates a new sub query.
//All fields in the sub query are AND-connected in the storage request.
//Multiple sub queries are OR-connected in the storage request.
// AddQuery creates a new sub query.
// All fields in the sub query are AND-connected in the storage request.
// Multiple sub queries are OR-connected in the storage request.
func (builder *SearchQueryBuilder) AddQuery() *SearchQuery {
query := &SearchQuery{
builder: builder,
@@ -133,61 +134,61 @@ func (builder *SearchQueryBuilder) AddQuery() *SearchQuery {
return query
}
//Or creates a new sub query on the search query builder
// Or creates a new sub query on the search query builder
func (query SearchQuery) Or() *SearchQuery {
return query.builder.AddQuery()
}
//AggregateTypes filters for events with the given aggregate types
// AggregateTypes filters for events with the given aggregate types
func (query *SearchQuery) AggregateTypes(types ...AggregateType) *SearchQuery {
query.aggregateTypes = types
return query
}
//SequenceGreater filters for events with sequence greater the requested sequence
// SequenceGreater filters for events with sequence greater the requested sequence
func (query *SearchQuery) SequenceGreater(sequence uint64) *SearchQuery {
query.eventSequenceGreater = sequence
return query
}
//SequenceLess filters for events with sequence less the requested sequence
// SequenceLess filters for events with sequence less the requested sequence
func (query *SearchQuery) SequenceLess(sequence uint64) *SearchQuery {
query.eventSequenceLess = sequence
return query
}
//AggregateIDs filters for events with the given aggregate id's
// AggregateIDs filters for events with the given aggregate id's
func (query *SearchQuery) AggregateIDs(ids ...string) *SearchQuery {
query.aggregateIDs = ids
return query
}
//InstanceID filters for events with the given instanceID
// InstanceID filters for events with the given instanceID
func (query *SearchQuery) InstanceID(instanceID string) *SearchQuery {
query.instanceID = instanceID
return query
}
//ExcludedInstanceID filters for events not having the given instanceIDs
// ExcludedInstanceID filters for events not having the given instanceIDs
func (query *SearchQuery) ExcludedInstanceID(instanceIDs ...string) *SearchQuery {
query.excludedInstanceIDs = instanceIDs
return query
}
//EventTypes filters for events with the given event types
// EventTypes filters for events with the given event types
func (query *SearchQuery) EventTypes(types ...EventType) *SearchQuery {
query.eventTypes = types
return query
}
//EventData filters for events with the given event data.
//Use this call with care as it will be slower than the other filters.
// EventData filters for events with the given event data.
// Use this call with care as it will be slower than the other filters.
func (query *SearchQuery) EventData(data map[string]interface{}) *SearchQuery {
query.eventData = data
return query
}
//Builder returns the SearchQueryBuilder of the sub query
// Builder returns the SearchQueryBuilder of the sub query
func (query *SearchQuery) Builder() *SearchQueryBuilder {
return query.builder
}
@@ -262,7 +263,7 @@ func (query *SearchQuery) aggregateIDFilter() *repository.Filter {
if len(query.aggregateIDs) == 1 {
return repository.NewFilter(repository.FieldAggregateID, query.aggregateIDs[0], repository.OperationEquals)
}
return repository.NewFilter(repository.FieldAggregateID, query.aggregateIDs, repository.OperationIn)
return repository.NewFilter(repository.FieldAggregateID, database.StringArray(query.aggregateIDs), repository.OperationIn)
}
func (query *SearchQuery) eventTypeFilter() *repository.Filter {
@@ -272,9 +273,9 @@ func (query *SearchQuery) eventTypeFilter() *repository.Filter {
if len(query.eventTypes) == 1 {
return repository.NewFilter(repository.FieldEventType, repository.EventType(query.eventTypes[0]), repository.OperationEquals)
}
eventTypes := make([]repository.EventType, len(query.eventTypes))
eventTypes := make(database.StringArray, len(query.eventTypes))
for i, eventType := range query.eventTypes {
eventTypes[i] = repository.EventType(eventType)
eventTypes[i] = string(eventType)
}
return repository.NewFilter(repository.FieldEventType, eventTypes, repository.OperationIn)
}
@@ -286,9 +287,9 @@ func (query *SearchQuery) aggregateTypeFilter() *repository.Filter {
if len(query.aggregateTypes) == 1 {
return repository.NewFilter(repository.FieldAggregateType, repository.AggregateType(query.aggregateTypes[0]), repository.OperationEquals)
}
aggregateTypes := make([]repository.AggregateType, len(query.aggregateTypes))
aggregateTypes := make(database.StringArray, len(query.aggregateTypes))
for i, aggregateType := range query.aggregateTypes {
aggregateTypes[i] = repository.AggregateType(aggregateType)
aggregateTypes[i] = string(aggregateType)
}
return repository.NewFilter(repository.FieldAggregateType, aggregateTypes, repository.OperationIn)
}
@@ -326,7 +327,7 @@ func (query *SearchQuery) excludedInstanceIDFilter() *repository.Filter {
if len(query.excludedInstanceIDs) == 0 {
return nil
}
return repository.NewFilter(repository.FieldInstanceID, query.excludedInstanceIDs, repository.OperationNotIn)
return repository.NewFilter(repository.FieldInstanceID, database.StringArray(query.excludedInstanceIDs), repository.OperationNotIn)
}
func (builder *SearchQueryBuilder) resourceOwnerFilter() *repository.Filter {

View File

@@ -5,6 +5,7 @@ import (
"reflect"
"testing"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore/repository"
)
@@ -312,7 +313,7 @@ func TestSearchQuerybuilderBuild(t *testing.T) {
Limit: 0,
Filters: [][]*repository.Filter{
{
repository.NewFilter(repository.FieldAggregateType, []repository.AggregateType{"user", "org"}, repository.OperationIn),
repository.NewFilter(repository.FieldAggregateType, database.StringArray{"user", "org"}, repository.OperationIn),
},
},
},
@@ -483,7 +484,7 @@ func TestSearchQuerybuilderBuild(t *testing.T) {
Filters: [][]*repository.Filter{
{
repository.NewFilter(repository.FieldAggregateType, repository.AggregateType("user"), repository.OperationEquals),
repository.NewFilter(repository.FieldAggregateID, []string{"1234", "0815"}, repository.OperationIn),
repository.NewFilter(repository.FieldAggregateID, database.StringArray{"1234", "0815"}, repository.OperationIn),
},
},
},
@@ -561,7 +562,7 @@ func TestSearchQuerybuilderBuild(t *testing.T) {
Filters: [][]*repository.Filter{
{
repository.NewFilter(repository.FieldAggregateType, repository.AggregateType("user"), repository.OperationEquals),
repository.NewFilter(repository.FieldEventType, []repository.EventType{"user.created", "user.changed"}, repository.OperationIn),
repository.NewFilter(repository.FieldEventType, database.StringArray{"user.created", "user.changed"}, repository.OperationIn),
},
},
},
@@ -740,10 +741,10 @@ func assertRepoQuery(t *testing.T, want, got *repository.SearchQuery) {
if !reflect.DeepEqual(got.Columns, want.Columns) {
t.Errorf("wrong columns in query: got: %v want: %v", got.Columns, want.Columns)
}
if !reflect.DeepEqual(got.Desc, want.Desc) {
if got.Desc != want.Desc {
t.Errorf("wrong desc in query: got: %v want: %v", got.Desc, want.Desc)
}
if !reflect.DeepEqual(got.Limit, want.Limit) {
if got.Limit != want.Limit {
t.Errorf("wrong limit in query: got: %v want: %v", got.Limit, want.Limit)
}

View File

@@ -2,8 +2,6 @@ package sql
import (
"database/sql"
_ "github.com/lib/pq"
)
func Start(client *sql.DB) *SQL {

View File

@@ -7,7 +7,6 @@ import (
"strconv"
"strings"
"github.com/lib/pq"
"github.com/zitadel/logging"
z_errors "github.com/zitadel/zitadel/internal/errors"
@@ -72,10 +71,6 @@ func prepareCondition(filters [][]*es_models.Filter) (clause string, values []in
subClauses := make([]string, 0, len(filter))
for _, f := range filter {
value := f.GetValue()
switch value.(type) {
case []bool, []float64, []int64, []string, []es_models.AggregateType, []es_models.EventType, *[]bool, *[]float64, *[]int64, *[]string, *[]es_models.AggregateType, *[]es_models.EventType:
value = pq.Array(value)
}
subClauses = append(subClauses, getCondition(f))
if subClauses[len(subClauses)-1] == "" {

View File

@@ -6,8 +6,6 @@ import (
"testing"
"time"
"github.com/lib/pq"
"github.com/zitadel/zitadel/internal/errors"
es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models"
)
@@ -365,7 +363,7 @@ func Test_prepareCondition(t *testing.T) {
},
res: res{
clause: " WHERE ( aggregate_type = ANY(?) )",
values: []interface{}{pq.Array([]es_models.AggregateType{"user", "org"})},
values: []interface{}{[]es_models.AggregateType{"user", "org"}},
},
},
{
@@ -381,7 +379,7 @@ func Test_prepareCondition(t *testing.T) {
},
res: res{
clause: " WHERE ( aggregate_type = ANY(?) AND aggregate_id = ? AND event_type = ANY(?) )",
values: []interface{}{pq.Array([]es_models.AggregateType{"user", "org"}), "1234", pq.Array([]es_models.EventType{"user.created", "org.created"})},
values: []interface{}{[]es_models.AggregateType{"user", "org"}, "1234", []es_models.EventType{"user.created", "org.created"}},
},
},
}

View File

@@ -3,9 +3,6 @@ package sql
import (
"context"
"database/sql"
//sql import
_ "github.com/lib/pq"
)
type SQL struct {

View File

@@ -31,7 +31,7 @@ func Renew(dbClient *sql.DB, lockTable, lockerID, viewModel, instanceID string,
return crdb.ExecuteTx(context.Background(), dbClient, nil, func(tx *sql.Tx) error {
insert := fmt.Sprintf(insertStmtFormat, lockTable)
result, err := tx.Exec(insert,
lockerID, waitTime.Milliseconds()/millisecondsAsSeconds, viewModel, instanceID)
lockerID, waitTime, viewModel, instanceID)
if err != nil {
tx.Rollback()

View File

@@ -38,7 +38,12 @@ func ReduceEvent(handler Handler, event *models.Event) {
if err != nil {
handler.Subscription().Unsubscribe()
logging.WithFields("cause", err, "stack", string(debug.Stack())).Error("reduce panicked")
logging.WithFields(
"cause", err,
"stack", string(debug.Stack()),
"sequence", event.Sequence,
"instnace", event.InstanceID,
).Error("reduce panicked")
}
}()
currentSequence, err := handler.CurrentSequence(event.InstanceID)

View File

@@ -75,7 +75,10 @@ func (s *spooledHandler) load(workerID string) {
err := recover()
if err != nil {
logging.WithFields("cause", err, "stack", string(debug.Stack())).Error("reduce panicked")
logging.WithFields(
"cause", err,
"stack", string(debug.Stack()),
).Error("reduce panicked")
}
}()
ctx, cancel := context.WithCancel(context.Background())
@@ -167,7 +170,7 @@ func (s *spooledHandler) query(ctx context.Context, instanceIDs ...string) ([]*m
return s.eventstore.FilterEvents(ctx, query)
}
//lock ensures the lock on the database.
// lock ensures the lock on the database.
// the returned channel will be closed if ctx is done or an error occured durring lock
func (s *spooledHandler) lock(ctx context.Context, errs chan<- error, workerID string) chan bool {
renewTimer := time.After(0)