mirror of
https://github.com/zitadel/zitadel.git
synced 2024-12-04 23:45:07 +00:00
Merge branch 'main' into add-instance-domain-field
This commit is contained in:
commit
27f3b987a0
@ -1,6 +1,7 @@
|
||||
package initialise
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
@ -8,8 +9,8 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
)
|
||||
|
||||
func exec(db *database.DB, stmt string, possibleErrCodes []string, args ...interface{}) error {
|
||||
_, err := db.Exec(stmt, args...)
|
||||
func exec(ctx context.Context, db database.ContextExecuter, stmt string, possibleErrCodes []string, args ...interface{}) error {
|
||||
_, err := db.ExecContext(ctx, stmt, args...)
|
||||
pgErr := new(pgconn.PgError)
|
||||
if errors.As(err, &pgErr) {
|
||||
for _, possibleCode := range possibleErrCodes {
|
||||
|
@ -59,7 +59,7 @@ The user provided by flags needs privileges to
|
||||
}
|
||||
|
||||
func InitAll(ctx context.Context, config *Config) {
|
||||
err := initialise(config.Database,
|
||||
err := initialise(ctx, config.Database,
|
||||
VerifyUser(config.Database.Username(), config.Database.Password()),
|
||||
VerifyDatabase(config.Database.DatabaseName()),
|
||||
VerifyGrant(config.Database.DatabaseName(), config.Database.Username()),
|
||||
@ -71,7 +71,7 @@ func InitAll(ctx context.Context, config *Config) {
|
||||
logging.OnError(err).Fatal("unable to initialize ZITADEL")
|
||||
}
|
||||
|
||||
func initialise(config database.Config, steps ...func(*database.DB) error) error {
|
||||
func initialise(ctx context.Context, config database.Config, steps ...func(context.Context, *database.DB) error) error {
|
||||
logging.Info("initialization started")
|
||||
|
||||
err := ReadStmts(config.Type())
|
||||
@ -85,12 +85,12 @@ func initialise(config database.Config, steps ...func(*database.DB) error) error
|
||||
}
|
||||
defer db.Close()
|
||||
|
||||
return Init(db, steps...)
|
||||
return Init(ctx, db, steps...)
|
||||
}
|
||||
|
||||
func Init(db *database.DB, steps ...func(*database.DB) error) error {
|
||||
func Init(ctx context.Context, db *database.DB, steps ...func(context.Context, *database.DB) error) error {
|
||||
for _, step := range steps {
|
||||
if err := step(db); err != nil {
|
||||
if err := step(ctx, db); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
@ -1,2 +1,2 @@
|
||||
-- replace %[1]s with the name of the database
|
||||
CREATE DATABASE IF NOT EXISTS "%[1]s"
|
||||
CREATE DATABASE IF NOT EXISTS "%[1]s";
|
||||
|
@ -19,3 +19,98 @@ CREATE TABLE IF NOT EXISTS eventstore.events2 (
|
||||
, INDEX es_wm (aggregate_id, instance_id, aggregate_type, event_type)
|
||||
, INDEX es_projection (instance_id, aggregate_type, event_type, "position" DESC)
|
||||
);
|
||||
|
||||
-- represents an event to be created.
|
||||
CREATE TYPE IF NOT EXISTS eventstore.command AS (
|
||||
instance_id TEXT
|
||||
, aggregate_type TEXT
|
||||
, aggregate_id TEXT
|
||||
, command_type TEXT
|
||||
, revision INT2
|
||||
, payload JSONB
|
||||
, creator TEXT
|
||||
, owner TEXT
|
||||
);
|
||||
|
||||
CREATE OR REPLACE FUNCTION eventstore.commands_to_events(commands eventstore.command[]) RETURNS SETOF eventstore.events2 VOLATILE AS $$
|
||||
SELECT
|
||||
("c").instance_id
|
||||
, ("c").aggregate_type
|
||||
, ("c").aggregate_id
|
||||
, ("c").command_type AS event_type
|
||||
, cs.sequence + ROW_NUMBER() OVER (PARTITION BY ("c").instance_id, ("c").aggregate_type, ("c").aggregate_id ORDER BY ("c").in_tx_order) AS sequence
|
||||
, ("c").revision
|
||||
, hlc_to_timestamp(cluster_logical_timestamp()) AS created_at
|
||||
, ("c").payload
|
||||
, ("c").creator
|
||||
, cs.owner
|
||||
, cluster_logical_timestamp() AS position
|
||||
, ("c").in_tx_order
|
||||
FROM (
|
||||
SELECT
|
||||
("c").instance_id
|
||||
, ("c").aggregate_type
|
||||
, ("c").aggregate_id
|
||||
, ("c").command_type
|
||||
, ("c").revision
|
||||
, ("c").payload
|
||||
, ("c").creator
|
||||
, ("c").owner
|
||||
, ROW_NUMBER() OVER () AS in_tx_order
|
||||
FROM
|
||||
UNNEST(commands) AS "c"
|
||||
) AS "c"
|
||||
JOIN (
|
||||
SELECT
|
||||
cmds.instance_id
|
||||
, cmds.aggregate_type
|
||||
, cmds.aggregate_id
|
||||
, CASE WHEN (e.owner IS NOT NULL OR e.owner <> '') THEN e.owner ELSE command_owners.owner END AS owner
|
||||
, COALESCE(MAX(e.sequence), 0) AS sequence
|
||||
FROM (
|
||||
SELECT DISTINCT
|
||||
("cmds").instance_id
|
||||
, ("cmds").aggregate_type
|
||||
, ("cmds").aggregate_id
|
||||
, ("cmds").owner
|
||||
FROM UNNEST(commands) AS "cmds"
|
||||
) AS cmds
|
||||
LEFT JOIN eventstore.events2 AS e
|
||||
ON cmds.instance_id = e.instance_id
|
||||
AND cmds.aggregate_type = e.aggregate_type
|
||||
AND cmds.aggregate_id = e.aggregate_id
|
||||
JOIN (
|
||||
SELECT
|
||||
DISTINCT ON (
|
||||
("c").instance_id
|
||||
, ("c").aggregate_type
|
||||
, ("c").aggregate_id
|
||||
)
|
||||
("c").instance_id
|
||||
, ("c").aggregate_type
|
||||
, ("c").aggregate_id
|
||||
, ("c").owner
|
||||
FROM
|
||||
UNNEST(commands) AS "c"
|
||||
) AS command_owners ON
|
||||
cmds.instance_id = command_owners.instance_id
|
||||
AND cmds.aggregate_type = command_owners.aggregate_type
|
||||
AND cmds.aggregate_id = command_owners.aggregate_id
|
||||
GROUP BY
|
||||
cmds.instance_id
|
||||
, cmds.aggregate_type
|
||||
, cmds.aggregate_id
|
||||
, 4 -- owner
|
||||
) AS cs
|
||||
ON ("c").instance_id = cs.instance_id
|
||||
AND ("c").aggregate_type = cs.aggregate_type
|
||||
AND ("c").aggregate_id = cs.aggregate_id
|
||||
ORDER BY
|
||||
in_tx_order
|
||||
$$ LANGUAGE SQL;
|
||||
|
||||
CREATE OR REPLACE FUNCTION eventstore.push(commands eventstore.command[]) RETURNS SETOF eventstore.events2 AS $$
|
||||
INSERT INTO eventstore.events2
|
||||
SELECT * FROM eventstore.commands_to_events(commands)
|
||||
RETURNING *
|
||||
$$ LANGUAGE SQL;
|
@ -19,4 +19,103 @@ CREATE TABLE IF NOT EXISTS eventstore.events2 (
|
||||
|
||||
CREATE INDEX IF NOT EXISTS es_active_instances ON eventstore.events2 (created_at DESC, instance_id);
|
||||
CREATE INDEX IF NOT EXISTS es_wm ON eventstore.events2 (aggregate_id, instance_id, aggregate_type, event_type);
|
||||
CREATE INDEX IF NOT EXISTS es_projection ON eventstore.events2 (instance_id, aggregate_type, event_type, "position");
|
||||
CREATE INDEX IF NOT EXISTS es_projection ON eventstore.events2 (instance_id, aggregate_type, event_type, "position");
|
||||
|
||||
-- represents an event to be created.
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE eventstore.command AS (
|
||||
instance_id TEXT
|
||||
, aggregate_type TEXT
|
||||
, aggregate_id TEXT
|
||||
, command_type TEXT
|
||||
, revision INT2
|
||||
, payload JSONB
|
||||
, creator TEXT
|
||||
, owner TEXT
|
||||
);
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
||||
|
||||
CREATE OR REPLACE FUNCTION eventstore.commands_to_events(commands eventstore.command[]) RETURNS SETOF eventstore.events2 VOLATILE AS $$
|
||||
SELECT
|
||||
c.instance_id
|
||||
, c.aggregate_type
|
||||
, c.aggregate_id
|
||||
, c.command_type AS event_type
|
||||
, cs.sequence + ROW_NUMBER() OVER (PARTITION BY c.instance_id, c.aggregate_type, c.aggregate_id ORDER BY c.in_tx_order) AS sequence
|
||||
, c.revision
|
||||
, NOW() AS created_at
|
||||
, c.payload
|
||||
, c.creator
|
||||
, cs.owner
|
||||
, EXTRACT(EPOCH FROM NOW()) AS position
|
||||
, c.in_tx_order
|
||||
FROM (
|
||||
SELECT
|
||||
c.instance_id
|
||||
, c.aggregate_type
|
||||
, c.aggregate_id
|
||||
, c.command_type
|
||||
, c.revision
|
||||
, c.payload
|
||||
, c.creator
|
||||
, c.owner
|
||||
, ROW_NUMBER() OVER () AS in_tx_order
|
||||
FROM
|
||||
UNNEST(commands) AS c
|
||||
) AS c
|
||||
JOIN (
|
||||
SELECT
|
||||
cmds.instance_id
|
||||
, cmds.aggregate_type
|
||||
, cmds.aggregate_id
|
||||
, CASE WHEN (e.owner IS NOT NULL OR e.owner <> '') THEN e.owner ELSE command_owners.owner END AS owner
|
||||
, COALESCE(MAX(e.sequence), 0) AS sequence
|
||||
FROM (
|
||||
SELECT DISTINCT
|
||||
instance_id
|
||||
, aggregate_type
|
||||
, aggregate_id
|
||||
, owner
|
||||
FROM UNNEST(commands)
|
||||
) AS cmds
|
||||
LEFT JOIN eventstore.events2 AS e
|
||||
ON cmds.instance_id = e.instance_id
|
||||
AND cmds.aggregate_type = e.aggregate_type
|
||||
AND cmds.aggregate_id = e.aggregate_id
|
||||
JOIN (
|
||||
SELECT
|
||||
DISTINCT ON (
|
||||
instance_id
|
||||
, aggregate_type
|
||||
, aggregate_id
|
||||
)
|
||||
instance_id
|
||||
, aggregate_type
|
||||
, aggregate_id
|
||||
, owner
|
||||
FROM
|
||||
UNNEST(commands)
|
||||
) AS command_owners ON
|
||||
cmds.instance_id = command_owners.instance_id
|
||||
AND cmds.aggregate_type = command_owners.aggregate_type
|
||||
AND cmds.aggregate_id = command_owners.aggregate_id
|
||||
GROUP BY
|
||||
cmds.instance_id
|
||||
, cmds.aggregate_type
|
||||
, cmds.aggregate_id
|
||||
, 4 -- owner
|
||||
) AS cs
|
||||
ON c.instance_id = cs.instance_id
|
||||
AND c.aggregate_type = cs.aggregate_type
|
||||
AND c.aggregate_id = cs.aggregate_id
|
||||
ORDER BY
|
||||
in_tx_order;
|
||||
$$ LANGUAGE SQL;
|
||||
|
||||
CREATE OR REPLACE FUNCTION eventstore.push(commands eventstore.command[]) RETURNS SETOF eventstore.events2 VOLATILE AS $$
|
||||
INSERT INTO eventstore.events2
|
||||
SELECT * FROM eventstore.commands_to_events(commands)
|
||||
RETURNING *
|
||||
$$ LANGUAGE SQL;
|
||||
|
@ -1,6 +1,7 @@
|
||||
package initialise
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
|
||||
@ -28,16 +29,16 @@ The user provided by flags needs privileges to
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
config := MustNewConfig(viper.GetViper())
|
||||
|
||||
err := initialise(config.Database, VerifyDatabase(config.Database.DatabaseName()))
|
||||
err := initialise(cmd.Context(), config.Database, VerifyDatabase(config.Database.DatabaseName()))
|
||||
logging.OnError(err).Fatal("unable to initialize the database")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func VerifyDatabase(databaseName string) func(*database.DB) error {
|
||||
return func(db *database.DB) error {
|
||||
func VerifyDatabase(databaseName string) func(context.Context, *database.DB) error {
|
||||
return func(ctx context.Context, db *database.DB) error {
|
||||
logging.WithFields("database", databaseName).Info("verify database")
|
||||
|
||||
return exec(db, fmt.Sprintf(databaseStmt, databaseName), []string{dbAlreadyExistsCode})
|
||||
return exec(ctx, db, fmt.Sprintf(databaseStmt, databaseName), []string{dbAlreadyExistsCode})
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package initialise
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
@ -55,7 +56,7 @@ func Test_verifyDB(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := VerifyDatabase(tt.args.database)(tt.args.db.db); !errors.Is(err, tt.targetErr) {
|
||||
if err := VerifyDatabase(tt.args.database)(context.Background(), tt.args.db.db); !errors.Is(err, tt.targetErr) {
|
||||
t.Errorf("verifyDB() error = %v, want: %v", err, tt.targetErr)
|
||||
}
|
||||
if err := tt.args.db.mock.ExpectationsWereMet(); err != nil {
|
||||
|
@ -1,6 +1,7 @@
|
||||
package initialise
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
|
||||
@ -23,16 +24,16 @@ Prerequisites:
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
config := MustNewConfig(viper.GetViper())
|
||||
|
||||
err := initialise(config.Database, VerifyGrant(config.Database.DatabaseName(), config.Database.Username()))
|
||||
err := initialise(cmd.Context(), config.Database, VerifyGrant(config.Database.DatabaseName(), config.Database.Username()))
|
||||
logging.OnError(err).Fatal("unable to set grant")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func VerifyGrant(databaseName, username string) func(*database.DB) error {
|
||||
return func(db *database.DB) error {
|
||||
func VerifyGrant(databaseName, username string) func(context.Context, *database.DB) error {
|
||||
return func(ctx context.Context, db *database.DB) error {
|
||||
logging.WithFields("user", username, "database", databaseName).Info("verify grant")
|
||||
|
||||
return exec(db, fmt.Sprintf(grantStmt, databaseName, username), nil)
|
||||
return exec(ctx, db, fmt.Sprintf(grantStmt, databaseName, username), nil)
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package initialise
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
@ -53,7 +54,7 @@ func Test_verifyGrant(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := VerifyGrant(tt.args.database, tt.args.username)(tt.args.db.db); !errors.Is(err, tt.targetErr) {
|
||||
if err := VerifyGrant(tt.args.database, tt.args.username)(context.Background(), tt.args.db.db); !errors.Is(err, tt.targetErr) {
|
||||
t.Errorf("VerifyGrant() error = %v, want: %v", err, tt.targetErr)
|
||||
}
|
||||
if err := tt.args.db.mock.ExpectationsWereMet(); err != nil {
|
||||
|
@ -1,6 +1,7 @@
|
||||
package initialise
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
|
||||
@ -26,19 +27,19 @@ Cockroach
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
config := MustNewConfig(viper.GetViper())
|
||||
|
||||
err := initialise(config.Database, VerifySettings(config.Database.DatabaseName(), config.Database.Username()))
|
||||
err := initialise(cmd.Context(), config.Database, VerifySettings(config.Database.DatabaseName(), config.Database.Username()))
|
||||
logging.OnError(err).Fatal("unable to set settings")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func VerifySettings(databaseName, username string) func(*database.DB) error {
|
||||
return func(db *database.DB) error {
|
||||
func VerifySettings(databaseName, username string) func(context.Context, *database.DB) error {
|
||||
return func(ctx context.Context, db *database.DB) error {
|
||||
if db.Type() == "postgres" {
|
||||
return nil
|
||||
}
|
||||
logging.WithFields("user", username, "database", databaseName).Info("verify settings")
|
||||
|
||||
return exec(db, fmt.Sprintf(settingsStmt, databaseName, username), nil)
|
||||
return exec(ctx, db, fmt.Sprintf(settingsStmt, databaseName, username), nil)
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package initialise
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
|
||||
@ -28,20 +29,20 @@ The user provided by flags needs privileges to
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
config := MustNewConfig(viper.GetViper())
|
||||
|
||||
err := initialise(config.Database, VerifyUser(config.Database.Username(), config.Database.Password()))
|
||||
err := initialise(cmd.Context(), config.Database, VerifyUser(config.Database.Username(), config.Database.Password()))
|
||||
logging.OnError(err).Fatal("unable to init user")
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func VerifyUser(username, password string) func(*database.DB) error {
|
||||
return func(db *database.DB) error {
|
||||
func VerifyUser(username, password string) func(context.Context, *database.DB) error {
|
||||
return func(ctx context.Context, db *database.DB) error {
|
||||
logging.WithFields("username", username).Info("verify user")
|
||||
|
||||
if password != "" {
|
||||
createUserStmt += " WITH PASSWORD '" + password + "'"
|
||||
}
|
||||
|
||||
return exec(db, fmt.Sprintf(createUserStmt, username), []string{roleAlreadyExistsCode})
|
||||
return exec(ctx, db, fmt.Sprintf(createUserStmt, username), []string{roleAlreadyExistsCode})
|
||||
}
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package initialise
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"testing"
|
||||
@ -70,7 +71,7 @@ func Test_verifyUser(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := VerifyUser(tt.args.username, tt.args.password)(tt.args.db.db); !errors.Is(err, tt.targetErr) {
|
||||
if err := VerifyUser(tt.args.username, tt.args.password)(context.Background(), tt.args.db.db); !errors.Is(err, tt.targetErr) {
|
||||
t.Errorf("VerifyGrant() error = %v, want: %v", err, tt.targetErr)
|
||||
}
|
||||
if err := tt.args.db.mock.ExpectationsWereMet(); err != nil {
|
||||
|
@ -2,6 +2,7 @@ package initialise
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"fmt"
|
||||
|
||||
@ -11,6 +12,7 @@ import (
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/database/dialect"
|
||||
es_v3 "github.com/zitadel/zitadel/internal/eventstore/v3"
|
||||
)
|
||||
|
||||
func newZitadel() *cobra.Command {
|
||||
@ -36,38 +38,44 @@ func VerifyZitadel(ctx context.Context, db *database.DB, config database.Config)
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := db.Conn(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
logging.WithFields().Info("verify system")
|
||||
if err := exec(db, fmt.Sprintf(createSystemStmt, config.Username()), nil); err != nil {
|
||||
if err := exec(ctx, conn, fmt.Sprintf(createSystemStmt, config.Username()), nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logging.WithFields().Info("verify encryption keys")
|
||||
if err := createEncryptionKeys(ctx, db); err != nil {
|
||||
if err := createEncryptionKeys(ctx, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logging.WithFields().Info("verify projections")
|
||||
if err := exec(db, fmt.Sprintf(createProjectionsStmt, config.Username()), nil); err != nil {
|
||||
if err := exec(ctx, conn, fmt.Sprintf(createProjectionsStmt, config.Username()), nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logging.WithFields().Info("verify eventstore")
|
||||
if err := exec(db, fmt.Sprintf(createEventstoreStmt, config.Username()), nil); err != nil {
|
||||
if err := exec(ctx, conn, fmt.Sprintf(createEventstoreStmt, config.Username()), nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logging.WithFields().Info("verify events tables")
|
||||
if err := createEvents(ctx, db); err != nil {
|
||||
if err := createEvents(ctx, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logging.WithFields().Info("verify system sequence")
|
||||
if err := exec(db, createSystemSequenceStmt, nil); err != nil {
|
||||
if err := exec(ctx, conn, createSystemSequenceStmt, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
logging.WithFields().Info("verify unique constraints")
|
||||
if err := exec(db, createUniqueConstraints, nil); err != nil {
|
||||
if err := exec(ctx, conn, createUniqueConstraints, nil); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@ -89,7 +97,7 @@ func verifyZitadel(ctx context.Context, config database.Config) error {
|
||||
return db.Close()
|
||||
}
|
||||
|
||||
func createEncryptionKeys(ctx context.Context, db *database.DB) error {
|
||||
func createEncryptionKeys(ctx context.Context, db database.Beginner) error {
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
@ -103,8 +111,8 @@ func createEncryptionKeys(ctx context.Context, db *database.DB) error {
|
||||
return tx.Commit()
|
||||
}
|
||||
|
||||
func createEvents(ctx context.Context, db *database.DB) (err error) {
|
||||
tx, err := db.BeginTx(ctx, nil)
|
||||
func createEvents(ctx context.Context, conn *sql.Conn) (err error) {
|
||||
tx, err := conn.BeginTx(ctx, nil)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -127,5 +135,8 @@ func createEvents(ctx context.Context, db *database.DB) (err error) {
|
||||
return row.Err()
|
||||
}
|
||||
_, err = tx.Exec(createEventsStmt)
|
||||
return err
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return es_v3.CheckExecutionPlan(ctx, conn)
|
||||
}
|
||||
|
@ -108,7 +108,12 @@ func Test_verifyEvents(t *testing.T) {
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := createEvents(context.Background(), tt.args.db.db); !errors.Is(err, tt.targetErr) {
|
||||
conn, err := tt.args.db.db.Conn(context.Background())
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
if err := createEvents(context.Background(), conn); !errors.Is(err, tt.targetErr) {
|
||||
t.Errorf("createEvents() error = %v, want: %v", err, tt.targetErr)
|
||||
}
|
||||
if err := tt.args.db.mock.ExpectationsWereMet(); err != nil {
|
||||
|
52
cmd/setup/40.go
Normal file
52
cmd/setup/40.go
Normal file
@ -0,0 +1,52 @@
|
||||
package setup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"fmt"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
var (
|
||||
//go:embed 40/cockroach/*.sql
|
||||
//go:embed 40/postgres/*.sql
|
||||
initPushFunc embed.FS
|
||||
)
|
||||
|
||||
type InitPushFunc struct {
|
||||
dbClient *database.DB
|
||||
}
|
||||
|
||||
func (mig *InitPushFunc) Execute(ctx context.Context, _ eventstore.Event) (err error) {
|
||||
statements, err := readStatements(initPushFunc, "40", mig.dbClient.Type())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn, err := mig.dbClient.Conn(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
closeErr := conn.Close()
|
||||
logging.OnError(closeErr).Debug("failed to release connection")
|
||||
// Force the pool to reopen connections to apply the new types
|
||||
mig.dbClient.Pool.Reset()
|
||||
}()
|
||||
|
||||
for _, stmt := range statements {
|
||||
logging.WithFields("file", stmt.file, "migration", mig.String()).Info("execute statement")
|
||||
if _, err := conn.ExecContext(ctx, stmt.query); err != nil {
|
||||
return fmt.Errorf("%s %s: %w", mig.String(), stmt.file, err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mig *InitPushFunc) String() string {
|
||||
return "40_init_push_func"
|
||||
}
|
107
cmd/setup/40/cockroach/40_init_push_func.sql
Normal file
107
cmd/setup/40/cockroach/40_init_push_func.sql
Normal file
@ -0,0 +1,107 @@
|
||||
-- represents an event to be created.
|
||||
CREATE TYPE IF NOT EXISTS eventstore.command AS (
|
||||
instance_id TEXT
|
||||
, aggregate_type TEXT
|
||||
, aggregate_id TEXT
|
||||
, command_type TEXT
|
||||
, revision INT2
|
||||
, payload JSONB
|
||||
, creator TEXT
|
||||
, owner TEXT
|
||||
);
|
||||
|
||||
/*
|
||||
select * from eventstore.commands_to_events(
|
||||
ARRAY[
|
||||
ROW('', 'system', 'SYSTEM', 'ct1', 1, '{"key": "value"}', 'c1', 'SYSTEM')
|
||||
, ROW('', 'system', 'SYSTEM', 'ct2', 1, '{"key": "value"}', 'c1', 'SYSTEM')
|
||||
, ROW('289525561255060732', 'org', '289575074711790844', 'ct3', 1, '{"key": "value"}', 'c1', '289575074711790844')
|
||||
, ROW('289525561255060732', 'user', '289575075164906748', 'ct3', 1, '{"key": "value"}', 'c1', '289575074711790844')
|
||||
, ROW('289525561255060732', 'oidc_session', 'V2_289575178579535100', 'ct3', 1, '{"key": "value"}', 'c1', '289575074711790844')
|
||||
, ROW('', 'system', 'SYSTEM', 'ct3', 1, '{"key": "value"}', 'c1', 'SYSTEM')
|
||||
]::eventstore.command[]
|
||||
);
|
||||
*/
|
||||
|
||||
CREATE OR REPLACE FUNCTION eventstore.commands_to_events(commands eventstore.command[]) RETURNS SETOF eventstore.events2 VOLATILE AS $$
|
||||
SELECT
|
||||
("c").instance_id
|
||||
, ("c").aggregate_type
|
||||
, ("c").aggregate_id
|
||||
, ("c").command_type AS event_type
|
||||
, cs.sequence + ROW_NUMBER() OVER (PARTITION BY ("c").instance_id, ("c").aggregate_type, ("c").aggregate_id ORDER BY ("c").in_tx_order) AS sequence
|
||||
, ("c").revision
|
||||
, hlc_to_timestamp(cluster_logical_timestamp()) AS created_at
|
||||
, ("c").payload
|
||||
, ("c").creator
|
||||
, cs.owner
|
||||
, cluster_logical_timestamp() AS position
|
||||
, ("c").in_tx_order
|
||||
FROM (
|
||||
SELECT
|
||||
("c").instance_id
|
||||
, ("c").aggregate_type
|
||||
, ("c").aggregate_id
|
||||
, ("c").command_type
|
||||
, ("c").revision
|
||||
, ("c").payload
|
||||
, ("c").creator
|
||||
, ("c").owner
|
||||
, ROW_NUMBER() OVER () AS in_tx_order
|
||||
FROM
|
||||
UNNEST(commands) AS "c"
|
||||
) AS "c"
|
||||
JOIN (
|
||||
SELECT
|
||||
cmds.instance_id
|
||||
, cmds.aggregate_type
|
||||
, cmds.aggregate_id
|
||||
, CASE WHEN (e.owner <> '') THEN e.owner ELSE command_owners.owner END AS owner
|
||||
, COALESCE(MAX(e.sequence), 0) AS sequence
|
||||
FROM (
|
||||
SELECT DISTINCT
|
||||
("cmds").instance_id
|
||||
, ("cmds").aggregate_type
|
||||
, ("cmds").aggregate_id
|
||||
, ("cmds").owner
|
||||
FROM UNNEST(commands) AS "cmds"
|
||||
) AS cmds
|
||||
LEFT JOIN eventstore.events2 AS e
|
||||
ON cmds.instance_id = e.instance_id
|
||||
AND cmds.aggregate_type = e.aggregate_type
|
||||
AND cmds.aggregate_id = e.aggregate_id
|
||||
JOIN (
|
||||
SELECT
|
||||
DISTINCT ON (
|
||||
("c").instance_id
|
||||
, ("c").aggregate_type
|
||||
, ("c").aggregate_id
|
||||
)
|
||||
("c").instance_id
|
||||
, ("c").aggregate_type
|
||||
, ("c").aggregate_id
|
||||
, ("c").owner
|
||||
FROM
|
||||
UNNEST(commands) AS "c"
|
||||
) AS command_owners ON
|
||||
cmds.instance_id = command_owners.instance_id
|
||||
AND cmds.aggregate_type = command_owners.aggregate_type
|
||||
AND cmds.aggregate_id = command_owners.aggregate_id
|
||||
GROUP BY
|
||||
cmds.instance_id
|
||||
, cmds.aggregate_type
|
||||
, cmds.aggregate_id
|
||||
, 4 -- owner
|
||||
) AS cs
|
||||
ON ("c").instance_id = cs.instance_id
|
||||
AND ("c").aggregate_type = cs.aggregate_type
|
||||
AND ("c").aggregate_id = cs.aggregate_id
|
||||
ORDER BY
|
||||
in_tx_order
|
||||
$$ LANGUAGE SQL;
|
||||
|
||||
CREATE OR REPLACE FUNCTION eventstore.push(commands eventstore.command[]) RETURNS SETOF eventstore.events2 AS $$
|
||||
INSERT INTO eventstore.events2
|
||||
SELECT * FROM eventstore.commands_to_events(commands)
|
||||
RETURNING *
|
||||
$$ LANGUAGE SQL;
|
15
cmd/setup/40/postgres/01_type.sql
Normal file
15
cmd/setup/40/postgres/01_type.sql
Normal file
@ -0,0 +1,15 @@
|
||||
-- represents an event to be created.
|
||||
DO $$ BEGIN
|
||||
CREATE TYPE eventstore.command AS (
|
||||
instance_id TEXT
|
||||
, aggregate_type TEXT
|
||||
, aggregate_id TEXT
|
||||
, command_type TEXT
|
||||
, revision INT2
|
||||
, payload JSONB
|
||||
, creator TEXT
|
||||
, owner TEXT
|
||||
);
|
||||
EXCEPTION
|
||||
WHEN duplicate_object THEN null;
|
||||
END $$;
|
82
cmd/setup/40/postgres/02_func.sql
Normal file
82
cmd/setup/40/postgres/02_func.sql
Normal file
@ -0,0 +1,82 @@
|
||||
CREATE OR REPLACE FUNCTION eventstore.commands_to_events(commands eventstore.command[]) RETURNS SETOF eventstore.events2 VOLATILE AS $$
|
||||
SELECT
|
||||
c.instance_id
|
||||
, c.aggregate_type
|
||||
, c.aggregate_id
|
||||
, c.command_type AS event_type
|
||||
, cs.sequence + ROW_NUMBER() OVER (PARTITION BY c.instance_id, c.aggregate_type, c.aggregate_id ORDER BY c.in_tx_order) AS sequence
|
||||
, c.revision
|
||||
, NOW() AS created_at
|
||||
, c.payload
|
||||
, c.creator
|
||||
, cs.owner
|
||||
, EXTRACT(EPOCH FROM NOW()) AS position
|
||||
, c.in_tx_order
|
||||
FROM (
|
||||
SELECT
|
||||
c.instance_id
|
||||
, c.aggregate_type
|
||||
, c.aggregate_id
|
||||
, c.command_type
|
||||
, c.revision
|
||||
, c.payload
|
||||
, c.creator
|
||||
, c.owner
|
||||
, ROW_NUMBER() OVER () AS in_tx_order
|
||||
FROM
|
||||
UNNEST(commands) AS c
|
||||
) AS c
|
||||
JOIN (
|
||||
SELECT
|
||||
cmds.instance_id
|
||||
, cmds.aggregate_type
|
||||
, cmds.aggregate_id
|
||||
, CASE WHEN (e.owner IS NOT NULL OR e.owner <> '') THEN e.owner ELSE command_owners.owner END AS owner
|
||||
, COALESCE(MAX(e.sequence), 0) AS sequence
|
||||
FROM (
|
||||
SELECT DISTINCT
|
||||
instance_id
|
||||
, aggregate_type
|
||||
, aggregate_id
|
||||
, owner
|
||||
FROM UNNEST(commands)
|
||||
) AS cmds
|
||||
LEFT JOIN eventstore.events2 AS e
|
||||
ON cmds.instance_id = e.instance_id
|
||||
AND cmds.aggregate_type = e.aggregate_type
|
||||
AND cmds.aggregate_id = e.aggregate_id
|
||||
JOIN (
|
||||
SELECT
|
||||
DISTINCT ON (
|
||||
instance_id
|
||||
, aggregate_type
|
||||
, aggregate_id
|
||||
)
|
||||
instance_id
|
||||
, aggregate_type
|
||||
, aggregate_id
|
||||
, owner
|
||||
FROM
|
||||
UNNEST(commands)
|
||||
) AS command_owners ON
|
||||
cmds.instance_id = command_owners.instance_id
|
||||
AND cmds.aggregate_type = command_owners.aggregate_type
|
||||
AND cmds.aggregate_id = command_owners.aggregate_id
|
||||
GROUP BY
|
||||
cmds.instance_id
|
||||
, cmds.aggregate_type
|
||||
, cmds.aggregate_id
|
||||
, 4 -- owner
|
||||
) AS cs
|
||||
ON c.instance_id = cs.instance_id
|
||||
AND c.aggregate_type = cs.aggregate_type
|
||||
AND c.aggregate_id = cs.aggregate_id
|
||||
ORDER BY
|
||||
in_tx_order;
|
||||
$$ LANGUAGE SQL;
|
||||
|
||||
CREATE OR REPLACE FUNCTION eventstore.push(commands eventstore.command[]) RETURNS SETOF eventstore.events2 VOLATILE AS $$
|
||||
INSERT INTO eventstore.events2
|
||||
SELECT * FROM eventstore.commands_to_events(commands)
|
||||
RETURNING *
|
||||
$$ LANGUAGE SQL;
|
@ -126,6 +126,7 @@ type Steps struct {
|
||||
s36FillV2Milestones *FillV3Milestones
|
||||
s37Apps7OIDConfigsBackChannelLogoutURI *Apps7OIDConfigsBackChannelLogoutURI
|
||||
s38BackChannelLogoutNotificationStart *BackChannelLogoutNotificationStart
|
||||
s40InitPushFunc *InitPushFunc
|
||||
s39DeleteStaleOrgFields *DeleteStaleOrgFields
|
||||
s41FillFieldsForInstanceDomains *FillFieldsForInstanceDomains
|
||||
}
|
||||
|
@ -170,6 +170,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
|
||||
steps.s37Apps7OIDConfigsBackChannelLogoutURI = &Apps7OIDConfigsBackChannelLogoutURI{dbClient: esPusherDBClient}
|
||||
steps.s38BackChannelLogoutNotificationStart = &BackChannelLogoutNotificationStart{dbClient: esPusherDBClient, esClient: eventstoreClient}
|
||||
steps.s39DeleteStaleOrgFields = &DeleteStaleOrgFields{dbClient: esPusherDBClient}
|
||||
steps.s40InitPushFunc = &InitPushFunc{dbClient: esPusherDBClient}
|
||||
steps.s41FillFieldsForInstanceDomains = &FillFieldsForInstanceDomains{eventstore: eventstoreClient}
|
||||
|
||||
err = projection.Create(ctx, projectionDBClient, eventstoreClient, config.Projections, nil, nil, nil)
|
||||
@ -191,6 +192,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
|
||||
|
||||
for _, step := range []migration.Migration{
|
||||
steps.s14NewEventsTable,
|
||||
steps.s40InitPushFunc,
|
||||
steps.s1ProjectionTable,
|
||||
steps.s2AssetsTable,
|
||||
steps.s28AddFieldTable,
|
||||
|
@ -469,12 +469,12 @@ func updateAppleProviderToCommand(req *admin_pb.UpdateAppleProviderRequest) comm
|
||||
}
|
||||
}
|
||||
|
||||
func addSAMLProviderToCommand(req *admin_pb.AddSAMLProviderRequest) command.SAMLProvider {
|
||||
func addSAMLProviderToCommand(req *admin_pb.AddSAMLProviderRequest) *command.SAMLProvider {
|
||||
var nameIDFormat *domain.SAMLNameIDFormat
|
||||
if req.NameIdFormat != nil {
|
||||
nameIDFormat = gu.Ptr(idp_grpc.SAMLNameIDFormatToDomain(req.GetNameIdFormat()))
|
||||
}
|
||||
return command.SAMLProvider{
|
||||
return &command.SAMLProvider{
|
||||
Name: req.Name,
|
||||
Metadata: req.GetMetadataXml(),
|
||||
MetadataURL: req.GetMetadataUrl(),
|
||||
@ -486,12 +486,12 @@ func addSAMLProviderToCommand(req *admin_pb.AddSAMLProviderRequest) command.SAML
|
||||
}
|
||||
}
|
||||
|
||||
func updateSAMLProviderToCommand(req *admin_pb.UpdateSAMLProviderRequest) command.SAMLProvider {
|
||||
func updateSAMLProviderToCommand(req *admin_pb.UpdateSAMLProviderRequest) *command.SAMLProvider {
|
||||
var nameIDFormat *domain.SAMLNameIDFormat
|
||||
if req.NameIdFormat != nil {
|
||||
nameIDFormat = gu.Ptr(idp_grpc.SAMLNameIDFormatToDomain(req.GetNameIdFormat()))
|
||||
}
|
||||
return command.SAMLProvider{
|
||||
return &command.SAMLProvider{
|
||||
Name: req.Name,
|
||||
Metadata: req.GetMetadataXml(),
|
||||
MetadataURL: req.GetMetadataUrl(),
|
||||
|
@ -462,12 +462,12 @@ func updateAppleProviderToCommand(req *mgmt_pb.UpdateAppleProviderRequest) comma
|
||||
}
|
||||
}
|
||||
|
||||
func addSAMLProviderToCommand(req *mgmt_pb.AddSAMLProviderRequest) command.SAMLProvider {
|
||||
func addSAMLProviderToCommand(req *mgmt_pb.AddSAMLProviderRequest) *command.SAMLProvider {
|
||||
var nameIDFormat *domain.SAMLNameIDFormat
|
||||
if req.NameIdFormat != nil {
|
||||
nameIDFormat = gu.Ptr(idp_grpc.SAMLNameIDFormatToDomain(req.GetNameIdFormat()))
|
||||
}
|
||||
return command.SAMLProvider{
|
||||
return &command.SAMLProvider{
|
||||
Name: req.Name,
|
||||
Metadata: req.GetMetadataXml(),
|
||||
MetadataURL: req.GetMetadataUrl(),
|
||||
@ -479,12 +479,12 @@ func addSAMLProviderToCommand(req *mgmt_pb.AddSAMLProviderRequest) command.SAMLP
|
||||
}
|
||||
}
|
||||
|
||||
func updateSAMLProviderToCommand(req *mgmt_pb.UpdateSAMLProviderRequest) command.SAMLProvider {
|
||||
func updateSAMLProviderToCommand(req *mgmt_pb.UpdateSAMLProviderRequest) *command.SAMLProvider {
|
||||
var nameIDFormat *domain.SAMLNameIDFormat
|
||||
if req.NameIdFormat != nil {
|
||||
nameIDFormat = gu.Ptr(idp_grpc.SAMLNameIDFormatToDomain(req.GetNameIdFormat()))
|
||||
}
|
||||
return command.SAMLProvider{
|
||||
return &command.SAMLProvider{
|
||||
Name: req.Name,
|
||||
Metadata: req.GetMetadataXml(),
|
||||
MetadataURL: req.GetMetadataUrl(),
|
||||
|
@ -216,6 +216,7 @@ func assertFeatureDisabledError(t *testing.T, err error) {
|
||||
}
|
||||
|
||||
func checkWebKeyListState(ctx context.Context, t *testing.T, instance *integration.Instance, nKeys int, expectActiveKeyID string, config any, creationDate *timestamppb.Timestamp) {
|
||||
t.Helper()
|
||||
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(ctx, time.Minute)
|
||||
assert.EventuallyWithT(t, func(collect *assert.CollectT) {
|
||||
|
@ -58,7 +58,7 @@ func TestServer_AddOTPSMS(t *testing.T) {
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "user mismatch",
|
||||
name: "no permission",
|
||||
args: args{
|
||||
ctx: integration.WithAuthorizationToken(context.Background(), sessionTokenOtherUser),
|
||||
req: &user.AddOTPSMSRequest{
|
||||
@ -127,14 +127,24 @@ func TestServer_RemoveOTPSMS(t *testing.T) {
|
||||
|
||||
userVerified := Instance.CreateHumanUser(CTX)
|
||||
Instance.RegisterUserPasskey(CTX, userVerified.GetUserId())
|
||||
_, sessionTokenVerified, _, _ := Instance.CreateVerifiedWebAuthNSession(t, CTX, userVerified.GetUserId())
|
||||
userVerifiedCtx := integration.WithAuthorizationToken(context.Background(), sessionTokenVerified)
|
||||
_, err := Instance.Client.UserV2.VerifyPhone(userVerifiedCtx, &user.VerifyPhoneRequest{
|
||||
_, err := Instance.Client.UserV2.VerifyPhone(CTX, &user.VerifyPhoneRequest{
|
||||
UserId: userVerified.GetUserId(),
|
||||
VerificationCode: userVerified.GetPhoneCode(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = Instance.Client.UserV2.AddOTPSMS(userVerifiedCtx, &user.AddOTPSMSRequest{UserId: userVerified.GetUserId()})
|
||||
_, err = Instance.Client.UserV2.AddOTPSMS(CTX, &user.AddOTPSMSRequest{UserId: userVerified.GetUserId()})
|
||||
require.NoError(t, err)
|
||||
|
||||
userSelf := Instance.CreateHumanUser(CTX)
|
||||
Instance.RegisterUserPasskey(CTX, userSelf.GetUserId())
|
||||
_, sessionTokenSelf, _, _ := Instance.CreateVerifiedWebAuthNSession(t, CTX, userSelf.GetUserId())
|
||||
userSelfCtx := integration.WithAuthorizationToken(context.Background(), sessionTokenSelf)
|
||||
_, err = Instance.Client.UserV2.VerifyPhone(CTX, &user.VerifyPhoneRequest{
|
||||
UserId: userSelf.GetUserId(),
|
||||
VerificationCode: userSelf.GetPhoneCode(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = Instance.Client.UserV2.AddOTPSMS(CTX, &user.AddOTPSMSRequest{UserId: userSelf.GetUserId()})
|
||||
require.NoError(t, err)
|
||||
|
||||
type args struct {
|
||||
@ -157,10 +167,24 @@ func TestServer_RemoveOTPSMS(t *testing.T) {
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "success, self",
|
||||
args: args{
|
||||
ctx: userSelfCtx,
|
||||
req: &user.RemoveOTPSMSRequest{
|
||||
UserId: userSelf.GetUserId(),
|
||||
},
|
||||
},
|
||||
want: &user.RemoveOTPSMSResponse{
|
||||
Details: &object.Details{
|
||||
ResourceOwner: Instance.DefaultOrg.Details.ResourceOwner,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
args: args{
|
||||
ctx: userVerifiedCtx,
|
||||
ctx: CTX,
|
||||
req: &user.RemoveOTPSMSRequest{
|
||||
UserId: userVerified.GetUserId(),
|
||||
},
|
||||
@ -230,7 +254,7 @@ func TestServer_AddOTPEmail(t *testing.T) {
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "user mismatch",
|
||||
name: "no permission",
|
||||
args: args{
|
||||
ctx: integration.WithAuthorizationToken(context.Background(), sessionTokenOtherUser),
|
||||
req: &user.AddOTPEmailRequest{
|
||||
@ -301,14 +325,24 @@ func TestServer_RemoveOTPEmail(t *testing.T) {
|
||||
|
||||
userVerified := Instance.CreateHumanUser(CTX)
|
||||
Instance.RegisterUserPasskey(CTX, userVerified.GetUserId())
|
||||
_, sessionTokenVerified, _, _ := Instance.CreateVerifiedWebAuthNSession(t, CTX, userVerified.GetUserId())
|
||||
userVerifiedCtx := integration.WithAuthorizationToken(context.Background(), sessionTokenVerified)
|
||||
_, err := Instance.Client.UserV2.VerifyEmail(userVerifiedCtx, &user.VerifyEmailRequest{
|
||||
_, err := Instance.Client.UserV2.VerifyEmail(CTX, &user.VerifyEmailRequest{
|
||||
UserId: userVerified.GetUserId(),
|
||||
VerificationCode: userVerified.GetEmailCode(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = Instance.Client.UserV2.AddOTPEmail(userVerifiedCtx, &user.AddOTPEmailRequest{UserId: userVerified.GetUserId()})
|
||||
_, err = Instance.Client.UserV2.AddOTPEmail(CTX, &user.AddOTPEmailRequest{UserId: userVerified.GetUserId()})
|
||||
require.NoError(t, err)
|
||||
|
||||
userSelf := Instance.CreateHumanUser(CTX)
|
||||
Instance.RegisterUserPasskey(CTX, userSelf.GetUserId())
|
||||
_, sessionTokenSelf, _, _ := Instance.CreateVerifiedWebAuthNSession(t, CTX, userSelf.GetUserId())
|
||||
userSelfCtx := integration.WithAuthorizationToken(context.Background(), sessionTokenSelf)
|
||||
_, err = Instance.Client.UserV2.VerifyEmail(CTX, &user.VerifyEmailRequest{
|
||||
UserId: userSelf.GetUserId(),
|
||||
VerificationCode: userSelf.GetEmailCode(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = Instance.Client.UserV2.AddOTPEmail(CTX, &user.AddOTPEmailRequest{UserId: userSelf.GetUserId()})
|
||||
require.NoError(t, err)
|
||||
|
||||
type args struct {
|
||||
@ -331,10 +365,25 @@ func TestServer_RemoveOTPEmail(t *testing.T) {
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "success, self",
|
||||
args: args{
|
||||
ctx: userSelfCtx,
|
||||
req: &user.RemoveOTPEmailRequest{
|
||||
UserId: userSelf.GetUserId(),
|
||||
},
|
||||
},
|
||||
want: &user.RemoveOTPEmailResponse{
|
||||
Details: &object.Details{
|
||||
ChangeDate: timestamppb.Now(),
|
||||
ResourceOwner: Instance.DefaultOrg.Details.ResourceOwner,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
args: args{
|
||||
ctx: userVerifiedCtx,
|
||||
ctx: CTX,
|
||||
req: &user.RemoveOTPEmailRequest{
|
||||
UserId: userVerified.GetUserId(),
|
||||
},
|
||||
|
@ -93,15 +93,30 @@ func TestServer_RegisterPasskey(t *testing.T) {
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "user mismatch",
|
||||
name: "user no permission",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
ctx: UserCTX,
|
||||
req: &user.RegisterPasskeyRequest{
|
||||
UserId: userID,
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "user permission",
|
||||
args: args{
|
||||
ctx: IamCTX,
|
||||
req: &user.RegisterPasskeyRequest{
|
||||
UserId: userID,
|
||||
},
|
||||
},
|
||||
want: &user.RegisterPasskeyResponse{
|
||||
Details: &object.Details{
|
||||
ChangeDate: timestamppb.Now(),
|
||||
ResourceOwner: Instance.DefaultOrg.Id,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user setting its own passkey",
|
||||
args: args{
|
||||
|
@ -13,7 +13,6 @@ func (s *Server) AddOTPSMS(ctx context.Context, req *user.AddOTPSMSRequest) (*us
|
||||
return nil, err
|
||||
}
|
||||
return &user.AddOTPSMSResponse{Details: object.DomainToDetailsPb(details)}, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *Server) RemoveOTPSMS(ctx context.Context, req *user.RemoveOTPSMSRequest) (*user.RemoveOTPSMSResponse, error) {
|
||||
|
@ -58,7 +58,7 @@ func TestServer_AddOTPSMS(t *testing.T) {
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "user mismatch",
|
||||
name: "no permission",
|
||||
args: args{
|
||||
ctx: integration.WithAuthorizationToken(context.Background(), sessionTokenOtherUser),
|
||||
req: &user.AddOTPSMSRequest{
|
||||
@ -127,14 +127,24 @@ func TestServer_RemoveOTPSMS(t *testing.T) {
|
||||
|
||||
userVerified := Instance.CreateHumanUser(CTX)
|
||||
Instance.RegisterUserPasskey(CTX, userVerified.GetUserId())
|
||||
_, sessionTokenVerified, _, _ := Instance.CreateVerifiedWebAuthNSession(t, CTX, userVerified.GetUserId())
|
||||
userVerifiedCtx := integration.WithAuthorizationToken(context.Background(), sessionTokenVerified)
|
||||
_, err := Client.VerifyPhone(userVerifiedCtx, &user.VerifyPhoneRequest{
|
||||
_, err := Instance.Client.UserV2beta.VerifyPhone(CTX, &user.VerifyPhoneRequest{
|
||||
UserId: userVerified.GetUserId(),
|
||||
VerificationCode: userVerified.GetPhoneCode(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = Client.AddOTPSMS(userVerifiedCtx, &user.AddOTPSMSRequest{UserId: userVerified.GetUserId()})
|
||||
_, err = Instance.Client.UserV2beta.AddOTPSMS(CTX, &user.AddOTPSMSRequest{UserId: userVerified.GetUserId()})
|
||||
require.NoError(t, err)
|
||||
|
||||
userSelf := Instance.CreateHumanUser(CTX)
|
||||
Instance.RegisterUserPasskey(CTX, userSelf.GetUserId())
|
||||
_, sessionTokenSelf, _, _ := Instance.CreateVerifiedWebAuthNSession(t, CTX, userSelf.GetUserId())
|
||||
userSelfCtx := integration.WithAuthorizationToken(context.Background(), sessionTokenSelf)
|
||||
_, err = Instance.Client.UserV2beta.VerifyPhone(CTX, &user.VerifyPhoneRequest{
|
||||
UserId: userSelf.GetUserId(),
|
||||
VerificationCode: userSelf.GetPhoneCode(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = Instance.Client.UserV2beta.AddOTPSMS(CTX, &user.AddOTPSMSRequest{UserId: userSelf.GetUserId()})
|
||||
require.NoError(t, err)
|
||||
|
||||
type args struct {
|
||||
@ -157,10 +167,24 @@ func TestServer_RemoveOTPSMS(t *testing.T) {
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "success, self",
|
||||
args: args{
|
||||
ctx: userSelfCtx,
|
||||
req: &user.RemoveOTPSMSRequest{
|
||||
UserId: userSelf.GetUserId(),
|
||||
},
|
||||
},
|
||||
want: &user.RemoveOTPSMSResponse{
|
||||
Details: &object.Details{
|
||||
ResourceOwner: Instance.DefaultOrg.Details.ResourceOwner,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
args: args{
|
||||
ctx: userVerifiedCtx,
|
||||
ctx: CTX,
|
||||
req: &user.RemoveOTPSMSRequest{
|
||||
UserId: userVerified.GetUserId(),
|
||||
},
|
||||
@ -301,14 +325,24 @@ func TestServer_RemoveOTPEmail(t *testing.T) {
|
||||
|
||||
userVerified := Instance.CreateHumanUser(CTX)
|
||||
Instance.RegisterUserPasskey(CTX, userVerified.GetUserId())
|
||||
_, sessionTokenVerified, _, _ := Instance.CreateVerifiedWebAuthNSession(t, CTX, userVerified.GetUserId())
|
||||
userVerifiedCtx := integration.WithAuthorizationToken(context.Background(), sessionTokenVerified)
|
||||
_, err := Client.VerifyEmail(userVerifiedCtx, &user.VerifyEmailRequest{
|
||||
_, err := Client.VerifyEmail(CTX, &user.VerifyEmailRequest{
|
||||
UserId: userVerified.GetUserId(),
|
||||
VerificationCode: userVerified.GetEmailCode(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = Client.AddOTPEmail(userVerifiedCtx, &user.AddOTPEmailRequest{UserId: userVerified.GetUserId()})
|
||||
_, err = Client.AddOTPEmail(CTX, &user.AddOTPEmailRequest{UserId: userVerified.GetUserId()})
|
||||
require.NoError(t, err)
|
||||
|
||||
userSelf := Instance.CreateHumanUser(CTX)
|
||||
Instance.RegisterUserPasskey(CTX, userSelf.GetUserId())
|
||||
_, sessionTokenSelf, _, _ := Instance.CreateVerifiedWebAuthNSession(t, IamCTX, userSelf.GetUserId())
|
||||
userSelfCtx := integration.WithAuthorizationToken(context.Background(), sessionTokenSelf)
|
||||
_, err = Client.VerifyEmail(CTX, &user.VerifyEmailRequest{
|
||||
UserId: userSelf.GetUserId(),
|
||||
VerificationCode: userSelf.GetEmailCode(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
_, err = Client.AddOTPEmail(CTX, &user.AddOTPEmailRequest{UserId: userSelf.GetUserId()})
|
||||
require.NoError(t, err)
|
||||
|
||||
type args struct {
|
||||
@ -331,10 +365,25 @@ func TestServer_RemoveOTPEmail(t *testing.T) {
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "success, self",
|
||||
args: args{
|
||||
ctx: userSelfCtx,
|
||||
req: &user.RemoveOTPEmailRequest{
|
||||
UserId: userSelf.GetUserId(),
|
||||
},
|
||||
},
|
||||
want: &user.RemoveOTPEmailResponse{
|
||||
Details: &object.Details{
|
||||
ChangeDate: timestamppb.Now(),
|
||||
ResourceOwner: Instance.DefaultOrg.Details.ResourceOwner,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "success",
|
||||
args: args{
|
||||
ctx: userVerifiedCtx,
|
||||
ctx: CTX,
|
||||
req: &user.RemoveOTPEmailRequest{
|
||||
UserId: userVerified.GetUserId(),
|
||||
},
|
||||
|
@ -92,15 +92,30 @@ func TestServer_RegisterPasskey(t *testing.T) {
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "user mismatch",
|
||||
name: "user no permission",
|
||||
args: args{
|
||||
ctx: CTX,
|
||||
ctx: UserCTX,
|
||||
req: &user.RegisterPasskeyRequest{
|
||||
UserId: userID,
|
||||
},
|
||||
},
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "user permission",
|
||||
args: args{
|
||||
ctx: IamCTX,
|
||||
req: &user.RegisterPasskeyRequest{
|
||||
UserId: userID,
|
||||
},
|
||||
},
|
||||
want: &user.RegisterPasskeyResponse{
|
||||
Details: &object.Details{
|
||||
ChangeDate: timestamppb.Now(),
|
||||
ResourceOwner: Instance.DefaultOrg.Id,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "user setting its own passkey",
|
||||
args: args{
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/saml"
|
||||
"github.com/zitadel/zitadel/internal/repository/instance"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
@ -511,10 +512,10 @@ func (c *Commands) UpdateInstanceAppleProvider(ctx context.Context, id string, p
|
||||
return pushedEventsToObjectDetails(pushedEvents), nil
|
||||
}
|
||||
|
||||
func (c *Commands) AddInstanceSAMLProvider(ctx context.Context, provider SAMLProvider) (string, *domain.ObjectDetails, error) {
|
||||
func (c *Commands) AddInstanceSAMLProvider(ctx context.Context, provider *SAMLProvider) (id string, details *domain.ObjectDetails, err error) {
|
||||
instanceID := authz.GetInstance(ctx).InstanceID()
|
||||
instanceAgg := instance.NewAggregate(instanceID)
|
||||
id, err := c.idGenerator.Next()
|
||||
id, err = c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
@ -530,7 +531,7 @@ func (c *Commands) AddInstanceSAMLProvider(ctx context.Context, provider SAMLPro
|
||||
return id, pushedEventsToObjectDetails(pushedEvents), nil
|
||||
}
|
||||
|
||||
func (c *Commands) UpdateInstanceSAMLProvider(ctx context.Context, id string, provider SAMLProvider) (*domain.ObjectDetails, error) {
|
||||
func (c *Commands) UpdateInstanceSAMLProvider(ctx context.Context, id string, provider *SAMLProvider) (details *domain.ObjectDetails, err error) {
|
||||
instanceID := authz.GetInstance(ctx).InstanceID()
|
||||
instanceAgg := instance.NewAggregate(instanceID)
|
||||
writeModel := NewSAMLInstanceIDPWriteModel(instanceID, id)
|
||||
@ -1719,7 +1720,7 @@ func (c *Commands) prepareUpdateInstanceAppleProvider(a *instance.Aggregate, wri
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Commands) prepareAddInstanceSAMLProvider(a *instance.Aggregate, writeModel *InstanceSAMLIDPWriteModel, provider SAMLProvider) preparation.Validation {
|
||||
func (c *Commands) prepareAddInstanceSAMLProvider(a *instance.Aggregate, writeModel *InstanceSAMLIDPWriteModel, provider *SAMLProvider) preparation.Validation {
|
||||
return func() (preparation.CreateCommands, error) {
|
||||
if provider.Name = strings.TrimSpace(provider.Name); provider.Name == "" {
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "INST-o07zjotgnd", "Errors.Invalid.Argument")
|
||||
@ -1734,6 +1735,9 @@ func (c *Commands) prepareAddInstanceSAMLProvider(a *instance.Aggregate, writeMo
|
||||
if len(provider.Metadata) == 0 {
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "INST-3bi3esi16t", "Errors.Invalid.Argument")
|
||||
}
|
||||
if _, err := saml.ParseMetadata(provider.Metadata); err != nil {
|
||||
return nil, zerrors.ThrowInvalidArgument(err, "INST-SF3rwhgh", "Errors.Project.App.SAMLMetadataFormat")
|
||||
}
|
||||
return func(ctx context.Context, filter preparation.FilterToQueryReducer) ([]eventstore.Command, error) {
|
||||
events, err := filter(ctx, writeModel.Query())
|
||||
if err != nil {
|
||||
@ -1772,7 +1776,7 @@ func (c *Commands) prepareAddInstanceSAMLProvider(a *instance.Aggregate, writeMo
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Commands) prepareUpdateInstanceSAMLProvider(a *instance.Aggregate, writeModel *InstanceSAMLIDPWriteModel, provider SAMLProvider) preparation.Validation {
|
||||
func (c *Commands) prepareUpdateInstanceSAMLProvider(a *instance.Aggregate, writeModel *InstanceSAMLIDPWriteModel, provider *SAMLProvider) preparation.Validation {
|
||||
return func() (preparation.CreateCommands, error) {
|
||||
if writeModel.ID = strings.TrimSpace(writeModel.ID); writeModel.ID == "" {
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "INST-7o3rq1owpm", "Errors.Invalid.Argument")
|
||||
@ -1790,6 +1794,9 @@ func (c *Commands) prepareUpdateInstanceSAMLProvider(a *instance.Aggregate, writ
|
||||
}
|
||||
provider.Metadata = data
|
||||
}
|
||||
if _, err := saml.ParseMetadata(provider.Metadata); err != nil {
|
||||
return nil, zerrors.ThrowInvalidArgument(err, "INST-dsfj3kl2", "Errors.Project.App.SAMLMetadataFormat")
|
||||
}
|
||||
return func(ctx context.Context, filter preparation.FilterToQueryReducer) ([]eventstore.Command, error) {
|
||||
events, err := filter(ctx, writeModel.Query())
|
||||
if err != nil {
|
||||
|
@ -22,6 +22,73 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
var (
|
||||
validSAMLMetadata = []byte(`<?xml version="1.0" encoding="UTF-8"?>
|
||||
<EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="http://localhost:8080/saml/v2/metadata" ID="_8b02ecf6-aea4-4eda-96c6-190551f05b07">
|
||||
<Signature xmlns="http://www.w3.org/2000/09/xmldsig#">
|
||||
<SignedInfo xmlns="http://www.w3.org/2000/09/xmldsig#">
|
||||
<CanonicalizationMethod xmlns="http://www.w3.org/2000/09/xmldsig#" Algorithm="http://www.w3.org/2001/10/xml-exc-c14n#"></CanonicalizationMethod>
|
||||
<SignatureMethod xmlns="http://www.w3.org/2000/09/xmldsig#" Algorithm="http://www.w3.org/2001/04/xmldsig-more#rsa-sha256"></SignatureMethod>
|
||||
<Reference xmlns="http://www.w3.org/2000/09/xmldsig#" URI="#_8b02ecf6-aea4-4eda-96c6-190551f05b07">
|
||||
<Transforms xmlns="http://www.w3.org/2000/09/xmldsig#">
|
||||
<Transform xmlns="http://www.w3.org/2000/09/xmldsig#" Algorithm="http://www.w3.org/2000/09/xmldsig#enveloped-signature"></Transform>
|
||||
<Transform xmlns="http://www.w3.org/2000/09/xmldsig#" Algorithm="http://www.w3.org/2001/10/xml-exc-c14n#"></Transform>
|
||||
</Transforms>
|
||||
<DigestMethod xmlns="http://www.w3.org/2000/09/xmldsig#" Algorithm="http://www.w3.org/2001/04/xmlenc#sha256"></DigestMethod>
|
||||
<DigestValue xmlns="http://www.w3.org/2000/09/xmldsig#">Tyw4csdpNNq0E7wi5FXWdVNkdPNg+cM6kK21VB2+iF0=</DigestValue>
|
||||
</Reference>
|
||||
</SignedInfo>
|
||||
<SignatureValue xmlns="http://www.w3.org/2000/09/xmldsig#">hWQSYmnBJENy/okk2qRDuHaZiyqpDsdV6BF9/T/LNjUh/8z4dV2NEZvkNhFEyQ+bqdj+NmRWvKqpg1dtgNJxQc32+IsLQvXNYyhMCtyG570/jaTOtm8daV4NKJyTV7SdwM6yfXgubz5YCRTyV13W2gBIFYppIRImIv5NDcjz+lEmWhnrkw8G2wRSFUY7VvkDn9rgsTzw/Pnsw6hlzpjGDYPMPx3ux3kjFVevdhFGNo+VC7t9ozruuGyH3yue9Re6FZoqa4oyWaPSOwei0ZH6UNqkX93Eo5Y49QKwaO8Rm+kWsOhdTqebVmCc+SpWbbrZbQj4nSLgWGlvCkZSivmH7ezr4Ol1ZkRetQ92UQ7xJS7E0y6uXAGvdgpDnyqHCOFfhTS6yqltHtc3m7JZex327xkv6e69uAEOSiv++sifVUIE0h/5u3hZLvwmTPrkoRVY4wgZ4ieb86QPvhw4UPeYapOhCBk5RfjoEFIeYwPUw5rtOlpTyeBJiKMpH1+mDAoa+8HQytZoMrnnY1s612vINtY7jU5igMwIk6MitQpRGibnBVBHRc2A6aE+XS333ganFK9hX6TzNkpHUb66NINDZ8Rgb1thn3MABArGlomtM5/enrAixWExZp70TSElor7SBdBW57H7OZCYUCobZuPRDLsCO6LLKeVrbdygWeRqr/o=</SignatureValue>
|
||||
<KeyInfo xmlns="http://www.w3.org/2000/09/xmldsig#">
|
||||
<X509Data xmlns="http://www.w3.org/2000/09/xmldsig#">
|
||||
<X509Certificate xmlns="http://www.w3.org/2000/09/xmldsig#">MIIFIjCCAwqgAwIBAgICA7YwDQYJKoZIhvcNAQELBQAwLDEQMA4GA1UEChMHWklUQURFTDEYMBYGA1UEAxMPWklUQURFTCBTQU1MIENBMB4XDTI0MTEyNzEwMjc0NFoXDTI1MTEyNzE2Mjc0NFowMjEQMA4GA1UEChMHWklUQURFTDEeMBwGA1UEAxMVWklUQURFTCBTQU1MIG1ldGFkYXRhMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEApEpYT7EjbRBp0Hw7PGCiSgUoJtwd2nwZOhGy5WZVWvraAtHzW5ih2B6UwEShjwCmRJZeKYEN9JKJbpAy2EdL/l2rm/pArVNvSQu6sN4izz5p2rd9NfHAO3/EcvYdrelWLQj8WQx6LVM282Z4wbclp8Jz1y8Ow43352hGfFVc1x8gauoNl5MAy4kdbvs8UqihqcRmEyIOWl6UwTApb+XIRSRz0Yop99Fv9ALJwfUppsx+d4j9rlRDvrQJMJz7GC/19L9INTbY0HsVEiTltdAWHwREwrpwxNJQt42p3W/zpf1mjwXd3qNNDZAr1t2POPP4SXd598kabBZ3EMWGGxFw+NYYajyjG5EFOZw09FFJn2jIcovejvigfdqem5DGPECvHefqcqHkBPGukI3RaotXpAYyAGfnV7slVytSW484IX3KloAJLICbETbFGGsGQzIDw8rUqWyaOCOttw2fVNDyRFUMHrGe1PhJ9qA1If+KCWYD0iJqF03rIEhdrvNSdQNYkRa0DdtpacQLpzQtqsUioODqX0W3uzLceJEXLBbU0ZEk8mWZM/auwMo3ycPNXDVwrb6AkUKar+sqSumUuixw7da3KF1/mynh6M2Eo4NRB16oUiyN0EYrit/RRJjsTdH+71cj0V+8KqO88cBpmm+lO6x4RM5xpOf/EwwQHivxgRkCAwEAAaNIMEYwDgYDVR0PAQH/BAQDAgWgMBMGA1UdJQQMMAoGCCsGAQUFBwMCMB8GA1UdIwQYMBaAFIzl7uckcPWldirXeOFL3rH6K8FLMA0GCSqGSIb3DQEBCwUAA4ICAQBz+7R99uX1Us9T4BB2RK3RD9K8Q5foNmxJ8GbxpOQFL8IG1DE3FqBssciJkOsKY+1+Y6eow2TgmD9MxfCY444C8k8YDDjxIcs+4dEaWMUxA6NoEy378ciy0U1E6rpYLxWYTxXmsELyODWwTrRNIiWfbBD2m0w9HYbK6QvX6IYQqYoTOJJ3WJKsMCeQ8XhQsJYNINZEq8RsERY/aikOlTWN7ax4Mkr3bfnz1euXGClExCOM6ej4m2I33i4nyYBvvRkRRZRQCfkAQ+5WFVZoVXrQHNe/Oifit7tfLaDuybcjgkzzY3o0YbczzbdV69fVoj53VpR3QQOB+PCF/VJPUMtUFPEC05yH76g24KVBiM/Ws8GaERW1AxgupHSmvTY3GSiwDXQ2NzgDxUHfRHo8rxenJdEcPlGM0DstbUONDSFGLwvGDiidUVtqj1UB4yGL26bgtmwf61G4qsTn9PJMWdRmCeeOf7fmloRxTA0EEey3bulBBHim466tWHUhgOP+g1X0iE7CnwL8aJ//CCiQOAv1O6x5RLyxrmVTehPLr1T8qvnBmxpmuYU0kfbYpO3tMVe7VLabBx0cYh7izClZKHhgEj1w4aE9tIk7nqVAwvVocT3io8RrcKixlnBrFd7RYIuF3+RsYC/kYEgnZYKAig5u2TySgGmJ7nIS24FYW68WDg==</X509Certificate>
|
||||
</X509Data>
|
||||
</KeyInfo>
|
||||
</Signature>
|
||||
<IDPSSODescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" WantAuthnRequestsSigned="1" ID="_fd70402c-8a31-4a9a-a4a7-da526524c609" validUntil="2024-12-02T16:54:55.656Z" protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
|
||||
<SingleSignOnService xmlns="urn:oasis:names:tc:SAML:2.0:metadata" Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="http://localhost:8080/saml/v2/SSO"></SingleSignOnService>
|
||||
<SingleSignOnService xmlns="urn:oasis:names:tc:SAML:2.0:metadata" Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" Location="http://localhost:8080/saml/v2/SSO"></SingleSignOnService>
|
||||
<AttributeProfile>urn:oasis:names:tc:SAML:2.0:profiles:attribute:basic</AttributeProfile>
|
||||
<Attribute xmlns="urn:oasis:names:tc:SAML:2.0:assertion" Name="Email" NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:basic"><AttributeValue></AttributeValue></Attribute>
|
||||
<Attribute xmlns="urn:oasis:names:tc:SAML:2.0:assertion" Name="SurName" NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:basic"><AttributeValue></AttributeValue></Attribute>
|
||||
<Attribute xmlns="urn:oasis:names:tc:SAML:2.0:assertion" Name="FirstName" NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:basic"><AttributeValue></AttributeValue></Attribute>
|
||||
<Attribute xmlns="urn:oasis:names:tc:SAML:2.0:assertion" Name="FullName" NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:basic"><AttributeValue></AttributeValue></Attribute>
|
||||
<Attribute xmlns="urn:oasis:names:tc:SAML:2.0:assertion" Name="UserName" NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:basic"><AttributeValue></AttributeValue></Attribute>
|
||||
<Attribute xmlns="urn:oasis:names:tc:SAML:2.0:assertion" Name="UserID" NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:basic"><AttributeValue></AttributeValue></Attribute>
|
||||
<SingleLogoutService xmlns="urn:oasis:names:tc:SAML:2.0:metadata" Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="http://localhost:8080/saml/v2/SLO"></SingleLogoutService>
|
||||
<SingleLogoutService xmlns="urn:oasis:names:tc:SAML:2.0:metadata" Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" Location="http://localhost:8080/saml/v2/SLO"></SingleLogoutService>
|
||||
<NameIDFormat>urn:oasis:names:tc:SAML:2.0:nameid-format:persistent</NameIDFormat>
|
||||
<KeyDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" use="signing">
|
||||
<KeyInfo xmlns="http://www.w3.org/2000/09/xmldsig#">
|
||||
<KeyName>http://localhost:8080/saml/v2/metadata IDP signing</KeyName>
|
||||
<X509Data xmlns="http://www.w3.org/2000/09/xmldsig#">
|
||||
<X509Certificate xmlns="http://www.w3.org/2000/09/xmldsig#">MIIFIjCCAwqgAwIBAgICA7QwDQYJKoZIhvcNAQELBQAwLDEQMA4GA1UEChMHWklUQURFTDEYMBYGA1UEAxMPWklUQURFTCBTQU1MIENBMB4XDTI0MTEyNzEwMjUwMloXDTI1MTEyNzE2MjUwMlowMjEQMA4GA1UEChMHWklUQURFTDEeMBwGA1UEAxMVWklUQURFTCBTQU1MIHJlc3BvbnNlMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA2lUgaI6AS/9xvM9DNSWK6Ho64LpK8UIioM26QfvAfeQ/I2pgX6SwWxEbd7qv+PkJzaFTjrXSlwOmWsJYma+UsdyFClaGFRyCgY8SWxPceandC8a+hQIDS/irLd9XF33RWp0b/09HjQl+n0HZ4teUFDUd2U1mUf3XCpn0+Ho316bmi6xSW6zaMy5RsbUl01hgWj2fgapAsGAHSBphwCE3Dz/9I/UfHWQw1k2/UTgjc9uIujcza6WgOxfsKluXYIOxwNKTfmzzOJMUwXz6GRgB2jhQI29MuKOZOITA7pXq5kZKf0lSRU8zKFTMJaK4zAHQ6f877Drr8XdAHemuXGZ2JdH/Dbdwarzy3YBMCWsAYlpeEvaVAdiSpyR7fAZktNuHd39Zg00Vlj2wdc44Vk5yVssW7pv5qnVZ7JTrXX2uBYFecLAXmplQ2ph1VdSXZLEDGgjiNA2T/fBj7G4/VjsuCBZFm1I0KCJp3HWEJx5dwwhSVc5wOJEzl7fMuPYMKWH/RM6P/7LnO1ulpdmiKPa4gHzdg3hDZn42NKcVt3UYf0phtxpWMrZp/DUEeizhckrC4ed6cfGtS3CUtJEqoycrCROJ5Hy+ONHl5Aqxt+JoPU+t/XATuctfPxQVcDr0itHzo2cjh/AVTU+IC7C0oQHSS9CC8Fp58UqbtYwFtSAd7ecCAwEAAaNIMEYwDgYDVR0PAQH/BAQDAgWgMBMGA1UdJQQMMAoGCCsGAQUFBwMCMB8GA1UdIwQYMBaAFIzl7uckcPWldirXeOFL3rH6K8FLMA0GCSqGSIb3DQEBCwUAA4ICAQAp+IGZScVIbRdCq5HPjlYBPOY7UbL8ZXnlMW/HLELV9GndnULuFhnuQTIdA5dquCsk8RI1fKsScEV1rqWvHZeSo5nVbvUaPJctoD/4GACqE6F8axs1AgSOvpJMyuycjSzSh6gDM1z37Fdqc/2IRqgi7SKdDsfJpi8XW8LtErpp4kyE1rEXopsXG2fe1UH25bZpXraUqYvp61rwVUCazAtV/U7ARG5AnT0mPqzUriIPrfL+v/+2ntV/BSc8/uCqYnHbwpIwjPURCaxo1Pmm6EEkm+V/Ss4ieNwwkD2bLLLST1LoVMim7Ebfy53PEKpsznKsGlVSu0YYKUsStWQVpwhKQw0bQLCJHdpvZtZSDgS9RbSMZz+aY/fpoNx6wDvmMgtdrb3pVXZ8vPKdq9YDrGfFqP60QdZ3CuSHXCM/zX4742GgImJ4KYAcTuF1+BkGf5JLAJOUZBkfCQ/kBT5wr8+EotLxASOC6717whLBYMEG6N8osEk+LDqoJRTLqkzirJsyOHWChKK47yGkdS3HBIZfo91QrJwKpfATYziBjEnqipkTu+6jFylBIkxKTPye4b3vgcodZP8LSNVXAsMGTPNPJxzPWQ37ba4zMnYZ5iUerlaox/SNsn68DT6RajIb1A1JDq+HNFc3hQP2bzk2y5pCax8zo5swjdklnm4clfB2Lw==</X509Certificate>
|
||||
</X509Data>
|
||||
</KeyInfo>
|
||||
</KeyDescriptor>
|
||||
</IDPSSODescriptor>
|
||||
<AttributeAuthorityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" ID="_b3fed381-af56-4160-abf5-5ffd1e21cf61" validUntil="2024-12-02T16:54:55.656Z" protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
|
||||
<AttributeService xmlns="urn:oasis:names:tc:SAML:2.0:metadata" Binding="urn:oasis:names:tc:SAML:2.0:bindings:SOAP" Location="http://localhost:8080/saml/v2/attribute"></AttributeService>
|
||||
<NameIDFormat>urn:oasis:names:tc:SAML:2.0:nameid-format:persistent</NameIDFormat>
|
||||
<AttributeProfile>urn:oasis:names:tc:SAML:2.0:profiles:attribute:basic</AttributeProfile>
|
||||
<Attribute xmlns="urn:oasis:names:tc:SAML:2.0:assertion" Name="Email" NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:basic"><AttributeValue></AttributeValue></Attribute>
|
||||
<Attribute xmlns="urn:oasis:names:tc:SAML:2.0:assertion" Name="SurName" NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:basic"><AttributeValue></AttributeValue></Attribute>
|
||||
<Attribute xmlns="urn:oasis:names:tc:SAML:2.0:assertion" Name="FirstName" NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:basic"><AttributeValue></AttributeValue></Attribute>
|
||||
<Attribute xmlns="urn:oasis:names:tc:SAML:2.0:assertion" Name="FullName" NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:basic"><AttributeValue></AttributeValue></Attribute>
|
||||
<Attribute xmlns="urn:oasis:names:tc:SAML:2.0:assertion" Name="UserName" NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:basic"><AttributeValue></AttributeValue></Attribute>
|
||||
<Attribute xmlns="urn:oasis:names:tc:SAML:2.0:assertion" Name="UserID" NameFormat="urn:oasis:names:tc:SAML:2.0:attrname-format:basic"><AttributeValue></AttributeValue></Attribute>
|
||||
<KeyDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" use="signing">
|
||||
<KeyInfo xmlns="http://www.w3.org/2000/09/xmldsig#">
|
||||
<KeyName>http://localhost:8080/saml/v2/metadata IDP signing</KeyName>
|
||||
<X509Data xmlns="http://www.w3.org/2000/09/xmldsig#">
|
||||
<X509Certificate xmlns="http://www.w3.org/2000/09/xmldsig#">MIIFIjCCAwqgAwIBAgICA7QwDQYJKoZIhvcNAQELBQAwLDEQMA4GA1UEChMHWklUQURFTDEYMBYGA1UEAxMPWklUQURFTCBTQU1MIENBMB4XDTI0MTEyNzEwMjUwMloXDTI1MTEyNzE2MjUwMlowMjEQMA4GA1UEChMHWklUQURFTDEeMBwGA1UEAxMVWklUQURFTCBTQU1MIHJlc3BvbnNlMIICIjANBgkqhkiG9w0BAQEFAAOCAg8AMIICCgKCAgEA2lUgaI6AS/9xvM9DNSWK6Ho64LpK8UIioM26QfvAfeQ/I2pgX6SwWxEbd7qv+PkJzaFTjrXSlwOmWsJYma+UsdyFClaGFRyCgY8SWxPceandC8a+hQIDS/irLd9XF33RWp0b/09HjQl+n0HZ4teUFDUd2U1mUf3XCpn0+Ho316bmi6xSW6zaMy5RsbUl01hgWj2fgapAsGAHSBphwCE3Dz/9I/UfHWQw1k2/UTgjc9uIujcza6WgOxfsKluXYIOxwNKTfmzzOJMUwXz6GRgB2jhQI29MuKOZOITA7pXq5kZKf0lSRU8zKFTMJaK4zAHQ6f877Drr8XdAHemuXGZ2JdH/Dbdwarzy3YBMCWsAYlpeEvaVAdiSpyR7fAZktNuHd39Zg00Vlj2wdc44Vk5yVssW7pv5qnVZ7JTrXX2uBYFecLAXmplQ2ph1VdSXZLEDGgjiNA2T/fBj7G4/VjsuCBZFm1I0KCJp3HWEJx5dwwhSVc5wOJEzl7fMuPYMKWH/RM6P/7LnO1ulpdmiKPa4gHzdg3hDZn42NKcVt3UYf0phtxpWMrZp/DUEeizhckrC4ed6cfGtS3CUtJEqoycrCROJ5Hy+ONHl5Aqxt+JoPU+t/XATuctfPxQVcDr0itHzo2cjh/AVTU+IC7C0oQHSS9CC8Fp58UqbtYwFtSAd7ecCAwEAAaNIMEYwDgYDVR0PAQH/BAQDAgWgMBMGA1UdJQQMMAoGCCsGAQUFBwMCMB8GA1UdIwQYMBaAFIzl7uckcPWldirXeOFL3rH6K8FLMA0GCSqGSIb3DQEBCwUAA4ICAQAp+IGZScVIbRdCq5HPjlYBPOY7UbL8ZXnlMW/HLELV9GndnULuFhnuQTIdA5dquCsk8RI1fKsScEV1rqWvHZeSo5nVbvUaPJctoD/4GACqE6F8axs1AgSOvpJMyuycjSzSh6gDM1z37Fdqc/2IRqgi7SKdDsfJpi8XW8LtErpp4kyE1rEXopsXG2fe1UH25bZpXraUqYvp61rwVUCazAtV/U7ARG5AnT0mPqzUriIPrfL+v/+2ntV/BSc8/uCqYnHbwpIwjPURCaxo1Pmm6EEkm+V/Ss4ieNwwkD2bLLLST1LoVMim7Ebfy53PEKpsznKsGlVSu0YYKUsStWQVpwhKQw0bQLCJHdpvZtZSDgS9RbSMZz+aY/fpoNx6wDvmMgtdrb3pVXZ8vPKdq9YDrGfFqP60QdZ3CuSHXCM/zX4742GgImJ4KYAcTuF1+BkGf5JLAJOUZBkfCQ/kBT5wr8+EotLxASOC6717whLBYMEG6N8osEk+LDqoJRTLqkzirJsyOHWChKK47yGkdS3HBIZfo91QrJwKpfATYziBjEnqipkTu+6jFylBIkxKTPye4b3vgcodZP8LSNVXAsMGTPNPJxzPWQ37ba4zMnYZ5iUerlaox/SNsn68DT6RajIb1A1JDq+HNFc3hQP2bzk2y5pCax8zo5swjdklnm4clfB2Lw==</X509Certificate>
|
||||
</X509Data>
|
||||
</KeyInfo>
|
||||
</KeyDescriptor>
|
||||
</AttributeAuthorityDescriptor>
|
||||
</EntityDescriptor>`)
|
||||
)
|
||||
|
||||
func TestCommandSide_AddInstanceGenericOAuthIDP(t *testing.T) {
|
||||
type fields struct {
|
||||
eventstore func(*testing.T) *eventstore.Eventstore
|
||||
@ -5180,7 +5247,7 @@ func TestCommandSide_AddInstanceSAMLIDP(t *testing.T) {
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
provider SAMLProvider
|
||||
provider *SAMLProvider
|
||||
}
|
||||
type res struct {
|
||||
id string
|
||||
@ -5201,7 +5268,7 @@ func TestCommandSide_AddInstanceSAMLIDP(t *testing.T) {
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instance1"),
|
||||
provider: SAMLProvider{},
|
||||
provider: &SAMLProvider{},
|
||||
},
|
||||
res{
|
||||
err: func(err error) bool {
|
||||
@ -5210,14 +5277,14 @@ func TestCommandSide_AddInstanceSAMLIDP(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
"invalid metadata",
|
||||
"no metadata",
|
||||
fields{
|
||||
eventstore: expectEventstore(),
|
||||
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "id1"),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instance1"),
|
||||
provider: SAMLProvider{
|
||||
provider: &SAMLProvider{
|
||||
Name: "name",
|
||||
},
|
||||
},
|
||||
@ -5227,6 +5294,25 @@ func TestCommandSide_AddInstanceSAMLIDP(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"invalid metadata, error",
|
||||
fields{
|
||||
eventstore: expectEventstore(),
|
||||
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "id1"),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instance1"),
|
||||
provider: &SAMLProvider{
|
||||
Name: "name",
|
||||
Metadata: []byte("metadata"),
|
||||
},
|
||||
},
|
||||
res{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, zerrors.ThrowInvalidArgument(nil, "INST-SF3rwhgh", "Errors.Project.App.SAMLMetadataFormat"))
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
fields: fields{
|
||||
@ -5236,7 +5322,7 @@ func TestCommandSide_AddInstanceSAMLIDP(t *testing.T) {
|
||||
instance.NewSAMLIDPAddedEvent(context.Background(), &instance.NewAggregate("instance1").Aggregate,
|
||||
"id1",
|
||||
"name",
|
||||
[]byte("metadata"),
|
||||
validSAMLMetadata,
|
||||
&crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
@ -5258,9 +5344,9 @@ func TestCommandSide_AddInstanceSAMLIDP(t *testing.T) {
|
||||
},
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instance1"),
|
||||
provider: SAMLProvider{
|
||||
provider: &SAMLProvider{
|
||||
Name: "name",
|
||||
Metadata: []byte("metadata"),
|
||||
Metadata: validSAMLMetadata,
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
@ -5277,7 +5363,7 @@ func TestCommandSide_AddInstanceSAMLIDP(t *testing.T) {
|
||||
instance.NewSAMLIDPAddedEvent(context.Background(), &instance.NewAggregate("instance1").Aggregate,
|
||||
"id1",
|
||||
"name",
|
||||
[]byte("metadata"),
|
||||
validSAMLMetadata,
|
||||
&crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
@ -5304,9 +5390,9 @@ func TestCommandSide_AddInstanceSAMLIDP(t *testing.T) {
|
||||
},
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instance1"),
|
||||
provider: SAMLProvider{
|
||||
provider: &SAMLProvider{
|
||||
Name: "name",
|
||||
Metadata: []byte("metadata"),
|
||||
Metadata: validSAMLMetadata,
|
||||
Binding: "binding",
|
||||
WithSignedRequest: true,
|
||||
NameIDFormat: gu.Ptr(domain.SAMLNameIDFormatTransient),
|
||||
@ -5356,7 +5442,7 @@ func TestCommandSide_UpdateInstanceGenericSAMLIDP(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
id string
|
||||
provider SAMLProvider
|
||||
provider *SAMLProvider
|
||||
}
|
||||
type res struct {
|
||||
want *domain.ObjectDetails
|
||||
@ -5375,7 +5461,7 @@ func TestCommandSide_UpdateInstanceGenericSAMLIDP(t *testing.T) {
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instance1"),
|
||||
provider: SAMLProvider{},
|
||||
provider: &SAMLProvider{},
|
||||
},
|
||||
res{
|
||||
err: func(err error) bool {
|
||||
@ -5391,7 +5477,7 @@ func TestCommandSide_UpdateInstanceGenericSAMLIDP(t *testing.T) {
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instance1"),
|
||||
id: "id1",
|
||||
provider: SAMLProvider{},
|
||||
provider: &SAMLProvider{},
|
||||
},
|
||||
res{
|
||||
err: func(err error) bool {
|
||||
@ -5400,14 +5486,14 @@ func TestCommandSide_UpdateInstanceGenericSAMLIDP(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
"invalid metadata",
|
||||
"no metadata",
|
||||
fields{
|
||||
eventstore: expectEventstore(),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instance1"),
|
||||
id: "id1",
|
||||
provider: SAMLProvider{
|
||||
provider: &SAMLProvider{
|
||||
Name: "name",
|
||||
},
|
||||
},
|
||||
@ -5417,6 +5503,25 @@ func TestCommandSide_UpdateInstanceGenericSAMLIDP(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"invalid metadata, error",
|
||||
fields{
|
||||
eventstore: expectEventstore(),
|
||||
},
|
||||
args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instance1"),
|
||||
id: "id1",
|
||||
provider: &SAMLProvider{
|
||||
Name: "name",
|
||||
Metadata: []byte("metadata"),
|
||||
},
|
||||
},
|
||||
res{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, zerrors.ThrowInvalidArgument(nil, "INST-dsfj3kl2", "Errors.Project.App.SAMLMetadataFormat"))
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "not found",
|
||||
fields: fields{
|
||||
@ -5427,9 +5532,9 @@ func TestCommandSide_UpdateInstanceGenericSAMLIDP(t *testing.T) {
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instance1"),
|
||||
id: "id1",
|
||||
provider: SAMLProvider{
|
||||
provider: &SAMLProvider{
|
||||
Name: "name",
|
||||
Metadata: []byte("metadata"),
|
||||
Metadata: validSAMLMetadata,
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
@ -5445,7 +5550,7 @@ func TestCommandSide_UpdateInstanceGenericSAMLIDP(t *testing.T) {
|
||||
instance.NewSAMLIDPAddedEvent(context.Background(), &instance.NewAggregate("instance1").Aggregate,
|
||||
"id1",
|
||||
"name",
|
||||
[]byte("metadata"),
|
||||
validSAMLMetadata,
|
||||
&crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
@ -5465,9 +5570,9 @@ func TestCommandSide_UpdateInstanceGenericSAMLIDP(t *testing.T) {
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instance1"),
|
||||
id: "id1",
|
||||
provider: SAMLProvider{
|
||||
provider: &SAMLProvider{
|
||||
Name: "name",
|
||||
Metadata: []byte("metadata"),
|
||||
Metadata: validSAMLMetadata,
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
@ -5505,7 +5610,7 @@ func TestCommandSide_UpdateInstanceGenericSAMLIDP(t *testing.T) {
|
||||
"id1",
|
||||
[]idp.SAMLIDPChanges{
|
||||
idp.ChangeSAMLName("new name"),
|
||||
idp.ChangeSAMLMetadata([]byte("new metadata")),
|
||||
idp.ChangeSAMLMetadata(validSAMLMetadata),
|
||||
idp.ChangeSAMLBinding("new binding"),
|
||||
idp.ChangeSAMLWithSignedRequest(true),
|
||||
idp.ChangeSAMLNameIDFormat(gu.Ptr(domain.SAMLNameIDFormatTransient)),
|
||||
@ -5527,9 +5632,9 @@ func TestCommandSide_UpdateInstanceGenericSAMLIDP(t *testing.T) {
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(context.Background(), "instance1"),
|
||||
id: "id1",
|
||||
provider: SAMLProvider{
|
||||
provider: &SAMLProvider{
|
||||
Name: "new name",
|
||||
Metadata: []byte("new metadata"),
|
||||
Metadata: validSAMLMetadata,
|
||||
Binding: "new binding",
|
||||
WithSignedRequest: true,
|
||||
NameIDFormat: gu.Ptr(domain.SAMLNameIDFormatTransient),
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/saml"
|
||||
"github.com/zitadel/zitadel/internal/repository/org"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
@ -446,9 +447,9 @@ func (c *Commands) UpdateOrgLDAPProvider(ctx context.Context, resourceOwner, id
|
||||
return pushedEventsToObjectDetails(pushedEvents), nil
|
||||
}
|
||||
|
||||
func (c *Commands) AddOrgSAMLProvider(ctx context.Context, resourceOwner string, provider SAMLProvider) (string, *domain.ObjectDetails, error) {
|
||||
func (c *Commands) AddOrgSAMLProvider(ctx context.Context, resourceOwner string, provider *SAMLProvider) (id string, details *domain.ObjectDetails, err error) {
|
||||
orgAgg := org.NewAggregate(resourceOwner)
|
||||
id, err := c.idGenerator.Next()
|
||||
id, err = c.idGenerator.Next()
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
@ -464,7 +465,7 @@ func (c *Commands) AddOrgSAMLProvider(ctx context.Context, resourceOwner string,
|
||||
return id, pushedEventsToObjectDetails(pushedEvents), nil
|
||||
}
|
||||
|
||||
func (c *Commands) UpdateOrgSAMLProvider(ctx context.Context, resourceOwner, id string, provider SAMLProvider) (*domain.ObjectDetails, error) {
|
||||
func (c *Commands) UpdateOrgSAMLProvider(ctx context.Context, resourceOwner, id string, provider *SAMLProvider) (details *domain.ObjectDetails, err error) {
|
||||
orgAgg := org.NewAggregate(resourceOwner)
|
||||
writeModel := NewSAMLOrgIDPWriteModel(resourceOwner, id)
|
||||
cmds, err := preparation.PrepareCommands(ctx, c.eventstore.Filter, c.prepareUpdateOrgSAMLProvider(orgAgg, writeModel, provider))
|
||||
@ -1703,7 +1704,7 @@ func (c *Commands) prepareUpdateOrgAppleProvider(a *org.Aggregate, writeModel *O
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Commands) prepareAddOrgSAMLProvider(a *org.Aggregate, writeModel *OrgSAMLIDPWriteModel, provider SAMLProvider) preparation.Validation {
|
||||
func (c *Commands) prepareAddOrgSAMLProvider(a *org.Aggregate, writeModel *OrgSAMLIDPWriteModel, provider *SAMLProvider) preparation.Validation {
|
||||
return func() (preparation.CreateCommands, error) {
|
||||
if provider.Name = strings.TrimSpace(provider.Name); provider.Name == "" {
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "ORG-957lr0f8u3", "Errors.Invalid.Argument")
|
||||
@ -1718,6 +1719,9 @@ func (c *Commands) prepareAddOrgSAMLProvider(a *org.Aggregate, writeModel *OrgSA
|
||||
}
|
||||
provider.Metadata = data
|
||||
}
|
||||
if _, err := saml.ParseMetadata(provider.Metadata); err != nil {
|
||||
return nil, zerrors.ThrowInvalidArgument(err, "ORG-SF3rwhgh", "Errors.Project.App.SAMLMetadataFormat")
|
||||
}
|
||||
return func(ctx context.Context, filter preparation.FilterToQueryReducer) ([]eventstore.Command, error) {
|
||||
events, err := filter(ctx, writeModel.Query())
|
||||
if err != nil {
|
||||
@ -1755,7 +1759,7 @@ func (c *Commands) prepareAddOrgSAMLProvider(a *org.Aggregate, writeModel *OrgSA
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Commands) prepareUpdateOrgSAMLProvider(a *org.Aggregate, writeModel *OrgSAMLIDPWriteModel, provider SAMLProvider) preparation.Validation {
|
||||
func (c *Commands) prepareUpdateOrgSAMLProvider(a *org.Aggregate, writeModel *OrgSAMLIDPWriteModel, provider *SAMLProvider) preparation.Validation {
|
||||
return func() (preparation.CreateCommands, error) {
|
||||
if writeModel.ID = strings.TrimSpace(writeModel.ID); writeModel.ID == "" {
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "ORG-wwdwdlaya0", "Errors.Invalid.Argument")
|
||||
@ -1773,6 +1777,9 @@ func (c *Commands) prepareUpdateOrgSAMLProvider(a *org.Aggregate, writeModel *Or
|
||||
if provider.Metadata == nil {
|
||||
return nil, zerrors.ThrowInvalidArgument(nil, "ORG-j6spncd74m", "Errors.Invalid.Argument")
|
||||
}
|
||||
if _, err := saml.ParseMetadata(provider.Metadata); err != nil {
|
||||
return nil, zerrors.ThrowInvalidArgument(err, "ORG-SFqqh42", "Errors.Project.App.SAMLMetadataFormat")
|
||||
}
|
||||
return func(ctx context.Context, filter preparation.FilterToQueryReducer) ([]eventstore.Command, error) {
|
||||
events, err := filter(ctx, writeModel.Query())
|
||||
if err != nil {
|
||||
|
@ -5348,7 +5348,7 @@ func TestCommandSide_AddOrgSAMLIDP(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
resourceOwner string
|
||||
provider SAMLProvider
|
||||
provider *SAMLProvider
|
||||
}
|
||||
type res struct {
|
||||
id string
|
||||
@ -5370,7 +5370,7 @@ func TestCommandSide_AddOrgSAMLIDP(t *testing.T) {
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
resourceOwner: "org1",
|
||||
provider: SAMLProvider{},
|
||||
provider: &SAMLProvider{},
|
||||
},
|
||||
res{
|
||||
err: func(err error) bool {
|
||||
@ -5379,7 +5379,7 @@ func TestCommandSide_AddOrgSAMLIDP(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
"invalid metadata",
|
||||
"no metadata",
|
||||
fields{
|
||||
eventstore: expectEventstore(),
|
||||
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "id1"),
|
||||
@ -5387,7 +5387,7 @@ func TestCommandSide_AddOrgSAMLIDP(t *testing.T) {
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
resourceOwner: "org1",
|
||||
provider: SAMLProvider{
|
||||
provider: &SAMLProvider{
|
||||
Name: "name",
|
||||
},
|
||||
},
|
||||
@ -5397,6 +5397,26 @@ func TestCommandSide_AddOrgSAMLIDP(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"invalid metadata, fail on error",
|
||||
fields{
|
||||
eventstore: expectEventstore(),
|
||||
idGenerator: id_mock.NewIDGeneratorExpectIDs(t, "id1"),
|
||||
},
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
resourceOwner: "org1",
|
||||
provider: &SAMLProvider{
|
||||
Name: "name",
|
||||
Metadata: []byte("metadata"),
|
||||
},
|
||||
},
|
||||
res{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, zerrors.ThrowInvalidArgument(nil, "ORG-SF3rwhgh", "Errors.Project.App.SAMLMetadataFormat"))
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "ok",
|
||||
fields: fields{
|
||||
@ -5406,7 +5426,7 @@ func TestCommandSide_AddOrgSAMLIDP(t *testing.T) {
|
||||
org.NewSAMLIDPAddedEvent(context.Background(), &org.NewAggregate("org1").Aggregate,
|
||||
"id1",
|
||||
"name",
|
||||
[]byte("metadata"),
|
||||
validSAMLMetadata,
|
||||
&crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
@ -5428,9 +5448,9 @@ func TestCommandSide_AddOrgSAMLIDP(t *testing.T) {
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
resourceOwner: "org1",
|
||||
provider: SAMLProvider{
|
||||
provider: &SAMLProvider{
|
||||
Name: "name",
|
||||
Metadata: []byte("metadata"),
|
||||
Metadata: validSAMLMetadata,
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
@ -5447,7 +5467,7 @@ func TestCommandSide_AddOrgSAMLIDP(t *testing.T) {
|
||||
org.NewSAMLIDPAddedEvent(context.Background(), &org.NewAggregate("org1").Aggregate,
|
||||
"id1",
|
||||
"name",
|
||||
[]byte("metadata"),
|
||||
validSAMLMetadata,
|
||||
&crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
@ -5475,9 +5495,9 @@ func TestCommandSide_AddOrgSAMLIDP(t *testing.T) {
|
||||
args: args{
|
||||
ctx: context.Background(),
|
||||
resourceOwner: "org1",
|
||||
provider: SAMLProvider{
|
||||
provider: &SAMLProvider{
|
||||
Name: "name",
|
||||
Metadata: []byte("metadata"),
|
||||
Metadata: validSAMLMetadata,
|
||||
Binding: "binding",
|
||||
WithSignedRequest: true,
|
||||
NameIDFormat: gu.Ptr(domain.SAMLNameIDFormatTransient),
|
||||
@ -5528,7 +5548,7 @@ func TestCommandSide_UpdateOrgSAMLIDP(t *testing.T) {
|
||||
ctx context.Context
|
||||
resourceOwner string
|
||||
id string
|
||||
provider SAMLProvider
|
||||
provider *SAMLProvider
|
||||
}
|
||||
type res struct {
|
||||
want *domain.ObjectDetails
|
||||
@ -5548,7 +5568,7 @@ func TestCommandSide_UpdateOrgSAMLIDP(t *testing.T) {
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
resourceOwner: "org1",
|
||||
provider: SAMLProvider{},
|
||||
provider: &SAMLProvider{},
|
||||
},
|
||||
res{
|
||||
err: func(err error) bool {
|
||||
@ -5565,7 +5585,7 @@ func TestCommandSide_UpdateOrgSAMLIDP(t *testing.T) {
|
||||
ctx: context.Background(),
|
||||
resourceOwner: "org1",
|
||||
id: "id1",
|
||||
provider: SAMLProvider{},
|
||||
provider: &SAMLProvider{},
|
||||
},
|
||||
res{
|
||||
err: func(err error) bool {
|
||||
@ -5574,7 +5594,7 @@ func TestCommandSide_UpdateOrgSAMLIDP(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
"invalid metadata",
|
||||
"no metadata",
|
||||
fields{
|
||||
eventstore: expectEventstore(),
|
||||
},
|
||||
@ -5582,7 +5602,7 @@ func TestCommandSide_UpdateOrgSAMLIDP(t *testing.T) {
|
||||
ctx: context.Background(),
|
||||
resourceOwner: "org1",
|
||||
id: "id1",
|
||||
provider: SAMLProvider{
|
||||
provider: &SAMLProvider{
|
||||
Name: "name",
|
||||
},
|
||||
},
|
||||
@ -5592,6 +5612,26 @@ func TestCommandSide_UpdateOrgSAMLIDP(t *testing.T) {
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
"invalid metadata, error",
|
||||
fields{
|
||||
eventstore: expectEventstore(),
|
||||
},
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
resourceOwner: "org1",
|
||||
id: "id1",
|
||||
provider: &SAMLProvider{
|
||||
Name: "name",
|
||||
Metadata: []byte("metadata"),
|
||||
},
|
||||
},
|
||||
res{
|
||||
err: func(err error) bool {
|
||||
return errors.Is(err, zerrors.ThrowInvalidArgument(nil, "ORG-SFqqh42", "Errors.Project.App.SAMLMetadataFormat"))
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "not found",
|
||||
fields: fields{
|
||||
@ -5603,9 +5643,9 @@ func TestCommandSide_UpdateOrgSAMLIDP(t *testing.T) {
|
||||
ctx: context.Background(),
|
||||
resourceOwner: "org1",
|
||||
id: "id1",
|
||||
provider: SAMLProvider{
|
||||
provider: &SAMLProvider{
|
||||
Name: "name",
|
||||
Metadata: []byte("metadata"),
|
||||
Metadata: validSAMLMetadata,
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
@ -5623,7 +5663,7 @@ func TestCommandSide_UpdateOrgSAMLIDP(t *testing.T) {
|
||||
org.NewSAMLIDPAddedEvent(context.Background(), &org.NewAggregate("org1").Aggregate,
|
||||
"id1",
|
||||
"name",
|
||||
[]byte("metadata"),
|
||||
validSAMLMetadata,
|
||||
&crypto.CryptoValue{
|
||||
CryptoType: crypto.TypeEncryption,
|
||||
Algorithm: "enc",
|
||||
@ -5644,9 +5684,9 @@ func TestCommandSide_UpdateOrgSAMLIDP(t *testing.T) {
|
||||
ctx: context.Background(),
|
||||
resourceOwner: "org1",
|
||||
id: "id1",
|
||||
provider: SAMLProvider{
|
||||
provider: &SAMLProvider{
|
||||
Name: "name",
|
||||
Metadata: []byte("metadata"),
|
||||
Metadata: validSAMLMetadata,
|
||||
},
|
||||
},
|
||||
res: res{
|
||||
@ -5684,7 +5724,7 @@ func TestCommandSide_UpdateOrgSAMLIDP(t *testing.T) {
|
||||
"id1",
|
||||
[]idp.SAMLIDPChanges{
|
||||
idp.ChangeSAMLName("new name"),
|
||||
idp.ChangeSAMLMetadata([]byte("new metadata")),
|
||||
idp.ChangeSAMLMetadata(validSAMLMetadata),
|
||||
idp.ChangeSAMLBinding("new binding"),
|
||||
idp.ChangeSAMLWithSignedRequest(true),
|
||||
idp.ChangeSAMLNameIDFormat(gu.Ptr(domain.SAMLNameIDFormatTransient)),
|
||||
@ -5707,9 +5747,9 @@ func TestCommandSide_UpdateOrgSAMLIDP(t *testing.T) {
|
||||
ctx: context.Background(),
|
||||
resourceOwner: "org1",
|
||||
id: "id1",
|
||||
provider: SAMLProvider{
|
||||
provider: &SAMLProvider{
|
||||
Name: "new name",
|
||||
Metadata: []byte("new metadata"),
|
||||
Metadata: validSAMLMetadata,
|
||||
Binding: "new binding",
|
||||
WithSignedRequest: true,
|
||||
NameIDFormat: gu.Ptr(domain.SAMLNameIDFormatTransient),
|
||||
|
@ -205,12 +205,6 @@ func (c *Commands) RemoveProjectGrant(ctx context.Context, projectID, grantID, r
|
||||
if grantID == "" || projectID == "" {
|
||||
return details, zerrors.ThrowInvalidArgument(nil, "PROJECT-1m9fJ", "Errors.IDMissing")
|
||||
}
|
||||
|
||||
err = c.checkProjectExists(ctx, projectID, resourceOwner)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
existingGrant, err := c.projectGrantWriteModelByID(ctx, grantID, projectID, resourceOwner)
|
||||
if err != nil {
|
||||
return details, err
|
||||
|
@ -1273,11 +1273,25 @@ func TestCommandSide_RemoveProjectGrant(t *testing.T) {
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "project not existing, precondition failed error",
|
||||
name: "project already removed, precondition failed error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(),
|
||||
expectFilter(
|
||||
eventFromEventPusher(project.NewGrantAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"projectgrant1",
|
||||
"grantedorg1",
|
||||
[]string{"key1"},
|
||||
)),
|
||||
eventFromEventPusher(
|
||||
project.NewProjectRemovedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"projectname1",
|
||||
nil,
|
||||
),
|
||||
),
|
||||
),
|
||||
),
|
||||
},
|
||||
args: args{
|
||||
@ -1287,7 +1301,7 @@ func TestCommandSide_RemoveProjectGrant(t *testing.T) {
|
||||
resourceOwner: "org1",
|
||||
},
|
||||
res: res{
|
||||
err: zerrors.IsPreconditionFailed,
|
||||
err: zerrors.IsNotFound,
|
||||
},
|
||||
},
|
||||
{
|
||||
@ -1295,15 +1309,6 @@ func TestCommandSide_RemoveProjectGrant(t *testing.T) {
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
project.NewProjectAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"projectname1", true, true, true,
|
||||
domain.PrivateLabelingSettingUnspecified,
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(),
|
||||
),
|
||||
},
|
||||
@ -1322,15 +1327,6 @@ func TestCommandSide_RemoveProjectGrant(t *testing.T) {
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
project.NewProjectAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"projectname1", true, true, true,
|
||||
domain.PrivateLabelingSettingUnspecified,
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
eventFromEventPusher(project.NewGrantAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
@ -1365,15 +1361,6 @@ func TestCommandSide_RemoveProjectGrant(t *testing.T) {
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
project.NewProjectAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"projectname1", true, true, true,
|
||||
domain.PrivateLabelingSettingUnspecified,
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
eventFromEventPusher(project.NewGrantAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
@ -1410,15 +1397,6 @@ func TestCommandSide_RemoveProjectGrant(t *testing.T) {
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(
|
||||
t,
|
||||
expectFilter(
|
||||
eventFromEventPusher(
|
||||
project.NewProjectAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
"projectname1", true, true, true,
|
||||
domain.PrivateLabelingSettingUnspecified,
|
||||
),
|
||||
),
|
||||
),
|
||||
expectFilter(
|
||||
eventFromEventPusher(project.NewGrantAddedEvent(context.Background(),
|
||||
&project.NewAggregate("project1", "org1").Aggregate,
|
||||
|
@ -7,7 +7,6 @@ import (
|
||||
"github.com/pquerna/otp"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
http_util "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/command/preparation"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
@ -79,10 +78,8 @@ func (c *Commands) createHumanTOTP(ctx context.Context, userID, resourceOwner st
|
||||
logging.WithError(err).WithField("traceID", tracing.TraceIDFromCtx(ctx)).Debug("unable to get human for loginname")
|
||||
return nil, zerrors.ThrowPreconditionFailed(err, "COMMAND-SqyJz", "Errors.User.NotFound")
|
||||
}
|
||||
if authz.GetCtxData(ctx).UserID != userID {
|
||||
if err := c.checkPermission(ctx, domain.PermissionUserCredentialWrite, human.ResourceOwner, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := c.checkPermissionUpdateUserCredentials(ctx, human.ResourceOwner, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
org, err := c.getOrg(ctx, human.ResourceOwner)
|
||||
if err != nil {
|
||||
@ -139,10 +136,8 @@ func (c *Commands) HumanCheckMFATOTPSetup(ctx context.Context, userID, code, use
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if authz.GetCtxData(ctx).UserID != userID {
|
||||
if err := c.checkPermission(ctx, domain.PermissionUserCredentialWrite, existingOTP.ResourceOwner, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := c.checkPermissionUpdateUserCredentials(ctx, existingOTP.ResourceOwner, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if existingOTP.State == domain.MFAStateUnspecified || existingOTP.State == domain.MFAStateRemoved {
|
||||
return nil, zerrors.ThrowNotFound(nil, "COMMAND-3Mif9s", "Errors.User.MFA.OTP.NotExisting")
|
||||
@ -242,10 +237,8 @@ func (c *Commands) HumanRemoveTOTP(ctx context.Context, userID, resourceOwner st
|
||||
if existingOTP.State == domain.MFAStateUnspecified || existingOTP.State == domain.MFAStateRemoved {
|
||||
return nil, zerrors.ThrowNotFound(nil, "COMMAND-Hd9sd", "Errors.User.MFA.OTP.NotExisting")
|
||||
}
|
||||
if userID != authz.GetCtxData(ctx).UserID {
|
||||
if err := c.checkPermission(ctx, domain.PermissionUserWrite, existingOTP.ResourceOwner, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := c.checkPermissionUpdateUser(ctx, existingOTP.ResourceOwner, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
userAgg := UserAggregateFromWriteModel(&existingOTP.WriteModel)
|
||||
pushedEvents, err := c.eventstore.Push(ctx, user.NewHumanOTPRemovedEvent(ctx, userAgg))
|
||||
@ -286,10 +279,8 @@ func (c *Commands) addHumanOTPSMS(ctx context.Context, userID, resourceOwner str
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if authz.GetCtxData(ctx).UserID != userID {
|
||||
if err := c.checkPermission(ctx, domain.PermissionUserCredentialWrite, otpWriteModel.ResourceOwner(), userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := c.checkPermissionUpdateUserCredentials(ctx, otpWriteModel.ResourceOwner(), userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if otpWriteModel.otpAdded {
|
||||
return nil, zerrors.ThrowAlreadyExists(nil, "COMMAND-Ad3g2", "Errors.User.MFA.OTP.AlreadyReady")
|
||||
@ -318,10 +309,8 @@ func (c *Commands) RemoveHumanOTPSMS(ctx context.Context, userID, resourceOwner
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if userID != authz.GetCtxData(ctx).UserID {
|
||||
if err := c.checkPermission(ctx, domain.PermissionUserWrite, existingOTP.WriteModel.ResourceOwner, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := c.checkPermissionUpdateUser(ctx, existingOTP.WriteModel.ResourceOwner, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !existingOTP.otpAdded {
|
||||
return nil, zerrors.ThrowNotFound(nil, "COMMAND-Sr3h3", "Errors.User.MFA.OTP.NotExisting")
|
||||
@ -420,10 +409,8 @@ func (c *Commands) addHumanOTPEmail(ctx context.Context, userID, resourceOwner s
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if authz.GetCtxData(ctx).UserID != userID {
|
||||
if err := c.checkPermission(ctx, domain.PermissionUserCredentialWrite, otpWriteModel.ResourceOwner(), userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := c.checkPermissionUpdateUserCredentials(ctx, otpWriteModel.ResourceOwner(), userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if otpWriteModel.otpAdded {
|
||||
return nil, zerrors.ThrowAlreadyExists(nil, "COMMAND-MKL2s", "Errors.User.MFA.OTP.AlreadyReady")
|
||||
@ -452,10 +439,8 @@ func (c *Commands) RemoveHumanOTPEmail(ctx context.Context, userID, resourceOwne
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if userID != authz.GetCtxData(ctx).UserID {
|
||||
if err := c.checkPermission(ctx, domain.PermissionUserWrite, existingOTP.WriteModel.ResourceOwner, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := c.checkPermissionUpdateUser(ctx, existingOTP.WriteModel.ResourceOwner, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !existingOTP.otpAdded {
|
||||
return nil, zerrors.ThrowNotFound(nil, "COMMAND-b312D", "Errors.User.MFA.OTP.NotExisting")
|
||||
|
@ -110,7 +110,7 @@ type setPasswordVerification func(ctx context.Context) (newEncodedPassword strin
|
||||
// setPasswordWithPermission returns a permission check as [setPasswordVerification] implementation
|
||||
func (c *Commands) setPasswordWithPermission(userID, orgID string) setPasswordVerification {
|
||||
return func(ctx context.Context) (_ string, err error) {
|
||||
return "", c.checkPermission(ctx, domain.PermissionUserWrite, orgID, userID)
|
||||
return "", c.checkPermissionUpdateUser(ctx, orgID, userID)
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
@ -146,10 +145,8 @@ func (c *Commands) addHumanWebAuthN(ctx context.Context, userID, resourceowner,
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
if authz.GetCtxData(ctx).UserID != userID {
|
||||
if err = c.checkPermission(ctx, domain.PermissionUserCredentialWrite, user.ResourceOwner, userID); err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
if err := c.checkPermissionUpdateUserCredentials(ctx, user.ResourceOwner, userID); err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
org, err := c.getOrg(ctx, user.ResourceOwner)
|
||||
if err != nil {
|
||||
@ -603,10 +600,9 @@ func (c *Commands) removeHumanWebAuthN(ctx context.Context, userID, webAuthNID,
|
||||
if existingWebAuthN.State == domain.MFAStateUnspecified || existingWebAuthN.State == domain.MFAStateRemoved {
|
||||
return nil, zerrors.ThrowNotFound(nil, "COMMAND-DAfb2", "Errors.User.WebAuthN.NotFound")
|
||||
}
|
||||
if userID != authz.GetCtxData(ctx).UserID {
|
||||
if err := c.checkPermission(ctx, domain.PermissionUserWrite, existingWebAuthN.ResourceOwner, existingWebAuthN.AggregateID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err := c.checkPermissionUpdateUser(ctx, existingWebAuthN.ResourceOwner, existingWebAuthN.AggregateID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
userAgg := UserAggregateFromWriteModel(&existingWebAuthN.WriteModel)
|
||||
|
@ -127,6 +127,16 @@ func (c *Commands) checkPermissionUpdateUser(ctx context.Context, resourceOwner,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Commands) checkPermissionUpdateUserCredentials(ctx context.Context, resourceOwner, userID string) error {
|
||||
if userID != "" && userID == authz.GetCtxData(ctx).UserID {
|
||||
return nil
|
||||
}
|
||||
if err := c.checkPermission(ctx, domain.PermissionUserCredentialWrite, resourceOwner, userID); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Commands) checkPermissionDeleteUser(ctx context.Context, resourceOwner, userID string) error {
|
||||
if userID != "" && userID == authz.GetCtxData(ctx).UserID {
|
||||
return nil
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
@ -118,10 +117,8 @@ func (c *Commands) changeUserEmailWithGeneratorEvents(ctx context.Context, userI
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if authz.GetCtxData(ctx).UserID != userID {
|
||||
if err = c.checkPermission(ctx, domain.PermissionUserWrite, cmd.aggregate.ResourceOwner, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = c.checkPermissionUpdateUser(ctx, cmd.aggregate.ResourceOwner, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = cmd.Change(ctx, domain.EmailAddress(email)); err != nil {
|
||||
return nil, err
|
||||
@ -137,10 +134,8 @@ func (c *Commands) resendUserEmailCodeWithGeneratorEvents(ctx context.Context, u
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if authz.GetCtxData(ctx).UserID != userID {
|
||||
if err = c.checkPermission(ctx, domain.PermissionUserWrite, cmd.aggregate.ResourceOwner, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = c.checkPermissionUpdateUser(ctx, cmd.aggregate.ResourceOwner, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cmd.model.Code == nil {
|
||||
return nil, zerrors.ThrowPreconditionFailed(err, "EMAIL-5w5ilin4yt", "Errors.User.Code.Empty")
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
@ -74,10 +73,8 @@ func (c *Commands) ResendInviteCode(ctx context.Context, userID, resourceOwner,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if authz.GetCtxData(ctx).UserID != userID {
|
||||
if err := c.checkPermission(ctx, domain.PermissionUserWrite, existingCode.ResourceOwner, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err := c.checkPermissionUpdateUser(ctx, existingCode.ResourceOwner, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !existingCode.UserState.Exists() {
|
||||
return nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-H3b2a", "Errors.User.NotFound")
|
||||
|
@ -6,7 +6,6 @@ import (
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/command/preparation"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
@ -18,7 +17,7 @@ import (
|
||||
// RegisterUserPasskey creates a passkey registration for the current authenticated user.
|
||||
// UserID, usually taken from the request is compared against the user ID in the context.
|
||||
func (c *Commands) RegisterUserPasskey(ctx context.Context, userID, resourceOwner, rpID string, authenticator domain.AuthenticatorAttachment) (*domain.WebAuthNRegistrationDetails, error) {
|
||||
if err := authz.UserIDInCTX(ctx, userID); err != nil {
|
||||
if err := c.checkPermissionUpdateUserCredentials(ctx, resourceOwner, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return c.registerUserPasskey(ctx, userID, resourceOwner, rpID, authenticator)
|
||||
|
@ -34,8 +34,9 @@ func TestCommands_RegisterUserPasskey(t *testing.T) {
|
||||
}
|
||||
userAgg := &user.NewAggregate("user1", "org1").Aggregate
|
||||
type fields struct {
|
||||
eventstore *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
eventstore func(t *testing.T) *eventstore.Eventstore
|
||||
idGenerator id.Generator
|
||||
checkPermission domain.PermissionCheck
|
||||
}
|
||||
type args struct {
|
||||
userID string
|
||||
@ -51,18 +52,22 @@ func TestCommands_RegisterUserPasskey(t *testing.T) {
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
name: "wrong user",
|
||||
name: "no permission",
|
||||
fields: fields{
|
||||
eventstore: expectEventstore(),
|
||||
checkPermission: newMockPermissionCheckNotAllowed(),
|
||||
},
|
||||
args: args{
|
||||
userID: "foo",
|
||||
resourceOwner: "org1",
|
||||
authenticator: domain.AuthenticatorAttachmentCrossPlattform,
|
||||
},
|
||||
wantErr: zerrors.ThrowPermissionDenied(nil, "AUTH-Bohd2", "Errors.User.UserIDWrong"),
|
||||
wantErr: zerrors.ThrowPermissionDenied(nil, "AUTHZ-HKJD33", "Errors.PermissionDenied"),
|
||||
},
|
||||
{
|
||||
name: "get human passwordless error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
eventstore: expectEventstore(
|
||||
expectFilterError(io.ErrClosedPipe),
|
||||
),
|
||||
},
|
||||
@ -76,7 +81,7 @@ func TestCommands_RegisterUserPasskey(t *testing.T) {
|
||||
{
|
||||
name: "id generator error",
|
||||
fields: fields{
|
||||
eventstore: eventstoreExpect(t,
|
||||
eventstore: expectEventstore(
|
||||
expectFilter(), // getHumanPasswordlessTokens
|
||||
expectFilter(eventFromEventPusher(
|
||||
user.NewHumanAddedEvent(ctx,
|
||||
@ -118,9 +123,10 @@ func TestCommands_RegisterUserPasskey(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := &Commands{
|
||||
eventstore: tt.fields.eventstore,
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
webauthnConfig: webauthnConfig,
|
||||
eventstore: tt.fields.eventstore(t),
|
||||
idGenerator: tt.fields.idGenerator,
|
||||
webauthnConfig: webauthnConfig,
|
||||
checkPermission: tt.fields.checkPermission,
|
||||
}
|
||||
_, err := c.RegisterUserPasskey(ctx, tt.args.userID, tt.args.resourceOwner, tt.args.rpID, tt.args.authenticator)
|
||||
require.ErrorIs(t, err, tt.wantErr)
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"io"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/repository/user"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
@ -50,10 +49,8 @@ func (c *Commands) requestPasswordReset(ctx context.Context, userID string, retu
|
||||
if model.UserState == domain.UserStateInitial {
|
||||
return nil, nil, zerrors.ThrowPreconditionFailed(nil, "COMMAND-Sfe4g", "Errors.User.NotInitialised")
|
||||
}
|
||||
if authz.GetCtxData(ctx).UserID != userID {
|
||||
if err = c.checkPermission(ctx, domain.PermissionUserWrite, model.ResourceOwner, userID); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
if err = c.checkPermissionUpdateUser(ctx, model.ResourceOwner, userID); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
var passwordCode *EncryptedCode
|
||||
var generatorID string
|
||||
|
@ -82,10 +82,8 @@ func (c *Commands) changeUserPhoneWithGenerator(ctx context.Context, userID, pho
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if authz.GetCtxData(ctx).UserID != userID {
|
||||
if err = c.checkPermission(ctx, domain.PermissionUserWrite, cmd.aggregate.ResourceOwner, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = c.checkPermissionUpdateUser(ctx, cmd.aggregate.ResourceOwner, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = cmd.Change(ctx, domain.PhoneNumber(phone)); err != nil {
|
||||
return nil, err
|
||||
@ -104,10 +102,8 @@ func (c *Commands) resendUserPhoneCodeWithGenerator(ctx context.Context, userID
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if authz.GetCtxData(ctx).UserID != userID {
|
||||
if err = c.checkPermission(ctx, domain.PermissionUserWrite, cmd.aggregate.ResourceOwner, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = c.checkPermissionUpdateUser(ctx, cmd.aggregate.ResourceOwner, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if cmd.model.Code == nil && cmd.model.GeneratorID == "" {
|
||||
return nil, zerrors.ThrowPreconditionFailed(err, "PHONE-5xrra88eq8", "Errors.User.Code.Empty")
|
||||
|
@ -3,15 +3,18 @@ package cockroach
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/database/dialect"
|
||||
)
|
||||
|
||||
@ -72,6 +75,12 @@ func (_ *Config) Decode(configs []interface{}) (dialect.Connector, error) {
|
||||
}
|
||||
|
||||
func (c *Config) Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpose dialect.DBPurpose) (*sql.DB, *pgxpool.Pool, error) {
|
||||
dialect.RegisterAfterConnect(func(ctx context.Context, c *pgx.Conn) error {
|
||||
// CockroachDB by default does not allow multiple modifications of the same table using ON CONFLICT
|
||||
// This is needed to fill the fields table of the eventstore during eventstore.Push.
|
||||
_, err := c.Exec(ctx, "SET enable_multiple_modifications_of_table = on")
|
||||
return err
|
||||
})
|
||||
connConfig, err := dialect.NewConnectionConfig(c.MaxOpenConns, c.MaxIdleConns, pusherRatio, spoolerRatio, purpose)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
@ -82,6 +91,29 @@ func (c *Config) Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpo
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if len(connConfig.AfterConnect) > 0 {
|
||||
config.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error {
|
||||
for _, f := range connConfig.AfterConnect {
|
||||
if err := f(ctx, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// For the pusher we set the app name with the instance ID
|
||||
if purpose == dialect.DBPurposeEventPusher {
|
||||
config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool {
|
||||
return setAppNameWithID(ctx, conn, purpose, authz.GetInstance(ctx).InstanceID())
|
||||
}
|
||||
config.AfterRelease = func(conn *pgx.Conn) bool {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
return setAppNameWithID(ctx, conn, purpose, "IDLE")
|
||||
}
|
||||
}
|
||||
|
||||
if connConfig.MaxOpenConns != 0 {
|
||||
config.MaxConns = int32(connConfig.MaxOpenConns)
|
||||
}
|
||||
@ -200,3 +232,11 @@ func (c Config) String(useAdmin bool, appName string) string {
|
||||
|
||||
return strings.Join(fields, " ")
|
||||
}
|
||||
|
||||
func setAppNameWithID(ctx context.Context, conn *pgx.Conn, purpose dialect.DBPurpose, id string) bool {
|
||||
// needs to be set like this because psql complains about parameters in the SET statement
|
||||
query := fmt.Sprintf("SET application_name = '%s_%s'", purpose.AppName(), id)
|
||||
_, err := conn.Exec(ctx, query)
|
||||
logging.OnError(err).Warn("failed to set application name")
|
||||
return err == nil
|
||||
}
|
||||
|
@ -18,21 +18,31 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
type QueryExecuter interface {
|
||||
Query(query string, args ...any) (*sql.Rows, error)
|
||||
type ContextQuerier interface {
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
Exec(query string, args ...any) (sql.Result, error)
|
||||
}
|
||||
|
||||
type ContextExecuter interface {
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
}
|
||||
|
||||
type ContextQueryExecuter interface {
|
||||
ContextQuerier
|
||||
ContextExecuter
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
QueryExecuter
|
||||
ContextQueryExecuter
|
||||
Beginner
|
||||
Conn(ctx context.Context) (*sql.Conn, error)
|
||||
}
|
||||
|
||||
type Beginner interface {
|
||||
BeginTx(ctx context.Context, opts *sql.TxOptions) (*sql.Tx, error)
|
||||
Begin() (*sql.Tx, error)
|
||||
}
|
||||
|
||||
type Tx interface {
|
||||
QueryExecuter
|
||||
ContextQueryExecuter
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
|
@ -1,8 +1,13 @@
|
||||
package dialect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"reflect"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -17,6 +22,7 @@ var (
|
||||
type ConnectionConfig struct {
|
||||
MaxOpenConns,
|
||||
MaxIdleConns uint32
|
||||
AfterConnect []func(ctx context.Context, c *pgx.Conn) error
|
||||
}
|
||||
|
||||
// takeRatio of MaxOpenConns and MaxIdleConns from config and returns
|
||||
@ -29,6 +35,7 @@ func (c *ConnectionConfig) takeRatio(ratio float64) (*ConnectionConfig, error) {
|
||||
out := &ConnectionConfig{
|
||||
MaxOpenConns: uint32(ratio * float64(c.MaxOpenConns)),
|
||||
MaxIdleConns: uint32(ratio * float64(c.MaxIdleConns)),
|
||||
AfterConnect: c.AfterConnect,
|
||||
}
|
||||
if c.MaxOpenConns != 0 && out.MaxOpenConns < 1 && ratio > 0 {
|
||||
out.MaxOpenConns = 1
|
||||
@ -40,6 +47,36 @@ func (c *ConnectionConfig) takeRatio(ratio float64) (*ConnectionConfig, error) {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
var afterConnectFuncs []func(ctx context.Context, c *pgx.Conn) error
|
||||
|
||||
func RegisterAfterConnect(f func(ctx context.Context, c *pgx.Conn) error) {
|
||||
afterConnectFuncs = append(afterConnectFuncs, f)
|
||||
}
|
||||
|
||||
func RegisterDefaultPgTypeVariants[T any](m *pgtype.Map, name, arrayName string) {
|
||||
// T
|
||||
var value T
|
||||
m.RegisterDefaultPgType(value, name)
|
||||
|
||||
// *T
|
||||
valueType := reflect.TypeOf(value)
|
||||
m.RegisterDefaultPgType(reflect.New(valueType).Interface(), name)
|
||||
|
||||
// []T
|
||||
sliceType := reflect.SliceOf(valueType)
|
||||
m.RegisterDefaultPgType(reflect.MakeSlice(sliceType, 0, 0).Interface(), arrayName)
|
||||
|
||||
// *[]T
|
||||
m.RegisterDefaultPgType(reflect.New(sliceType).Interface(), arrayName)
|
||||
|
||||
// []*T
|
||||
sliceOfPointerType := reflect.SliceOf(reflect.TypeOf(reflect.New(valueType).Interface()))
|
||||
m.RegisterDefaultPgType(reflect.MakeSlice(sliceOfPointerType, 0, 0).Interface(), arrayName)
|
||||
|
||||
// *[]*T
|
||||
m.RegisterDefaultPgType(reflect.New(sliceOfPointerType).Interface(), arrayName)
|
||||
}
|
||||
|
||||
// NewConnectionConfig calculates [ConnectionConfig] values from the passed ratios
|
||||
// and returns the config applicable for the requested purpose.
|
||||
//
|
||||
@ -59,11 +96,13 @@ func NewConnectionConfig(openConns, idleConns uint32, pusherRatio, projectionRat
|
||||
queryConfig := &ConnectionConfig{
|
||||
MaxOpenConns: openConns,
|
||||
MaxIdleConns: idleConns,
|
||||
AfterConnect: afterConnectFuncs,
|
||||
}
|
||||
pusherConfig, err := queryConfig.takeRatio(pusherRatio)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("event pusher: %w", err)
|
||||
}
|
||||
|
||||
spoolerConfig, err := queryConfig.takeRatio(projectionRatio)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("projection spooler: %w", err)
|
||||
|
@ -3,15 +3,18 @@ package postgres
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/database/dialect"
|
||||
)
|
||||
|
||||
@ -83,6 +86,27 @@ func (c *Config) Connect(useAdmin bool, pusherRatio, spoolerRatio float64, purpo
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
config.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error {
|
||||
for _, f := range connConfig.AfterConnect {
|
||||
if err := f(ctx, conn); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// For the pusher we set the app name with the instance ID
|
||||
if purpose == dialect.DBPurposeEventPusher {
|
||||
config.BeforeAcquire = func(ctx context.Context, conn *pgx.Conn) bool {
|
||||
return setAppNameWithID(ctx, conn, purpose, authz.GetInstance(ctx).InstanceID())
|
||||
}
|
||||
config.AfterRelease = func(conn *pgx.Conn) bool {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
return setAppNameWithID(ctx, conn, purpose, "IDLE")
|
||||
}
|
||||
}
|
||||
|
||||
if connConfig.MaxOpenConns != 0 {
|
||||
config.MaxConns = int32(connConfig.MaxOpenConns)
|
||||
}
|
||||
@ -209,3 +233,11 @@ func (c Config) String(useAdmin bool, appName string) string {
|
||||
|
||||
return strings.Join(fields, " ")
|
||||
}
|
||||
|
||||
func setAppNameWithID(ctx context.Context, conn *pgx.Conn, purpose dialect.DBPurpose, id string) bool {
|
||||
// needs to be set like this because psql complains about parameters in the SET statement
|
||||
query := fmt.Sprintf("SET application_name = '%s_%s'", purpose.AppName(), id)
|
||||
_, err := conn.Exec(ctx, query)
|
||||
logging.OnError(err).Warn("failed to set application name")
|
||||
return err == nil
|
||||
}
|
||||
|
@ -3,8 +3,12 @@ package eventstore
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/service"
|
||||
)
|
||||
@ -84,8 +88,10 @@ func (e *BaseEvent) DataAsBytes() []byte {
|
||||
}
|
||||
|
||||
// Revision implements action
|
||||
func (*BaseEvent) Revision() uint16 {
|
||||
return 0
|
||||
func (e *BaseEvent) Revision() uint16 {
|
||||
revision, err := strconv.ParseUint(strings.TrimPrefix(string(e.Agg.Version), "v"), 10, 16)
|
||||
logging.OnError(err).Debug("failed to parse event revision")
|
||||
return uint16(revision)
|
||||
}
|
||||
|
||||
// Unmarshal implements Event
|
||||
|
@ -90,7 +90,7 @@ func (es *Eventstore) Push(ctx context.Context, cmds ...Command) ([]Event, error
|
||||
|
||||
// PushWithClient pushes the events in a single transaction using the provided database client
|
||||
// an event needs at least an aggregate
|
||||
func (es *Eventstore) PushWithClient(ctx context.Context, client database.QueryExecuter, cmds ...Command) ([]Event, error) {
|
||||
func (es *Eventstore) PushWithClient(ctx context.Context, client database.ContextQueryExecuter, cmds ...Command) ([]Event, error) {
|
||||
if es.PushTimeout > 0 {
|
||||
var cancel func()
|
||||
ctx, cancel = context.WithTimeout(ctx, es.PushTimeout)
|
||||
@ -301,7 +301,7 @@ type Pusher interface {
|
||||
// Health checks if the connection to the storage is available
|
||||
Health(ctx context.Context) error
|
||||
// Push stores the actions
|
||||
Push(ctx context.Context, client database.QueryExecuter, commands ...Command) (_ []Event, err error)
|
||||
Push(ctx context.Context, client database.ContextQueryExecuter, commands ...Command) (_ []Event, err error)
|
||||
// Client returns the underlying database connection
|
||||
Client() *database.DB
|
||||
}
|
||||
|
@ -347,7 +347,7 @@ func (repo *testPusher) Health(ctx context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (repo *testPusher) Push(_ context.Context, _ database.QueryExecuter, commands ...Command) (events []Event, err error) {
|
||||
func (repo *testPusher) Push(_ context.Context, _ database.ContextQueryExecuter, commands ...Command) (events []Event, err error) {
|
||||
if len(repo.errs) != 0 {
|
||||
err, repo.errs = repo.errs[0], repo.errs[1:]
|
||||
return nil, err
|
||||
@ -490,6 +490,7 @@ func TestEventstore_Push(t *testing.T) {
|
||||
Type: "test.aggregate",
|
||||
ResourceOwner: "caos",
|
||||
InstanceID: "zitadel",
|
||||
Version: "v1",
|
||||
},
|
||||
Data: []byte(nil),
|
||||
User: "editorUser",
|
||||
@ -534,6 +535,7 @@ func TestEventstore_Push(t *testing.T) {
|
||||
Type: "test.aggregate",
|
||||
ResourceOwner: "caos",
|
||||
InstanceID: "zitadel",
|
||||
Version: "v1",
|
||||
},
|
||||
Data: []byte(nil),
|
||||
User: "editorUser",
|
||||
@ -585,6 +587,7 @@ func TestEventstore_Push(t *testing.T) {
|
||||
Type: "test.aggregate",
|
||||
ResourceOwner: "caos",
|
||||
InstanceID: "zitadel",
|
||||
Version: "v1",
|
||||
},
|
||||
Data: []byte(nil),
|
||||
User: "editorUser",
|
||||
@ -596,6 +599,7 @@ func TestEventstore_Push(t *testing.T) {
|
||||
Type: "test.aggregate",
|
||||
ResourceOwner: "caos",
|
||||
InstanceID: "zitadel",
|
||||
Version: "v1",
|
||||
},
|
||||
Data: []byte(nil),
|
||||
User: "editorUser",
|
||||
@ -658,6 +662,7 @@ func TestEventstore_Push(t *testing.T) {
|
||||
Type: "test.aggregate",
|
||||
ResourceOwner: "caos",
|
||||
InstanceID: "zitadel",
|
||||
Version: "v1",
|
||||
},
|
||||
Data: []byte(nil),
|
||||
User: "editorUser",
|
||||
@ -669,6 +674,7 @@ func TestEventstore_Push(t *testing.T) {
|
||||
Type: "test.aggregate",
|
||||
ResourceOwner: "caos",
|
||||
InstanceID: "zitadel",
|
||||
Version: "v1",
|
||||
},
|
||||
Data: []byte(nil),
|
||||
User: "editorUser",
|
||||
@ -682,6 +688,7 @@ func TestEventstore_Push(t *testing.T) {
|
||||
Type: "test.aggregate",
|
||||
ResourceOwner: "caos",
|
||||
InstanceID: "zitadel",
|
||||
Version: "v1",
|
||||
},
|
||||
Data: []byte(nil),
|
||||
User: "editorUser",
|
||||
@ -778,6 +785,7 @@ func TestEventstore_Push(t *testing.T) {
|
||||
Type: "test.aggregate",
|
||||
ResourceOwner: "caos",
|
||||
InstanceID: "zitadel",
|
||||
Version: "v1",
|
||||
},
|
||||
Data: []byte(nil),
|
||||
User: "editorUser",
|
||||
@ -828,6 +836,7 @@ func TestEventstore_Push(t *testing.T) {
|
||||
Type: "test.aggregate",
|
||||
ResourceOwner: "caos",
|
||||
InstanceID: "zitadel",
|
||||
Version: "v1",
|
||||
},
|
||||
Data: []byte(nil),
|
||||
User: "editorUser",
|
||||
@ -883,6 +892,7 @@ func TestEventstore_Push(t *testing.T) {
|
||||
Type: "test.aggregate",
|
||||
ResourceOwner: "caos",
|
||||
InstanceID: "zitadel",
|
||||
Version: "v1",
|
||||
},
|
||||
Data: []byte(nil),
|
||||
User: "editorUser",
|
||||
|
@ -9,6 +9,8 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/cockroach-go/v2/testserver"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/cmd/initialise"
|
||||
@ -39,10 +41,17 @@ func TestMain(m *testing.M) {
|
||||
testCRDBClient = &database.DB{
|
||||
Database: new(testDB),
|
||||
}
|
||||
testCRDBClient.DB, err = sql.Open("postgres", ts.PGURL().String())
|
||||
|
||||
connConfig, err := pgxpool.ParseConfig(ts.PGURL().String())
|
||||
if err != nil {
|
||||
logging.WithFields("error", err).Fatal("unable to connect to db")
|
||||
logging.WithFields("error", err).Fatal("unable to parse db url")
|
||||
}
|
||||
connConfig.AfterConnect = new_es.RegisterEventstoreTypes
|
||||
pool, err := pgxpool.NewWithConfig(context.Background(), connConfig)
|
||||
if err != nil {
|
||||
logging.WithFields("error", err).Fatal("unable to create db pool")
|
||||
}
|
||||
testCRDBClient.DB = stdlib.OpenDBFromPool(pool)
|
||||
if err = testCRDBClient.Ping(); err != nil {
|
||||
logging.WithFields("error", err).Fatal("unable to ping db")
|
||||
}
|
||||
@ -55,7 +64,7 @@ func TestMain(m *testing.M) {
|
||||
clients["v3(inmemory)"] = testCRDBClient
|
||||
|
||||
if localDB, err := connectLocalhost(); err == nil {
|
||||
if err = initDB(localDB); err != nil {
|
||||
if err = initDB(context.Background(), localDB); err != nil {
|
||||
logging.WithFields("error", err).Fatal("migrations failed")
|
||||
}
|
||||
pushers["v3(singlenode)"] = new_es.NewEventstore(localDB)
|
||||
@ -69,14 +78,14 @@ func TestMain(m *testing.M) {
|
||||
ts.Stop()
|
||||
}()
|
||||
|
||||
if err = initDB(testCRDBClient); err != nil {
|
||||
if err = initDB(context.Background(), testCRDBClient); err != nil {
|
||||
logging.WithFields("error", err).Fatal("migrations failed")
|
||||
}
|
||||
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func initDB(db *database.DB) error {
|
||||
func initDB(ctx context.Context, db *database.DB) error {
|
||||
initialise.ReadStmts("cockroach")
|
||||
config := new(database.Config)
|
||||
config.SetConnector(&cockroach.Config{
|
||||
@ -85,7 +94,7 @@ func initDB(db *database.DB) error {
|
||||
},
|
||||
Database: "zitadel",
|
||||
})
|
||||
err := initialise.Init(db,
|
||||
err := initialise.Init(ctx, db,
|
||||
initialise.VerifyUser(config.Username(), ""),
|
||||
initialise.VerifyDatabase(config.DatabaseName()),
|
||||
initialise.VerifyGrant(config.DatabaseName(), config.Username()),
|
||||
@ -93,7 +102,7 @@ func initDB(db *database.DB) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = initialise.VerifyZitadel(context.Background(), db, *config)
|
||||
err = initialise.VerifyZitadel(ctx, db, *config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -3,8 +3,12 @@ package repository
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
@ -82,7 +86,9 @@ func (e *Event) Type() eventstore.EventType {
|
||||
|
||||
// Revision implements [eventstore.Event]
|
||||
func (e *Event) Revision() uint16 {
|
||||
return 0
|
||||
revision, err := strconv.ParseUint(strings.TrimPrefix(string(e.Version), "v"), 10, 16)
|
||||
logging.OnError(err).Debug("failed to parse event revision")
|
||||
return uint16(revision)
|
||||
}
|
||||
|
||||
// Sequence implements [eventstore.Event]
|
||||
|
@ -165,7 +165,7 @@ func (mr *MockPusherMockRecorder) Health(arg0 any) *gomock.Call {
|
||||
}
|
||||
|
||||
// Push mocks base method.
|
||||
func (m *MockPusher) Push(arg0 context.Context, arg1 database.QueryExecuter, arg2 ...eventstore.Command) ([]eventstore.Event, error) {
|
||||
func (m *MockPusher) Push(arg0 context.Context, arg1 database.ContextQueryExecuter, arg2 ...eventstore.Command) ([]eventstore.Event, error) {
|
||||
m.ctrl.T.Helper()
|
||||
varargs := []any{arg0, arg1}
|
||||
for _, a := range arg2 {
|
||||
|
@ -80,7 +80,7 @@ func (m *MockRepository) ExpectInstanceIDsError(err error) *MockRepository {
|
||||
// The call will sleep at least the amount of passed duration.
|
||||
func (m *MockRepository) ExpectPush(expectedCommands []eventstore.Command, sleep time.Duration) *MockRepository {
|
||||
m.MockPusher.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(ctx context.Context, _ database.QueryExecuter, commands ...eventstore.Command) ([]eventstore.Event, error) {
|
||||
func(ctx context.Context, _ database.ContextQueryExecuter, commands ...eventstore.Command) ([]eventstore.Event, error) {
|
||||
m.MockPusher.ctrl.T.Helper()
|
||||
|
||||
time.Sleep(sleep)
|
||||
@ -135,7 +135,7 @@ func (m *MockRepository) ExpectPushFailed(err error, expectedCommands []eventsto
|
||||
m.MockPusher.ctrl.T.Helper()
|
||||
|
||||
m.MockPusher.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(ctx context.Context, _ database.QueryExecuter, commands ...eventstore.Command) ([]eventstore.Event, error) {
|
||||
func(ctx context.Context, _ database.ContextQueryExecuter, commands ...eventstore.Command) ([]eventstore.Event, error) {
|
||||
if len(expectedCommands) != len(commands) {
|
||||
return nil, fmt.Errorf("unexpected amount of commands: want %d, got %d", len(expectedCommands), len(commands))
|
||||
}
|
||||
@ -197,7 +197,7 @@ func (e *mockEvent) CreatedAt() time.Time {
|
||||
|
||||
func (m *MockRepository) ExpectRandomPush(expectedCommands []eventstore.Command) *MockRepository {
|
||||
m.MockPusher.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(ctx context.Context, _ database.QueryExecuter, commands ...eventstore.Command) ([]eventstore.Event, error) {
|
||||
func(ctx context.Context, _ database.ContextQueryExecuter, commands ...eventstore.Command) ([]eventstore.Event, error) {
|
||||
assert.Len(m.MockPusher.ctrl.T, commands, len(expectedCommands))
|
||||
|
||||
events := make([]eventstore.Event, len(commands))
|
||||
@ -215,7 +215,7 @@ func (m *MockRepository) ExpectRandomPush(expectedCommands []eventstore.Command)
|
||||
|
||||
func (m *MockRepository) ExpectRandomPushFailed(err error, expectedEvents []eventstore.Command) *MockRepository {
|
||||
m.MockPusher.EXPECT().Push(gomock.Any(), gomock.Any(), gomock.Any()).DoAndReturn(
|
||||
func(ctx context.Context, _ database.QueryExecuter, events ...eventstore.Command) ([]eventstore.Event, error) {
|
||||
func(ctx context.Context, _ database.ContextQueryExecuter, events ...eventstore.Command) ([]eventstore.Event, error) {
|
||||
assert.Len(m.MockPusher.ctrl.T, events, len(expectedEvents))
|
||||
return nil, err
|
||||
},
|
||||
|
@ -8,11 +8,14 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/cockroachdb/cockroach-go/v2/testserver"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/cmd/initialise"
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/database/cockroach"
|
||||
new_es "github.com/zitadel/zitadel/internal/eventstore/v3"
|
||||
)
|
||||
|
||||
var (
|
||||
@ -29,10 +32,18 @@ func TestMain(m *testing.M) {
|
||||
logging.WithFields("error", err).Fatal("unable to start db")
|
||||
}
|
||||
|
||||
testCRDBClient, err = sql.Open("postgres", ts.PGURL().String())
|
||||
connConfig, err := pgxpool.ParseConfig(ts.PGURL().String())
|
||||
if err != nil {
|
||||
logging.WithFields("error", err).Fatal("unable to connect to db")
|
||||
logging.WithFields("error", err).Fatal("unable to parse db url")
|
||||
}
|
||||
connConfig.AfterConnect = new_es.RegisterEventstoreTypes
|
||||
pool, err := pgxpool.NewWithConfig(context.Background(), connConfig)
|
||||
if err != nil {
|
||||
logging.WithFields("error", err).Fatal("unable to create db pool")
|
||||
}
|
||||
|
||||
testCRDBClient = stdlib.OpenDBFromPool(pool)
|
||||
|
||||
if err = testCRDBClient.Ping(); err != nil {
|
||||
logging.WithFields("error", err).Fatal("unable to ping db")
|
||||
}
|
||||
@ -42,14 +53,14 @@ func TestMain(m *testing.M) {
|
||||
ts.Stop()
|
||||
}()
|
||||
|
||||
if err = initDB(&database.DB{DB: testCRDBClient, Database: &cockroach.Config{Database: "zitadel"}}); err != nil {
|
||||
if err = initDB(context.Background(), &database.DB{DB: testCRDBClient, Database: &cockroach.Config{Database: "zitadel"}}); err != nil {
|
||||
logging.WithFields("error", err).Fatal("migrations failed")
|
||||
}
|
||||
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func initDB(db *database.DB) error {
|
||||
func initDB(ctx context.Context, db *database.DB) error {
|
||||
config := new(database.Config)
|
||||
config.SetConnector(&cockroach.Config{User: cockroach.User{Username: "zitadel"}, Database: "zitadel"})
|
||||
|
||||
@ -57,7 +68,7 @@ func initDB(db *database.DB) error {
|
||||
return err
|
||||
}
|
||||
|
||||
err := initialise.Init(db,
|
||||
err := initialise.Init(ctx, db,
|
||||
initialise.VerifyUser(config.Username(), ""),
|
||||
initialise.VerifyDatabase(config.DatabaseName()),
|
||||
initialise.VerifyGrant(config.DatabaseName(), config.Username()),
|
||||
|
@ -1,11 +1,13 @@
|
||||
package eventstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
@ -14,33 +16,98 @@ var (
|
||||
_ eventstore.Event = (*event)(nil)
|
||||
)
|
||||
|
||||
type command struct {
|
||||
InstanceID string
|
||||
AggregateType string
|
||||
AggregateID string
|
||||
CommandType string
|
||||
Revision uint16
|
||||
Payload Payload
|
||||
Creator string
|
||||
Owner string
|
||||
}
|
||||
|
||||
func (c *command) Aggregate() *eventstore.Aggregate {
|
||||
return &eventstore.Aggregate{
|
||||
ID: c.AggregateID,
|
||||
Type: eventstore.AggregateType(c.AggregateType),
|
||||
ResourceOwner: c.Owner,
|
||||
InstanceID: c.InstanceID,
|
||||
Version: eventstore.Version("v" + strconv.Itoa(int(c.Revision))),
|
||||
}
|
||||
}
|
||||
|
||||
type event struct {
|
||||
aggregate *eventstore.Aggregate
|
||||
creator string
|
||||
revision uint16
|
||||
typ eventstore.EventType
|
||||
command *command
|
||||
createdAt time.Time
|
||||
sequence uint64
|
||||
position float64
|
||||
payload Payload
|
||||
}
|
||||
|
||||
func commandToEvent(sequence *latestSequence, command eventstore.Command) (_ *event, err error) {
|
||||
// TODO: remove on v3
|
||||
func commandToEventOld(sequence *latestSequence, cmd eventstore.Command) (_ *event, err error) {
|
||||
var payload Payload
|
||||
if command.Payload() != nil {
|
||||
payload, err = json.Marshal(command.Payload())
|
||||
if cmd.Payload() != nil {
|
||||
payload, err = json.Marshal(cmd.Payload())
|
||||
if err != nil {
|
||||
logging.WithError(err).Warn("marshal payload failed")
|
||||
return nil, zerrors.ThrowInternal(err, "V3-MInPK", "Errors.Internal")
|
||||
}
|
||||
}
|
||||
return &event{
|
||||
aggregate: sequence.aggregate,
|
||||
creator: command.Creator(),
|
||||
revision: command.Revision(),
|
||||
typ: command.Type(),
|
||||
payload: payload,
|
||||
sequence: sequence.sequence,
|
||||
command: &command{
|
||||
InstanceID: sequence.aggregate.InstanceID,
|
||||
AggregateType: string(sequence.aggregate.Type),
|
||||
AggregateID: sequence.aggregate.ID,
|
||||
CommandType: string(cmd.Type()),
|
||||
Revision: cmd.Revision(),
|
||||
Payload: payload,
|
||||
Creator: cmd.Creator(),
|
||||
Owner: sequence.aggregate.ResourceOwner,
|
||||
},
|
||||
sequence: sequence.sequence,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func commandsToEvents(ctx context.Context, cmds []eventstore.Command) (_ []eventstore.Event, _ []*command, err error) {
|
||||
events := make([]eventstore.Event, len(cmds))
|
||||
commands := make([]*command, len(cmds))
|
||||
for i, cmd := range cmds {
|
||||
if cmd.Aggregate().InstanceID == "" {
|
||||
cmd.Aggregate().InstanceID = authz.GetInstance(ctx).InstanceID()
|
||||
}
|
||||
events[i], err = commandToEvent(cmd)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
commands[i] = events[i].(*event).command
|
||||
}
|
||||
return events, commands, nil
|
||||
}
|
||||
|
||||
func commandToEvent(cmd eventstore.Command) (_ eventstore.Event, err error) {
|
||||
var payload Payload
|
||||
if cmd.Payload() != nil {
|
||||
payload, err = json.Marshal(cmd.Payload())
|
||||
if err != nil {
|
||||
logging.WithError(err).Warn("marshal payload failed")
|
||||
return nil, zerrors.ThrowInternal(err, "V3-MInPK", "Errors.Internal")
|
||||
}
|
||||
}
|
||||
|
||||
command := &command{
|
||||
InstanceID: cmd.Aggregate().InstanceID,
|
||||
AggregateType: string(cmd.Aggregate().Type),
|
||||
AggregateID: cmd.Aggregate().ID,
|
||||
CommandType: string(cmd.Type()),
|
||||
Revision: cmd.Revision(),
|
||||
Payload: payload,
|
||||
Creator: cmd.Creator(),
|
||||
Owner: cmd.Aggregate().ResourceOwner,
|
||||
}
|
||||
|
||||
return &event{
|
||||
command: command,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@ -56,22 +123,22 @@ func (e *event) EditorUser() string {
|
||||
|
||||
// Aggregate implements [eventstore.Event]
|
||||
func (e *event) Aggregate() *eventstore.Aggregate {
|
||||
return e.aggregate
|
||||
return e.command.Aggregate()
|
||||
}
|
||||
|
||||
// Creator implements [eventstore.Event]
|
||||
func (e *event) Creator() string {
|
||||
return e.creator
|
||||
return e.command.Creator
|
||||
}
|
||||
|
||||
// Revision implements [eventstore.Event]
|
||||
func (e *event) Revision() uint16 {
|
||||
return e.revision
|
||||
return e.command.Revision
|
||||
}
|
||||
|
||||
// Type implements [eventstore.Event]
|
||||
func (e *event) Type() eventstore.EventType {
|
||||
return e.typ
|
||||
return eventstore.EventType(e.command.CommandType)
|
||||
}
|
||||
|
||||
// CreatedAt implements [eventstore.Event]
|
||||
@ -91,10 +158,10 @@ func (e *event) Position() float64 {
|
||||
|
||||
// Unmarshal implements [eventstore.Event]
|
||||
func (e *event) Unmarshal(ptr any) error {
|
||||
if len(e.payload) == 0 {
|
||||
if len(e.command.Payload) == 0 {
|
||||
return nil
|
||||
}
|
||||
if err := json.Unmarshal(e.payload, ptr); err != nil {
|
||||
if err := json.Unmarshal(e.command.Payload, ptr); err != nil {
|
||||
return zerrors.ThrowInternal(err, "V3-u8qVo", "Errors.Internal")
|
||||
}
|
||||
|
||||
@ -103,5 +170,5 @@ func (e *event) Unmarshal(ptr any) error {
|
||||
|
||||
// DataAsBytes implements [eventstore.Event]
|
||||
func (e *event) DataAsBytes() []byte {
|
||||
return e.payload
|
||||
return e.command.Payload
|
||||
}
|
||||
|
@ -1,16 +1,122 @@
|
||||
package eventstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
func Test_commandToEvent(t *testing.T) {
|
||||
payload := struct {
|
||||
ID string
|
||||
}{
|
||||
ID: "test",
|
||||
}
|
||||
payloadMarshalled, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal of payload failed: %v", err)
|
||||
}
|
||||
type args struct {
|
||||
command eventstore.Command
|
||||
}
|
||||
type want struct {
|
||||
event *event
|
||||
err func(t *testing.T, err error)
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "no payload",
|
||||
args: args{
|
||||
command: &mockCommand{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
payload: nil,
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
event: mockEvent(
|
||||
mockAggregate("V3-Red9I"),
|
||||
0,
|
||||
nil,
|
||||
).(*event),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "struct payload",
|
||||
args: args{
|
||||
command: &mockCommand{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
payload: payload,
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
event: mockEvent(
|
||||
mockAggregate("V3-Red9I"),
|
||||
0,
|
||||
payloadMarshalled,
|
||||
).(*event),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "pointer payload",
|
||||
args: args{
|
||||
command: &mockCommand{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
payload: &payload,
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
event: mockEvent(
|
||||
mockAggregate("V3-Red9I"),
|
||||
0,
|
||||
payloadMarshalled,
|
||||
).(*event),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid payload",
|
||||
args: args{
|
||||
command: &mockCommand{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
payload: func() {},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
err: func(t *testing.T, err error) {
|
||||
assert.Error(t, err)
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
if tt.want.err == nil {
|
||||
tt.want.err = func(t *testing.T, err error) {
|
||||
require.NoError(t, err)
|
||||
}
|
||||
}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := commandToEvent(tt.args.command)
|
||||
|
||||
tt.want.err(t, err)
|
||||
if tt.want.event == nil {
|
||||
assert.Nil(t, got)
|
||||
return
|
||||
}
|
||||
assert.Equal(t, tt.want.event, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_commandToEventOld(t *testing.T) {
|
||||
payload := struct {
|
||||
ID string
|
||||
}{
|
||||
@ -119,10 +225,258 @@ func Test_commandToEvent(t *testing.T) {
|
||||
}
|
||||
}
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := commandToEvent(tt.args.sequence, tt.args.command)
|
||||
got, err := commandToEventOld(tt.args.sequence, tt.args.command)
|
||||
|
||||
tt.want.err(t, err)
|
||||
assert.Equal(t, tt.want.event, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_commandsToEvents(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
payload := struct {
|
||||
ID string
|
||||
}{
|
||||
ID: "test",
|
||||
}
|
||||
payloadMarshalled, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
t.Fatalf("marshal of payload failed: %v", err)
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
cmds []eventstore.Command
|
||||
}
|
||||
type want struct {
|
||||
events []eventstore.Event
|
||||
commands []*command
|
||||
err func(t *testing.T, err error)
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want want
|
||||
}{
|
||||
{
|
||||
name: "no commands",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
cmds: nil,
|
||||
},
|
||||
want: want{
|
||||
events: []eventstore.Event{},
|
||||
commands: []*command{},
|
||||
err: func(t *testing.T, err error) {
|
||||
require.NoError(t, err)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single command no payload",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
cmds: []eventstore.Command{
|
||||
&mockCommand{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
payload: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
events: []eventstore.Event{
|
||||
mockEvent(
|
||||
mockAggregate("V3-Red9I"),
|
||||
0,
|
||||
nil,
|
||||
),
|
||||
},
|
||||
commands: []*command{
|
||||
{
|
||||
InstanceID: "instance",
|
||||
AggregateType: "type",
|
||||
AggregateID: "V3-Red9I",
|
||||
Owner: "ro",
|
||||
CommandType: "event.type",
|
||||
Revision: 1,
|
||||
Payload: nil,
|
||||
Creator: "creator",
|
||||
},
|
||||
},
|
||||
err: func(t *testing.T, err error) {
|
||||
require.NoError(t, err)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single command no instance id",
|
||||
args: args{
|
||||
ctx: authz.WithInstanceID(ctx, "instance from ctx"),
|
||||
cmds: []eventstore.Command{
|
||||
&mockCommand{
|
||||
aggregate: mockAggregateWithInstance("V3-Red9I", ""),
|
||||
payload: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
events: []eventstore.Event{
|
||||
mockEvent(
|
||||
mockAggregateWithInstance("V3-Red9I", "instance from ctx"),
|
||||
0,
|
||||
nil,
|
||||
),
|
||||
},
|
||||
commands: []*command{
|
||||
{
|
||||
InstanceID: "instance from ctx",
|
||||
AggregateType: "type",
|
||||
AggregateID: "V3-Red9I",
|
||||
Owner: "ro",
|
||||
CommandType: "event.type",
|
||||
Revision: 1,
|
||||
Payload: nil,
|
||||
Creator: "creator",
|
||||
},
|
||||
},
|
||||
err: func(t *testing.T, err error) {
|
||||
require.NoError(t, err)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "single command with payload",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
cmds: []eventstore.Command{
|
||||
&mockCommand{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
payload: payload,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
events: []eventstore.Event{
|
||||
mockEvent(
|
||||
mockAggregate("V3-Red9I"),
|
||||
0,
|
||||
payloadMarshalled,
|
||||
),
|
||||
},
|
||||
commands: []*command{
|
||||
{
|
||||
InstanceID: "instance",
|
||||
AggregateType: "type",
|
||||
AggregateID: "V3-Red9I",
|
||||
Owner: "ro",
|
||||
CommandType: "event.type",
|
||||
Revision: 1,
|
||||
Payload: payloadMarshalled,
|
||||
Creator: "creator",
|
||||
},
|
||||
},
|
||||
err: func(t *testing.T, err error) {
|
||||
require.NoError(t, err)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "multiple commands",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
cmds: []eventstore.Command{
|
||||
&mockCommand{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
payload: payload,
|
||||
},
|
||||
&mockCommand{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
payload: nil,
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
events: []eventstore.Event{
|
||||
mockEvent(
|
||||
mockAggregate("V3-Red9I"),
|
||||
0,
|
||||
payloadMarshalled,
|
||||
),
|
||||
mockEvent(
|
||||
mockAggregate("V3-Red9I"),
|
||||
0,
|
||||
nil,
|
||||
),
|
||||
},
|
||||
commands: []*command{
|
||||
{
|
||||
InstanceID: "instance",
|
||||
AggregateType: "type",
|
||||
AggregateID: "V3-Red9I",
|
||||
CommandType: "event.type",
|
||||
Revision: 1,
|
||||
Payload: payloadMarshalled,
|
||||
Creator: "creator",
|
||||
Owner: "ro",
|
||||
},
|
||||
{
|
||||
InstanceID: "instance",
|
||||
AggregateType: "type",
|
||||
AggregateID: "V3-Red9I",
|
||||
CommandType: "event.type",
|
||||
Revision: 1,
|
||||
Payload: nil,
|
||||
Creator: "creator",
|
||||
Owner: "ro",
|
||||
},
|
||||
},
|
||||
err: func(t *testing.T, err error) {
|
||||
require.NoError(t, err)
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "invalid command",
|
||||
args: args{
|
||||
ctx: ctx,
|
||||
cmds: []eventstore.Command{
|
||||
&mockCommand{
|
||||
aggregate: mockAggregate("V3-Red9I"),
|
||||
payload: func() {},
|
||||
},
|
||||
},
|
||||
},
|
||||
want: want{
|
||||
events: nil,
|
||||
commands: nil,
|
||||
err: func(t *testing.T, err error) {
|
||||
assert.Error(t, err)
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gotEvents, gotCommands, err := commandsToEvents(tt.args.ctx, tt.args.cmds)
|
||||
|
||||
tt.want.err(t, err)
|
||||
assert.Equal(t, tt.want.events, gotEvents)
|
||||
require.Len(t, gotCommands, len(tt.want.commands))
|
||||
for i, wantCommand := range tt.want.commands {
|
||||
assertCommand(t, wantCommand, gotCommands[i])
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func assertCommand(t *testing.T, want, got *command) {
|
||||
t.Helper()
|
||||
assert.Equal(t, want.CommandType, got.CommandType)
|
||||
assert.Equal(t, want.Payload, got.Payload)
|
||||
assert.Equal(t, want.Creator, got.Creator)
|
||||
assert.Equal(t, want.Owner, got.Owner)
|
||||
assert.Equal(t, want.AggregateID, got.AggregateID)
|
||||
assert.Equal(t, want.AggregateType, got.AggregateType)
|
||||
assert.Equal(t, want.InstanceID, got.InstanceID)
|
||||
assert.Equal(t, want.Revision, got.Revision)
|
||||
}
|
||||
|
@ -2,11 +2,26 @@ package eventstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/DATA-DOG/go-sqlmock"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
"github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/database/dialect"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
func init() {
|
||||
dialect.RegisterAfterConnect(RegisterEventstoreTypes)
|
||||
}
|
||||
|
||||
var (
|
||||
// pushPlaceholderFmt defines how data are inserted into the events table
|
||||
pushPlaceholderFmt string
|
||||
@ -20,6 +35,123 @@ type Eventstore struct {
|
||||
client *database.DB
|
||||
}
|
||||
|
||||
var (
|
||||
textType = &pgtype.Type{
|
||||
Name: "text",
|
||||
OID: pgtype.TextOID,
|
||||
Codec: pgtype.TextCodec{},
|
||||
}
|
||||
commandType = &pgtype.Type{
|
||||
Codec: &pgtype.CompositeCodec{
|
||||
Fields: []pgtype.CompositeCodecField{
|
||||
{
|
||||
Name: "instance_id",
|
||||
Type: textType,
|
||||
},
|
||||
{
|
||||
Name: "aggregate_type",
|
||||
Type: textType,
|
||||
},
|
||||
{
|
||||
Name: "aggregate_id",
|
||||
Type: textType,
|
||||
},
|
||||
{
|
||||
Name: "command_type",
|
||||
Type: textType,
|
||||
},
|
||||
{
|
||||
Name: "revision",
|
||||
Type: &pgtype.Type{
|
||||
Name: "int2",
|
||||
OID: pgtype.Int2OID,
|
||||
Codec: pgtype.Int2Codec{},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "payload",
|
||||
Type: &pgtype.Type{
|
||||
Name: "jsonb",
|
||||
OID: pgtype.JSONBOID,
|
||||
Codec: &pgtype.JSONBCodec{
|
||||
Marshal: json.Marshal,
|
||||
Unmarshal: json.Unmarshal,
|
||||
},
|
||||
},
|
||||
},
|
||||
{
|
||||
Name: "creator",
|
||||
Type: textType,
|
||||
},
|
||||
{
|
||||
Name: "owner",
|
||||
Type: textType,
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
commandArrayCodec = &pgtype.Type{
|
||||
Codec: &pgtype.ArrayCodec{
|
||||
ElementType: commandType,
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
var typeMu sync.Mutex
|
||||
|
||||
func RegisterEventstoreTypes(ctx context.Context, conn *pgx.Conn) error {
|
||||
// conn.TypeMap is not thread safe
|
||||
typeMu.Lock()
|
||||
defer typeMu.Unlock()
|
||||
|
||||
m := conn.TypeMap()
|
||||
|
||||
var cmd *command
|
||||
if _, ok := m.TypeForValue(cmd); ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
if commandType.OID == 0 || commandArrayCodec.OID == 0 {
|
||||
err := conn.QueryRow(ctx, "select oid, typarray from pg_type where typname = $1 and typnamespace = (select oid from pg_namespace where nspname = $2)", "command", "eventstore").
|
||||
Scan(&commandType.OID, &commandArrayCodec.OID)
|
||||
if err != nil {
|
||||
logging.WithError(err).Debug("failed to get oid for command type")
|
||||
return nil
|
||||
}
|
||||
if commandType.OID == 0 || commandArrayCodec.OID == 0 {
|
||||
logging.Debug("oid for command type not found")
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
m.RegisterTypes([]*pgtype.Type{
|
||||
{
|
||||
Name: "eventstore.command",
|
||||
Codec: commandType.Codec,
|
||||
OID: commandType.OID,
|
||||
},
|
||||
{
|
||||
Name: "command",
|
||||
Codec: commandType.Codec,
|
||||
OID: commandType.OID,
|
||||
},
|
||||
{
|
||||
Name: "eventstore._command",
|
||||
Codec: commandArrayCodec.Codec,
|
||||
OID: commandArrayCodec.OID,
|
||||
},
|
||||
{
|
||||
Name: "_command",
|
||||
Codec: commandArrayCodec.Codec,
|
||||
OID: commandArrayCodec.OID,
|
||||
},
|
||||
})
|
||||
dialect.RegisterDefaultPgTypeVariants[command](m, "eventstore.command", "eventstore._command")
|
||||
dialect.RegisterDefaultPgTypeVariants[command](m, "command", "_command")
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Client implements the [eventstore.Pusher]
|
||||
func (es *Eventstore) Client() *database.DB {
|
||||
return es.client
|
||||
@ -41,3 +173,45 @@ func NewEventstore(client *database.DB) *Eventstore {
|
||||
func (es *Eventstore) Health(ctx context.Context) error {
|
||||
return es.client.PingContext(ctx)
|
||||
}
|
||||
|
||||
var errTypesNotFound = errors.New("types not found")
|
||||
|
||||
func CheckExecutionPlan(ctx context.Context, conn *sql.Conn) error {
|
||||
return conn.Raw(func(driverConn any) error {
|
||||
if _, ok := driverConn.(sqlmock.SqlmockCommon); ok {
|
||||
return nil
|
||||
}
|
||||
conn, ok := driverConn.(*stdlib.Conn)
|
||||
if !ok {
|
||||
return errTypesNotFound
|
||||
}
|
||||
|
||||
return RegisterEventstoreTypes(ctx, conn.Conn())
|
||||
})
|
||||
}
|
||||
|
||||
func (es *Eventstore) pushTx(ctx context.Context, client database.ContextQueryExecuter) (tx database.Tx, deferrable func(err error) error, err error) {
|
||||
tx, ok := client.(database.Tx)
|
||||
if ok {
|
||||
return tx, nil, nil
|
||||
}
|
||||
beginner, ok := client.(database.Beginner)
|
||||
if !ok {
|
||||
beginner = es.client
|
||||
}
|
||||
|
||||
isolationLevel := sql.LevelReadCommitted
|
||||
// cockroach requires serializable to execute the push function
|
||||
// because we use [cluster_logical_timestamp()](https://www.cockroachlabs.com/docs/stable/functions-and-operators#system-info-functions)
|
||||
if es.client.Type() == "cockroach" {
|
||||
isolationLevel = sql.LevelSerializable
|
||||
}
|
||||
tx, err = beginner.BeginTx(ctx, &sql.TxOptions{
|
||||
Isolation: isolationLevel,
|
||||
ReadOnly: false,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return tx, func(err error) error { return database.CloseTransaction(tx, err) }, nil
|
||||
}
|
||||
|
@ -143,7 +143,7 @@ func buildSearchCondition(builder *strings.Builder, index int, conditions map[ev
|
||||
return args
|
||||
}
|
||||
|
||||
func handleFieldCommands(ctx context.Context, tx database.Tx, commands []eventstore.Command) error {
|
||||
func (es *Eventstore) handleFieldCommands(ctx context.Context, tx database.Tx, commands []eventstore.Command) error {
|
||||
for _, command := range commands {
|
||||
if len(command.Fields()) > 0 {
|
||||
if err := handleFieldOperations(ctx, tx, command.Fields()); err != nil {
|
||||
|
@ -48,12 +48,17 @@ func (e *mockCommand) Fields() []*eventstore.FieldOperation {
|
||||
|
||||
func mockEvent(aggregate *eventstore.Aggregate, sequence uint64, payload Payload) eventstore.Event {
|
||||
return &event{
|
||||
aggregate: aggregate,
|
||||
creator: "creator",
|
||||
revision: 1,
|
||||
typ: "event.type",
|
||||
sequence: sequence,
|
||||
payload: payload,
|
||||
command: &command{
|
||||
InstanceID: aggregate.InstanceID,
|
||||
AggregateType: string(aggregate.Type),
|
||||
AggregateID: aggregate.ID,
|
||||
Owner: aggregate.ResourceOwner,
|
||||
Creator: "creator",
|
||||
Revision: 1,
|
||||
CommandType: "event.type",
|
||||
Payload: payload,
|
||||
},
|
||||
sequence: sequence,
|
||||
}
|
||||
}
|
||||
|
||||
@ -66,3 +71,13 @@ func mockAggregate(id string) *eventstore.Aggregate {
|
||||
Version: "v1",
|
||||
}
|
||||
}
|
||||
|
||||
func mockAggregateWithInstance(id, instance string) *eventstore.Aggregate {
|
||||
return &eventstore.Aggregate{
|
||||
ID: id,
|
||||
InstanceID: instance,
|
||||
Type: "type",
|
||||
ResourceOwner: "ro",
|
||||
Version: "v1",
|
||||
}
|
||||
}
|
||||
|
@ -4,83 +4,58 @@ import (
|
||||
"context"
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/cockroachdb/cockroach-go/v2/crdb"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/database/dialect"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
var appNamePrefix = dialect.DBPurposeEventPusher.AppName() + "_"
|
||||
|
||||
var pushTxOpts = &sql.TxOptions{
|
||||
Isolation: sql.LevelReadCommitted,
|
||||
ReadOnly: false,
|
||||
}
|
||||
|
||||
func (es *Eventstore) Push(ctx context.Context, client database.QueryExecuter, commands ...eventstore.Command) (events []eventstore.Event, err error) {
|
||||
func (es *Eventstore) Push(ctx context.Context, client database.ContextQueryExecuter, commands ...eventstore.Command) (events []eventstore.Event, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
var tx database.Tx
|
||||
events, err = es.writeCommands(ctx, client, commands)
|
||||
if isSetupNotExecutedError(err) {
|
||||
return es.pushWithoutFunc(ctx, client, commands...)
|
||||
}
|
||||
|
||||
return events, err
|
||||
}
|
||||
|
||||
func (es *Eventstore) writeCommands(ctx context.Context, client database.ContextQueryExecuter, commands []eventstore.Command) (_ []eventstore.Event, err error) {
|
||||
var conn *sql.Conn
|
||||
switch c := client.(type) {
|
||||
case database.Tx:
|
||||
tx = c
|
||||
case database.Client:
|
||||
// We cannot use READ COMMITTED on CockroachDB because we use cluster_logical_timestamp() which is not supported in this isolation level
|
||||
var opts *sql.TxOptions
|
||||
if es.client.Database.Type() == "postgres" {
|
||||
opts = pushTxOpts
|
||||
}
|
||||
tx, err = c.BeginTx(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
err = database.CloseTransaction(tx, err)
|
||||
}()
|
||||
default:
|
||||
// We cannot use READ COMMITTED on CockroachDB because we use cluster_logical_timestamp() which is not supported in this isolation level
|
||||
var opts *sql.TxOptions
|
||||
if es.client.Database.Type() == "postgres" {
|
||||
opts = pushTxOpts
|
||||
}
|
||||
tx, err = es.client.BeginTx(ctx, opts)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
err = database.CloseTransaction(tx, err)
|
||||
}()
|
||||
conn, err = c.Conn(ctx)
|
||||
case nil:
|
||||
conn, err = es.client.Conn(ctx)
|
||||
client = conn
|
||||
}
|
||||
// tx is not closed because [crdb.ExecuteInTx] takes care of that
|
||||
var (
|
||||
sequences []*latestSequence
|
||||
)
|
||||
|
||||
// needs to be set like this because psql complains about parameters in the SET statement
|
||||
_, err = tx.ExecContext(ctx, "SET application_name = '"+appNamePrefix+authz.GetInstance(ctx).InstanceID()+"'")
|
||||
if err != nil {
|
||||
logging.WithError(err).Warn("failed to set application name")
|
||||
return nil, err
|
||||
}
|
||||
|
||||
sequences, err = latestSequences(ctx, tx, commands)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if conn != nil {
|
||||
defer conn.Close()
|
||||
}
|
||||
|
||||
events, err = insertEvents(ctx, tx, sequences, commands)
|
||||
tx, close, err := es.pushTx(ctx, client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if close != nil {
|
||||
defer func() {
|
||||
err = close(err)
|
||||
}()
|
||||
}
|
||||
|
||||
events, err := writeEvents(ctx, tx, commands)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -89,16 +64,7 @@ func (es *Eventstore) Push(ctx context.Context, client database.QueryExecuter, c
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// CockroachDB by default does not allow multiple modifications of the same table using ON CONFLICT
|
||||
// Thats why we enable it manually
|
||||
if es.client.Type() == "cockroach" {
|
||||
_, err = tx.Exec("SET enable_multiple_modifications_of_table = on")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = handleFieldCommands(ctx, tx, commands)
|
||||
err = es.handleFieldCommands(ctx, tx, commands)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -106,120 +72,30 @@ func (es *Eventstore) Push(ctx context.Context, client database.QueryExecuter, c
|
||||
return events, nil
|
||||
}
|
||||
|
||||
//go:embed push.sql
|
||||
var pushStmt string
|
||||
func writeEvents(ctx context.Context, tx database.Tx, commands []eventstore.Command) (_ []eventstore.Event, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
func insertEvents(ctx context.Context, tx database.Tx, sequences []*latestSequence, commands []eventstore.Command) ([]eventstore.Event, error) {
|
||||
events, placeholders, args, err := mapCommands(commands, sequences)
|
||||
events, cmds, err := commandsToEvents(ctx, commands)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := tx.QueryContext(ctx, fmt.Sprintf(pushStmt, strings.Join(placeholders, ", ")), args...)
|
||||
rows, err := tx.QueryContext(ctx, `select owner, created_at, "sequence", position from eventstore.push($1::eventstore.command[])`, cmds)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for i := 0; rows.Next(); i++ {
|
||||
err = rows.Scan(&events[i].(*event).createdAt, &events[i].(*event).position)
|
||||
err = rows.Scan(&events[i].(*event).command.Owner, &events[i].(*event).createdAt, &events[i].(*event).sequence, &events[i].(*event).position)
|
||||
if err != nil {
|
||||
logging.WithError(err).Warn("failed to scan events")
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
pgErr := new(pgconn.PgError)
|
||||
if errors.As(err, &pgErr) {
|
||||
// Check if push tries to write an event just written
|
||||
// by another transaction
|
||||
if pgErr.Code == "40001" {
|
||||
// TODO: @livio-a should we return the parent or not?
|
||||
return nil, zerrors.ThrowInvalidArgument(err, "V3-p5xAn", "Errors.AlreadyExists")
|
||||
}
|
||||
}
|
||||
logging.WithError(rows.Err()).Warn("failed to push events")
|
||||
return nil, zerrors.ThrowInternal(err, "V3-VGnZY", "Errors.Internal")
|
||||
if err = rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
const argsPerCommand = 10
|
||||
|
||||
func mapCommands(commands []eventstore.Command, sequences []*latestSequence) (events []eventstore.Event, placeholders []string, args []any, err error) {
|
||||
events = make([]eventstore.Event, len(commands))
|
||||
args = make([]any, 0, len(commands)*argsPerCommand)
|
||||
placeholders = make([]string, len(commands))
|
||||
|
||||
for i, command := range commands {
|
||||
sequence := searchSequenceByCommand(sequences, command)
|
||||
if sequence == nil {
|
||||
logging.WithFields(
|
||||
"aggType", command.Aggregate().Type,
|
||||
"aggID", command.Aggregate().ID,
|
||||
"instance", command.Aggregate().InstanceID,
|
||||
).Panic("no sequence found")
|
||||
// added return for linting
|
||||
return nil, nil, nil, nil
|
||||
}
|
||||
sequence.sequence++
|
||||
|
||||
events[i], err = commandToEvent(sequence, command)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
placeholders[i] = fmt.Sprintf(pushPlaceholderFmt,
|
||||
i*argsPerCommand+1,
|
||||
i*argsPerCommand+2,
|
||||
i*argsPerCommand+3,
|
||||
i*argsPerCommand+4,
|
||||
i*argsPerCommand+5,
|
||||
i*argsPerCommand+6,
|
||||
i*argsPerCommand+7,
|
||||
i*argsPerCommand+8,
|
||||
i*argsPerCommand+9,
|
||||
i*argsPerCommand+10,
|
||||
)
|
||||
|
||||
revision, err := strconv.Atoi(strings.TrimPrefix(string(events[i].(*event).aggregate.Version), "v"))
|
||||
if err != nil {
|
||||
return nil, nil, nil, zerrors.ThrowInternal(err, "V3-JoZEp", "Errors.Internal")
|
||||
}
|
||||
args = append(args,
|
||||
events[i].(*event).aggregate.InstanceID,
|
||||
events[i].(*event).aggregate.ResourceOwner,
|
||||
events[i].(*event).aggregate.Type,
|
||||
events[i].(*event).aggregate.ID,
|
||||
revision,
|
||||
events[i].(*event).creator,
|
||||
events[i].(*event).typ,
|
||||
events[i].(*event).payload,
|
||||
events[i].(*event).sequence,
|
||||
i,
|
||||
)
|
||||
}
|
||||
|
||||
return events, placeholders, args, nil
|
||||
}
|
||||
|
||||
type transaction struct {
|
||||
database.Tx
|
||||
}
|
||||
|
||||
var _ crdb.Tx = (*transaction)(nil)
|
||||
|
||||
func (t *transaction) Exec(ctx context.Context, query string, args ...interface{}) error {
|
||||
_, err := t.Tx.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *transaction) Commit(ctx context.Context) error {
|
||||
return t.Tx.Commit()
|
||||
}
|
||||
|
||||
func (t *transaction) Rollback(ctx context.Context) error {
|
||||
return t.Tx.Rollback()
|
||||
}
|
||||
|
@ -70,11 +70,11 @@ func Test_mapCommands(t *testing.T) {
|
||||
args: []any{
|
||||
"instance",
|
||||
"ro",
|
||||
eventstore.AggregateType("type"),
|
||||
"type",
|
||||
"V3-VEIvq",
|
||||
1,
|
||||
uint16(1),
|
||||
"creator",
|
||||
eventstore.EventType("event.type"),
|
||||
"event.type",
|
||||
Payload(nil),
|
||||
uint64(1),
|
||||
0,
|
||||
@ -121,22 +121,22 @@ func Test_mapCommands(t *testing.T) {
|
||||
// first event
|
||||
"instance",
|
||||
"ro",
|
||||
eventstore.AggregateType("type"),
|
||||
"type",
|
||||
"V3-VEIvq",
|
||||
1,
|
||||
uint16(1),
|
||||
"creator",
|
||||
eventstore.EventType("event.type"),
|
||||
"event.type",
|
||||
Payload(nil),
|
||||
uint64(6),
|
||||
0,
|
||||
// second event
|
||||
"instance",
|
||||
"ro",
|
||||
eventstore.AggregateType("type"),
|
||||
"type",
|
||||
"V3-VEIvq",
|
||||
1,
|
||||
uint16(1),
|
||||
"creator",
|
||||
eventstore.EventType("event.type"),
|
||||
"event.type",
|
||||
Payload(nil),
|
||||
uint64(7),
|
||||
1,
|
||||
@ -187,22 +187,22 @@ func Test_mapCommands(t *testing.T) {
|
||||
// first event
|
||||
"instance",
|
||||
"ro",
|
||||
eventstore.AggregateType("type"),
|
||||
"type",
|
||||
"V3-VEIvq",
|
||||
1,
|
||||
uint16(1),
|
||||
"creator",
|
||||
eventstore.EventType("event.type"),
|
||||
"event.type",
|
||||
Payload(nil),
|
||||
uint64(6),
|
||||
0,
|
||||
// second event
|
||||
"instance",
|
||||
"ro",
|
||||
eventstore.AggregateType("type"),
|
||||
"type",
|
||||
"V3-IT6VN",
|
||||
1,
|
||||
uint16(1),
|
||||
"creator",
|
||||
eventstore.EventType("event.type"),
|
||||
"event.type",
|
||||
Payload(nil),
|
||||
uint64(1),
|
||||
1,
|
||||
|
183
internal/eventstore/v3/push_without_func.go
Normal file
183
internal/eventstore/v3/push_without_func.go
Normal file
@ -0,0 +1,183 @@
|
||||
package eventstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
_ "embed"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/cockroachdb/cockroach-go/v2/crdb"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
type transaction struct {
|
||||
database.Tx
|
||||
}
|
||||
|
||||
var _ crdb.Tx = (*transaction)(nil)
|
||||
|
||||
func (t *transaction) Exec(ctx context.Context, query string, args ...interface{}) error {
|
||||
_, err := t.Tx.ExecContext(ctx, query, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
func (t *transaction) Commit(ctx context.Context) error {
|
||||
return t.Tx.Commit()
|
||||
}
|
||||
|
||||
func (t *transaction) Rollback(ctx context.Context) error {
|
||||
return t.Tx.Rollback()
|
||||
}
|
||||
|
||||
// checks whether the error is caused because setup step 39 was not executed
|
||||
func isSetupNotExecutedError(err error) bool {
|
||||
if err == nil {
|
||||
return false
|
||||
}
|
||||
var pgErr *pgconn.PgError
|
||||
if errors.As(err, &pgErr) {
|
||||
return (pgErr.Code == "42704" && strings.Contains(pgErr.Message, "eventstore.command")) ||
|
||||
(pgErr.Code == "42883" && strings.Contains(pgErr.Message, "eventstore.push"))
|
||||
}
|
||||
return errors.Is(err, errTypesNotFound)
|
||||
}
|
||||
|
||||
var (
|
||||
//go:embed push.sql
|
||||
pushStmt string
|
||||
)
|
||||
|
||||
// pushWithoutFunc implements pushing events before setup step 39 was introduced.
|
||||
// TODO: remove with v3
|
||||
func (es *Eventstore) pushWithoutFunc(ctx context.Context, client database.ContextQueryExecuter, commands ...eventstore.Command) (events []eventstore.Event, err error) {
|
||||
tx, closeTx, err := es.pushTx(ctx, client)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
err = closeTx(err)
|
||||
}()
|
||||
|
||||
// tx is not closed because [crdb.ExecuteInTx] takes care of that
|
||||
var (
|
||||
sequences []*latestSequence
|
||||
)
|
||||
sequences, err = latestSequences(ctx, tx, commands)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
events, err = es.writeEventsOld(ctx, tx, sequences, commands)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = handleUniqueConstraints(ctx, tx, commands); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = es.handleFieldCommands(ctx, tx, commands)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
func (es *Eventstore) writeEventsOld(ctx context.Context, tx database.Tx, sequences []*latestSequence, commands []eventstore.Command) ([]eventstore.Event, error) {
|
||||
events, placeholders, args, err := mapCommands(commands, sequences)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := tx.QueryContext(ctx, fmt.Sprintf(pushStmt, strings.Join(placeholders, ", ")), args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for i := 0; rows.Next(); i++ {
|
||||
err = rows.Scan(&events[i].(*event).createdAt, &events[i].(*event).position)
|
||||
if err != nil {
|
||||
logging.WithError(err).Warn("failed to scan events")
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if err := rows.Err(); err != nil {
|
||||
pgErr := new(pgconn.PgError)
|
||||
if errors.As(err, &pgErr) {
|
||||
// Check if push tries to write an event just written
|
||||
// by another transaction
|
||||
if pgErr.Code == "40001" {
|
||||
// TODO: @livio-a should we return the parent or not?
|
||||
return nil, zerrors.ThrowInvalidArgument(err, "V3-p5xAn", "Errors.AlreadyExists")
|
||||
}
|
||||
}
|
||||
logging.WithError(rows.Err()).Warn("failed to push events")
|
||||
return nil, zerrors.ThrowInternal(err, "V3-VGnZY", "Errors.Internal")
|
||||
}
|
||||
|
||||
return events, nil
|
||||
}
|
||||
|
||||
const argsPerCommand = 10
|
||||
|
||||
func mapCommands(commands []eventstore.Command, sequences []*latestSequence) (events []eventstore.Event, placeholders []string, args []any, err error) {
|
||||
events = make([]eventstore.Event, len(commands))
|
||||
args = make([]any, 0, len(commands)*argsPerCommand)
|
||||
placeholders = make([]string, len(commands))
|
||||
|
||||
for i, command := range commands {
|
||||
sequence := searchSequenceByCommand(sequences, command)
|
||||
if sequence == nil {
|
||||
logging.WithFields(
|
||||
"aggType", command.Aggregate().Type,
|
||||
"aggID", command.Aggregate().ID,
|
||||
"instance", command.Aggregate().InstanceID,
|
||||
).Panic("no sequence found")
|
||||
// added return for linting
|
||||
return nil, nil, nil, nil
|
||||
}
|
||||
sequence.sequence++
|
||||
|
||||
events[i], err = commandToEventOld(sequence, command)
|
||||
if err != nil {
|
||||
return nil, nil, nil, err
|
||||
}
|
||||
|
||||
placeholders[i] = fmt.Sprintf(pushPlaceholderFmt,
|
||||
i*argsPerCommand+1,
|
||||
i*argsPerCommand+2,
|
||||
i*argsPerCommand+3,
|
||||
i*argsPerCommand+4,
|
||||
i*argsPerCommand+5,
|
||||
i*argsPerCommand+6,
|
||||
i*argsPerCommand+7,
|
||||
i*argsPerCommand+8,
|
||||
i*argsPerCommand+9,
|
||||
i*argsPerCommand+10,
|
||||
)
|
||||
|
||||
args = append(args,
|
||||
events[i].(*event).command.InstanceID,
|
||||
events[i].(*event).command.Owner,
|
||||
events[i].(*event).command.AggregateType,
|
||||
events[i].(*event).command.AggregateID,
|
||||
events[i].(*event).command.Revision,
|
||||
events[i].(*event).command.Creator,
|
||||
events[i].(*event).command.CommandType,
|
||||
events[i].(*event).command.Payload,
|
||||
events[i].(*event).sequence,
|
||||
i,
|
||||
)
|
||||
}
|
||||
|
||||
return events, placeholders, args, nil
|
||||
}
|
@ -12,6 +12,7 @@ import (
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
@ -24,7 +25,10 @@ var (
|
||||
addConstraintStmt string
|
||||
)
|
||||
|
||||
func handleUniqueConstraints(ctx context.Context, tx database.Tx, commands []eventstore.Command) error {
|
||||
func handleUniqueConstraints(ctx context.Context, tx database.Tx, commands []eventstore.Command) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
deletePlaceholders := make([]string, 0)
|
||||
deleteArgs := make([]any, 0)
|
||||
|
||||
|
@ -1,15 +1,19 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/rsa"
|
||||
"crypto/tls"
|
||||
"crypto/x509"
|
||||
"encoding/xml"
|
||||
"io"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/crewjam/saml"
|
||||
"github.com/crewjam/saml/samlsp"
|
||||
"golang.org/x/text/encoding/ianaindex"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
@ -104,6 +108,41 @@ func WithEntityID(entityID string) ProviderOpts {
|
||||
}
|
||||
}
|
||||
|
||||
// ParseMetadata parses the metadata with the provided XML encoding and returns the EntityDescriptor
|
||||
func ParseMetadata(metadata []byte) (*saml.EntityDescriptor, error) {
|
||||
entityDescriptor := new(saml.EntityDescriptor)
|
||||
reader := bytes.NewReader(metadata)
|
||||
decoder := xml.NewDecoder(reader)
|
||||
decoder.CharsetReader = func(charset string, reader io.Reader) (io.Reader, error) {
|
||||
enc, err := ianaindex.IANA.Encoding(charset)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return enc.NewDecoder().Reader(reader), nil
|
||||
}
|
||||
if err := decoder.Decode(entityDescriptor); err != nil {
|
||||
if err.Error() == "expected element type <EntityDescriptor> but have <EntitiesDescriptor>" {
|
||||
// reset reader to start of metadata so we can try to parse it as an EntitiesDescriptor
|
||||
if _, err := reader.Seek(0, io.SeekStart); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
entities := &saml.EntitiesDescriptor{}
|
||||
if err := decoder.Decode(entities); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
for i, e := range entities.EntityDescriptors {
|
||||
if len(e.IDPSSODescriptors) > 0 {
|
||||
return &entities.EntityDescriptors[i], nil
|
||||
}
|
||||
}
|
||||
return nil, zerrors.ThrowInternal(nil, "SAML-Ejoi3r2", "no entity found with IDPSSODescriptor")
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
return entityDescriptor, nil
|
||||
}
|
||||
|
||||
func New(
|
||||
name string,
|
||||
rootURLStr string,
|
||||
@ -112,8 +151,8 @@ func New(
|
||||
key []byte,
|
||||
options ...ProviderOpts,
|
||||
) (*Provider, error) {
|
||||
entityDescriptor := new(saml.EntityDescriptor)
|
||||
if err := xml.Unmarshal(metadata, entityDescriptor); err != nil {
|
||||
entityDescriptor, err := ParseMetadata(metadata)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
keyPair, err := tls.X509KeyPair(certificate, key)
|
||||
@ -180,6 +219,7 @@ func (p *Provider) GetSP() (*samlsp.Middleware, error) {
|
||||
if p.binding != "" {
|
||||
sp.Binding = p.binding
|
||||
}
|
||||
sp.ServiceProvider.MetadataValidDuration = time.Until(sp.ServiceProvider.Certificate.NotAfter)
|
||||
return sp, nil
|
||||
}
|
||||
|
||||
|
@ -1,6 +1,7 @@
|
||||
package saml
|
||||
|
||||
import (
|
||||
"encoding/xml"
|
||||
"testing"
|
||||
|
||||
"github.com/crewjam/saml"
|
||||
@ -10,6 +11,7 @@ import (
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/idp/providers/saml/requesttracker"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func TestProvider_Options(t *testing.T) {
|
||||
@ -170,3 +172,111 @@ func TestProvider_Options(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestParseMetadata(t *testing.T) {
|
||||
type args struct {
|
||||
metadata []byte
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *saml.EntityDescriptor
|
||||
wantErr error
|
||||
}{
|
||||
{
|
||||
"invalid",
|
||||
args{
|
||||
metadata: []byte(`<Test></Test>`),
|
||||
},
|
||||
nil,
|
||||
xml.UnmarshalError("expected element type <EntityDescriptor> but have <Test>"),
|
||||
},
|
||||
{
|
||||
"valid entity descriptor",
|
||||
args{
|
||||
metadata: []byte(`<?xml version="1.0" encoding="UTF-8"?><EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="http://localhost:8000/metadata"><IDPSSODescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata"><SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="http://localhost:8000/sso"></SingleSignOnService></IDPSSODescriptor></EntityDescriptor>`),
|
||||
},
|
||||
&saml.EntityDescriptor{
|
||||
EntityID: "http://localhost:8000/metadata",
|
||||
IDPSSODescriptors: []saml.IDPSSODescriptor{
|
||||
{
|
||||
XMLName: xml.Name{
|
||||
Space: "urn:oasis:names:tc:SAML:2.0:metadata",
|
||||
Local: "IDPSSODescriptor",
|
||||
},
|
||||
SingleSignOnServices: []saml.Endpoint{
|
||||
{
|
||||
Binding: saml.HTTPRedirectBinding,
|
||||
Location: "http://localhost:8000/sso",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"valid entity descriptor, non utf-8",
|
||||
args{
|
||||
metadata: []byte(`<?xml version="1.0" encoding="windows-1252"?><EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="http://localhost:8000/metadata"><IDPSSODescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata"><SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="http://localhost:8000/sso"></SingleSignOnService></IDPSSODescriptor></EntityDescriptor>`),
|
||||
},
|
||||
&saml.EntityDescriptor{
|
||||
EntityID: "http://localhost:8000/metadata",
|
||||
IDPSSODescriptors: []saml.IDPSSODescriptor{
|
||||
{
|
||||
XMLName: xml.Name{
|
||||
Space: "urn:oasis:names:tc:SAML:2.0:metadata",
|
||||
Local: "IDPSSODescriptor",
|
||||
},
|
||||
SingleSignOnServices: []saml.Endpoint{
|
||||
{
|
||||
Binding: saml.HTTPRedirectBinding,
|
||||
Location: "http://localhost:8000/sso",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"entities descriptor without IDPSSODescriptor",
|
||||
args{
|
||||
metadata: []byte(`<?xml version="1.0" encoding="UTF-8"?><EntitiesDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata"><EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="http://localhost:8000/metadata"></EntityDescriptor></EntitiesDescriptor>`),
|
||||
},
|
||||
nil,
|
||||
zerrors.ThrowInternal(nil, "SAML-Ejoi3r2", "no entity found with IDPSSODescriptor"),
|
||||
},
|
||||
{
|
||||
"valid entities descriptor",
|
||||
args{
|
||||
metadata: []byte(`<?xml version="1.0" encoding="UTF-8"?><EntitiesDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata"><EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" entityID="http://localhost:8000/metadata"><IDPSSODescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata"><SingleSignOnService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Redirect" Location="http://localhost:8000/sso"></SingleSignOnService></IDPSSODescriptor></EntityDescriptor></EntitiesDescriptor>`),
|
||||
},
|
||||
&saml.EntityDescriptor{
|
||||
EntityID: "http://localhost:8000/metadata",
|
||||
IDPSSODescriptors: []saml.IDPSSODescriptor{
|
||||
{
|
||||
XMLName: xml.Name{
|
||||
Space: "urn:oasis:names:tc:SAML:2.0:metadata",
|
||||
Local: "IDPSSODescriptor",
|
||||
},
|
||||
SingleSignOnServices: []saml.Endpoint{
|
||||
{
|
||||
Binding: saml.HTTPRedirectBinding,
|
||||
Location: "http://localhost:8000/sso",
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := ParseMetadata(tt.args.metadata)
|
||||
assert.ErrorIs(t, err, tt.wantErr)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -9,7 +9,6 @@ import (
|
||||
|
||||
"github.com/crewjam/saml"
|
||||
"github.com/crewjam/saml/samlsp"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/idp"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
@ -71,7 +70,7 @@ func (s *Session) FetchUser(ctx context.Context) (user idp.User, err error) {
|
||||
if err != nil {
|
||||
invalidRespErr := new(saml.InvalidResponseError)
|
||||
if errors.As(err, &invalidRespErr) {
|
||||
logging.WithError(invalidRespErr.PrivateErr).Info("invalid SAML response details")
|
||||
return nil, zerrors.ThrowInvalidArgument(invalidRespErr.PrivateErr, "SAML-ajl3irfs", "Errors.Intent.ResponseInvalid")
|
||||
}
|
||||
return nil, zerrors.ThrowInvalidArgument(err, "SAML-nuo0vphhh9", "Errors.Intent.ResponseInvalid")
|
||||
}
|
||||
|
@ -134,7 +134,7 @@ func TestSession_FetchUser(t *testing.T) {
|
||||
requestID: "id-b22c90db88bf01d82ffb0a7b6fe25ac9fcb2c679",
|
||||
},
|
||||
want: want{
|
||||
err: zerrors.ThrowInvalidArgument(nil, "SAML-nuo0vphhh9", "Errors.Intent.ResponseInvalid"),
|
||||
err: zerrors.ThrowInvalidArgument(nil, "SAML-ajl3irfs", "Errors.Intent.ResponseInvalid"),
|
||||
},
|
||||
},
|
||||
{
|
||||
|
Loading…
Reference in New Issue
Block a user