refactor(eventstore): move push logic to sql (#8816)

# Which Problems Are Solved

If many events are written to the same aggregate id it can happen that
zitadel [starts to retry the push
transaction](48ffc902cc/internal/eventstore/eventstore.go (L101))
because [the locking
behaviour](48ffc902cc/internal/eventstore/v3/sequence.go (L25))
during push does compute the wrong sequence because newly committed
events are not visible to the transaction. These events impact the
current sequence.

In cases with high command traffic on a single aggregate id this can
have severe impact on general performance of zitadel. Because many
connections of the `eventstore pusher` database pool are blocked by each
other.

# How the Problems Are Solved

To improve the performance this locking mechanism was removed and the
business logic of push is moved to sql functions which reduce network
traffic and can be analyzed by the database before the actual push. For
clients of the eventstore framework nothing changed.

# Additional Changes

- after a connection is established prefetches the newly added database
types
- `eventstore.BaseEvent` now returns the correct revision of the event

# Additional Context

- part of https://github.com/zitadel/zitadel/issues/8931

---------

Co-authored-by: Tim Möhlmann <tim+github@zitadel.com>
Co-authored-by: Livio Spring <livio.a@gmail.com>
Co-authored-by: Max Peintner <max@caos.ch>
Co-authored-by: Elio Bischof <elio@zitadel.com>
Co-authored-by: Stefan Benz <46600784+stebenz@users.noreply.github.com>
Co-authored-by: Miguel Cabrerizo <30386061+doncicuto@users.noreply.github.com>
Co-authored-by: Joakim Lodén <Loddan@users.noreply.github.com>
Co-authored-by: Yxnt <Yxnt@users.noreply.github.com>
Co-authored-by: Stefan Benz <stefan@caos.ch>
Co-authored-by: Harsha Reddy <harsha.reddy@klaviyo.com>
Co-authored-by: Zach H <zhirschtritt@gmail.com>
This commit is contained in:
Silvan
2024-12-04 14:51:40 +01:00
committed by GitHub
parent 14db628856
commit dab5d9e756
42 changed files with 1591 additions and 277 deletions

View File

@@ -1,6 +1,7 @@
package initialise package initialise
import ( import (
"context"
"errors" "errors"
"github.com/jackc/pgx/v5/pgconn" "github.com/jackc/pgx/v5/pgconn"
@@ -8,8 +9,8 @@ import (
"github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/database"
) )
func exec(db *database.DB, stmt string, possibleErrCodes []string, args ...interface{}) error { func exec(ctx context.Context, db database.ContextExecuter, stmt string, possibleErrCodes []string, args ...interface{}) error {
_, err := db.Exec(stmt, args...) _, err := db.ExecContext(ctx, stmt, args...)
pgErr := new(pgconn.PgError) pgErr := new(pgconn.PgError)
if errors.As(err, &pgErr) { if errors.As(err, &pgErr) {
for _, possibleCode := range possibleErrCodes { for _, possibleCode := range possibleErrCodes {

View File

@@ -59,7 +59,7 @@ The user provided by flags needs privileges to
} }
func InitAll(ctx context.Context, config *Config) { func InitAll(ctx context.Context, config *Config) {
err := initialise(config.Database, err := initialise(ctx, config.Database,
VerifyUser(config.Database.Username(), config.Database.Password()), VerifyUser(config.Database.Username(), config.Database.Password()),
VerifyDatabase(config.Database.DatabaseName()), VerifyDatabase(config.Database.DatabaseName()),
VerifyGrant(config.Database.DatabaseName(), config.Database.Username()), VerifyGrant(config.Database.DatabaseName(), config.Database.Username()),
@@ -71,7 +71,7 @@ func InitAll(ctx context.Context, config *Config) {
logging.OnError(err).Fatal("unable to initialize ZITADEL") logging.OnError(err).Fatal("unable to initialize ZITADEL")
} }
func initialise(config database.Config, steps ...func(*database.DB) error) error { func initialise(ctx context.Context, config database.Config, steps ...func(context.Context, *database.DB) error) error {
logging.Info("initialization started") logging.Info("initialization started")
err := ReadStmts(config.Type()) err := ReadStmts(config.Type())
@@ -85,12 +85,12 @@ func initialise(config database.Config, steps ...func(*database.DB) error) error
} }
defer db.Close() defer db.Close()
return Init(db, steps...) return Init(ctx, db, steps...)
} }
func Init(db *database.DB, steps ...func(*database.DB) error) error { func Init(ctx context.Context, db *database.DB, steps ...func(context.Context, *database.DB) error) error {
for _, step := range steps { for _, step := range steps {
if err := step(db); err != nil { if err := step(ctx, db); err != nil {
return err return err
} }
} }

View File

@@ -1,2 +1,2 @@
-- replace %[1]s with the name of the database -- replace %[1]s with the name of the database
CREATE DATABASE IF NOT EXISTS "%[1]s" CREATE DATABASE IF NOT EXISTS "%[1]s";

View File

@@ -19,3 +19,98 @@ CREATE TABLE IF NOT EXISTS eventstore.events2 (
, INDEX es_wm (aggregate_id, instance_id, aggregate_type, event_type) , INDEX es_wm (aggregate_id, instance_id, aggregate_type, event_type)
, INDEX es_projection (instance_id, aggregate_type, event_type, "position" DESC) , INDEX es_projection (instance_id, aggregate_type, event_type, "position" DESC)
); );
-- represents an event to be created.
CREATE TYPE IF NOT EXISTS eventstore.command AS (
instance_id TEXT
, aggregate_type TEXT
, aggregate_id TEXT
, command_type TEXT
, revision INT2
, payload JSONB
, creator TEXT
, owner TEXT
);
CREATE OR REPLACE FUNCTION eventstore.commands_to_events(commands eventstore.command[]) RETURNS SETOF eventstore.events2 VOLATILE AS $$
SELECT
("c").instance_id
, ("c").aggregate_type
, ("c").aggregate_id
, ("c").command_type AS event_type
, cs.sequence + ROW_NUMBER() OVER (PARTITION BY ("c").instance_id, ("c").aggregate_type, ("c").aggregate_id ORDER BY ("c").in_tx_order) AS sequence
, ("c").revision
, hlc_to_timestamp(cluster_logical_timestamp()) AS created_at
, ("c").payload
, ("c").creator
, cs.owner
, cluster_logical_timestamp() AS position
, ("c").in_tx_order
FROM (
SELECT
("c").instance_id
, ("c").aggregate_type
, ("c").aggregate_id
, ("c").command_type
, ("c").revision
, ("c").payload
, ("c").creator
, ("c").owner
, ROW_NUMBER() OVER () AS in_tx_order
FROM
UNNEST(commands) AS "c"
) AS "c"
JOIN (
SELECT
cmds.instance_id
, cmds.aggregate_type
, cmds.aggregate_id
, CASE WHEN (e.owner IS NOT NULL OR e.owner <> '') THEN e.owner ELSE command_owners.owner END AS owner
, COALESCE(MAX(e.sequence), 0) AS sequence
FROM (
SELECT DISTINCT
("cmds").instance_id
, ("cmds").aggregate_type
, ("cmds").aggregate_id
, ("cmds").owner
FROM UNNEST(commands) AS "cmds"
) AS cmds
LEFT JOIN eventstore.events2 AS e
ON cmds.instance_id = e.instance_id
AND cmds.aggregate_type = e.aggregate_type
AND cmds.aggregate_id = e.aggregate_id
JOIN (
SELECT
DISTINCT ON (
("c").instance_id
, ("c").aggregate_type
, ("c").aggregate_id
)
("c").instance_id
, ("c").aggregate_type
, ("c").aggregate_id
, ("c").owner
FROM
UNNEST(commands) AS "c"
) AS command_owners ON
cmds.instance_id = command_owners.instance_id
AND cmds.aggregate_type = command_owners.aggregate_type
AND cmds.aggregate_id = command_owners.aggregate_id
GROUP BY
cmds.instance_id
, cmds.aggregate_type
, cmds.aggregate_id
, 4 -- owner
) AS cs
ON ("c").instance_id = cs.instance_id
AND ("c").aggregate_type = cs.aggregate_type
AND ("c").aggregate_id = cs.aggregate_id
ORDER BY
in_tx_order
$$ LANGUAGE SQL;
CREATE OR REPLACE FUNCTION eventstore.push(commands eventstore.command[]) RETURNS SETOF eventstore.events2 AS $$
INSERT INTO eventstore.events2
SELECT * FROM eventstore.commands_to_events(commands)
RETURNING *
$$ LANGUAGE SQL;

View File

@@ -19,4 +19,103 @@ CREATE TABLE IF NOT EXISTS eventstore.events2 (
CREATE INDEX IF NOT EXISTS es_active_instances ON eventstore.events2 (created_at DESC, instance_id); CREATE INDEX IF NOT EXISTS es_active_instances ON eventstore.events2 (created_at DESC, instance_id);
CREATE INDEX IF NOT EXISTS es_wm ON eventstore.events2 (aggregate_id, instance_id, aggregate_type, event_type); CREATE INDEX IF NOT EXISTS es_wm ON eventstore.events2 (aggregate_id, instance_id, aggregate_type, event_type);
CREATE INDEX IF NOT EXISTS es_projection ON eventstore.events2 (instance_id, aggregate_type, event_type, "position"); CREATE INDEX IF NOT EXISTS es_projection ON eventstore.events2 (instance_id, aggregate_type, event_type, "position");
-- represents an event to be created.
DO $$ BEGIN
CREATE TYPE eventstore.command AS (
instance_id TEXT
, aggregate_type TEXT
, aggregate_id TEXT
, command_type TEXT
, revision INT2
, payload JSONB
, creator TEXT
, owner TEXT
);
EXCEPTION
WHEN duplicate_object THEN null;
END $$;
CREATE OR REPLACE FUNCTION eventstore.commands_to_events(commands eventstore.command[]) RETURNS SETOF eventstore.events2 VOLATILE AS $$
SELECT
c.instance_id
, c.aggregate_type
, c.aggregate_id
, c.command_type AS event_type
, cs.sequence + ROW_NUMBER() OVER (PARTITION BY c.instance_id, c.aggregate_type, c.aggregate_id ORDER BY c.in_tx_order) AS sequence
, c.revision
, NOW() AS created_at
, c.payload
, c.creator
, cs.owner
, EXTRACT(EPOCH FROM NOW()) AS position
, c.in_tx_order
FROM (
SELECT
c.instance_id
, c.aggregate_type
, c.aggregate_id
, c.command_type
, c.revision
, c.payload
, c.creator
, c.owner
, ROW_NUMBER() OVER () AS in_tx_order
FROM
UNNEST(commands) AS c
) AS c
JOIN (
SELECT
cmds.instance_id
, cmds.aggregate_type
, cmds.aggregate_id
, CASE WHEN (e.owner IS NOT NULL OR e.owner <> '') THEN e.owner ELSE command_owners.owner END AS owner
, COALESCE(MAX(e.sequence), 0) AS sequence
FROM (
SELECT DISTINCT
instance_id
, aggregate_type
, aggregate_id
, owner
FROM UNNEST(commands)
) AS cmds
LEFT JOIN eventstore.events2 AS e
ON cmds.instance_id = e.instance_id
AND cmds.aggregate_type = e.aggregate_type
AND cmds.aggregate_id = e.aggregate_id
JOIN (
SELECT
DISTINCT ON (
instance_id
, aggregate_type
, aggregate_id
)
instance_id
, aggregate_type
, aggregate_id
, owner
FROM
UNNEST(commands)
) AS command_owners ON
cmds.instance_id = command_owners.instance_id
AND cmds.aggregate_type = command_owners.aggregate_type
AND cmds.aggregate_id = command_owners.aggregate_id
GROUP BY
cmds.instance_id
, cmds.aggregate_type
, cmds.aggregate_id
, 4 -- owner
) AS cs
ON c.instance_id = cs.instance_id
AND c.aggregate_type = cs.aggregate_type
AND c.aggregate_id = cs.aggregate_id
ORDER BY
in_tx_order;
$$ LANGUAGE SQL;
CREATE OR REPLACE FUNCTION eventstore.push(commands eventstore.command[]) RETURNS SETOF eventstore.events2 VOLATILE AS $$
INSERT INTO eventstore.events2
SELECT * FROM eventstore.commands_to_events(commands)
RETURNING *
$$ LANGUAGE SQL;

View File

@@ -1,6 +1,7 @@
package initialise package initialise
import ( import (
"context"
_ "embed" _ "embed"
"fmt" "fmt"
@@ -28,16 +29,16 @@ The user provided by flags needs privileges to
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
config := MustNewConfig(viper.GetViper()) config := MustNewConfig(viper.GetViper())
err := initialise(config.Database, VerifyDatabase(config.Database.DatabaseName())) err := initialise(cmd.Context(), config.Database, VerifyDatabase(config.Database.DatabaseName()))
logging.OnError(err).Fatal("unable to initialize the database") logging.OnError(err).Fatal("unable to initialize the database")
}, },
} }
} }
func VerifyDatabase(databaseName string) func(*database.DB) error { func VerifyDatabase(databaseName string) func(context.Context, *database.DB) error {
return func(db *database.DB) error { return func(ctx context.Context, db *database.DB) error {
logging.WithFields("database", databaseName).Info("verify database") logging.WithFields("database", databaseName).Info("verify database")
return exec(db, fmt.Sprintf(databaseStmt, databaseName), []string{dbAlreadyExistsCode}) return exec(ctx, db, fmt.Sprintf(databaseStmt, databaseName), []string{dbAlreadyExistsCode})
} }
} }

View File

@@ -1,6 +1,7 @@
package initialise package initialise
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"testing" "testing"
@@ -55,7 +56,7 @@ func Test_verifyDB(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := VerifyDatabase(tt.args.database)(tt.args.db.db); !errors.Is(err, tt.targetErr) { if err := VerifyDatabase(tt.args.database)(context.Background(), tt.args.db.db); !errors.Is(err, tt.targetErr) {
t.Errorf("verifyDB() error = %v, want: %v", err, tt.targetErr) t.Errorf("verifyDB() error = %v, want: %v", err, tt.targetErr)
} }
if err := tt.args.db.mock.ExpectationsWereMet(); err != nil { if err := tt.args.db.mock.ExpectationsWereMet(); err != nil {

View File

@@ -1,6 +1,7 @@
package initialise package initialise
import ( import (
"context"
_ "embed" _ "embed"
"fmt" "fmt"
@@ -23,16 +24,16 @@ Prerequisites:
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
config := MustNewConfig(viper.GetViper()) config := MustNewConfig(viper.GetViper())
err := initialise(config.Database, VerifyGrant(config.Database.DatabaseName(), config.Database.Username())) err := initialise(cmd.Context(), config.Database, VerifyGrant(config.Database.DatabaseName(), config.Database.Username()))
logging.OnError(err).Fatal("unable to set grant") logging.OnError(err).Fatal("unable to set grant")
}, },
} }
} }
func VerifyGrant(databaseName, username string) func(*database.DB) error { func VerifyGrant(databaseName, username string) func(context.Context, *database.DB) error {
return func(db *database.DB) error { return func(ctx context.Context, db *database.DB) error {
logging.WithFields("user", username, "database", databaseName).Info("verify grant") logging.WithFields("user", username, "database", databaseName).Info("verify grant")
return exec(db, fmt.Sprintf(grantStmt, databaseName, username), nil) return exec(ctx, db, fmt.Sprintf(grantStmt, databaseName, username), nil)
} }
} }

View File

@@ -1,6 +1,7 @@
package initialise package initialise
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"testing" "testing"
@@ -53,7 +54,7 @@ func Test_verifyGrant(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := VerifyGrant(tt.args.database, tt.args.username)(tt.args.db.db); !errors.Is(err, tt.targetErr) { if err := VerifyGrant(tt.args.database, tt.args.username)(context.Background(), tt.args.db.db); !errors.Is(err, tt.targetErr) {
t.Errorf("VerifyGrant() error = %v, want: %v", err, tt.targetErr) t.Errorf("VerifyGrant() error = %v, want: %v", err, tt.targetErr)
} }
if err := tt.args.db.mock.ExpectationsWereMet(); err != nil { if err := tt.args.db.mock.ExpectationsWereMet(); err != nil {

View File

@@ -1,6 +1,7 @@
package initialise package initialise
import ( import (
"context"
_ "embed" _ "embed"
"fmt" "fmt"
@@ -26,19 +27,19 @@ Cockroach
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
config := MustNewConfig(viper.GetViper()) config := MustNewConfig(viper.GetViper())
err := initialise(config.Database, VerifySettings(config.Database.DatabaseName(), config.Database.Username())) err := initialise(cmd.Context(), config.Database, VerifySettings(config.Database.DatabaseName(), config.Database.Username()))
logging.OnError(err).Fatal("unable to set settings") logging.OnError(err).Fatal("unable to set settings")
}, },
} }
} }
func VerifySettings(databaseName, username string) func(*database.DB) error { func VerifySettings(databaseName, username string) func(context.Context, *database.DB) error {
return func(db *database.DB) error { return func(ctx context.Context, db *database.DB) error {
if db.Type() == "postgres" { if db.Type() == "postgres" {
return nil return nil
} }
logging.WithFields("user", username, "database", databaseName).Info("verify settings") logging.WithFields("user", username, "database", databaseName).Info("verify settings")
return exec(db, fmt.Sprintf(settingsStmt, databaseName, username), nil) return exec(ctx, db, fmt.Sprintf(settingsStmt, databaseName, username), nil)
} }
} }

View File

@@ -1,6 +1,7 @@
package initialise package initialise
import ( import (
"context"
_ "embed" _ "embed"
"fmt" "fmt"
@@ -28,20 +29,20 @@ The user provided by flags needs privileges to
Run: func(cmd *cobra.Command, args []string) { Run: func(cmd *cobra.Command, args []string) {
config := MustNewConfig(viper.GetViper()) config := MustNewConfig(viper.GetViper())
err := initialise(config.Database, VerifyUser(config.Database.Username(), config.Database.Password())) err := initialise(cmd.Context(), config.Database, VerifyUser(config.Database.Username(), config.Database.Password()))
logging.OnError(err).Fatal("unable to init user") logging.OnError(err).Fatal("unable to init user")
}, },
} }
} }
func VerifyUser(username, password string) func(*database.DB) error { func VerifyUser(username, password string) func(context.Context, *database.DB) error {
return func(db *database.DB) error { return func(ctx context.Context, db *database.DB) error {
logging.WithFields("username", username).Info("verify user") logging.WithFields("username", username).Info("verify user")
if password != "" { if password != "" {
createUserStmt += " WITH PASSWORD '" + password + "'" createUserStmt += " WITH PASSWORD '" + password + "'"
} }
return exec(db, fmt.Sprintf(createUserStmt, username), []string{roleAlreadyExistsCode}) return exec(ctx, db, fmt.Sprintf(createUserStmt, username), []string{roleAlreadyExistsCode})
} }
} }

View File

@@ -1,6 +1,7 @@
package initialise package initialise
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"testing" "testing"
@@ -70,7 +71,7 @@ func Test_verifyUser(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := VerifyUser(tt.args.username, tt.args.password)(tt.args.db.db); !errors.Is(err, tt.targetErr) { if err := VerifyUser(tt.args.username, tt.args.password)(context.Background(), tt.args.db.db); !errors.Is(err, tt.targetErr) {
t.Errorf("VerifyGrant() error = %v, want: %v", err, tt.targetErr) t.Errorf("VerifyGrant() error = %v, want: %v", err, tt.targetErr)
} }
if err := tt.args.db.mock.ExpectationsWereMet(); err != nil { if err := tt.args.db.mock.ExpectationsWereMet(); err != nil {

View File

@@ -2,6 +2,7 @@ package initialise
import ( import (
"context" "context"
"database/sql"
_ "embed" _ "embed"
"fmt" "fmt"
@@ -11,6 +12,7 @@ import (
"github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/dialect" "github.com/zitadel/zitadel/internal/database/dialect"
es_v3 "github.com/zitadel/zitadel/internal/eventstore/v3"
) )
func newZitadel() *cobra.Command { func newZitadel() *cobra.Command {
@@ -36,38 +38,44 @@ func VerifyZitadel(ctx context.Context, db *database.DB, config database.Config)
return err return err
} }
conn, err := db.Conn(ctx)
if err != nil {
return err
}
defer conn.Close()
logging.WithFields().Info("verify system") logging.WithFields().Info("verify system")
if err := exec(db, fmt.Sprintf(createSystemStmt, config.Username()), nil); err != nil { if err := exec(ctx, conn, fmt.Sprintf(createSystemStmt, config.Username()), nil); err != nil {
return err return err
} }
logging.WithFields().Info("verify encryption keys") logging.WithFields().Info("verify encryption keys")
if err := createEncryptionKeys(ctx, db); err != nil { if err := createEncryptionKeys(ctx, conn); err != nil {
return err return err
} }
logging.WithFields().Info("verify projections") logging.WithFields().Info("verify projections")
if err := exec(db, fmt.Sprintf(createProjectionsStmt, config.Username()), nil); err != nil { if err := exec(ctx, conn, fmt.Sprintf(createProjectionsStmt, config.Username()), nil); err != nil {
return err return err
} }
logging.WithFields().Info("verify eventstore") logging.WithFields().Info("verify eventstore")
if err := exec(db, fmt.Sprintf(createEventstoreStmt, config.Username()), nil); err != nil { if err := exec(ctx, conn, fmt.Sprintf(createEventstoreStmt, config.Username()), nil); err != nil {
return err return err
} }
logging.WithFields().Info("verify events tables") logging.WithFields().Info("verify events tables")
if err := createEvents(ctx, db); err != nil { if err := createEvents(ctx, conn); err != nil {
return err return err
} }
logging.WithFields().Info("verify system sequence") logging.WithFields().Info("verify system sequence")
if err := exec(db, createSystemSequenceStmt, nil); err != nil { if err := exec(ctx, conn, createSystemSequenceStmt, nil); err != nil {
return err return err
} }
logging.WithFields().Info("verify unique constraints") logging.WithFields().Info("verify unique constraints")
if err := exec(db, createUniqueConstraints, nil); err != nil { if err := exec(ctx, conn, createUniqueConstraints, nil); err != nil {
return err return err
} }
@@ -89,7 +97,7 @@ func verifyZitadel(ctx context.Context, config database.Config) error {
return db.Close() return db.Close()
} }
func createEncryptionKeys(ctx context.Context, db *database.DB) error { func createEncryptionKeys(ctx context.Context, db database.Beginner) error {
tx, err := db.BeginTx(ctx, nil) tx, err := db.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
@@ -103,8 +111,8 @@ func createEncryptionKeys(ctx context.Context, db *database.DB) error {
return tx.Commit() return tx.Commit()
} }
func createEvents(ctx context.Context, db *database.DB) (err error) { func createEvents(ctx context.Context, conn *sql.Conn) (err error) {
tx, err := db.BeginTx(ctx, nil) tx, err := conn.BeginTx(ctx, nil)
if err != nil { if err != nil {
return err return err
} }
@@ -127,5 +135,8 @@ func createEvents(ctx context.Context, db *database.DB) (err error) {
return row.Err() return row.Err()
} }
_, err = tx.Exec(createEventsStmt) _, err = tx.Exec(createEventsStmt)
return err if err != nil {
return err
}
return es_v3.CheckExecutionPlan(ctx, conn)
} }

View File

@@ -108,7 +108,12 @@ func Test_verifyEvents(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
if err := createEvents(context.Background(), tt.args.db.db); !errors.Is(err, tt.targetErr) { conn, err := tt.args.db.db.Conn(context.Background())
if err != nil {
t.Error(err)
return
}
if err := createEvents(context.Background(), conn); !errors.Is(err, tt.targetErr) {
t.Errorf("createEvents() error = %v, want: %v", err, tt.targetErr) t.Errorf("createEvents() error = %v, want: %v", err, tt.targetErr)
} }
if err := tt.args.db.mock.ExpectationsWereMet(); err != nil { if err := tt.args.db.mock.ExpectationsWereMet(); err != nil {

52
cmd/setup/40.go Normal file
View File

@@ -0,0 +1,52 @@
package setup
import (
"context"
"embed"
"fmt"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
)
var (
//go:embed 40/cockroach/*.sql
//go:embed 40/postgres/*.sql
initPushFunc embed.FS
)
type InitPushFunc struct {
dbClient *database.DB
}
func (mig *InitPushFunc) Execute(ctx context.Context, _ eventstore.Event) (err error) {
statements, err := readStatements(initPushFunc, "40", mig.dbClient.Type())
if err != nil {
return err
}
conn, err := mig.dbClient.Conn(ctx)
if err != nil {
return err
}
defer func() {
closeErr := conn.Close()
logging.OnError(closeErr).Debug("failed to release connection")
// Force the pool to reopen connections to apply the new types
mig.dbClient.Pool.Reset()
}()
for _, stmt := range statements {
logging.WithFields("file", stmt.file, "migration", mig.String()).Info("execute statement")
if _, err := conn.ExecContext(ctx, stmt.query); err != nil {
return fmt.Errorf("%s %s: %w", mig.String(), stmt.file, err)
}
}
return nil
}
func (mig *InitPushFunc) String() string {
return "40_init_push_func"
}

View File

@@ -0,0 +1,107 @@
-- represents an event to be created.
CREATE TYPE IF NOT EXISTS eventstore.command AS (
instance_id TEXT
, aggregate_type TEXT
, aggregate_id TEXT
, command_type TEXT
, revision INT2
, payload JSONB
, creator TEXT
, owner TEXT
);
/*
select * from eventstore.commands_to_events(
ARRAY[
ROW('', 'system', 'SYSTEM', 'ct1', 1, '{"key": "value"}', 'c1', 'SYSTEM')
, ROW('', 'system', 'SYSTEM', 'ct2', 1, '{"key": "value"}', 'c1', 'SYSTEM')
, ROW('289525561255060732', 'org', '289575074711790844', 'ct3', 1, '{"key": "value"}', 'c1', '289575074711790844')
, ROW('289525561255060732', 'user', '289575075164906748', 'ct3', 1, '{"key": "value"}', 'c1', '289575074711790844')
, ROW('289525561255060732', 'oidc_session', 'V2_289575178579535100', 'ct3', 1, '{"key": "value"}', 'c1', '289575074711790844')
, ROW('', 'system', 'SYSTEM', 'ct3', 1, '{"key": "value"}', 'c1', 'SYSTEM')
]::eventstore.command[]
);
*/
CREATE OR REPLACE FUNCTION eventstore.commands_to_events(commands eventstore.command[]) RETURNS SETOF eventstore.events2 VOLATILE AS $$
SELECT
("c").instance_id
, ("c").aggregate_type
, ("c").aggregate_id
, ("c").command_type AS event_type
, cs.sequence + ROW_NUMBER() OVER (PARTITION BY ("c").instance_id, ("c").aggregate_type, ("c").aggregate_id ORDER BY ("c").in_tx_order) AS sequence
, ("c").revision
, hlc_to_timestamp(cluster_logical_timestamp()) AS created_at
, ("c").payload
, ("c").creator
, cs.owner
, cluster_logical_timestamp() AS position
, ("c").in_tx_order
FROM (
SELECT
("c").instance_id
, ("c").aggregate_type
, ("c").aggregate_id
, ("c").command_type
, ("c").revision
, ("c").payload
, ("c").creator
, ("c").owner
, ROW_NUMBER() OVER () AS in_tx_order
FROM
UNNEST(commands) AS "c"
) AS "c"
JOIN (
SELECT
cmds.instance_id
, cmds.aggregate_type
, cmds.aggregate_id
, CASE WHEN (e.owner <> '') THEN e.owner ELSE command_owners.owner END AS owner
, COALESCE(MAX(e.sequence), 0) AS sequence
FROM (
SELECT DISTINCT
("cmds").instance_id
, ("cmds").aggregate_type
, ("cmds").aggregate_id
, ("cmds").owner
FROM UNNEST(commands) AS "cmds"
) AS cmds
LEFT JOIN eventstore.events2 AS e
ON cmds.instance_id = e.instance_id
AND cmds.aggregate_type = e.aggregate_type
AND cmds.aggregate_id = e.aggregate_id
JOIN (
SELECT
DISTINCT ON (
("c").instance_id
, ("c").aggregate_type
, ("c").aggregate_id
)
("c").instance_id
, ("c").aggregate_type
, ("c").aggregate_id
, ("c").owner
FROM
UNNEST(commands) AS "c"
) AS command_owners ON
cmds.instance_id = command_owners.instance_id
AND cmds.aggregate_type = command_owners.aggregate_type
AND cmds.aggregate_id = command_owners.aggregate_id
GROUP BY
cmds.instance_id
, cmds.aggregate_type
, cmds.aggregate_id
, 4 -- owner
) AS cs
ON ("c").instance_id = cs.instance_id
AND ("c").aggregate_type = cs.aggregate_type
AND ("c").aggregate_id = cs.aggregate_id
ORDER BY
in_tx_order
$$ LANGUAGE SQL;
CREATE OR REPLACE FUNCTION eventstore.push(commands eventstore.command[]) RETURNS SETOF eventstore.events2 AS $$
INSERT INTO eventstore.events2
SELECT * FROM eventstore.commands_to_events(commands)
RETURNING *
$$ LANGUAGE SQL;

View File

@@ -0,0 +1,15 @@
-- represents an event to be created.
DO $$ BEGIN
CREATE TYPE eventstore.command AS (
instance_id TEXT
, aggregate_type TEXT
, aggregate_id TEXT
, command_type TEXT
, revision INT2
, payload JSONB
, creator TEXT
, owner TEXT
);
EXCEPTION
WHEN duplicate_object THEN null;
END $$;

View File

@@ -0,0 +1,82 @@
CREATE OR REPLACE FUNCTION eventstore.commands_to_events(commands eventstore.command[]) RETURNS SETOF eventstore.events2 VOLATILE AS $$
SELECT
c.instance_id
, c.aggregate_type
, c.aggregate_id
, c.command_type AS event_type
, cs.sequence + ROW_NUMBER() OVER (PARTITION BY c.instance_id, c.aggregate_type, c.aggregate_id ORDER BY c.in_tx_order) AS sequence
, c.revision
, NOW() AS created_at
, c.payload
, c.creator
, cs.owner
, EXTRACT(EPOCH FROM NOW()) AS position
, c.in_tx_order
FROM (
SELECT
c.instance_id
, c.aggregate_type
, c.aggregate_id
, c.command_type
, c.revision
, c.payload
, c.creator
, c.owner
, ROW_NUMBER() OVER () AS in_tx_order
FROM
UNNEST(commands) AS c
) AS c
JOIN (
SELECT
cmds.instance_id
, cmds.aggregate_type
, cmds.aggregate_id
, CASE WHEN (e.owner IS NOT NULL OR e.owner <> '') THEN e.owner ELSE command_owners.owner END AS owner
, COALESCE(MAX(e.sequence), 0) AS sequence
FROM (
SELECT DISTINCT
instance_id
, aggregate_type
, aggregate_id
, owner
FROM UNNEST(commands)
) AS cmds
LEFT JOIN eventstore.events2 AS e
ON cmds.instance_id = e.instance_id
AND cmds.aggregate_type = e.aggregate_type
AND cmds.aggregate_id = e.aggregate_id
JOIN (
SELECT
DISTINCT ON (
instance_id
, aggregate_type
, aggregate_id
)
instance_id
, aggregate_type
, aggregate_id
, owner
FROM
UNNEST(commands)
) AS command_owners ON
cmds.instance_id = command_owners.instance_id
AND cmds.aggregate_type = command_owners.aggregate_type
AND cmds.aggregate_id = command_owners.aggregate_id
GROUP BY
cmds.instance_id
, cmds.aggregate_type
, cmds.aggregate_id
, 4 -- owner
) AS cs
ON c.instance_id = cs.instance_id
AND c.aggregate_type = cs.aggregate_type
AND c.aggregate_id = cs.aggregate_id
ORDER BY
in_tx_order;
$$ LANGUAGE SQL;
CREATE OR REPLACE FUNCTION eventstore.push(commands eventstore.command[]) RETURNS SETOF eventstore.events2 VOLATILE AS $$
INSERT INTO eventstore.events2
SELECT * FROM eventstore.commands_to_events(commands)
RETURNING *
$$ LANGUAGE SQL;

View File

@@ -126,6 +126,7 @@ type Steps struct {
s36FillV2Milestones *FillV3Milestones s36FillV2Milestones *FillV3Milestones
s37Apps7OIDConfigsBackChannelLogoutURI *Apps7OIDConfigsBackChannelLogoutURI s37Apps7OIDConfigsBackChannelLogoutURI *Apps7OIDConfigsBackChannelLogoutURI
s38BackChannelLogoutNotificationStart *BackChannelLogoutNotificationStart s38BackChannelLogoutNotificationStart *BackChannelLogoutNotificationStart
s40InitPushFunc *InitPushFunc
s39DeleteStaleOrgFields *DeleteStaleOrgFields s39DeleteStaleOrgFields *DeleteStaleOrgFields
} }

View File

@@ -170,6 +170,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
steps.s37Apps7OIDConfigsBackChannelLogoutURI = &Apps7OIDConfigsBackChannelLogoutURI{dbClient: esPusherDBClient} steps.s37Apps7OIDConfigsBackChannelLogoutURI = &Apps7OIDConfigsBackChannelLogoutURI{dbClient: esPusherDBClient}
steps.s38BackChannelLogoutNotificationStart = &BackChannelLogoutNotificationStart{dbClient: esPusherDBClient, esClient: eventstoreClient} steps.s38BackChannelLogoutNotificationStart = &BackChannelLogoutNotificationStart{dbClient: esPusherDBClient, esClient: eventstoreClient}
steps.s39DeleteStaleOrgFields = &DeleteStaleOrgFields{dbClient: esPusherDBClient} steps.s39DeleteStaleOrgFields = &DeleteStaleOrgFields{dbClient: esPusherDBClient}
steps.s40InitPushFunc = &InitPushFunc{dbClient: esPusherDBClient}
err = projection.Create(ctx, projectionDBClient, eventstoreClient, config.Projections, nil, nil, nil) err = projection.Create(ctx, projectionDBClient, eventstoreClient, config.Projections, nil, nil, nil)
logging.OnError(err).Fatal("unable to start projections") logging.OnError(err).Fatal("unable to start projections")
@@ -190,6 +191,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
for _, step := range []migration.Migration{ for _, step := range []migration.Migration{
steps.s14NewEventsTable, steps.s14NewEventsTable,
steps.s40InitPushFunc,
steps.s1ProjectionTable, steps.s1ProjectionTable,
steps.s2AssetsTable, steps.s2AssetsTable,
steps.s28AddFieldTable, steps.s28AddFieldTable,

View File

@@ -216,6 +216,7 @@ func assertFeatureDisabledError(t *testing.T, err error) {
} }
func checkWebKeyListState(ctx context.Context, t *testing.T, instance *integration.Instance, nKeys int, expectActiveKeyID string, config any, creationDate *timestamppb.Timestamp) { func checkWebKeyListState(ctx context.Context, t *testing.T, instance *integration.Instance, nKeys int, expectActiveKeyID string, config any, creationDate *timestamppb.Timestamp) {
t.Helper()
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(ctx, time.Minute) retryDuration, tick := integration.WaitForAndTickWithMaxDuration(ctx, time.Minute)
assert.EventuallyWithT(t, func(collect *assert.CollectT) { assert.EventuallyWithT(t, func(collect *assert.CollectT) {

View File

@@ -3,15 +3,18 @@ package cockroach
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/stdlib" "github.com/jackc/pgx/v5/stdlib"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
"github.com/zitadel/logging" "github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database/dialect" "github.com/zitadel/zitadel/internal/database/dialect"
) )
@@ -72,6 +75,12 @@ func (_ *Config) Decode(configs []interface{}) (dialect.Connector, error) {
} }
func (c *Config) Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpose dialect.DBPurpose) (*sql.DB, *pgxpool.Pool, error) { func (c *Config) Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpose dialect.DBPurpose) (*sql.DB, *pgxpool.Pool, error) {
dialect.RegisterAfterConnect(func(ctx context.Context, c *pgx.Conn) error {
// CockroachDB by default does not allow multiple modifications of the same table using ON CONFLICT
// This is needed to fill the fields table of the eventstore during eventstore.Push.
_, err := c.Exec(ctx, "SET enable_multiple_modifications_of_table = on")
return err
})
connConfig, err := dialect.NewConnectionConfig(c.MaxOpenConns, c.MaxIdleConns, pusherRatio, spoolerRatio, purpose) connConfig, err := dialect.NewConnectionConfig(c.MaxOpenConns, c.MaxIdleConns, pusherRatio, spoolerRatio, purpose)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
@@ -82,6 +91,29 @@ func (c *Config) Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpo
return nil, nil, err return nil, nil, err
} }
if len(connConfig.AfterConnect) > 0 {
config.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error {
for _, f := range connConfig.AfterConnect {
if err := f(ctx, conn); err != nil {
return err
}
}
return nil
}
}
// For the pusher we set the app name with the instance ID
if purpose == dialect.DBPurposeEventPusher {
config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool {
return setAppNameWithID(ctx, conn, purpose, authz.GetInstance(ctx).InstanceID())
}
config.AfterRelease = func(conn *pgx.Conn) bool {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
return setAppNameWithID(ctx, conn, purpose, "IDLE")
}
}
if connConfig.MaxOpenConns != 0 { if connConfig.MaxOpenConns != 0 {
config.MaxConns = int32(connConfig.MaxOpenConns) config.MaxConns = int32(connConfig.MaxOpenConns)
} }
@@ -200,3 +232,11 @@ func (c Config) String(useAdmin bool, appName string) string {
return strings.Join(fields, " ") return strings.Join(fields, " ")
} }
func setAppNameWithID(ctx context.Context, conn *pgx.Conn, purpose dialect.DBPurpose, id string) bool {
// needs to be set like this because psql complains about parameters in the SET statement
query := fmt.Sprintf("SET application_name = '%s_%s'", purpose.AppName(), id)
_, err := conn.Exec(ctx, query)
logging.OnError(err).Warn("failed to set application name")
return err == nil
}

View File

@@ -18,21 +18,31 @@ import (
"github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/internal/zerrors"
) )
type QueryExecuter interface { type ContextQuerier interface {
Query(query string, args ...any) (*sql.Rows, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
Exec(query string, args ...any) (sql.Result, error) }
type ContextExecuter interface {
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
} }
type ContextQueryExecuter interface {
ContextQuerier
ContextExecuter
}
type Client interface { type Client interface {
QueryExecuter ContextQueryExecuter
Beginner
Conn(ctx context.Context) (*sql.Conn, error)
}
type Beginner interface {
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error) BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
Begin() (*sql.Tx, error)
} }
type Tx interface { type Tx interface {
QueryExecuter ContextQueryExecuter
Commit() error Commit() error
Rollback() error Rollback() error
} }

View File

@@ -1,8 +1,13 @@
package dialect package dialect
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"reflect"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
) )
var ( var (
@@ -17,6 +22,7 @@ var (
type ConnectionConfig struct { type ConnectionConfig struct {
MaxOpenConns, MaxOpenConns,
MaxIdleConns uint32 MaxIdleConns uint32
AfterConnect []func(ctx context.Context, c *pgx.Conn) error
} }
// takeRatio of MaxOpenConns and MaxIdleConns from config and returns // takeRatio of MaxOpenConns and MaxIdleConns from config and returns
@@ -29,6 +35,7 @@ func (c *ConnectionConfig) takeRatio(ratio float64) (*ConnectionConfig, error) {
out := &ConnectionConfig{ out := &ConnectionConfig{
MaxOpenConns: uint32(ratio * float64(c.MaxOpenConns)), MaxOpenConns: uint32(ratio * float64(c.MaxOpenConns)),
MaxIdleConns: uint32(ratio * float64(c.MaxIdleConns)), MaxIdleConns: uint32(ratio * float64(c.MaxIdleConns)),
AfterConnect: c.AfterConnect,
} }
if c.MaxOpenConns != 0 && out.MaxOpenConns < 1 && ratio > 0 { if c.MaxOpenConns != 0 && out.MaxOpenConns < 1 && ratio > 0 {
out.MaxOpenConns = 1 out.MaxOpenConns = 1
@@ -40,6 +47,36 @@ func (c *ConnectionConfig) takeRatio(ratio float64) (*ConnectionConfig, error) {
return out, nil return out, nil
} }
var afterConnectFuncs []func(ctx context.Context, c *pgx.Conn) error
func RegisterAfterConnect(f func(ctx context.Context, c *pgx.Conn) error) {
afterConnectFuncs = append(afterConnectFuncs, f)
}
func RegisterDefaultPgTypeVariants[T any](m *pgtype.Map, name, arrayName string) {
// T
var value T
m.RegisterDefaultPgType(value, name)
// *T
valueType := reflect.TypeOf(value)
m.RegisterDefaultPgType(reflect.New(valueType).Interface(), name)
// []T
sliceType := reflect.SliceOf(valueType)
m.RegisterDefaultPgType(reflect.MakeSlice(sliceType, 0, 0).Interface(), arrayName)
// *[]T
m.RegisterDefaultPgType(reflect.New(sliceType).Interface(), arrayName)
// []*T
sliceOfPointerType := reflect.SliceOf(reflect.TypeOf(reflect.New(valueType).Interface()))
m.RegisterDefaultPgType(reflect.MakeSlice(sliceOfPointerType, 0, 0).Interface(), arrayName)
// *[]*T
m.RegisterDefaultPgType(reflect.New(sliceOfPointerType).Interface(), arrayName)
}
// NewConnectionConfig calculates [ConnectionConfig] values from the passed ratios // NewConnectionConfig calculates [ConnectionConfig] values from the passed ratios
// and returns the config applicable for the requested purpose. // and returns the config applicable for the requested purpose.
// //
@@ -59,11 +96,13 @@ func NewConnectionConfig(openConns, idleConns uint32, pusherRatio, projectionRat
queryConfig := &ConnectionConfig{ queryConfig := &ConnectionConfig{
MaxOpenConns: openConns, MaxOpenConns: openConns,
MaxIdleConns: idleConns, MaxIdleConns: idleConns,
AfterConnect: afterConnectFuncs,
} }
pusherConfig, err := queryConfig.takeRatio(pusherRatio) pusherConfig, err := queryConfig.takeRatio(pusherRatio)
if err != nil { if err != nil {
return nil, fmt.Errorf("event pusher: %w", err) return nil, fmt.Errorf("event pusher: %w", err)
} }
spoolerConfig, err := queryConfig.takeRatio(projectionRatio) spoolerConfig, err := queryConfig.takeRatio(projectionRatio)
if err != nil { if err != nil {
return nil, fmt.Errorf("projection spooler: %w", err) return nil, fmt.Errorf("projection spooler: %w", err)

View File

@@ -3,15 +3,18 @@ package postgres
import ( import (
"context" "context"
"database/sql" "database/sql"
"fmt"
"strconv" "strconv"
"strings" "strings"
"time" "time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/stdlib" "github.com/jackc/pgx/v5/stdlib"
"github.com/mitchellh/mapstructure" "github.com/mitchellh/mapstructure"
"github.com/zitadel/logging" "github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database/dialect" "github.com/zitadel/zitadel/internal/database/dialect"
) )
@@ -83,6 +86,27 @@ func (c *Config) Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpo
return nil, nil, err return nil, nil, err
} }
config.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error {
for _, f := range connConfig.AfterConnect {
if err := f(ctx, conn); err != nil {
return err
}
}
return nil
}
// For the pusher we set the app name with the instance ID
if purpose == dialect.DBPurposeEventPusher {
config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool {
return setAppNameWithID(ctx, conn, purpose, authz.GetInstance(ctx).InstanceID())
}
config.AfterRelease = func(conn *pgx.Conn) bool {
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
return setAppNameWithID(ctx, conn, purpose, "IDLE")
}
}
if connConfig.MaxOpenConns != 0 { if connConfig.MaxOpenConns != 0 {
config.MaxConns = int32(connConfig.MaxOpenConns) config.MaxConns = int32(connConfig.MaxOpenConns)
} }
@@ -209,3 +233,11 @@ func (c Config) String(useAdmin bool, appName string) string {
return strings.Join(fields, " ") return strings.Join(fields, " ")
} }
func setAppNameWithID(ctx context.Context, conn *pgx.Conn, purpose dialect.DBPurpose, id string) bool {
// needs to be set like this because psql complains about parameters in the SET statement
query := fmt.Sprintf("SET application_name = '%s_%s'", purpose.AppName(), id)
_, err := conn.Exec(ctx, query)
logging.OnError(err).Warn("failed to set application name")
return err == nil
}

View File

@@ -3,8 +3,12 @@ package eventstore
import ( import (
"context" "context"
"encoding/json" "encoding/json"
"strconv"
"strings"
"time" "time"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/service" "github.com/zitadel/zitadel/internal/api/service"
) )
@@ -84,8 +88,10 @@ func (e *BaseEvent) DataAsBytes() []byte {
} }
// Revision implements action // Revision implements action
func (*BaseEvent) Revision() uint16 { func (e *BaseEvent) Revision() uint16 {
return 0 revision, err := strconv.ParseUint(strings.TrimPrefix(string(e.Agg.Version), "v"), 10, 16)
logging.OnError(err).Debug("failed to parse event revision")
return uint16(revision)
} }
// Unmarshal implements Event // Unmarshal implements Event

View File

@@ -90,7 +90,7 @@ func (es *Eventstore) Push(ctx context.Context, cmds ...Command) ([]Event, error
// PushWithClient pushes the events in a single transaction using the provided database client // PushWithClient pushes the events in a single transaction using the provided database client
// an event needs at least an aggregate // an event needs at least an aggregate
func (es *Eventstore) PushWithClient(ctx context.Context, client database.QueryExecuter, cmds ...Command) ([]Event, error) { func (es *Eventstore) PushWithClient(ctx context.Context, client database.ContextQueryExecuter, cmds ...Command) ([]Event, error) {
if es.PushTimeout > 0 { if es.PushTimeout > 0 {
var cancel func() var cancel func()
ctx, cancel = context.WithTimeout(ctx, es.PushTimeout) ctx, cancel = context.WithTimeout(ctx, es.PushTimeout)
@@ -301,7 +301,7 @@ type Pusher interface {
// Health checks if the connection to the storage is available // Health checks if the connection to the storage is available
Health(ctx context.Context) error Health(ctx context.Context) error
// Push stores the actions // Push stores the actions
Push(ctx context.Context, client database.QueryExecuter, commands ...Command) (_ []Event, err error) Push(ctx context.Context, client database.ContextQueryExecuter, commands ...Command) (_ []Event, err error)
// Client returns the underlying database connection // Client returns the underlying database connection
Client() *database.DB Client() *database.DB
} }

View File

@@ -347,7 +347,7 @@ func (repo *testPusher) Health(ctx context.Context) error {
return nil return nil
} }
func (repo *testPusher) Push(_ context.Context, _ database.QueryExecuter, commands ...Command) (events []Event, err error) { func (repo *testPusher) Push(_ context.Context, _ database.ContextQueryExecuter, commands ...Command) (events []Event, err error) {
if len(repo.errs) != 0 { if len(repo.errs) != 0 {
err, repo.errs = repo.errs[0], repo.errs[1:] err, repo.errs = repo.errs[0], repo.errs[1:]
return nil, err return nil, err
@@ -490,6 +490,7 @@ func TestEventstore_Push(t *testing.T) {
Type: "test.aggregate", Type: "test.aggregate",
ResourceOwner: "caos", ResourceOwner: "caos",
InstanceID: "zitadel", InstanceID: "zitadel",
Version: "v1",
}, },
Data: []byte(nil), Data: []byte(nil),
User: "editorUser", User: "editorUser",
@@ -534,6 +535,7 @@ func TestEventstore_Push(t *testing.T) {
Type: "test.aggregate", Type: "test.aggregate",
ResourceOwner: "caos", ResourceOwner: "caos",
InstanceID: "zitadel", InstanceID: "zitadel",
Version: "v1",
}, },
Data: []byte(nil), Data: []byte(nil),
User: "editorUser", User: "editorUser",
@@ -585,6 +587,7 @@ func TestEventstore_Push(t *testing.T) {
Type: "test.aggregate", Type: "test.aggregate",
ResourceOwner: "caos", ResourceOwner: "caos",
InstanceID: "zitadel", InstanceID: "zitadel",
Version: "v1",
}, },
Data: []byte(nil), Data: []byte(nil),
User: "editorUser", User: "editorUser",
@@ -596,6 +599,7 @@ func TestEventstore_Push(t *testing.T) {
Type: "test.aggregate", Type: "test.aggregate",
ResourceOwner: "caos", ResourceOwner: "caos",
InstanceID: "zitadel", InstanceID: "zitadel",
Version: "v1",
}, },
Data: []byte(nil), Data: []byte(nil),
User: "editorUser", User: "editorUser",
@@ -658,6 +662,7 @@ func TestEventstore_Push(t *testing.T) {
Type: "test.aggregate", Type: "test.aggregate",
ResourceOwner: "caos", ResourceOwner: "caos",
InstanceID: "zitadel", InstanceID: "zitadel",
Version: "v1",
}, },
Data: []byte(nil), Data: []byte(nil),
User: "editorUser", User: "editorUser",
@@ -669,6 +674,7 @@ func TestEventstore_Push(t *testing.T) {
Type: "test.aggregate", Type: "test.aggregate",
ResourceOwner: "caos", ResourceOwner: "caos",
InstanceID: "zitadel", InstanceID: "zitadel",
Version: "v1",
}, },
Data: []byte(nil), Data: []byte(nil),
User: "editorUser", User: "editorUser",
@@ -682,6 +688,7 @@ func TestEventstore_Push(t *testing.T) {
Type: "test.aggregate", Type: "test.aggregate",
ResourceOwner: "caos", ResourceOwner: "caos",
InstanceID: "zitadel", InstanceID: "zitadel",
Version: "v1",
}, },
Data: []byte(nil), Data: []byte(nil),
User: "editorUser", User: "editorUser",
@@ -778,6 +785,7 @@ func TestEventstore_Push(t *testing.T) {
Type: "test.aggregate", Type: "test.aggregate",
ResourceOwner: "caos", ResourceOwner: "caos",
InstanceID: "zitadel", InstanceID: "zitadel",
Version: "v1",
}, },
Data: []byte(nil), Data: []byte(nil),
User: "editorUser", User: "editorUser",
@@ -828,6 +836,7 @@ func TestEventstore_Push(t *testing.T) {
Type: "test.aggregate", Type: "test.aggregate",
ResourceOwner: "caos", ResourceOwner: "caos",
InstanceID: "zitadel", InstanceID: "zitadel",
Version: "v1",
}, },
Data: []byte(nil), Data: []byte(nil),
User: "editorUser", User: "editorUser",
@@ -883,6 +892,7 @@ func TestEventstore_Push(t *testing.T) {
Type: "test.aggregate", Type: "test.aggregate",
ResourceOwner: "caos", ResourceOwner: "caos",
InstanceID: "zitadel", InstanceID: "zitadel",
Version: "v1",
}, },
Data: []byte(nil), Data: []byte(nil),
User: "editorUser", User: "editorUser",

View File

@@ -9,6 +9,8 @@ import (
"time" "time"
"github.com/cockroachdb/cockroach-go/v2/testserver" "github.com/cockroachdb/cockroach-go/v2/testserver"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/stdlib"
"github.com/zitadel/logging" "github.com/zitadel/logging"
"github.com/zitadel/zitadel/cmd/initialise" "github.com/zitadel/zitadel/cmd/initialise"
@@ -39,10 +41,17 @@ func TestMain(m *testing.M) {
testCRDBClient = &database.DB{ testCRDBClient = &database.DB{
Database: new(testDB), Database: new(testDB),
} }
testCRDBClient.DB, err = sql.Open("postgres", ts.PGURL().String())
connConfig, err := pgxpool.ParseConfig(ts.PGURL().String())
if err != nil { if err != nil {
logging.WithFields("error", err).Fatal("unable to connect to db") logging.WithFields("error", err).Fatal("unable to parse db url")
} }
connConfig.AfterConnect = new_es.RegisterEventstoreTypes
pool, err := pgxpool.NewWithConfig(context.Background(), connConfig)
if err != nil {
logging.WithFields("error", err).Fatal("unable to create db pool")
}
testCRDBClient.DB = stdlib.OpenDBFromPool(pool)
if err = testCRDBClient.Ping(); err != nil { if err = testCRDBClient.Ping(); err != nil {
logging.WithFields("error", err).Fatal("unable to ping db") logging.WithFields("error", err).Fatal("unable to ping db")
} }
@@ -55,7 +64,7 @@ func TestMain(m *testing.M) {
clients["v3(inmemory)"] = testCRDBClient clients["v3(inmemory)"] = testCRDBClient
if localDB, err := connectLocalhost(); err == nil { if localDB, err := connectLocalhost(); err == nil {
if err = initDB(localDB); err != nil { if err = initDB(context.Background(), localDB); err != nil {
logging.WithFields("error", err).Fatal("migrations failed") logging.WithFields("error", err).Fatal("migrations failed")
} }
pushers["v3(singlenode)"] = new_es.NewEventstore(localDB) pushers["v3(singlenode)"] = new_es.NewEventstore(localDB)
@@ -69,14 +78,14 @@ func TestMain(m *testing.M) {
ts.Stop() ts.Stop()
}() }()
if err = initDB(testCRDBClient); err != nil { if err = initDB(context.Background(), testCRDBClient); err != nil {
logging.WithFields("error", err).Fatal("migrations failed") logging.WithFields("error", err).Fatal("migrations failed")
} }
os.Exit(m.Run()) os.Exit(m.Run())
} }
func initDB(db *database.DB) error { func initDB(ctx context.Context, db *database.DB) error {
initialise.ReadStmts("cockroach") initialise.ReadStmts("cockroach")
config := new(database.Config) config := new(database.Config)
config.SetConnector(&cockroach.Config{ config.SetConnector(&cockroach.Config{
@@ -85,7 +94,7 @@ func initDB(db *database.DB) error {
}, },
Database: "zitadel", Database: "zitadel",
}) })
err := initialise.Init(db, err := initialise.Init(ctx, db,
initialise.VerifyUser(config.Username(), ""), initialise.VerifyUser(config.Username(), ""),
initialise.VerifyDatabase(config.DatabaseName()), initialise.VerifyDatabase(config.DatabaseName()),
initialise.VerifyGrant(config.DatabaseName(), config.Username()), initialise.VerifyGrant(config.DatabaseName(), config.Username()),
@@ -93,7 +102,7 @@ func initDB(db *database.DB) error {
if err != nil { if err != nil {
return err return err
} }
err = initialise.VerifyZitadel(context.Background(), db, *config) err = initialise.VerifyZitadel(ctx, db, *config)
if err != nil { if err != nil {
return err return err
} }

View File

@@ -3,8 +3,12 @@ package repository
import ( import (
"database/sql" "database/sql"
"encoding/json" "encoding/json"
"strconv"
"strings"
"time" "time"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
) )
@@ -82,7 +86,9 @@ func (e *Event) Type() eventstore.EventType {
// Revision implements [eventstore.Event] // Revision implements [eventstore.Event]
func (e *Event) Revision() uint16 { func (e *Event) Revision() uint16 {
return 0 revision, err := strconv.ParseUint(strings.TrimPrefix(string(e.Version), "v"), 10, 16)
logging.OnError(err).Debug("failed to parse event revision")
return uint16(revision)
} }
// Sequence implements [eventstore.Event] // Sequence implements [eventstore.Event]

View File

@@ -165,7 +165,7 @@ func (mr *MockPusherMockRecorder) Health(arg0 any) *gomock.Call {
} }
// Push mocks base method. // Push mocks base method.
func (m *MockPusher) Push(arg0 context.Context, arg1 database.QueryExecuter, arg2 ...eventstore.Command) ([]eventstore.Event, error) { func (m *MockPusher) Push(arg0 context.Context, arg1 database.ContextQueryExecuter, arg2 ...eventstore.Command) ([]eventstore.Event, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
varargs := []any{arg0, arg1} varargs := []any{arg0, arg1}
for _, a := range arg2 { for _, a := range arg2 {

View File

@@ -80,7 +80,7 @@ func (m *MockRepository) ExpectInstanceIDsError(err error) *MockRepository {
// The call will sleep at least the amount of passed duration. // The call will sleep at least the amount of passed duration.
func (m *MockRepository) ExpectPush(expectedCommands []eventstore.Command, sleep time.Duration) *MockRepository { func (m *MockRepository) ExpectPush(expectedCommands []eventstore.Command, sleep time.Duration) *MockRepository {
m.MockPusher.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( m.MockPusher.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
func(ctx context.Context, _ database.QueryExecuter, commands ...eventstore.Command) ([]eventstore.Event, error) { func(ctx context.Context, _ database.ContextQueryExecuter, commands ...eventstore.Command) ([]eventstore.Event, error) {
m.MockPusher.ctrl.T.Helper() m.MockPusher.ctrl.T.Helper()
time.Sleep(sleep) time.Sleep(sleep)
@@ -135,7 +135,7 @@ func (m *MockRepository) ExpectPushFailed(err error, expectedCommands []eventsto
m.MockPusher.ctrl.T.Helper() m.MockPusher.ctrl.T.Helper()
m.MockPusher.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( m.MockPusher.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
func(ctx context.Context, _ database.QueryExecuter, commands ...eventstore.Command) ([]eventstore.Event, error) { func(ctx context.Context, _ database.ContextQueryExecuter, commands ...eventstore.Command) ([]eventstore.Event, error) {
if len(expectedCommands) != len(commands) { if len(expectedCommands) != len(commands) {
return nil, fmt.Errorf("unexpected amount of commands: want %d, got %d", len(expectedCommands), len(commands)) return nil, fmt.Errorf("unexpected amount of commands: want %d, got %d", len(expectedCommands), len(commands))
} }
@@ -197,7 +197,7 @@ func (e *mockEvent) CreatedAt() time.Time {
func (m *MockRepository) ExpectRandomPush(expectedCommands []eventstore.Command) *MockRepository { func (m *MockRepository) ExpectRandomPush(expectedCommands []eventstore.Command) *MockRepository {
m.MockPusher.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( m.MockPusher.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
func(ctx context.Context, _ database.QueryExecuter, commands ...eventstore.Command) ([]eventstore.Event, error) { func(ctx context.Context, _ database.ContextQueryExecuter, commands ...eventstore.Command) ([]eventstore.Event, error) {
assert.Len(m.MockPusher.ctrl.T, commands, len(expectedCommands)) assert.Len(m.MockPusher.ctrl.T, commands, len(expectedCommands))
events := make([]eventstore.Event, len(commands)) events := make([]eventstore.Event, len(commands))
@@ -215,7 +215,7 @@ func (m *MockRepository) ExpectRandomPush(expectedCommands []eventstore.Command)
func (m *MockRepository) ExpectRandomPushFailed(err error, expectedEvents []eventstore.Command) *MockRepository { func (m *MockRepository) ExpectRandomPushFailed(err error, expectedEvents []eventstore.Command) *MockRepository {
m.MockPusher.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn( m.MockPusher.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
func(ctx context.Context, _ database.QueryExecuter, events ...eventstore.Command) ([]eventstore.Event, error) { func(ctx context.Context, _ database.ContextQueryExecuter, events ...eventstore.Command) ([]eventstore.Event, error) {
assert.Len(m.MockPusher.ctrl.T, events, len(expectedEvents)) assert.Len(m.MockPusher.ctrl.T, events, len(expectedEvents))
return nil, err return nil, err
}, },

View File

@@ -8,11 +8,14 @@ import (
"time" "time"
"github.com/cockroachdb/cockroach-go/v2/testserver" "github.com/cockroachdb/cockroach-go/v2/testserver"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/jackc/pgx/v5/stdlib"
"github.com/zitadel/logging" "github.com/zitadel/logging"
"github.com/zitadel/zitadel/cmd/initialise" "github.com/zitadel/zitadel/cmd/initialise"
"github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/cockroach" "github.com/zitadel/zitadel/internal/database/cockroach"
new_es "github.com/zitadel/zitadel/internal/eventstore/v3"
) )
var ( var (
@@ -29,10 +32,18 @@ func TestMain(m *testing.M) {
logging.WithFields("error", err).Fatal("unable to start db") logging.WithFields("error", err).Fatal("unable to start db")
} }
testCRDBClient, err = sql.Open("postgres", ts.PGURL().String()) connConfig, err := pgxpool.ParseConfig(ts.PGURL().String())
if err != nil { if err != nil {
logging.WithFields("error", err).Fatal("unable to connect to db") logging.WithFields("error", err).Fatal("unable to parse db url")
} }
connConfig.AfterConnect = new_es.RegisterEventstoreTypes
pool, err := pgxpool.NewWithConfig(context.Background(), connConfig)
if err != nil {
logging.WithFields("error", err).Fatal("unable to create db pool")
}
testCRDBClient = stdlib.OpenDBFromPool(pool)
if err = testCRDBClient.Ping(); err != nil { if err = testCRDBClient.Ping(); err != nil {
logging.WithFields("error", err).Fatal("unable to ping db") logging.WithFields("error", err).Fatal("unable to ping db")
} }
@@ -42,14 +53,14 @@ func TestMain(m *testing.M) {
ts.Stop() ts.Stop()
}() }()
if err = initDB(&database.DB{DB: testCRDBClient, Database: &cockroach.Config{Database: "zitadel"}}); err != nil { if err = initDB(context.Background(), &database.DB{DB: testCRDBClient, Database: &cockroach.Config{Database: "zitadel"}}); err != nil {
logging.WithFields("error", err).Fatal("migrations failed") logging.WithFields("error", err).Fatal("migrations failed")
} }
os.Exit(m.Run()) os.Exit(m.Run())
} }
func initDB(db *database.DB) error { func initDB(ctx context.Context, db *database.DB) error {
config := new(database.Config) config := new(database.Config)
config.SetConnector(&cockroach.Config{User: cockroach.User{Username: "zitadel"}, Database: "zitadel"}) config.SetConnector(&cockroach.Config{User: cockroach.User{Username: "zitadel"}, Database: "zitadel"})
@@ -57,7 +68,7 @@ func initDB(db *database.DB) error {
return err return err
} }
err := initialise.Init(db, err := initialise.Init(ctx, db,
initialise.VerifyUser(config.Username(), ""), initialise.VerifyUser(config.Username(), ""),
initialise.VerifyDatabase(config.DatabaseName()), initialise.VerifyDatabase(config.DatabaseName()),
initialise.VerifyGrant(config.DatabaseName(), config.Username()), initialise.VerifyGrant(config.DatabaseName(), config.Username()),

View File

@@ -1,11 +1,13 @@
package eventstore package eventstore
import ( import (
"context"
"encoding/json" "encoding/json"
"strconv"
"time" "time"
"github.com/zitadel/logging" "github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/internal/zerrors"
) )
@@ -14,33 +16,98 @@ var (
_ eventstore.Event = (*event)(nil) _ eventstore.Event = (*event)(nil)
) )
type command struct {
InstanceID string
AggregateType string
AggregateID string
CommandType string
Revision uint16
Payload Payload
Creator string
Owner string
}
func (c *command) Aggregate() *eventstore.Aggregate {
return &eventstore.Aggregate{
ID: c.AggregateID,
Type: eventstore.AggregateType(c.AggregateType),
ResourceOwner: c.Owner,
InstanceID: c.InstanceID,
Version: eventstore.Version("v" + strconv.Itoa(int(c.Revision))),
}
}
type event struct { type event struct {
aggregate *eventstore.Aggregate command *command
creator string
revision uint16
typ eventstore.EventType
createdAt time.Time createdAt time.Time
sequence uint64 sequence uint64
position float64 position float64
payload Payload
} }
func commandToEvent(sequence *latestSequence, command eventstore.Command) (_ *event, err error) { // TODO: remove on v3
func commandToEventOld(sequence *latestSequence, cmd eventstore.Command) (_ *event, err error) {
var payload Payload var payload Payload
if command.Payload() != nil { if cmd.Payload() != nil {
payload, err = json.Marshal(command.Payload()) payload, err = json.Marshal(cmd.Payload())
if err != nil { if err != nil {
logging.WithError(err).Warn("marshal payload failed") logging.WithError(err).Warn("marshal payload failed")
return nil, zerrors.ThrowInternal(err, "V3-MInPK", "Errors.Internal") return nil, zerrors.ThrowInternal(err, "V3-MInPK", "Errors.Internal")
} }
} }
return &event{ return &event{
aggregate: sequence.aggregate, command: &command{
creator: command.Creator(), InstanceID: sequence.aggregate.InstanceID,
revision: command.Revision(), AggregateType: string(sequence.aggregate.Type),
typ: command.Type(), AggregateID: sequence.aggregate.ID,
payload: payload, CommandType: string(cmd.Type()),
sequence: sequence.sequence, Revision: cmd.Revision(),
Payload: payload,
Creator: cmd.Creator(),
Owner: sequence.aggregate.ResourceOwner,
},
sequence: sequence.sequence,
}, nil
}
func commandsToEvents(ctx context.Context, cmds []eventstore.Command) (_ []eventstore.Event, _ []*command, err error) {
events := make([]eventstore.Event, len(cmds))
commands := make([]*command, len(cmds))
for i, cmd := range cmds {
if cmd.Aggregate().InstanceID == "" {
cmd.Aggregate().InstanceID = authz.GetInstance(ctx).InstanceID()
}
events[i], err = commandToEvent(cmd)
if err != nil {
return nil, nil, err
}
commands[i] = events[i].(*event).command
}
return events, commands, nil
}
func commandToEvent(cmd eventstore.Command) (_ eventstore.Event, err error) {
var payload Payload
if cmd.Payload() != nil {
payload, err = json.Marshal(cmd.Payload())
if err != nil {
logging.WithError(err).Warn("marshal payload failed")
return nil, zerrors.ThrowInternal(err, "V3-MInPK", "Errors.Internal")
}
}
command := &command{
InstanceID: cmd.Aggregate().InstanceID,
AggregateType: string(cmd.Aggregate().Type),
AggregateID: cmd.Aggregate().ID,
CommandType: string(cmd.Type()),
Revision: cmd.Revision(),
Payload: payload,
Creator: cmd.Creator(),
Owner: cmd.Aggregate().ResourceOwner,
}
return &event{
command: command,
}, nil }, nil
} }
@@ -56,22 +123,22 @@ func (e *event) EditorUser() string {
// Aggregate implements [eventstore.Event] // Aggregate implements [eventstore.Event]
func (e *event) Aggregate() *eventstore.Aggregate { func (e *event) Aggregate() *eventstore.Aggregate {
return e.aggregate return e.command.Aggregate()
} }
// Creator implements [eventstore.Event] // Creator implements [eventstore.Event]
func (e *event) Creator() string { func (e *event) Creator() string {
return e.creator return e.command.Creator
} }
// Revision implements [eventstore.Event] // Revision implements [eventstore.Event]
func (e *event) Revision() uint16 { func (e *event) Revision() uint16 {
return e.revision return e.command.Revision
} }
// Type implements [eventstore.Event] // Type implements [eventstore.Event]
func (e *event) Type() eventstore.EventType { func (e *event) Type() eventstore.EventType {
return e.typ return eventstore.EventType(e.command.CommandType)
} }
// CreatedAt implements [eventstore.Event] // CreatedAt implements [eventstore.Event]
@@ -91,10 +158,10 @@ func (e *event) Position() float64 {
// Unmarshal implements [eventstore.Event] // Unmarshal implements [eventstore.Event]
func (e *event) Unmarshal(ptr any) error { func (e *event) Unmarshal(ptr any) error {
if len(e.payload) == 0 { if len(e.command.Payload) == 0 {
return nil return nil
} }
if err := json.Unmarshal(e.payload, ptr); err != nil { if err := json.Unmarshal(e.command.Payload, ptr); err != nil {
return zerrors.ThrowInternal(err, "V3-u8qVo", "Errors.Internal") return zerrors.ThrowInternal(err, "V3-u8qVo", "Errors.Internal")
} }
@@ -103,5 +170,5 @@ func (e *event) Unmarshal(ptr any) error {
// DataAsBytes implements [eventstore.Event] // DataAsBytes implements [eventstore.Event]
func (e *event) DataAsBytes() []byte { func (e *event) DataAsBytes() []byte {
return e.payload return e.command.Payload
} }

View File

@@ -1,16 +1,122 @@
package eventstore package eventstore
import ( import (
"context"
"encoding/json" "encoding/json"
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
) )
func Test_commandToEvent(t *testing.T) { func Test_commandToEvent(t *testing.T) {
payload := struct {
ID string
}{
ID: "test",
}
payloadMarshalled, err := json.Marshal(payload)
if err != nil {
t.Fatalf("marshal of payload failed: %v", err)
}
type args struct {
command eventstore.Command
}
type want struct {
event *event
err func(t *testing.T, err error)
}
tests := []struct {
name string
args args
want want
}{
{
name: "no payload",
args: args{
command: &mockCommand{
aggregate: mockAggregate("V3-Red9I"),
payload: nil,
},
},
want: want{
event: mockEvent(
mockAggregate("V3-Red9I"),
0,
nil,
).(*event),
},
},
{
name: "struct payload",
args: args{
command: &mockCommand{
aggregate: mockAggregate("V3-Red9I"),
payload: payload,
},
},
want: want{
event: mockEvent(
mockAggregate("V3-Red9I"),
0,
payloadMarshalled,
).(*event),
},
},
{
name: "pointer payload",
args: args{
command: &mockCommand{
aggregate: mockAggregate("V3-Red9I"),
payload: &payload,
},
},
want: want{
event: mockEvent(
mockAggregate("V3-Red9I"),
0,
payloadMarshalled,
).(*event),
},
},
{
name: "invalid payload",
args: args{
command: &mockCommand{
aggregate: mockAggregate("V3-Red9I"),
payload: func() {},
},
},
want: want{
err: func(t *testing.T, err error) {
assert.Error(t, err)
},
},
},
}
for _, tt := range tests {
if tt.want.err == nil {
tt.want.err = func(t *testing.T, err error) {
require.NoError(t, err)
}
}
t.Run(tt.name, func(t *testing.T) {
got, err := commandToEvent(tt.args.command)
tt.want.err(t, err)
if tt.want.event == nil {
assert.Nil(t, got)
return
}
assert.Equal(t, tt.want.event, got)
})
}
}
func Test_commandToEventOld(t *testing.T) {
payload := struct { payload := struct {
ID string ID string
}{ }{
@@ -119,10 +225,258 @@ func Test_commandToEvent(t *testing.T) {
} }
} }
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := commandToEvent(tt.args.sequence, tt.args.command) got, err := commandToEventOld(tt.args.sequence, tt.args.command)
tt.want.err(t, err) tt.want.err(t, err)
assert.Equal(t, tt.want.event, got) assert.Equal(t, tt.want.event, got)
}) })
} }
} }
func Test_commandsToEvents(t *testing.T) {
ctx := context.Background()
payload := struct {
ID string
}{
ID: "test",
}
payloadMarshalled, err := json.Marshal(payload)
if err != nil {
t.Fatalf("marshal of payload failed: %v", err)
}
type args struct {
ctx context.Context
cmds []eventstore.Command
}
type want struct {
events []eventstore.Event
commands []*command
err func(t *testing.T, err error)
}
tests := []struct {
name string
args args
want want
}{
{
name: "no commands",
args: args{
ctx: ctx,
cmds: nil,
},
want: want{
events: []eventstore.Event{},
commands: []*command{},
err: func(t *testing.T, err error) {
require.NoError(t, err)
},
},
},
{
name: "single command no payload",
args: args{
ctx: ctx,
cmds: []eventstore.Command{
&mockCommand{
aggregate: mockAggregate("V3-Red9I"),
payload: nil,
},
},
},
want: want{
events: []eventstore.Event{
mockEvent(
mockAggregate("V3-Red9I"),
0,
nil,
),
},
commands: []*command{
{
InstanceID: "instance",
AggregateType: "type",
AggregateID: "V3-Red9I",
Owner: "ro",
CommandType: "event.type",
Revision: 1,
Payload: nil,
Creator: "creator",
},
},
err: func(t *testing.T, err error) {
require.NoError(t, err)
},
},
},
{
name: "single command no instance id",
args: args{
ctx: authz.WithInstanceID(ctx, "instance from ctx"),
cmds: []eventstore.Command{
&mockCommand{
aggregate: mockAggregateWithInstance("V3-Red9I", ""),
payload: nil,
},
},
},
want: want{
events: []eventstore.Event{
mockEvent(
mockAggregateWithInstance("V3-Red9I", "instance from ctx"),
0,
nil,
),
},
commands: []*command{
{
InstanceID: "instance from ctx",
AggregateType: "type",
AggregateID: "V3-Red9I",
Owner: "ro",
CommandType: "event.type",
Revision: 1,
Payload: nil,
Creator: "creator",
},
},
err: func(t *testing.T, err error) {
require.NoError(t, err)
},
},
},
{
name: "single command with payload",
args: args{
ctx: ctx,
cmds: []eventstore.Command{
&mockCommand{
aggregate: mockAggregate("V3-Red9I"),
payload: payload,
},
},
},
want: want{
events: []eventstore.Event{
mockEvent(
mockAggregate("V3-Red9I"),
0,
payloadMarshalled,
),
},
commands: []*command{
{
InstanceID: "instance",
AggregateType: "type",
AggregateID: "V3-Red9I",
Owner: "ro",
CommandType: "event.type",
Revision: 1,
Payload: payloadMarshalled,
Creator: "creator",
},
},
err: func(t *testing.T, err error) {
require.NoError(t, err)
},
},
},
{
name: "multiple commands",
args: args{
ctx: ctx,
cmds: []eventstore.Command{
&mockCommand{
aggregate: mockAggregate("V3-Red9I"),
payload: payload,
},
&mockCommand{
aggregate: mockAggregate("V3-Red9I"),
payload: nil,
},
},
},
want: want{
events: []eventstore.Event{
mockEvent(
mockAggregate("V3-Red9I"),
0,
payloadMarshalled,
),
mockEvent(
mockAggregate("V3-Red9I"),
0,
nil,
),
},
commands: []*command{
{
InstanceID: "instance",
AggregateType: "type",
AggregateID: "V3-Red9I",
CommandType: "event.type",
Revision: 1,
Payload: payloadMarshalled,
Creator: "creator",
Owner: "ro",
},
{
InstanceID: "instance",
AggregateType: "type",
AggregateID: "V3-Red9I",
CommandType: "event.type",
Revision: 1,
Payload: nil,
Creator: "creator",
Owner: "ro",
},
},
err: func(t *testing.T, err error) {
require.NoError(t, err)
},
},
},
{
name: "invalid command",
args: args{
ctx: ctx,
cmds: []eventstore.Command{
&mockCommand{
aggregate: mockAggregate("V3-Red9I"),
payload: func() {},
},
},
},
want: want{
events: nil,
commands: nil,
err: func(t *testing.T, err error) {
assert.Error(t, err)
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotEvents, gotCommands, err := commandsToEvents(tt.args.ctx, tt.args.cmds)
tt.want.err(t, err)
assert.Equal(t, tt.want.events, gotEvents)
require.Len(t, gotCommands, len(tt.want.commands))
for i, wantCommand := range tt.want.commands {
assertCommand(t, wantCommand, gotCommands[i])
}
})
}
}
func assertCommand(t *testing.T, want, got *command) {
t.Helper()
assert.Equal(t, want.CommandType, got.CommandType)
assert.Equal(t, want.Payload, got.Payload)
assert.Equal(t, want.Creator, got.Creator)
assert.Equal(t, want.Owner, got.Owner)
assert.Equal(t, want.AggregateID, got.AggregateID)
assert.Equal(t, want.AggregateType, got.AggregateType)
assert.Equal(t, want.InstanceID, got.InstanceID)
assert.Equal(t, want.Revision, got.Revision)
}

View File

@@ -2,11 +2,26 @@ package eventstore
import ( import (
"context" "context"
"database/sql"
"encoding/json"
"errors"
"sync"
"github.com/DATA-DOG/go-sqlmock"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgtype"
"github.com/jackc/pgx/v5/stdlib"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/dialect"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
) )
func init() {
dialect.RegisterAfterConnect(RegisterEventstoreTypes)
}
var ( var (
// pushPlaceholderFmt defines how data are inserted into the events table // pushPlaceholderFmt defines how data are inserted into the events table
pushPlaceholderFmt string pushPlaceholderFmt string
@@ -20,6 +35,123 @@ type Eventstore struct {
client *database.DB client *database.DB
} }
var (
textType = &pgtype.Type{
Name: "text",
OID: pgtype.TextOID,
Codec: pgtype.TextCodec{},
}
commandType = &pgtype.Type{
Codec: &pgtype.CompositeCodec{
Fields: []pgtype.CompositeCodecField{
{
Name: "instance_id",
Type: textType,
},
{
Name: "aggregate_type",
Type: textType,
},
{
Name: "aggregate_id",
Type: textType,
},
{
Name: "command_type",
Type: textType,
},
{
Name: "revision",
Type: &pgtype.Type{
Name: "int2",
OID: pgtype.Int2OID,
Codec: pgtype.Int2Codec{},
},
},
{
Name: "payload",
Type: &pgtype.Type{
Name: "jsonb",
OID: pgtype.JSONBOID,
Codec: &pgtype.JSONBCodec{
Marshal: json.Marshal,
Unmarshal: json.Unmarshal,
},
},
},
{
Name: "creator",
Type: textType,
},
{
Name: "owner",
Type: textType,
},
},
},
}
commandArrayCodec = &pgtype.Type{
Codec: &pgtype.ArrayCodec{
ElementType: commandType,
},
}
)
var typeMu sync.Mutex
func RegisterEventstoreTypes(ctx context.Context, conn *pgx.Conn) error {
// conn.TypeMap is not thread safe
typeMu.Lock()
defer typeMu.Unlock()
m := conn.TypeMap()
var cmd *command
if _, ok := m.TypeForValue(cmd); ok {
return nil
}
if commandType.OID == 0 || commandArrayCodec.OID == 0 {
err := conn.QueryRow(ctx, "select oid, typarray from pg_type where typname = $1 and typnamespace = (select oid from pg_namespace where nspname = $2)", "command", "eventstore").
Scan(&commandType.OID, &commandArrayCodec.OID)
if err != nil {
logging.WithError(err).Debug("failed to get oid for command type")
return nil
}
if commandType.OID == 0 || commandArrayCodec.OID == 0 {
logging.Debug("oid for command type not found")
return nil
}
}
m.RegisterTypes([]*pgtype.Type{
{
Name: "eventstore.command",
Codec: commandType.Codec,
OID: commandType.OID,
},
{
Name: "command",
Codec: commandType.Codec,
OID: commandType.OID,
},
{
Name: "eventstore._command",
Codec: commandArrayCodec.Codec,
OID: commandArrayCodec.OID,
},
{
Name: "_command",
Codec: commandArrayCodec.Codec,
OID: commandArrayCodec.OID,
},
})
dialect.RegisterDefaultPgTypeVariants[command](m, "eventstore.command", "eventstore._command")
dialect.RegisterDefaultPgTypeVariants[command](m, "command", "_command")
return nil
}
// Client implements the [eventstore.Pusher] // Client implements the [eventstore.Pusher]
func (es *Eventstore) Client() *database.DB { func (es *Eventstore) Client() *database.DB {
return es.client return es.client
@@ -41,3 +173,45 @@ func NewEventstore(client *database.DB) *Eventstore {
func (es *Eventstore) Health(ctx context.Context) error { func (es *Eventstore) Health(ctx context.Context) error {
return es.client.PingContext(ctx) return es.client.PingContext(ctx)
} }
var errTypesNotFound = errors.New("types not found")
func CheckExecutionPlan(ctx context.Context, conn *sql.Conn) error {
return conn.Raw(func(driverConn any) error {
if _, ok := driverConn.(sqlmock.SqlmockCommon); ok {
return nil
}
conn, ok := driverConn.(*stdlib.Conn)
if !ok {
return errTypesNotFound
}
return RegisterEventstoreTypes(ctx, conn.Conn())
})
}
func (es *Eventstore) pushTx(ctx context.Context, client database.ContextQueryExecuter) (tx database.Tx, deferrable func(err error) error, err error) {
tx, ok := client.(database.Tx)
if ok {
return tx, nil, nil
}
beginner, ok := client.(database.Beginner)
if !ok {
beginner = es.client
}
isolationLevel := sql.LevelReadCommitted
// cockroach requires serializable to execute the push function
// because we use [cluster_logical_timestamp()](https://www.cockroachlabs.com/docs/stable/functions-and-operators#system-info-functions)
if es.client.Type() == "cockroach" {
isolationLevel = sql.LevelSerializable
}
tx, err = beginner.BeginTx(ctx, &sql.TxOptions{
Isolation: isolationLevel,
ReadOnly: false,
})
if err != nil {
return nil, nil, err
}
return tx, func(err error) error { return database.CloseTransaction(tx, err) }, nil
}

View File

@@ -143,7 +143,7 @@ func buildSearchCondition(builder *strings.Builder, index int, conditions map[ev
return args return args
} }
func handleFieldCommands(ctx context.Context, tx database.Tx, commands []eventstore.Command) error { func (es *Eventstore) handleFieldCommands(ctx context.Context, tx database.Tx, commands []eventstore.Command) error {
for _, command := range commands { for _, command := range commands {
if len(command.Fields()) > 0 { if len(command.Fields()) > 0 {
if err := handleFieldOperations(ctx, tx, command.Fields()); err != nil { if err := handleFieldOperations(ctx, tx, command.Fields()); err != nil {

View File

@@ -48,12 +48,17 @@ func (e *mockCommand) Fields() []*eventstore.FieldOperation {
func mockEvent(aggregate *eventstore.Aggregate, sequence uint64, payload Payload) eventstore.Event { func mockEvent(aggregate *eventstore.Aggregate, sequence uint64, payload Payload) eventstore.Event {
return &event{ return &event{
aggregate: aggregate, command: &command{
creator: "creator", InstanceID: aggregate.InstanceID,
revision: 1, AggregateType: string(aggregate.Type),
typ: "event.type", AggregateID: aggregate.ID,
sequence: sequence, Owner: aggregate.ResourceOwner,
payload: payload, Creator: "creator",
Revision: 1,
CommandType: "event.type",
Payload: payload,
},
sequence: sequence,
} }
} }
@@ -66,3 +71,13 @@ func mockAggregate(id string) *eventstore.Aggregate {
Version: "v1", Version: "v1",
} }
} }
func mockAggregateWithInstance(id, instance string) *eventstore.Aggregate {
return &eventstore.Aggregate{
ID: id,
InstanceID: instance,
Type: "type",
ResourceOwner: "ro",
Version: "v1",
}
}

View File

@@ -4,83 +4,58 @@ import (
"context" "context"
"database/sql" "database/sql"
_ "embed" _ "embed"
"errors"
"fmt"
"strconv"
"strings"
"github.com/cockroachdb/cockroach-go/v2/crdb"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/logging" "github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/dialect"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
) )
var appNamePrefix = dialect.DBPurposeEventPusher.AppName() + "_"
var pushTxOpts = &sql.TxOptions{ var pushTxOpts = &sql.TxOptions{
Isolation: sql.LevelReadCommitted, Isolation: sql.LevelReadCommitted,
ReadOnly: false, ReadOnly: false,
} }
func (es *Eventstore) Push(ctx context.Context, client database.QueryExecuter, commands ...eventstore.Command) (events []eventstore.Event, err error) { func (es *Eventstore) Push(ctx context.Context, client database.ContextQueryExecuter, commands ...eventstore.Command) (events []eventstore.Event, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
var tx database.Tx events, err = es.writeCommands(ctx, client, commands)
if isSetupNotExecutedError(err) {
return es.pushWithoutFunc(ctx, client, commands...)
}
return events, err
}
func (es *Eventstore) writeCommands(ctx context.Context, client database.ContextQueryExecuter, commands []eventstore.Command) (_ []eventstore.Event, err error) {
var conn *sql.Conn
switch c := client.(type) { switch c := client.(type) {
case database.Tx:
tx = c
case database.Client: case database.Client:
// We cannot use READ COMMITTED on CockroachDB because we use cluster_logical_timestamp() which is not supported in this isolation level conn, err = c.Conn(ctx)
var opts *sql.TxOptions case nil:
if es.client.Database.Type() == "postgres" { conn, err = es.client.Conn(ctx)
opts = pushTxOpts client = conn
}
tx, err = c.BeginTx(ctx, opts)
if err != nil {
return nil, err
}
defer func() {
err = database.CloseTransaction(tx, err)
}()
default:
// We cannot use READ COMMITTED on CockroachDB because we use cluster_logical_timestamp() which is not supported in this isolation level
var opts *sql.TxOptions
if es.client.Database.Type() == "postgres" {
opts = pushTxOpts
}
tx, err = es.client.BeginTx(ctx, opts)
if err != nil {
return nil, err
}
defer func() {
err = database.CloseTransaction(tx, err)
}()
} }
// tx is not closed because [crdb.ExecuteInTx] takes care of that
var (
sequences []*latestSequence
)
// needs to be set like this because psql complains about parameters in the SET statement
_, err = tx.ExecContext(ctx, "SET application_name = '"+appNamePrefix+authz.GetInstance(ctx).InstanceID()+"'")
if err != nil {
logging.WithError(err).Warn("failed to set application name")
return nil, err
}
sequences, err = latestSequences(ctx, tx, commands)
if err != nil { if err != nil {
return nil, err return nil, err
} }
if conn != nil {
defer conn.Close()
}
events, err = insertEvents(ctx, tx, sequences, commands) tx, close, err := es.pushTx(ctx, client)
if err != nil {
return nil, err
}
if close != nil {
defer func() {
err = close(err)
}()
}
events, err := writeEvents(ctx, tx, commands)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -89,16 +64,7 @@ func (es *Eventstore) Push(ctx context.Context, client database.QueryExecuter, c
return nil, err return nil, err
} }
// CockroachDB by default does not allow multiple modifications of the same table using ON CONFLICT err = es.handleFieldCommands(ctx, tx, commands)
// Thats why we enable it manually
if es.client.Type() == "cockroach" {
_, err = tx.Exec("SET enable_multiple_modifications_of_table = on")
if err != nil {
return nil, err
}
}
err = handleFieldCommands(ctx, tx, commands)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -106,120 +72,30 @@ func (es *Eventstore) Push(ctx context.Context, client database.QueryExecuter, c
return events, nil return events, nil
} }
//go:embed push.sql func writeEvents(ctx context.Context, tx database.Tx, commands []eventstore.Command) (_ []eventstore.Event, err error) {
var pushStmt string ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
func insertEvents(ctx context.Context, tx database.Tx, sequences []*latestSequence, commands []eventstore.Command) ([]eventstore.Event, error) { events, cmds, err := commandsToEvents(ctx, commands)
events, placeholders, args, err := mapCommands(commands, sequences)
if err != nil { if err != nil {
return nil, err return nil, err
} }
rows, err := tx.QueryContext(ctx, fmt.Sprintf(pushStmt, strings.Join(placeholders, ", ")), args...) rows, err := tx.QueryContext(ctx, `select owner, created_at, "sequence", position from eventstore.push($1::eventstore.command[])`, cmds)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer rows.Close() defer rows.Close()
for i := 0; rows.Next(); i++ { for i := 0; rows.Next(); i++ {
err = rows.Scan(&events[i].(*event).createdAt, &events[i].(*event).position) err = rows.Scan(&events[i].(*event).command.Owner, &events[i].(*event).createdAt, &events[i].(*event).sequence, &events[i].(*event).position)
if err != nil { if err != nil {
logging.WithError(err).Warn("failed to scan events") logging.WithError(err).Warn("failed to scan events")
return nil, err return nil, err
} }
} }
if err = rows.Err(); err != nil {
if err := rows.Err(); err != nil { return nil, err
pgErr := new(pgconn.PgError)
if errors.As(err, &pgErr) {
// Check if push tries to write an event just written
// by another transaction
if pgErr.Code == "40001" {
// TODO: @livio-a should we return the parent or not?
return nil, zerrors.ThrowInvalidArgument(err, "V3-p5xAn", "Errors.AlreadyExists")
}
}
logging.WithError(rows.Err()).Warn("failed to push events")
return nil, zerrors.ThrowInternal(err, "V3-VGnZY", "Errors.Internal")
} }
return events, nil return events, nil
} }
const argsPerCommand = 10
func mapCommands(commands []eventstore.Command, sequences []*latestSequence) (events []eventstore.Event, placeholders []string, args []any, err error) {
events = make([]eventstore.Event, len(commands))
args = make([]any, 0, len(commands)*argsPerCommand)
placeholders = make([]string, len(commands))
for i, command := range commands {
sequence := searchSequenceByCommand(sequences, command)
if sequence == nil {
logging.WithFields(
"aggType", command.Aggregate().Type,
"aggID", command.Aggregate().ID,
"instance", command.Aggregate().InstanceID,
).Panic("no sequence found")
// added return for linting
return nil, nil, nil, nil
}
sequence.sequence++
events[i], err = commandToEvent(sequence, command)
if err != nil {
return nil, nil, nil, err
}
placeholders[i] = fmt.Sprintf(pushPlaceholderFmt,
i*argsPerCommand+1,
i*argsPerCommand+2,
i*argsPerCommand+3,
i*argsPerCommand+4,
i*argsPerCommand+5,
i*argsPerCommand+6,
i*argsPerCommand+7,
i*argsPerCommand+8,
i*argsPerCommand+9,
i*argsPerCommand+10,
)
revision, err := strconv.Atoi(strings.TrimPrefix(string(events[i].(*event).aggregate.Version), "v"))
if err != nil {
return nil, nil, nil, zerrors.ThrowInternal(err, "V3-JoZEp", "Errors.Internal")
}
args = append(args,
events[i].(*event).aggregate.InstanceID,
events[i].(*event).aggregate.ResourceOwner,
events[i].(*event).aggregate.Type,
events[i].(*event).aggregate.ID,
revision,
events[i].(*event).creator,
events[i].(*event).typ,
events[i].(*event).payload,
events[i].(*event).sequence,
i,
)
}
return events, placeholders, args, nil
}
type transaction struct {
database.Tx
}
var _ crdb.Tx = (*transaction)(nil)
func (t *transaction) Exec(ctx context.Context, query string, args ...interface{}) error {
_, err := t.Tx.ExecContext(ctx, query, args...)
return err
}
func (t *transaction) Commit(ctx context.Context) error {
return t.Tx.Commit()
}
func (t *transaction) Rollback(ctx context.Context) error {
return t.Tx.Rollback()
}

View File

@@ -70,11 +70,11 @@ func Test_mapCommands(t *testing.T) {
args: []any{ args: []any{
"instance", "instance",
"ro", "ro",
eventstore.AggregateType("type"), "type",
"V3-VEIvq", "V3-VEIvq",
1, uint16(1),
"creator", "creator",
eventstore.EventType("event.type"), "event.type",
Payload(nil), Payload(nil),
uint64(1), uint64(1),
0, 0,
@@ -121,22 +121,22 @@ func Test_mapCommands(t *testing.T) {
// first event // first event
"instance", "instance",
"ro", "ro",
eventstore.AggregateType("type"), "type",
"V3-VEIvq", "V3-VEIvq",
1, uint16(1),
"creator", "creator",
eventstore.EventType("event.type"), "event.type",
Payload(nil), Payload(nil),
uint64(6), uint64(6),
0, 0,
// second event // second event
"instance", "instance",
"ro", "ro",
eventstore.AggregateType("type"), "type",
"V3-VEIvq", "V3-VEIvq",
1, uint16(1),
"creator", "creator",
eventstore.EventType("event.type"), "event.type",
Payload(nil), Payload(nil),
uint64(7), uint64(7),
1, 1,
@@ -187,22 +187,22 @@ func Test_mapCommands(t *testing.T) {
// first event // first event
"instance", "instance",
"ro", "ro",
eventstore.AggregateType("type"), "type",
"V3-VEIvq", "V3-VEIvq",
1, uint16(1),
"creator", "creator",
eventstore.EventType("event.type"), "event.type",
Payload(nil), Payload(nil),
uint64(6), uint64(6),
0, 0,
// second event // second event
"instance", "instance",
"ro", "ro",
eventstore.AggregateType("type"), "type",
"V3-IT6VN", "V3-IT6VN",
1, uint16(1),
"creator", "creator",
eventstore.EventType("event.type"), "event.type",
Payload(nil), Payload(nil),
uint64(1), uint64(1),
1, 1,

View File

@@ -0,0 +1,183 @@
package eventstore
import (
"context"
_ "embed"
"errors"
"fmt"
"strings"
"github.com/cockroachdb/cockroach-go/v2/crdb"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/zerrors"
)
type transaction struct {
database.Tx
}
var _ crdb.Tx = (*transaction)(nil)
func (t *transaction) Exec(ctx context.Context, query string, args ...interface{}) error {
_, err := t.Tx.ExecContext(ctx, query, args...)
return err
}
func (t *transaction) Commit(ctx context.Context) error {
return t.Tx.Commit()
}
func (t *transaction) Rollback(ctx context.Context) error {
return t.Tx.Rollback()
}
// checks whether the error is caused because setup step 39 was not executed
func isSetupNotExecutedError(err error) bool {
if err == nil {
return false
}
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
return (pgErr.Code == "42704" && strings.Contains(pgErr.Message, "eventstore.command")) ||
(pgErr.Code == "42883" && strings.Contains(pgErr.Message, "eventstore.push"))
}
return errors.Is(err, errTypesNotFound)
}
var (
//go:embed push.sql
pushStmt string
)
// pushWithoutFunc implements pushing events before setup step 39 was introduced.
// TODO: remove with v3
func (es *Eventstore) pushWithoutFunc(ctx context.Context, client database.ContextQueryExecuter, commands ...eventstore.Command) (events []eventstore.Event, err error) {
tx, closeTx, err := es.pushTx(ctx, client)
if err != nil {
return nil, err
}
defer func() {
err = closeTx(err)
}()
// tx is not closed because [crdb.ExecuteInTx] takes care of that
var (
sequences []*latestSequence
)
sequences, err = latestSequences(ctx, tx, commands)
if err != nil {
return nil, err
}
events, err = es.writeEventsOld(ctx, tx, sequences, commands)
if err != nil {
return nil, err
}
if err = handleUniqueConstraints(ctx, tx, commands); err != nil {
return nil, err
}
err = es.handleFieldCommands(ctx, tx, commands)
if err != nil {
return nil, err
}
return events, nil
}
func (es *Eventstore) writeEventsOld(ctx context.Context, tx database.Tx, sequences []*latestSequence, commands []eventstore.Command) ([]eventstore.Event, error) {
events, placeholders, args, err := mapCommands(commands, sequences)
if err != nil {
return nil, err
}
rows, err := tx.QueryContext(ctx, fmt.Sprintf(pushStmt, strings.Join(placeholders, ", ")), args...)
if err != nil {
return nil, err
}
defer rows.Close()
for i := 0; rows.Next(); i++ {
err = rows.Scan(&events[i].(*event).createdAt, &events[i].(*event).position)
if err != nil {
logging.WithError(err).Warn("failed to scan events")
return nil, err
}
}
if err := rows.Err(); err != nil {
pgErr := new(pgconn.PgError)
if errors.As(err, &pgErr) {
// Check if push tries to write an event just written
// by another transaction
if pgErr.Code == "40001" {
// TODO: @livio-a should we return the parent or not?
return nil, zerrors.ThrowInvalidArgument(err, "V3-p5xAn", "Errors.AlreadyExists")
}
}
logging.WithError(rows.Err()).Warn("failed to push events")
return nil, zerrors.ThrowInternal(err, "V3-VGnZY", "Errors.Internal")
}
return events, nil
}
const argsPerCommand = 10
func mapCommands(commands []eventstore.Command, sequences []*latestSequence) (events []eventstore.Event, placeholders []string, args []any, err error) {
events = make([]eventstore.Event, len(commands))
args = make([]any, 0, len(commands)*argsPerCommand)
placeholders = make([]string, len(commands))
for i, command := range commands {
sequence := searchSequenceByCommand(sequences, command)
if sequence == nil {
logging.WithFields(
"aggType", command.Aggregate().Type,
"aggID", command.Aggregate().ID,
"instance", command.Aggregate().InstanceID,
).Panic("no sequence found")
// added return for linting
return nil, nil, nil, nil
}
sequence.sequence++
events[i], err = commandToEventOld(sequence, command)
if err != nil {
return nil, nil, nil, err
}
placeholders[i] = fmt.Sprintf(pushPlaceholderFmt,
i*argsPerCommand+1,
i*argsPerCommand+2,
i*argsPerCommand+3,
i*argsPerCommand+4,
i*argsPerCommand+5,
i*argsPerCommand+6,
i*argsPerCommand+7,
i*argsPerCommand+8,
i*argsPerCommand+9,
i*argsPerCommand+10,
)
args = append(args,
events[i].(*event).command.InstanceID,
events[i].(*event).command.Owner,
events[i].(*event).command.AggregateType,
events[i].(*event).command.AggregateID,
events[i].(*event).command.Revision,
events[i].(*event).command.Creator,
events[i].(*event).command.CommandType,
events[i].(*event).command.Payload,
events[i].(*event).sequence,
i,
)
}
return events, placeholders, args, nil
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/internal/zerrors"
) )
@@ -24,7 +25,10 @@ var (
addConstraintStmt string addConstraintStmt string
) )
func handleUniqueConstraints(ctx context.Context, tx database.Tx, commands []eventstore.Command) error { func handleUniqueConstraints(ctx context.Context, tx database.Tx, commands []eventstore.Command) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
deletePlaceholders := make([]string, 0) deletePlaceholders := make([]string, 0)
deleteArgs := make([]any, 0) deleteArgs := make([]any, 0)