diff --git a/.github/workflows/core-integration-test.yml b/.github/workflows/core-integration-test.yml index cc9d898f5c..3e689e3db6 100644 --- a/.github/workflows/core-integration-test.yml +++ b/.github/workflows/core-integration-test.yml @@ -76,7 +76,6 @@ jobs: if: ${{ steps.cache.outputs.cache-hit != 'true' }} env: ZITADEL_MASTERKEY: MasterkeyNeedsToHave32Characters - INTEGRATION_DB_FLAVOR: postgres run: make core_integration_test - name: upload server logs @@ -102,71 +101,3 @@ jobs: with: key: integration-test-postgres-${{ inputs.core_cache_key }} path: ${{ steps.go-cache-path.outputs.GO_CACHE_PATH }} - - # TODO: produces the following output: ERROR: unknown command "cockroach start-single-node --insecure" for "cockroach" - # cockroach: - # runs-on: ubuntu-latest - # services: - # cockroach: - # image: cockroachdb/cockroach:latest - # ports: - # - 26257:26257 - # - 8080:8080 - # env: - # COCKROACH_ARGS: "start-single-node --insecure" - # options: >- - # --health-cmd "curl http://localhost:8080/health?ready=1 || exit 1" - # --health-interval 10s - # --health-timeout 5s - # --health-retries 5 - # --health-start-period 10s - # steps: - # - - # uses: actions/checkout@v4 - # - - # uses: actions/setup-go@v5 - # with: - # go-version: ${{ inputs.go_version }} - # - - # uses: actions/cache/restore@v4 - # timeout-minutes: 1 - # name: restore core - # with: - # path: ${{ inputs.core_cache_path }} - # key: ${{ inputs.core_cache_key }} - # fail-on-cache-miss: true - # - - # id: go-cache-path - # name: set cache path - # run: echo "GO_CACHE_PATH=$(go env GOCACHE)" >> $GITHUB_OUTPUT - # - - # uses: actions/cache/restore@v4 - # id: cache - # timeout-minutes: 1 - # name: restore previous results - # with: - # key: integration-test-crdb-${{ inputs.core_cache_key }} - # restore-keys: | - # integration-test-crdb-core- - # path: ${{ steps.go-cache-path.outputs.GO_CACHE_PATH }} - # - - # name: test - # if: ${{ steps.cache.outputs.cache-hit != 'true' }} - # env: - # ZITADEL_MASTERKEY: MasterkeyNeedsToHave32Characters - # INTEGRATION_DB_FLAVOR: cockroach - # run: make core_integration_test - # - - # name: publish coverage - # uses: codecov/codecov-action@v4.3.0 - # with: - # file: profile.cov - # name: core-integration-tests-cockroach - # flags: core-integration-tests-cockroach - # - - # uses: actions/cache/save@v4 - # name: cache results - # if: ${{ steps.cache.outputs.cache-hit != 'true' }} - # with: - # key: integration-test-crdb-${{ inputs.core_cache_key }} - # path: ${{ steps.go-cache-path.outputs.GO_CACHE_PATH }} \ No newline at end of file diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 6fafd3dd6f..0663c9be0d 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -216,12 +216,6 @@ Integration tests are run as gRPC clients against a running ZITADEL server binar The server binary is typically [build with coverage enabled](https://go.dev/doc/build-cover). It is also possible to run a ZITADEL sever in a debugger and run the integrations tests like that. In order to run the server, a database is required. -The database flavor can **optionally** be set in the environment to `cockroach` or `postgres`. The default is `postgres`. - -```bash -export INTEGRATION_DB_FLAVOR="cockroach" -``` - In order to prepare the local system, the following will bring up the database, builds a coverage binary, initializes the database and starts the sever. ```bash diff --git a/Makefile b/Makefile index 27e76c0614..b5145cef3d 100644 --- a/Makefile +++ b/Makefile @@ -8,10 +8,9 @@ COMMIT_SHA ?= $(shell git rev-parse HEAD) ZITADEL_IMAGE ?= zitadel:local GOCOVERDIR = tmp/coverage -INTEGRATION_DB_FLAVOR ?= postgres ZITADEL_MASTERKEY ?= MasterkeyNeedsToHave32Characters -export GOCOVERDIR INTEGRATION_DB_FLAVOR ZITADEL_MASTERKEY +export GOCOVERDIR ZITADEL_MASTERKEY .PHONY: compile compile: core_build console_build compile_pipeline @@ -113,7 +112,7 @@ core_unit_test: .PHONY: core_integration_db_up core_integration_db_up: - docker compose -f internal/integration/config/docker-compose.yaml up --pull always --wait $${INTEGRATION_DB_FLAVOR} cache + docker compose -f internal/integration/config/docker-compose.yaml up --pull always --wait cache .PHONY: core_integration_db_down core_integration_db_down: @@ -123,13 +122,13 @@ core_integration_db_down: core_integration_setup: go build -cover -race -tags integration -o zitadel.test main.go mkdir -p $${GOCOVERDIR} - GORACE="halt_on_error=1" ./zitadel.test init --config internal/integration/config/zitadel.yaml --config internal/integration/config/${INTEGRATION_DB_FLAVOR}.yaml - GORACE="halt_on_error=1" ./zitadel.test setup --masterkeyFromEnv --init-projections --config internal/integration/config/zitadel.yaml --config internal/integration/config/${INTEGRATION_DB_FLAVOR}.yaml --steps internal/integration/config/steps.yaml + GORACE="halt_on_error=1" ./zitadel.test init --config internal/integration/config/zitadel.yaml --config internal/integration/config/postgres.yaml + GORACE="halt_on_error=1" ./zitadel.test setup --masterkeyFromEnv --init-projections --config internal/integration/config/zitadel.yaml --config internal/integration/config/postgres.yaml --steps internal/integration/config/steps.yaml .PHONY: core_integration_server_start core_integration_server_start: core_integration_setup GORACE="log_path=tmp/race.log" \ - ./zitadel.test start --masterkeyFromEnv --config internal/integration/config/zitadel.yaml --config internal/integration/config/${INTEGRATION_DB_FLAVOR}.yaml \ + ./zitadel.test start --masterkeyFromEnv --config internal/integration/config/zitadel.yaml --config internal/integration/config/postgres.yaml \ > tmp/zitadel.log 2>&1 \ & printf $$! > tmp/zitadel.pid diff --git a/build/workflow.Dockerfile b/build/workflow.Dockerfile index db27daf91c..2286531192 100644 --- a/build/workflow.Dockerfile +++ b/build/workflow.Dockerfile @@ -199,7 +199,6 @@ ENV PATH="/go/bin:/usr/local/go/bin:${PATH}" WORKDIR /go/src/github.com/zitadel/zitadel # default vars -ENV DB_FLAVOR=postgres ENV POSTGRES_USER=zitadel ENV POSTGRES_DB=zitadel ENV POSTGRES_PASSWORD=postgres @@ -231,12 +230,6 @@ COPY --from=test-core-unit /go/src/github.com/zitadel/zitadel/profile.cov /cover # integration test core # ####################################### FROM test-core-base AS test-core-integration -ENV DB_FLAVOR=cockroach - -# install cockroach -COPY --from=cockroachdb/cockroach:latest /cockroach/cockroach /usr/local/bin/ -ENV COCKROACH_BINARY=/cockroach/cockroach - ENV ZITADEL_MASTERKEY=MasterkeyNeedsToHave32Characters COPY build/core-integration-test.sh /usr/local/bin/run-tests.sh diff --git a/cmd/defaults.yaml b/cmd/defaults.yaml index 71ad22a4f9..0c70080556 100644 --- a/cmd/defaults.yaml +++ b/cmd/defaults.yaml @@ -110,67 +110,36 @@ PublicHostHeaders: # ZITADEL_PUBLICHOSTHEADERS WebAuthNName: ZITADEL # ZITADEL_WEBAUTHNNAME Database: - # CockroachDB is the default database of ZITADEL - cockroach: - Host: localhost # ZITADEL_DATABASE_COCKROACH_HOST - Port: 26257 # ZITADEL_DATABASE_COCKROACH_PORT - Database: zitadel # ZITADEL_DATABASE_COCKROACH_DATABASE - MaxOpenConns: 5 # ZITADEL_DATABASE_COCKROACH_MAXOPENCONNS - MaxIdleConns: 2 # ZITADEL_DATABASE_COCKROACH_MAXIDLECONNS - MaxConnLifetime: 30m # ZITADEL_DATABASE_COCKROACH_MAXCONNLIFETIME - MaxConnIdleTime: 5m # ZITADEL_DATABASE_COCKROACH_MAXCONNIDLETIME - Options: "" # ZITADEL_DATABASE_COCKROACH_OPTIONS + # Postgres is the default database of ZITADEL + postgres: + Host: localhost # ZITADEL_DATABASE_POSTGRES_HOST + Port: 5432 # ZITADEL_DATABASE_POSTGRES_PORT + Database: zitadel # ZITADEL_DATABASE_POSTGRES_DATABASE + MaxOpenConns: 5 # ZITADEL_DATABASE_POSTGRES_MAXOPENCONNS + MaxIdleConns: 2 # ZITADEL_DATABASE_POSTGRES_MAXIDLECONNS + MaxConnLifetime: 30m # ZITADEL_DATABASE_POSTGRES_MAXCONNLIFETIME + MaxConnIdleTime: 5m # ZITADEL_DATABASE_POSTGRES_MAXCONNIDLETIME + Options: "" # ZITADEL_DATABASE_POSTGRES_OPTIONS User: - Username: zitadel # ZITADEL_DATABASE_COCKROACH_USER_USERNAME - Password: "" # ZITADEL_DATABASE_COCKROACH_USER_PASSWORD + Username: zitadel # ZITADEL_DATABASE_POSTGRES_USER_USERNAME + Password: "" # ZITADEL_DATABASE_POSTGRES_USER_PASSWORD SSL: - Mode: disable # ZITADEL_DATABASE_COCKROACH_USER_SSL_MODE - RootCert: "" # ZITADEL_DATABASE_COCKROACH_USER_SSL_ROOTCERT - Cert: "" # ZITADEL_DATABASE_COCKROACH_USER_SSL_CERT - Key: "" # ZITADEL_DATABASE_COCKROACH_USER_SSL_KEY + Mode: disable # ZITADEL_DATABASE_POSTGRES_USER_SSL_MODE + RootCert: "" # ZITADEL_DATABASE_POSTGRES_USER_SSL_ROOTCERT + Cert: "" # ZITADEL_DATABASE_POSTGRES_USER_SSL_CERT + Key: "" # ZITADEL_DATABASE_POSTGRES_USER_SSL_KEY Admin: # By default, ExistingDatabase is not specified in the connection string # If the connection resolves to a database that is not existing in your system, configure an existing one here - # It is used in zitadel init to connect to cockroach and create a dedicated database for ZITADEL. - ExistingDatabase: # ZITADEL_DATABASE_COCKROACH_ADMIN_EXISTINGDATABASE - Username: root # ZITADEL_DATABASE_COCKROACH_ADMIN_USERNAME - Password: "" # ZITADEL_DATABASE_COCKROACH_ADMIN_PASSWORD - SSL: - Mode: disable # ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_MODE - RootCert: "" # ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_ROOTCERT - Cert: "" # ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_CERT - Key: "" # ZITADEL_DATABASE_COCKROACH_ADMIN_SSL_KEY - # Postgres is used as soon as a value is set - # The values describe the possible fields to set values - postgres: - Host: # ZITADEL_DATABASE_POSTGRES_HOST - Port: # ZITADEL_DATABASE_POSTGRES_PORT - Database: # ZITADEL_DATABASE_POSTGRES_DATABASE - MaxOpenConns: # ZITADEL_DATABASE_POSTGRES_MAXOPENCONNS - MaxIdleConns: # ZITADEL_DATABASE_POSTGRES_MAXIDLECONNS - MaxConnLifetime: # ZITADEL_DATABASE_POSTGRES_MAXCONNLIFETIME - MaxConnIdleTime: # ZITADEL_DATABASE_POSTGRES_MAXCONNIDLETIME - Options: # ZITADEL_DATABASE_POSTGRES_OPTIONS - User: - Username: # ZITADEL_DATABASE_POSTGRES_USER_USERNAME - Password: # ZITADEL_DATABASE_POSTGRES_USER_PASSWORD - SSL: - Mode: # ZITADEL_DATABASE_POSTGRES_USER_SSL_MODE - RootCert: # ZITADEL_DATABASE_POSTGRES_USER_SSL_ROOTCERT - Cert: # ZITADEL_DATABASE_POSTGRES_USER_SSL_CERT - Key: # ZITADEL_DATABASE_POSTGRES_USER_SSL_KEY - Admin: - # The default ExistingDatabase is postgres - # If your db system doesn't have a database named postgres, configure an existing database here # It is used in zitadel init to connect to postgres and create a dedicated database for ZITADEL. ExistingDatabase: # ZITADEL_DATABASE_POSTGRES_ADMIN_EXISTINGDATABASE - Username: # ZITADEL_DATABASE_POSTGRES_ADMIN_USERNAME - Password: # ZITADEL_DATABASE_POSTGRES_ADMIN_PASSWORD + Username: postgres # ZITADEL_DATABASE_POSTGRES_ADMIN_USERNAME + Password: postgres # ZITADEL_DATABASE_POSTGRES_ADMIN_PASSWORD SSL: - Mode: # ZITADEL_DATABASE_POSTGRES_ADMIN_SSL_MODE - RootCert: # ZITADEL_DATABASE_POSTGRES_ADMIN_SSL_ROOTCERT - Cert: # ZITADEL_DATABASE_POSTGRES_ADMIN_SSL_CERT - Key: # ZITADEL_DATABASE_POSTGRES_ADMIN_SSL_KEY + Mode: disable # ZITADEL_DATABASE_POSTGRES_ADMIN_SSL_MODE + RootCert: "" # ZITADEL_DATABASE_POSTGRES_ADMIN_SSL_ROOTCERT + Cert: "" # ZITADEL_DATABASE_POSTGRES_ADMIN_SSL_CERT + Key: "" # ZITADEL_DATABASE_POSTGRES_ADMIN_SSL_KEY # Caches are EXPERIMENTAL. The following config may have breaking changes in the future. # If no config is provided, caching is disabled by default. @@ -444,7 +413,6 @@ Projections: Notifications: # Notifications can be processed by either a sequential mode (legacy) or a new parallel mode. # The parallel mode is currently only recommended for Postgres databases. - # For CockroachDB, the sequential mode is recommended, see: https://github.com/zitadel/zitadel/issues/9002 # If legacy mode is enabled, the worker config below is ignored. LegacyEnabled: true # ZITADEL_NOTIFICATIONS_LEGACYENABLED # The amount of workers processing the notification request events. diff --git a/cmd/initialise/config.go b/cmd/initialise/config.go index 3fe7173860..899018ddcb 100644 --- a/cmd/initialise/config.go +++ b/cmd/initialise/config.go @@ -19,7 +19,7 @@ func MustNewConfig(v *viper.Viper) *Config { config := new(Config) err := v.Unmarshal(config, viper.DecodeHook(mapstructure.ComposeDecodeHookFunc( - database.DecodeHook, + database.DecodeHook(false), mapstructure.TextUnmarshallerHookFunc(), )), ) diff --git a/cmd/initialise/init.go b/cmd/initialise/init.go index 02fd481eab..cc505325a9 100644 --- a/cmd/initialise/init.go +++ b/cmd/initialise/init.go @@ -12,20 +12,17 @@ import ( ) var ( - //go:embed sql/cockroach/* - //go:embed sql/postgres/* + //go:embed sql/*.sql stmts embed.FS createUserStmt string grantStmt string - settingsStmt string databaseStmt string createEventstoreStmt string createProjectionsStmt string createSystemStmt string createEncryptionKeysStmt string createEventsStmt string - createSystemSequenceStmt string createUniqueConstraints string roleAlreadyExistsCode = "42710" @@ -39,7 +36,7 @@ func New() *cobra.Command { Long: `Sets up the minimum requirements to start ZITADEL. Prerequisites: -- database (PostgreSql or cockroachdb) +- PostgreSql database The user provided by flags needs privileges to - create the database if it does not exist @@ -53,7 +50,7 @@ The user provided by flags needs privileges to }, } - cmd.AddCommand(newZitadel(), newDatabase(), newUser(), newGrant(), newSettings()) + cmd.AddCommand(newZitadel(), newDatabase(), newUser(), newGrant()) return cmd } @@ -62,7 +59,6 @@ func InitAll(ctx context.Context, config *Config) { VerifyUser(config.Database.Username(), config.Database.Password()), VerifyDatabase(config.Database.DatabaseName()), VerifyGrant(config.Database.DatabaseName(), config.Database.Username()), - VerifySettings(config.Database.DatabaseName(), config.Database.Username()), ) logging.OnError(err).Fatal("unable to initialize the database") @@ -73,7 +69,7 @@ func InitAll(ctx context.Context, config *Config) { func initialise(ctx context.Context, config database.Config, steps ...func(context.Context, *database.DB) error) error { logging.Info("initialization started") - err := ReadStmts(config.Type()) + err := ReadStmts() if err != nil { return err } @@ -97,58 +93,48 @@ func Init(ctx context.Context, db *database.DB, steps ...func(context.Context, * return nil } -func ReadStmts(typ string) (err error) { - createUserStmt, err = readStmt(typ, "01_user") +func ReadStmts() (err error) { + createUserStmt, err = readStmt("01_user") if err != nil { return err } - databaseStmt, err = readStmt(typ, "02_database") + databaseStmt, err = readStmt("02_database") if err != nil { return err } - grantStmt, err = readStmt(typ, "03_grant_user") + grantStmt, err = readStmt("03_grant_user") if err != nil { return err } - createEventstoreStmt, err = readStmt(typ, "04_eventstore") + createEventstoreStmt, err = readStmt("04_eventstore") if err != nil { return err } - createProjectionsStmt, err = readStmt(typ, "05_projections") + createProjectionsStmt, err = readStmt("05_projections") if err != nil { return err } - createSystemStmt, err = readStmt(typ, "06_system") + createSystemStmt, err = readStmt("06_system") if err != nil { return err } - createEncryptionKeysStmt, err = readStmt(typ, "07_encryption_keys_table") + createEncryptionKeysStmt, err = readStmt("07_encryption_keys_table") if err != nil { return err } - createEventsStmt, err = readStmt(typ, "08_events_table") + createEventsStmt, err = readStmt("08_events_table") if err != nil { return err } - createSystemSequenceStmt, err = readStmt(typ, "09_system_sequence") - if err != nil { - return err - } - - createUniqueConstraints, err = readStmt(typ, "10_unique_constraints_table") - if err != nil { - return err - } - - settingsStmt, err = readStmt(typ, "11_settings") + createUniqueConstraints, err = readStmt("10_unique_constraints_table") if err != nil { return err } @@ -156,7 +142,7 @@ func ReadStmts(typ string) (err error) { return nil } -func readStmt(typ, step string) (string, error) { - stmt, err := stmts.ReadFile("sql/" + typ + "/" + step + ".sql") +func readStmt(step string) (string, error) { + stmt, err := stmts.ReadFile("sql/" + step + ".sql") return string(stmt), err } diff --git a/cmd/initialise/sql/cockroach/01_user.sql b/cmd/initialise/sql/01_user.sql similarity index 56% rename from cmd/initialise/sql/cockroach/01_user.sql rename to cmd/initialise/sql/01_user.sql index 4e621216ce..7be2d4ae4d 100644 --- a/cmd/initialise/sql/cockroach/01_user.sql +++ b/cmd/initialise/sql/01_user.sql @@ -1,2 +1,2 @@ -- replace %[1]s with the name of the user -CREATE USER IF NOT EXISTS "%[1]s" \ No newline at end of file +CREATE USER "%[1]s" \ No newline at end of file diff --git a/cmd/initialise/sql/cockroach/02_database.sql b/cmd/initialise/sql/02_database.sql similarity index 54% rename from cmd/initialise/sql/cockroach/02_database.sql rename to cmd/initialise/sql/02_database.sql index 6103b95b31..172913661b 100644 --- a/cmd/initialise/sql/cockroach/02_database.sql +++ b/cmd/initialise/sql/02_database.sql @@ -1,2 +1,2 @@ -- replace %[1]s with the name of the database -CREATE DATABASE IF NOT EXISTS "%[1]s"; +CREATE DATABASE "%[1]s" \ No newline at end of file diff --git a/cmd/initialise/sql/postgres/03_grant_user.sql b/cmd/initialise/sql/03_grant_user.sql similarity index 100% rename from cmd/initialise/sql/postgres/03_grant_user.sql rename to cmd/initialise/sql/03_grant_user.sql diff --git a/cmd/initialise/sql/cockroach/04_eventstore.sql b/cmd/initialise/sql/04_eventstore.sql similarity index 100% rename from cmd/initialise/sql/cockroach/04_eventstore.sql rename to cmd/initialise/sql/04_eventstore.sql diff --git a/cmd/initialise/sql/cockroach/05_projections.sql b/cmd/initialise/sql/05_projections.sql similarity index 100% rename from cmd/initialise/sql/cockroach/05_projections.sql rename to cmd/initialise/sql/05_projections.sql diff --git a/cmd/initialise/sql/cockroach/06_system.sql b/cmd/initialise/sql/06_system.sql similarity index 100% rename from cmd/initialise/sql/cockroach/06_system.sql rename to cmd/initialise/sql/06_system.sql diff --git a/cmd/initialise/sql/cockroach/07_encryption_keys_table.sql b/cmd/initialise/sql/07_encryption_keys_table.sql similarity index 100% rename from cmd/initialise/sql/cockroach/07_encryption_keys_table.sql rename to cmd/initialise/sql/07_encryption_keys_table.sql diff --git a/cmd/initialise/sql/postgres/08_events_table.sql b/cmd/initialise/sql/08_events_table.sql similarity index 100% rename from cmd/initialise/sql/postgres/08_events_table.sql rename to cmd/initialise/sql/08_events_table.sql diff --git a/cmd/initialise/sql/postgres/10_unique_constraints_table.sql b/cmd/initialise/sql/10_unique_constraints_table.sql similarity index 100% rename from cmd/initialise/sql/postgres/10_unique_constraints_table.sql rename to cmd/initialise/sql/10_unique_constraints_table.sql diff --git a/cmd/initialise/sql/README.md b/cmd/initialise/sql/README.md index b477c0fb73..b7a18f0f98 100644 --- a/cmd/initialise/sql/README.md +++ b/cmd/initialise/sql/README.md @@ -11,6 +11,5 @@ The sql-files in this folder initialize the ZITADEL database and user. These obj - 05_projections.sql: creates the schema needed to read the data - 06_system.sql: creates the schema needed for ZITADEL itself - 07_encryption_keys_table.sql: creates the table for encryption keys (for event data) -- files 08_enable_hash_sharded_indexes.sql and 09_events_table.sql must run in the same session - - 08_enable_hash_sharded_indexes.sql enables the [hash sharded index](https://www.cockroachlabs.com/docs/stable/hash-sharded-indexes.html) feature for this session - - 09_events_table.sql creates the table for eventsourcing +- 08_events_table.sql creates the table for eventsourcing +- 10_unique_constraints_table.sql creates the table to check unique constraints for events diff --git a/cmd/initialise/sql/cockroach/03_grant_user.sql b/cmd/initialise/sql/cockroach/03_grant_user.sql deleted file mode 100644 index de0d2743eb..0000000000 --- a/cmd/initialise/sql/cockroach/03_grant_user.sql +++ /dev/null @@ -1,4 +0,0 @@ --- replace the first %[1]s with the database --- replace the second \%[2]s with the user -GRANT ALL ON DATABASE "%[1]s" TO "%[2]s"; -GRANT SYSTEM VIEWACTIVITY TO "%[2]s"; \ No newline at end of file diff --git a/cmd/initialise/sql/cockroach/08_events_table.sql b/cmd/initialise/sql/cockroach/08_events_table.sql deleted file mode 100644 index ebaf18ce2a..0000000000 --- a/cmd/initialise/sql/cockroach/08_events_table.sql +++ /dev/null @@ -1,116 +0,0 @@ -CREATE TABLE IF NOT EXISTS eventstore.events2 ( - instance_id TEXT NOT NULL - , aggregate_type TEXT NOT NULL - , aggregate_id TEXT NOT NULL - - , event_type TEXT NOT NULL - , "sequence" BIGINT NOT NULL - , revision SMALLINT NOT NULL - , created_at TIMESTAMPTZ NOT NULL - , payload JSONB - , creator TEXT NOT NULL - , "owner" TEXT NOT NULL - - , "position" DECIMAL NOT NULL - , in_tx_order INTEGER NOT NULL - - , PRIMARY KEY (instance_id, aggregate_type, aggregate_id, "sequence") - , INDEX es_active_instances (created_at DESC) STORING ("position") - , INDEX es_wm (aggregate_id, instance_id, aggregate_type, event_type) - , INDEX es_projection (instance_id, aggregate_type, event_type, "position" DESC) -); - --- represents an event to be created. -CREATE TYPE IF NOT EXISTS eventstore.command AS ( - instance_id TEXT - , aggregate_type TEXT - , aggregate_id TEXT - , command_type TEXT - , revision INT2 - , payload JSONB - , creator TEXT - , owner TEXT -); - -CREATE OR REPLACE FUNCTION eventstore.commands_to_events(commands eventstore.command[]) RETURNS SETOF eventstore.events2 VOLATILE AS $$ -SELECT - ("c").instance_id - , ("c").aggregate_type - , ("c").aggregate_id - , ("c").command_type AS event_type - , cs.sequence + ROW_NUMBER() OVER (PARTITION BY ("c").instance_id, ("c").aggregate_type, ("c").aggregate_id ORDER BY ("c").in_tx_order) AS sequence - , ("c").revision - , hlc_to_timestamp(cluster_logical_timestamp()) AS created_at - , ("c").payload - , ("c").creator - , cs.owner - , cluster_logical_timestamp() AS position - , ("c").in_tx_order -FROM ( - SELECT - ("c").instance_id - , ("c").aggregate_type - , ("c").aggregate_id - , ("c").command_type - , ("c").revision - , ("c").payload - , ("c").creator - , ("c").owner - , ROW_NUMBER() OVER () AS in_tx_order - FROM - UNNEST(commands) AS "c" -) AS "c" -JOIN ( - SELECT - cmds.instance_id - , cmds.aggregate_type - , cmds.aggregate_id - , CASE WHEN (e.owner IS NOT NULL OR e.owner <> '') THEN e.owner ELSE command_owners.owner END AS owner - , COALESCE(MAX(e.sequence), 0) AS sequence - FROM ( - SELECT DISTINCT - ("cmds").instance_id - , ("cmds").aggregate_type - , ("cmds").aggregate_id - , ("cmds").owner - FROM UNNEST(commands) AS "cmds" - ) AS cmds - LEFT JOIN eventstore.events2 AS e - ON cmds.instance_id = e.instance_id - AND cmds.aggregate_type = e.aggregate_type - AND cmds.aggregate_id = e.aggregate_id - JOIN ( - SELECT - DISTINCT ON ( - ("c").instance_id - , ("c").aggregate_type - , ("c").aggregate_id - ) - ("c").instance_id - , ("c").aggregate_type - , ("c").aggregate_id - , ("c").owner - FROM - UNNEST(commands) AS "c" - ) AS command_owners ON - cmds.instance_id = command_owners.instance_id - AND cmds.aggregate_type = command_owners.aggregate_type - AND cmds.aggregate_id = command_owners.aggregate_id - GROUP BY - cmds.instance_id - , cmds.aggregate_type - , cmds.aggregate_id - , 4 -- owner -) AS cs - ON ("c").instance_id = cs.instance_id - AND ("c").aggregate_type = cs.aggregate_type - AND ("c").aggregate_id = cs.aggregate_id -ORDER BY - in_tx_order -$$ LANGUAGE SQL; - -CREATE OR REPLACE FUNCTION eventstore.push(commands eventstore.command[]) RETURNS SETOF eventstore.events2 AS $$ - INSERT INTO eventstore.events2 - SELECT * FROM eventstore.commands_to_events(commands) - RETURNING * -$$ LANGUAGE SQL; \ No newline at end of file diff --git a/cmd/initialise/sql/cockroach/09_system_sequence.sql b/cmd/initialise/sql/cockroach/09_system_sequence.sql deleted file mode 100644 index 596e887664..0000000000 --- a/cmd/initialise/sql/cockroach/09_system_sequence.sql +++ /dev/null @@ -1 +0,0 @@ -CREATE SEQUENCE IF NOT EXISTS eventstore.system_seq diff --git a/cmd/initialise/sql/cockroach/10_unique_constraints_table.sql b/cmd/initialise/sql/cockroach/10_unique_constraints_table.sql deleted file mode 100644 index 2594a248b7..0000000000 --- a/cmd/initialise/sql/cockroach/10_unique_constraints_table.sql +++ /dev/null @@ -1,6 +0,0 @@ -CREATE TABLE IF NOT EXISTS eventstore.unique_constraints ( - instance_id TEXT, - unique_type TEXT, - unique_field TEXT, - PRIMARY KEY (instance_id, unique_type, unique_field) -) diff --git a/cmd/initialise/sql/cockroach/11_settings.sql b/cmd/initialise/sql/cockroach/11_settings.sql deleted file mode 100644 index 5fa9dd72f6..0000000000 --- a/cmd/initialise/sql/cockroach/11_settings.sql +++ /dev/null @@ -1,4 +0,0 @@ --- replace the first %[1]q with the database in double quotes --- replace the second \%[2]q with the user in double quotes$ --- For more information see technical advisory 10009 (https://zitadel.com/docs/support/advisory/a10009) -ALTER ROLE %[2]q IN DATABASE %[1]q SET enable_durable_locking_for_serializable = on; \ No newline at end of file diff --git a/cmd/initialise/sql/postgres/01_user.sql b/cmd/initialise/sql/postgres/01_user.sql deleted file mode 100644 index cd60b9a2cf..0000000000 --- a/cmd/initialise/sql/postgres/01_user.sql +++ /dev/null @@ -1 +0,0 @@ -CREATE USER "%[1]s" \ No newline at end of file diff --git a/cmd/initialise/sql/postgres/02_database.sql b/cmd/initialise/sql/postgres/02_database.sql deleted file mode 100644 index 895a1f29d5..0000000000 --- a/cmd/initialise/sql/postgres/02_database.sql +++ /dev/null @@ -1 +0,0 @@ -CREATE DATABASE "%[1]s" \ No newline at end of file diff --git a/cmd/initialise/sql/postgres/04_eventstore.sql b/cmd/initialise/sql/postgres/04_eventstore.sql deleted file mode 100644 index 3cb4fc0d3e..0000000000 --- a/cmd/initialise/sql/postgres/04_eventstore.sql +++ /dev/null @@ -1,3 +0,0 @@ -CREATE SCHEMA IF NOT EXISTS eventstore; - -GRANT ALL ON ALL TABLES IN SCHEMA eventstore TO "%[1]s"; \ No newline at end of file diff --git a/cmd/initialise/sql/postgres/05_projections.sql b/cmd/initialise/sql/postgres/05_projections.sql deleted file mode 100644 index 91ca6662ee..0000000000 --- a/cmd/initialise/sql/postgres/05_projections.sql +++ /dev/null @@ -1,3 +0,0 @@ -CREATE SCHEMA IF NOT EXISTS projections; - -GRANT ALL ON ALL TABLES IN SCHEMA projections TO "%[1]s"; \ No newline at end of file diff --git a/cmd/initialise/sql/postgres/06_system.sql b/cmd/initialise/sql/postgres/06_system.sql deleted file mode 100644 index 6c9138918b..0000000000 --- a/cmd/initialise/sql/postgres/06_system.sql +++ /dev/null @@ -1,3 +0,0 @@ -CREATE SCHEMA IF NOT EXISTS system; - -GRANT ALL ON ALL TABLES IN SCHEMA system TO "%[1]s"; \ No newline at end of file diff --git a/cmd/initialise/sql/postgres/07_encryption_keys_table.sql b/cmd/initialise/sql/postgres/07_encryption_keys_table.sql deleted file mode 100644 index 61cb617fdf..0000000000 --- a/cmd/initialise/sql/postgres/07_encryption_keys_table.sql +++ /dev/null @@ -1,6 +0,0 @@ -CREATE TABLE IF NOT EXISTS system.encryption_keys ( - id TEXT NOT NULL - , key TEXT NOT NULL - - , PRIMARY KEY (id) -); diff --git a/cmd/initialise/sql/postgres/09_system_sequence.sql b/cmd/initialise/sql/postgres/09_system_sequence.sql deleted file mode 100644 index 15383b3878..0000000000 --- a/cmd/initialise/sql/postgres/09_system_sequence.sql +++ /dev/null @@ -1 +0,0 @@ -CREATE SEQUENCE IF NOT EXISTS eventstore.system_seq; diff --git a/cmd/initialise/sql/postgres/11_settings.sql b/cmd/initialise/sql/postgres/11_settings.sql deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/cmd/initialise/verify_database.go b/cmd/initialise/verify_database.go index 6e04e489f5..3e3bea9efa 100644 --- a/cmd/initialise/verify_database.go +++ b/cmd/initialise/verify_database.go @@ -19,7 +19,7 @@ func newDatabase() *cobra.Command { Long: `Sets up the ZITADEL database. Prerequisites: -- cockroachDB or postgreSQL +- postgreSQL The user provided by flags needs privileges to - create the database if it does not exist diff --git a/cmd/initialise/verify_database_test.go b/cmd/initialise/verify_database_test.go index d7da97847f..1899605e4f 100644 --- a/cmd/initialise/verify_database_test.go +++ b/cmd/initialise/verify_database_test.go @@ -8,7 +8,7 @@ import ( ) func Test_verifyDB(t *testing.T) { - err := ReadStmts("cockroach") //TODO: check all dialects + err := ReadStmts() if err != nil { t.Errorf("unable to read stmts: %v", err) t.FailNow() @@ -27,7 +27,7 @@ func Test_verifyDB(t *testing.T) { name: "doesn't exists, create fails", args: args{ db: prepareDB(t, - expectExec("-- replace zitadel with the name of the database\nCREATE DATABASE IF NOT EXISTS \"zitadel\"", sql.ErrTxDone), + expectExec("-- replace zitadel with the name of the database\nCREATE DATABASE \"zitadel\"", sql.ErrTxDone), ), database: "zitadel", }, @@ -37,7 +37,7 @@ func Test_verifyDB(t *testing.T) { name: "doesn't exists, create successful", args: args{ db: prepareDB(t, - expectExec("-- replace zitadel with the name of the database\nCREATE DATABASE IF NOT EXISTS \"zitadel\"", nil), + expectExec("-- replace zitadel with the name of the database\nCREATE DATABASE \"zitadel\"", nil), ), database: "zitadel", }, @@ -47,7 +47,7 @@ func Test_verifyDB(t *testing.T) { name: "already exists", args: args{ db: prepareDB(t, - expectExec("-- replace zitadel with the name of the database\nCREATE DATABASE IF NOT EXISTS \"zitadel\"", nil), + expectExec("-- replace zitadel with the name of the database\nCREATE DATABASE \"zitadel\"", nil), ), database: "zitadel", }, diff --git a/cmd/initialise/verify_grant.go b/cmd/initialise/verify_grant.go index a14a495bff..27f0bd4d08 100644 --- a/cmd/initialise/verify_grant.go +++ b/cmd/initialise/verify_grant.go @@ -19,7 +19,7 @@ func newGrant() *cobra.Command { Long: `Sets ALL grant to the database user. Prerequisites: -- cockroachDB or postgreSQL +- postgreSQL `, Run: func(cmd *cobra.Command, args []string) { config := MustNewConfig(viper.GetViper()) diff --git a/cmd/initialise/verify_settings.go b/cmd/initialise/verify_settings.go deleted file mode 100644 index 6f4ba7c074..0000000000 --- a/cmd/initialise/verify_settings.go +++ /dev/null @@ -1,45 +0,0 @@ -package initialise - -import ( - "context" - _ "embed" - "fmt" - - "github.com/spf13/cobra" - "github.com/spf13/viper" - "github.com/zitadel/logging" - - "github.com/zitadel/zitadel/internal/database" -) - -func newSettings() *cobra.Command { - return &cobra.Command{ - Use: "settings", - Short: "Ensures proper settings on the database", - Long: `Ensures proper settings on the database. - -Prerequisites: -- cockroachDB or postgreSQL - -Cockroach -- Sets enable_durable_locking_for_serializable to on for the zitadel user and database -`, - Run: func(cmd *cobra.Command, args []string) { - config := MustNewConfig(viper.GetViper()) - - err := initialise(cmd.Context(), config.Database, VerifySettings(config.Database.DatabaseName(), config.Database.Username())) - logging.OnError(err).Fatal("unable to set settings") - }, - } -} - -func VerifySettings(databaseName, username string) func(context.Context, *database.DB) error { - return func(ctx context.Context, db *database.DB) error { - if db.Type() == "postgres" { - return nil - } - logging.WithFields("user", username, "database", databaseName).Info("verify settings") - - return exec(ctx, db, fmt.Sprintf(settingsStmt, databaseName, username), nil) - } -} diff --git a/cmd/initialise/verify_user.go b/cmd/initialise/verify_user.go index 43bdb91420..3adca93e53 100644 --- a/cmd/initialise/verify_user.go +++ b/cmd/initialise/verify_user.go @@ -19,7 +19,7 @@ func newUser() *cobra.Command { Long: `Sets up the ZITADEL database user. Prerequisites: -- cockroachDB or postgreSQL +- postgreSQL The user provided by flags needs privileges to - create the database if it does not exist diff --git a/cmd/initialise/verify_user_test.go b/cmd/initialise/verify_user_test.go index 53b35e67db..40cde5baa2 100644 --- a/cmd/initialise/verify_user_test.go +++ b/cmd/initialise/verify_user_test.go @@ -8,7 +8,7 @@ import ( ) func Test_verifyUser(t *testing.T) { - err := ReadStmts("cockroach") //TODO: check all dialects + err := ReadStmts() if err != nil { t.Errorf("unable to read stmts: %v", err) t.FailNow() @@ -28,7 +28,7 @@ func Test_verifyUser(t *testing.T) { name: "doesn't exists, create fails", args: args{ db: prepareDB(t, - expectExec("-- replace zitadel-user with the name of the user\nCREATE USER IF NOT EXISTS \"zitadel-user\"", sql.ErrTxDone), + expectExec("-- replace zitadel-user with the name of the user\nCREATE USER \"zitadel-user\"", sql.ErrTxDone), ), username: "zitadel-user", password: "", @@ -39,7 +39,7 @@ func Test_verifyUser(t *testing.T) { name: "correct without password", args: args{ db: prepareDB(t, - expectExec("-- replace zitadel-user with the name of the user\nCREATE USER IF NOT EXISTS \"zitadel-user\"", nil), + expectExec("-- replace zitadel-user with the name of the user\nCREATE USER \"zitadel-user\"", nil), ), username: "zitadel-user", password: "", @@ -50,7 +50,7 @@ func Test_verifyUser(t *testing.T) { name: "correct with password", args: args{ db: prepareDB(t, - expectExec("-- replace zitadel-user with the name of the user\nCREATE USER IF NOT EXISTS \"zitadel-user\" WITH PASSWORD 'password'", nil), + expectExec("-- replace zitadel-user with the name of the user\nCREATE USER \"zitadel-user\" WITH PASSWORD 'password'", nil), ), username: "zitadel-user", password: "password", @@ -61,7 +61,7 @@ func Test_verifyUser(t *testing.T) { name: "already exists", args: args{ db: prepareDB(t, - expectExec("-- replace zitadel-user with the name of the user\nCREATE USER IF NOT EXISTS \"zitadel-user\" WITH PASSWORD 'password'", nil), + expectExec("-- replace zitadel-user with the name of the user\nCREATE USER \"zitadel-user\" WITH PASSWORD 'password'", nil), ), username: "zitadel-user", password: "", diff --git a/cmd/initialise/verify_zitadel.go b/cmd/initialise/verify_zitadel.go index 1ae85a21fa..78f28809c2 100644 --- a/cmd/initialise/verify_zitadel.go +++ b/cmd/initialise/verify_zitadel.go @@ -21,7 +21,7 @@ func newZitadel() *cobra.Command { Long: `initialize ZITADEL internals. Prerequisites: -- cockroachDB or postgreSQL with user and database +- postgreSQL with user and database `, Run: func(cmd *cobra.Command, args []string) { config := MustNewConfig(viper.GetViper()) @@ -32,7 +32,7 @@ Prerequisites: } func VerifyZitadel(ctx context.Context, db *database.DB, config database.Config) error { - err := ReadStmts(config.Type()) + err := ReadStmts() if err != nil { return err } @@ -68,11 +68,6 @@ func VerifyZitadel(ctx context.Context, db *database.DB, config database.Config) return err } - logging.WithFields().Info("verify system sequence") - if err := exec(ctx, conn, createSystemSequenceStmt, nil); err != nil { - return err - } - logging.WithFields().Info("verify unique constraints") if err := exec(ctx, conn, createUniqueConstraints, nil); err != nil { return err diff --git a/cmd/initialise/verify_zitadel_test.go b/cmd/initialise/verify_zitadel_test.go index 194911a179..7fccd4c0a2 100644 --- a/cmd/initialise/verify_zitadel_test.go +++ b/cmd/initialise/verify_zitadel_test.go @@ -9,7 +9,7 @@ import ( ) func Test_verifyEvents(t *testing.T) { - err := ReadStmts("cockroach") //TODO: check all dialects + err := ReadStmts() if err != nil { t.Errorf("unable to read stmts: %v", err) t.FailNow() diff --git a/cmd/key/key.go b/cmd/key/key.go index 1dba8fd969..a1cf15b34e 100644 --- a/cmd/key/key.go +++ b/cmd/key/key.go @@ -40,7 +40,7 @@ func newKey() *cobra.Command { Long: `create new encryption key(s) (encrypted by the provided master key) provide key(s) by YAML file and/or by argument Requirements: -- cockroachdb`, +- postgreSQL`, Example: `new -f keys.yaml new key1=somekey key2=anotherkey new -f keys.yaml key2=anotherkey`, diff --git a/cmd/mirror/config.go b/cmd/mirror/config.go index cc98000869..89b0876e5f 100644 --- a/cmd/mirror/config.go +++ b/cmd/mirror/config.go @@ -71,7 +71,7 @@ func mustNewConfig(v *viper.Viper, config any) { mapstructure.StringToTimeDurationHookFunc(), mapstructure.StringToTimeHookFunc(time.RFC3339), mapstructure.StringToSliceHookFunc(","), - database.DecodeHook, + database.DecodeHook(true), actions.HTTPConfigDecodeHook, hook.EnumHookFunc(internal_authz.MemberTypeString), mapstructure.TextUnmarshallerHookFunc(), diff --git a/cmd/mirror/defaults.yaml b/cmd/mirror/defaults.yaml index 7db91ecc0b..f7499461d7 100644 --- a/cmd/mirror/defaults.yaml +++ b/cmd/mirror/defaults.yaml @@ -5,8 +5,6 @@ Source: Database: zitadel # ZITADEL_DATABASE_COCKROACH_DATABASE MaxOpenConns: 6 # ZITADEL_DATABASE_COCKROACH_MAXOPENCONNS MaxIdleConns: 6 # ZITADEL_DATABASE_COCKROACH_MAXIDLECONNS - EventPushConnRatio: 0.33 # ZITADEL_DATABASE_COCKROACH_EVENTPUSHCONNRATIO - ProjectionSpoolerConnRatio: 0.33 # ZITADEL_DATABASE_COCKROACH_PROJECTIONSPOOLERCONNRATIO MaxConnLifetime: 30m # ZITADEL_DATABASE_COCKROACH_MAXCONNLIFETIME MaxConnIdleTime: 5m # ZITADEL_DATABASE_COCKROACH_MAXCONNIDLETIME Options: "" # ZITADEL_DATABASE_COCKROACH_OPTIONS @@ -39,41 +37,20 @@ Source: Key: # ZITADEL_DATABASE_POSTGRES_USER_SSL_KEY Destination: - cockroach: - Host: localhost # ZITADEL_DATABASE_COCKROACH_HOST - Port: 26257 # ZITADEL_DATABASE_COCKROACH_PORT - Database: zitadel # ZITADEL_DATABASE_COCKROACH_DATABASE - MaxOpenConns: 0 # ZITADEL_DATABASE_COCKROACH_MAXOPENCONNS - MaxIdleConns: 0 # ZITADEL_DATABASE_COCKROACH_MAXIDLECONNS - MaxConnLifetime: 30m # ZITADEL_DATABASE_COCKROACH_MAXCONNLIFETIME - MaxConnIdleTime: 5m # ZITADEL_DATABASE_COCKROACH_MAXCONNIDLETIME - EventPushConnRatio: 0.01 # ZITADEL_DATABASE_COCKROACH_EVENTPUSHCONNRATIO - ProjectionSpoolerConnRatio: 0.5 # ZITADEL_DATABASE_COCKROACH_PROJECTIONSPOOLERCONNRATIO - Options: "" # ZITADEL_DATABASE_COCKROACH_OPTIONS - User: - Username: zitadel # ZITADEL_DATABASE_COCKROACH_USER_USERNAME - Password: "" # ZITADEL_DATABASE_COCKROACH_USER_PASSWORD - SSL: - Mode: disable # ZITADEL_DATABASE_COCKROACH_USER_SSL_MODE - RootCert: "" # ZITADEL_DATABASE_COCKROACH_USER_SSL_ROOTCERT - Cert: "" # ZITADEL_DATABASE_COCKROACH_USER_SSL_CERT - Key: "" # ZITADEL_DATABASE_COCKROACH_USER_SSL_KEY - # Postgres is used as soon as a value is set - # The values describe the possible fields to set values postgres: - Host: # ZITADEL_DATABASE_POSTGRES_HOST - Port: # ZITADEL_DATABASE_POSTGRES_PORT - Database: # ZITADEL_DATABASE_POSTGRES_DATABASE - MaxOpenConns: # ZITADEL_DATABASE_POSTGRES_MAXOPENCONNS - MaxIdleConns: # ZITADEL_DATABASE_POSTGRES_MAXIDLECONNS - MaxConnLifetime: # ZITADEL_DATABASE_POSTGRES_MAXCONNLIFETIME - MaxConnIdleTime: # ZITADEL_DATABASE_POSTGRES_MAXCONNIDLETIME - Options: # ZITADEL_DATABASE_POSTGRES_OPTIONS + Host: localhost # ZITADEL_DATABASE_POSTGRES_HOST + Port: 5432 # ZITADEL_DATABASE_POSTGRES_PORT + Database: zitadel # ZITADEL_DATABASE_POSTGRES_DATABASE + MaxOpenConns: 5 # ZITADEL_DATABASE_POSTGRES_MAXOPENCONNS + MaxIdleConns: 2 # ZITADEL_DATABASE_POSTGRES_MAXIDLECONNS + MaxConnLifetime: 30m # ZITADEL_DATABASE_POSTGRES_MAXCONNLIFETIME + MaxConnIdleTime: 5m # ZITADEL_DATABASE_POSTGRES_MAXCONNIDLETIME + Options: "" # ZITADEL_DATABASE_POSTGRES_OPTIONS User: - Username: # ZITADEL_DATABASE_POSTGRES_USER_USERNAME + Username: zitadel # ZITADEL_DATABASE_POSTGRES_USER_USERNAME Password: # ZITADEL_DATABASE_POSTGRES_USER_PASSWORD SSL: - Mode: # ZITADEL_DATABASE_POSTGRES_USER_SSL_MODE + Mode: disable # ZITADEL_DATABASE_POSTGRES_USER_SSL_MODE RootCert: # ZITADEL_DATABASE_POSTGRES_USER_SSL_ROOTCERT Cert: # ZITADEL_DATABASE_POSTGRES_USER_SSL_CERT Key: # ZITADEL_DATABASE_POSTGRES_USER_SSL_KEY diff --git a/cmd/mirror/event_store.go b/cmd/mirror/event_store.go index 3825462126..9b26865416 100644 --- a/cmd/mirror/event_store.go +++ b/cmd/mirror/event_store.go @@ -56,15 +56,15 @@ func copyEventstore(ctx context.Context, config *Migration) { } func positionQuery(db *db.DB) string { - switch db.Type() { - case "postgres": - return "SELECT EXTRACT(EPOCH FROM clock_timestamp())" - case "cockroach": - return "SELECT cluster_logical_timestamp()" - default: - logging.WithFields("db_type", db.Type()).Fatal("database type not recognized") - return "" - } + // switch db.Type() { + // case "postgres": + return "SELECT EXTRACT(EPOCH FROM clock_timestamp())" + // case "cockroach": + // return "SELECT cluster_logical_timestamp()" + // default: + // logging.WithFields("db_type", db.Type()).Fatal("database type not recognized") + // return "" + // } } func copyEvents(ctx context.Context, source, dest *db.DB, bulkSize uint32) { diff --git a/cmd/mirror/projections.go b/cmd/mirror/projections.go index c15747e74a..64e594ad9b 100644 --- a/cmd/mirror/projections.go +++ b/cmd/mirror/projections.go @@ -117,7 +117,7 @@ func projections( staticStorage, err := config.AssetStorage.NewStorage(client.DB) logging.OnError(err).Fatal("unable create static storage") - config.Eventstore.Querier = old_es.NewCRDB(client) + config.Eventstore.Querier = old_es.NewPostgres(client) config.Eventstore.Pusher = new_es.NewEventstore(client) es := eventstore.NewEventstore(config.Eventstore) esV4 := es_v4.NewEventstoreFromOne(es_v4_pg.New(client, &es_v4_pg.Config{ diff --git a/cmd/setup/07.go b/cmd/setup/07.go index 73b9d3480b..590b220eb3 100644 --- a/cmd/setup/07.go +++ b/cmd/setup/07.go @@ -3,7 +3,7 @@ package setup import ( "context" "database/sql" - "embed" + _ "embed" "strings" "github.com/zitadel/zitadel/internal/eventstore" @@ -12,31 +12,20 @@ import ( var ( //go:embed 07/logstore.sql createLogstoreSchema07 string - //go:embed 07/cockroach/access.sql - //go:embed 07/postgres/access.sql - createAccessLogsTable07 embed.FS - //go:embed 07/cockroach/execution.sql - //go:embed 07/postgres/execution.sql - createExecutionLogsTable07 embed.FS + //go:embed 07/access.sql + createAccessLogsTable07 string + //go:embed 07/execution.sql + createExecutionLogsTable07 string ) type LogstoreTables struct { dbClient *sql.DB username string - dbType string } func (mig *LogstoreTables) Execute(ctx context.Context, _ eventstore.Event) error { - accessStmt, err := readStmt(createAccessLogsTable07, "07", mig.dbType, "access.sql") - if err != nil { - return err - } - executionStmt, err := readStmt(createExecutionLogsTable07, "07", mig.dbType, "execution.sql") - if err != nil { - return err - } - stmt := strings.ReplaceAll(createLogstoreSchema07, "%[1]s", mig.username) + accessStmt + executionStmt - _, err = mig.dbClient.ExecContext(ctx, stmt) + stmt := strings.ReplaceAll(createLogstoreSchema07, "%[1]s", mig.username) + createAccessLogsTable07 + createExecutionLogsTable07 + _, err := mig.dbClient.ExecContext(ctx, stmt) return err } diff --git a/cmd/setup/07/postgres/access.sql b/cmd/setup/07/access.sql similarity index 100% rename from cmd/setup/07/postgres/access.sql rename to cmd/setup/07/access.sql diff --git a/cmd/setup/07/cockroach/access.sql b/cmd/setup/07/cockroach/access.sql deleted file mode 100644 index fc5354cf32..0000000000 --- a/cmd/setup/07/cockroach/access.sql +++ /dev/null @@ -1,14 +0,0 @@ -CREATE TABLE IF NOT EXISTS logstore.access ( - log_date TIMESTAMPTZ NOT NULL - , protocol INT NOT NULL - , request_url TEXT NOT NULL - , response_status INT NOT NULL - , request_headers JSONB - , response_headers JSONB - , instance_id TEXT NOT NULL - , project_id TEXT NOT NULL - , requested_domain TEXT - , requested_host TEXT - - , INDEX protocol_date_desc (instance_id, protocol, log_date DESC) STORING (request_url, response_status, request_headers) -); diff --git a/cmd/setup/07/cockroach/execution.sql b/cmd/setup/07/cockroach/execution.sql deleted file mode 100644 index b8e18b525a..0000000000 --- a/cmd/setup/07/cockroach/execution.sql +++ /dev/null @@ -1,11 +0,0 @@ -CREATE TABLE IF NOT EXISTS logstore.execution ( - log_date TIMESTAMPTZ NOT NULL - , took INTERVAL - , message TEXT NOT NULL - , loglevel INT NOT NULL - , instance_id TEXT NOT NULL - , action_id TEXT NOT NULL - , metadata JSONB - - , INDEX log_date_desc (instance_id, log_date DESC) STORING (took) -); diff --git a/cmd/setup/07/postgres/execution.sql b/cmd/setup/07/execution.sql similarity index 100% rename from cmd/setup/07/postgres/execution.sql rename to cmd/setup/07/execution.sql diff --git a/cmd/setup/08.go b/cmd/setup/08.go index bec6a65ebb..fa006bd3cf 100644 --- a/cmd/setup/08.go +++ b/cmd/setup/08.go @@ -2,16 +2,15 @@ package setup import ( "context" - "embed" + _ "embed" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/eventstore" ) var ( - //go:embed 08/cockroach/08.sql - //go:embed 08/postgres/08.sql - tokenIndexes08 embed.FS + //go:embed 08/08.sql + tokenIndexes08 string ) type AuthTokenIndexes struct { @@ -19,11 +18,7 @@ type AuthTokenIndexes struct { } func (mig *AuthTokenIndexes) Execute(ctx context.Context, _ eventstore.Event) error { - stmt, err := readStmt(tokenIndexes08, "08", mig.dbClient.Type(), "08.sql") - if err != nil { - return err - } - _, err = mig.dbClient.ExecContext(ctx, stmt) + _, err := mig.dbClient.ExecContext(ctx, tokenIndexes08) return err } diff --git a/cmd/setup/08/postgres/08.sql b/cmd/setup/08/08.sql similarity index 100% rename from cmd/setup/08/postgres/08.sql rename to cmd/setup/08/08.sql diff --git a/cmd/setup/08/cockroach/08.sql b/cmd/setup/08/cockroach/08.sql deleted file mode 100644 index aec4d54303..0000000000 --- a/cmd/setup/08/cockroach/08.sql +++ /dev/null @@ -1,5 +0,0 @@ -CREATE INDEX IF NOT EXISTS inst_refresh_tkn_idx ON auth.tokens(instance_id, refresh_token_id); -CREATE INDEX IF NOT EXISTS inst_app_tkn_idx ON auth.tokens(instance_id, application_id); -CREATE INDEX IF NOT EXISTS inst_ro_tkn_idx ON auth.tokens(instance_id, resource_owner); -DROP INDEX IF EXISTS auth.tokens@user_user_agent_idx; -CREATE INDEX IF NOT EXISTS inst_usr_agnt_tkn_idx ON auth.tokens(instance_id, user_id, user_agent_id); \ No newline at end of file diff --git a/cmd/setup/10.go b/cmd/setup/10.go index 93c017305c..b134fcab62 100644 --- a/cmd/setup/10.go +++ b/cmd/setup/10.go @@ -3,7 +3,7 @@ package setup import ( "context" "database/sql" - "embed" + _ "embed" "time" "github.com/cockroachdb/cockroach-go/v2/crdb" @@ -18,9 +18,8 @@ var ( correctCreationDate10CreateTable string //go:embed 10/10_fill_table.sql correctCreationDate10FillTable string - //go:embed 10/cockroach/10_update.sql - //go:embed 10/postgres/10_update.sql - correctCreationDate10Update embed.FS + //go:embed 10/10_update.sql + correctCreationDate10Update string //go:embed 10/10_count_wrong_events.sql correctCreationDate10CountWrongEvents string //go:embed 10/10_empty_table.sql @@ -40,11 +39,6 @@ func (mig *CorrectCreationDate) Execute(ctx context.Context, _ eventstore.Event) logging.WithFields("mig", mig.String(), "iteration", i).Debug("start iteration") var affected int64 err = crdb.ExecuteTx(ctx, mig.dbClient.DB, nil, func(tx *sql.Tx) error { - if mig.dbClient.Type() == "cockroach" { - if _, err := tx.Exec("SET experimental_enable_temp_tables=on"); err != nil { - return err - } - } _, err := tx.ExecContext(ctx, correctCreationDate10CreateTable) if err != nil { return err @@ -66,11 +60,7 @@ func (mig *CorrectCreationDate) Execute(ctx context.Context, _ eventstore.Event) return err } - updateStmt, err := readStmt(correctCreationDate10Update, "10", mig.dbClient.Type(), "10_update.sql") - if err != nil { - return err - } - _, err = tx.ExecContext(ctx, updateStmt) + _, err = tx.ExecContext(ctx, correctCreationDate10Update) if err != nil { return err } diff --git a/cmd/setup/10/postgres/10_update.sql b/cmd/setup/10/10_update.sql similarity index 100% rename from cmd/setup/10/postgres/10_update.sql rename to cmd/setup/10/10_update.sql diff --git a/cmd/setup/10/cockroach/10_update.sql b/cmd/setup/10/cockroach/10_update.sql deleted file mode 100644 index 9e7d7f993a..0000000000 --- a/cmd/setup/10/cockroach/10_update.sql +++ /dev/null @@ -1 +0,0 @@ -UPDATE eventstore.events e SET (creation_date, "position") = (we.next_cd, we.next_cd::DECIMAL) FROM wrong_events we WHERE e.event_sequence = we.event_sequence AND e.instance_id = we.instance_id; diff --git a/cmd/setup/14.go b/cmd/setup/14.go index f0ea1b819a..2cd5ac2c57 100644 --- a/cmd/setup/14.go +++ b/cmd/setup/14.go @@ -15,8 +15,7 @@ import ( ) var ( - //go:embed 14/cockroach/*.sql - //go:embed 14/postgres/*.sql + //go:embed 14/*.sql newEventsTable embed.FS ) @@ -40,7 +39,7 @@ func (mig *NewEventsTable) Execute(ctx context.Context, _ eventstore.Event) erro return err } - statements, err := readStatements(newEventsTable, "14", mig.dbClient.Type()) + statements, err := readStatements(newEventsTable, "14") if err != nil { return err } diff --git a/cmd/setup/14/cockroach/01_disable_inserts.sql b/cmd/setup/14/01_disable_inserts.sql similarity index 100% rename from cmd/setup/14/cockroach/01_disable_inserts.sql rename to cmd/setup/14/01_disable_inserts.sql diff --git a/cmd/setup/14/postgres/02_create_and_fill_events2.sql b/cmd/setup/14/02_create_and_fill_events2.sql similarity index 100% rename from cmd/setup/14/postgres/02_create_and_fill_events2.sql rename to cmd/setup/14/02_create_and_fill_events2.sql diff --git a/cmd/setup/14/postgres/03_events2_pk.sql b/cmd/setup/14/03_events2_pk.sql similarity index 100% rename from cmd/setup/14/postgres/03_events2_pk.sql rename to cmd/setup/14/03_events2_pk.sql diff --git a/cmd/setup/14/postgres/04_constraints.sql b/cmd/setup/14/04_constraints.sql similarity index 100% rename from cmd/setup/14/postgres/04_constraints.sql rename to cmd/setup/14/04_constraints.sql diff --git a/cmd/setup/14/postgres/05_indexes.sql b/cmd/setup/14/05_indexes.sql similarity index 100% rename from cmd/setup/14/postgres/05_indexes.sql rename to cmd/setup/14/05_indexes.sql diff --git a/cmd/setup/14/cockroach/02_create_and_fill_events2.sql b/cmd/setup/14/cockroach/02_create_and_fill_events2.sql deleted file mode 100644 index 300ac4b621..0000000000 --- a/cmd/setup/14/cockroach/02_create_and_fill_events2.sql +++ /dev/null @@ -1,33 +0,0 @@ -CREATE TABLE eventstore.events2 ( - instance_id, - aggregate_type, - aggregate_id, - - event_type, - "sequence", - revision, - created_at, - payload, - creator, - "owner", - - "position", - in_tx_order, - - PRIMARY KEY (instance_id, aggregate_type, aggregate_id, "sequence") -) AS SELECT - instance_id, - aggregate_type, - aggregate_id, - - event_type, - event_sequence, - substr(aggregate_version, 2)::SMALLINT, - creation_date, - event_data, - editor_user, - resource_owner, - - creation_date::DECIMAL, - event_sequence -FROM eventstore.events_old; \ No newline at end of file diff --git a/cmd/setup/14/cockroach/03_constraints.sql b/cmd/setup/14/cockroach/03_constraints.sql deleted file mode 100644 index 62f119cc43..0000000000 --- a/cmd/setup/14/cockroach/03_constraints.sql +++ /dev/null @@ -1,7 +0,0 @@ -ALTER TABLE eventstore.events2 ALTER COLUMN event_type SET NOT NULL; -ALTER TABLE eventstore.events2 ALTER COLUMN revision SET NOT NULL; -ALTER TABLE eventstore.events2 ALTER COLUMN created_at SET NOT NULL; -ALTER TABLE eventstore.events2 ALTER COLUMN creator SET NOT NULL; -ALTER TABLE eventstore.events2 ALTER COLUMN "owner" SET NOT NULL; -ALTER TABLE eventstore.events2 ALTER COLUMN "position" SET NOT NULL; -ALTER TABLE eventstore.events2 ALTER COLUMN in_tx_order SET NOT NULL; \ No newline at end of file diff --git a/cmd/setup/14/cockroach/04_indexes.sql b/cmd/setup/14/cockroach/04_indexes.sql deleted file mode 100644 index a442653606..0000000000 --- a/cmd/setup/14/cockroach/04_indexes.sql +++ /dev/null @@ -1,3 +0,0 @@ -CREATE INDEX IF NOT EXISTS es_active_instances ON eventstore.events2 (created_at DESC) STORING ("position"); -CREATE INDEX IF NOT EXISTS es_wm ON eventstore.events2 (aggregate_id, instance_id, aggregate_type, event_type); -CREATE INDEX IF NOT EXISTS es_projection ON eventstore.events2 (instance_id, aggregate_type, event_type, "position"); \ No newline at end of file diff --git a/cmd/setup/14/postgres/01_disable_inserts.sql b/cmd/setup/14/postgres/01_disable_inserts.sql deleted file mode 100644 index 0f3c277eba..0000000000 --- a/cmd/setup/14/postgres/01_disable_inserts.sql +++ /dev/null @@ -1 +0,0 @@ -ALTER TABLE eventstore.events RENAME TO events_old; \ No newline at end of file diff --git a/cmd/setup/15.go b/cmd/setup/15.go index 2e75ffb118..54161ddef9 100644 --- a/cmd/setup/15.go +++ b/cmd/setup/15.go @@ -11,8 +11,7 @@ import ( ) var ( - //go:embed 15/cockroach/*.sql - //go:embed 15/postgres/*.sql + //go:embed 15/*.sql currentProjectionState embed.FS ) @@ -21,7 +20,7 @@ type CurrentProjectionState struct { } func (mig *CurrentProjectionState) Execute(ctx context.Context, _ eventstore.Event) error { - statements, err := readStatements(currentProjectionState, "15", mig.dbClient.Type()) + statements, err := readStatements(currentProjectionState, "15") if err != nil { return err } diff --git a/cmd/setup/15/cockroach/01_new_failed_events.sql b/cmd/setup/15/01_new_failed_events.sql similarity index 100% rename from cmd/setup/15/cockroach/01_new_failed_events.sql rename to cmd/setup/15/01_new_failed_events.sql diff --git a/cmd/setup/15/postgres/02_fe_from_projections.sql b/cmd/setup/15/02_fe_from_projections.sql similarity index 100% rename from cmd/setup/15/postgres/02_fe_from_projections.sql rename to cmd/setup/15/02_fe_from_projections.sql diff --git a/cmd/setup/15/cockroach/03_fe_from_adminapi.sql b/cmd/setup/15/03_fe_from_adminapi.sql similarity index 100% rename from cmd/setup/15/cockroach/03_fe_from_adminapi.sql rename to cmd/setup/15/03_fe_from_adminapi.sql diff --git a/cmd/setup/15/cockroach/04_fe_from_auth.sql b/cmd/setup/15/04_fe_from_auth.sql similarity index 100% rename from cmd/setup/15/cockroach/04_fe_from_auth.sql rename to cmd/setup/15/04_fe_from_auth.sql diff --git a/cmd/setup/15/cockroach/05_current_states.sql b/cmd/setup/15/05_current_states.sql similarity index 100% rename from cmd/setup/15/cockroach/05_current_states.sql rename to cmd/setup/15/05_current_states.sql diff --git a/cmd/setup/15/postgres/06_cs_from_projections.sql b/cmd/setup/15/06_cs_from_projections.sql similarity index 100% rename from cmd/setup/15/postgres/06_cs_from_projections.sql rename to cmd/setup/15/06_cs_from_projections.sql diff --git a/cmd/setup/15/postgres/07_cs_from_adminapi.sql b/cmd/setup/15/07_cs_from_adminapi.sql similarity index 100% rename from cmd/setup/15/postgres/07_cs_from_adminapi.sql rename to cmd/setup/15/07_cs_from_adminapi.sql diff --git a/cmd/setup/15/postgres/08_cs_from_auth.sql b/cmd/setup/15/08_cs_from_auth.sql similarity index 100% rename from cmd/setup/15/postgres/08_cs_from_auth.sql rename to cmd/setup/15/08_cs_from_auth.sql diff --git a/cmd/setup/15/cockroach/02_fe_from_projections.sql b/cmd/setup/15/cockroach/02_fe_from_projections.sql deleted file mode 100644 index 8bf7a4b8d4..0000000000 --- a/cmd/setup/15/cockroach/02_fe_from_projections.sql +++ /dev/null @@ -1,26 +0,0 @@ -INSERT INTO projections.failed_events2 ( - projection_name - , instance_id - , aggregate_type - , aggregate_id - , event_creation_date - , failed_sequence - , failure_count - , error - , last_failed -) SELECT - fe.projection_name - , fe.instance_id - , e.aggregate_type - , e.aggregate_id - , e.created_at - , e.sequence - , fe.failure_count - , fe.error - , fe.last_failed -FROM - projections.failed_events fe -JOIN eventstore.events2 e ON - e.instance_id = fe.instance_id - AND e.sequence = fe.failed_sequence -ON CONFLICT DO NOTHING; \ No newline at end of file diff --git a/cmd/setup/15/cockroach/06_cs_from_projections.sql b/cmd/setup/15/cockroach/06_cs_from_projections.sql deleted file mode 100644 index 579afb6d4c..0000000000 --- a/cmd/setup/15/cockroach/06_cs_from_projections.sql +++ /dev/null @@ -1,29 +0,0 @@ -INSERT INTO projections.current_states ( - projection_name - , instance_id - , event_date - , "position" - , last_updated -) (SELECT - cs.projection_name - , cs.instance_id - , e.created_at - , e.position - , cs.timestamp -FROM - projections.current_sequences cs -JOIN eventstore.events2 e ON - e.instance_id = cs.instance_id - AND e.aggregate_type = cs.aggregate_type - AND e.sequence = cs.current_sequence - AND cs.current_sequence = ( - SELECT - MAX(cs2.current_sequence) - FROM - projections.current_sequences cs2 - WHERE - cs.projection_name = cs2.projection_name - AND cs.instance_id = cs2.instance_id - ) -) -ON CONFLICT DO NOTHING; \ No newline at end of file diff --git a/cmd/setup/15/cockroach/07_cs_from_adminapi.sql b/cmd/setup/15/cockroach/07_cs_from_adminapi.sql deleted file mode 100644 index c40d13a067..0000000000 --- a/cmd/setup/15/cockroach/07_cs_from_adminapi.sql +++ /dev/null @@ -1,28 +0,0 @@ -INSERT INTO projections.current_states ( - projection_name - , instance_id - , event_date - , "position" - , last_updated -) (SELECT - cs.view_name - , cs.instance_id - , e.created_at - , e.position - , cs.last_successful_spooler_run -FROM - adminapi.current_sequences cs -JOIN eventstore.events2 e ON - e.instance_id = cs.instance_id - AND e.sequence = cs.current_sequence - AND cs.current_sequence = ( - SELECT - MAX(cs2.current_sequence) - FROM - adminapi.current_sequences cs2 - WHERE - cs.view_name = cs2.view_name - AND cs.instance_id = cs2.instance_id - ) -) -ON CONFLICT DO NOTHING; \ No newline at end of file diff --git a/cmd/setup/15/cockroach/08_cs_from_auth.sql b/cmd/setup/15/cockroach/08_cs_from_auth.sql deleted file mode 100644 index c8e7236107..0000000000 --- a/cmd/setup/15/cockroach/08_cs_from_auth.sql +++ /dev/null @@ -1,28 +0,0 @@ -INSERT INTO projections.current_states ( - projection_name - , instance_id - , event_date - , "position" - , last_updated -) (SELECT - cs.view_name - , cs.instance_id - , e.created_at - , e.position - , cs.last_successful_spooler_run -FROM - auth.current_sequences cs -JOIN eventstore.events2 e ON - e.instance_id = cs.instance_id - AND e.sequence = cs.current_sequence - AND cs.current_sequence = ( - SELECT - MAX(cs2.current_sequence) - FROM - auth.current_sequences cs2 - WHERE - cs.view_name = cs2.view_name - AND cs.instance_id = cs2.instance_id - ) -) -ON CONFLICT DO NOTHING; \ No newline at end of file diff --git a/cmd/setup/15/postgres/01_new_failed_events.sql b/cmd/setup/15/postgres/01_new_failed_events.sql deleted file mode 100644 index 5fa39c08a5..0000000000 --- a/cmd/setup/15/postgres/01_new_failed_events.sql +++ /dev/null @@ -1,16 +0,0 @@ -CREATE TABLE IF NOT EXISTS projections.failed_events2 ( - projection_name TEXT NOT NULL - , instance_id TEXT NOT NULL - - , aggregate_type TEXT NOT NULL - , aggregate_id TEXT NOT NULL - , event_creation_date TIMESTAMPTZ NOT NULL - , failed_sequence INT8 NOT NULL - - , failure_count INT2 NULL DEFAULT 0 - , error TEXT - , last_failed TIMESTAMPTZ - - , PRIMARY KEY (projection_name, instance_id, aggregate_type, aggregate_id, failed_sequence) -); -CREATE INDEX IF NOT EXISTS fe2_instance_id_idx on projections.failed_events2 (instance_id); \ No newline at end of file diff --git a/cmd/setup/15/postgres/03_fe_from_adminapi.sql b/cmd/setup/15/postgres/03_fe_from_adminapi.sql deleted file mode 100644 index 1616662fed..0000000000 --- a/cmd/setup/15/postgres/03_fe_from_adminapi.sql +++ /dev/null @@ -1,26 +0,0 @@ -INSERT INTO projections.failed_events2 ( - projection_name - , instance_id - , aggregate_type - , aggregate_id - , event_creation_date - , failed_sequence - , failure_count - , error - , last_failed -) SELECT - fe.view_name - , fe.instance_id - , e.aggregate_type - , e.aggregate_id - , e.created_at - , e.sequence - , fe.failure_count - , fe.err_msg - , fe.last_failed -FROM - adminapi.failed_events fe -JOIN eventstore.events2 e ON - e.instance_id = fe.instance_id - AND e.sequence = fe.failed_sequence -ON CONFLICT DO NOTHING; \ No newline at end of file diff --git a/cmd/setup/15/postgres/04_fe_from_auth.sql b/cmd/setup/15/postgres/04_fe_from_auth.sql deleted file mode 100644 index a249293e24..0000000000 --- a/cmd/setup/15/postgres/04_fe_from_auth.sql +++ /dev/null @@ -1,26 +0,0 @@ -INSERT INTO projections.failed_events2 ( - projection_name - , instance_id - , aggregate_type - , aggregate_id - , event_creation_date - , failed_sequence - , failure_count - , error - , last_failed -) SELECT - fe.view_name - , fe.instance_id - , e.aggregate_type - , e.aggregate_id - , e.created_at - , e.sequence - , fe.failure_count - , fe.err_msg - , fe.last_failed -FROM - auth.failed_events fe -JOIN eventstore.events2 e ON - e.instance_id = fe.instance_id - AND e.sequence = fe.failed_sequence -ON CONFLICT DO NOTHING; \ No newline at end of file diff --git a/cmd/setup/15/postgres/05_current_states.sql b/cmd/setup/15/postgres/05_current_states.sql deleted file mode 100644 index bc2f5ed771..0000000000 --- a/cmd/setup/15/postgres/05_current_states.sql +++ /dev/null @@ -1,15 +0,0 @@ -CREATE TABLE IF NOT EXISTS projections.current_states ( - projection_name TEXT NOT NULL - , instance_id TEXT NOT NULL - - , last_updated TIMESTAMPTZ - - , aggregate_id TEXT - , aggregate_type TEXT - , "sequence" INT8 - , event_date TIMESTAMPTZ - , "position" DECIMAL - - , PRIMARY KEY (projection_name, instance_id) -); -CREATE INDEX IF NOT EXISTS cs_instance_id_idx ON projections.current_states (instance_id); \ No newline at end of file diff --git a/cmd/setup/34.go b/cmd/setup/34.go index 59854e9e97..75e4076803 100644 --- a/cmd/setup/34.go +++ b/cmd/setup/34.go @@ -3,17 +3,14 @@ package setup import ( "context" _ "embed" - "fmt" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/eventstore" ) var ( - //go:embed 34/cockroach/34_cache_schema.sql - addCacheSchemaCockroach string - //go:embed 34/postgres/34_cache_schema.sql - addCacheSchemaPostgres string + //go:embed 34/34_cache_schema.sql + addCacheSchema string ) type AddCacheSchema struct { @@ -21,14 +18,7 @@ type AddCacheSchema struct { } func (mig *AddCacheSchema) Execute(ctx context.Context, _ eventstore.Event) (err error) { - switch mig.dbClient.Type() { - case "cockroach": - _, err = mig.dbClient.ExecContext(ctx, addCacheSchemaCockroach) - case "postgres": - _, err = mig.dbClient.ExecContext(ctx, addCacheSchemaPostgres) - default: - err = fmt.Errorf("add cache schema: unsupported db type %q", mig.dbClient.Type()) - } + _, err = mig.dbClient.ExecContext(ctx, addCacheSchema) return err } diff --git a/cmd/setup/34/postgres/34_cache_schema.sql b/cmd/setup/34/34_cache_schema.sql similarity index 100% rename from cmd/setup/34/postgres/34_cache_schema.sql rename to cmd/setup/34/34_cache_schema.sql diff --git a/cmd/setup/34/cockroach/34_cache_schema.sql b/cmd/setup/34/cockroach/34_cache_schema.sql deleted file mode 100644 index 0f866b0ccd..0000000000 --- a/cmd/setup/34/cockroach/34_cache_schema.sql +++ /dev/null @@ -1,27 +0,0 @@ -create schema if not exists cache; - -create table if not exists cache.objects ( - cache_name varchar not null, - id uuid not null default gen_random_uuid(), - created_at timestamptz not null default now(), - last_used_at timestamptz not null default now(), - payload jsonb not null, - - primary key(cache_name, id) -); - -create table if not exists cache.string_keys( - cache_name varchar not null check (cache_name <> ''), - index_id integer not null check (index_id > 0), - index_key varchar not null check (index_key <> ''), - object_id uuid not null, - - primary key (cache_name, index_id, index_key), - constraint fk_object - foreign key(cache_name, object_id) - references cache.objects(cache_name, id) - on delete cascade -); - -create index if not exists string_keys_object_id_idx - on cache.string_keys (cache_name, object_id); -- for delete cascade diff --git a/cmd/setup/35.go b/cmd/setup/35.go index f8473cfbfd..68e08bdfdb 100644 --- a/cmd/setup/35.go +++ b/cmd/setup/35.go @@ -21,7 +21,7 @@ type AddPositionToIndexEsWm struct { } func (mig *AddPositionToIndexEsWm) Execute(ctx context.Context, _ eventstore.Event) error { - statements, err := readStatements(addPositionToEsWmIndex, "35", "") + statements, err := readStatements(addPositionToEsWmIndex, "35") if err != nil { return err } diff --git a/cmd/setup/40.go b/cmd/setup/40.go index b16b9226f7..86cdab0d11 100644 --- a/cmd/setup/40.go +++ b/cmd/setup/40.go @@ -24,8 +24,7 @@ const ( ) var ( - //go:embed 40/cockroach/*.sql - //go:embed 40/postgres/*.sql + //go:embed 40/*.sql initPushFunc embed.FS ) @@ -112,5 +111,5 @@ func (mig *InitPushFunc) inTxOrderType(ctx context.Context) (typeName string, er } func (mig *InitPushFunc) filePath(fileName string) string { - return path.Join("40", mig.dbClient.Type(), fileName) + return path.Join("40", fileName) } diff --git a/cmd/setup/40/cockroach/00_in_tx_order_type.sql b/cmd/setup/40/00_in_tx_order_type.sql similarity index 100% rename from cmd/setup/40/cockroach/00_in_tx_order_type.sql rename to cmd/setup/40/00_in_tx_order_type.sql diff --git a/cmd/setup/40/postgres/01_type.sql b/cmd/setup/40/01_type.sql similarity index 100% rename from cmd/setup/40/postgres/01_type.sql rename to cmd/setup/40/01_type.sql diff --git a/cmd/setup/40/postgres/02_func.sql b/cmd/setup/40/02_func.sql similarity index 100% rename from cmd/setup/40/postgres/02_func.sql rename to cmd/setup/40/02_func.sql diff --git a/cmd/setup/40/cockroach/01_type.sql b/cmd/setup/40/cockroach/01_type.sql deleted file mode 100644 index e26af2f828..0000000000 --- a/cmd/setup/40/cockroach/01_type.sql +++ /dev/null @@ -1,10 +0,0 @@ -CREATE TYPE IF NOT EXISTS eventstore.command AS ( - instance_id TEXT - , aggregate_type TEXT - , aggregate_id TEXT - , command_type TEXT - , revision INT2 - , payload JSONB - , creator TEXT - , owner TEXT -); diff --git a/cmd/setup/40/cockroach/02_func.sql b/cmd/setup/40/cockroach/02_func.sql deleted file mode 100644 index 9cb45529ad..0000000000 --- a/cmd/setup/40/cockroach/02_func.sql +++ /dev/null @@ -1,137 +0,0 @@ -CREATE OR REPLACE FUNCTION eventstore.latest_aggregate_state( - instance_id TEXT - , aggregate_type TEXT - , aggregate_id TEXT - - , sequence OUT BIGINT - , owner OUT TEXT -) - LANGUAGE 'plpgsql' -AS $$ - BEGIN - SELECT - COALESCE(e.sequence, 0) AS sequence - , e.owner - INTO - sequence - , owner - FROM - eventstore.events2 e - WHERE - e.instance_id = $1 - AND e.aggregate_type = $2 - AND e.aggregate_id = $3 - ORDER BY - e.sequence DESC - LIMIT 1; - - RETURN; - END; -$$; - -CREATE OR REPLACE FUNCTION eventstore.commands_to_events2(commands eventstore.command[]) - RETURNS eventstore.events2[] - LANGUAGE 'plpgsql' -AS $$ -DECLARE - current_sequence BIGINT; - current_owner TEXT; - - instance_id TEXT; - aggregate_type TEXT; - aggregate_id TEXT; - - _events eventstore.events2[]; - - _aggregates CURSOR FOR - select - DISTINCT ("c").instance_id - , ("c").aggregate_type - , ("c").aggregate_id - FROM - UNNEST(commands) AS c; -BEGIN - OPEN _aggregates; - LOOP - FETCH NEXT IN _aggregates INTO instance_id, aggregate_type, aggregate_id; - -- crdb does not support EXIT WHEN NOT FOUND - EXIT WHEN instance_id IS NULL; - - SELECT - * - INTO - current_sequence - , current_owner - FROM eventstore.latest_aggregate_state( - instance_id - , aggregate_type - , aggregate_id - ); - - -- RETURN QUERY is not supported by crdb: https://github.com/cockroachdb/cockroach/issues/105240 - SELECT - ARRAY_CAT(_events, ARRAY_AGG(e)) - INTO - _events - FROM ( - SELECT - ("c").instance_id - , ("c").aggregate_type - , ("c").aggregate_id - , ("c").command_type -- AS event_type - , COALESCE(current_sequence, 0) + ROW_NUMBER() OVER () -- AS sequence - , ("c").revision - , NOW() -- AS created_at - , ("c").payload - , ("c").creator - , COALESCE(current_owner, ("c").owner) -- AS owner - , cluster_logical_timestamp() -- AS position - , ordinality::{{ .InTxOrderType }} -- AS in_tx_order - FROM - UNNEST(commands) WITH ORDINALITY AS c - WHERE - ("c").instance_id = instance_id - AND ("c").aggregate_type = aggregate_type - AND ("c").aggregate_id = aggregate_id - ) AS e; - END LOOP; - CLOSE _aggregates; - RETURN _events; -END; -$$; - -CREATE OR REPLACE FUNCTION eventstore.push(commands eventstore.command[]) RETURNS SETOF eventstore.events2 AS $$ - INSERT INTO eventstore.events2 - SELECT - ("e").instance_id - , ("e").aggregate_type - , ("e").aggregate_id - , ("e").event_type - , ("e").sequence - , ("e").revision - , ("e").created_at - , ("e").payload - , ("e").creator - , ("e").owner - , ("e")."position" - , ("e").in_tx_order - FROM - UNNEST(eventstore.commands_to_events2(commands)) e - ORDER BY - in_tx_order - RETURNING * -$$ LANGUAGE SQL; - -/* -select (c).* from UNNEST(eventstore.commands_to_events2( -ARRAY[ - ROW('', 'system', 'SYSTEM', 'ct1', 1, '{"key": "value"}', 'c1', 'SYSTEM') - , ROW('', 'system', 'SYSTEM', 'ct2', 1, '{"key": "value"}', 'c1', 'SYSTEM') - , ROW('289525561255060732', 'org', '289575074711790844', 'ct3', 1, '{"key": "value"}', 'c1', '289575074711790844') - , ROW('289525561255060732', 'user', '289575075164906748', 'ct3', 1, '{"key": "value"}', 'c1', '289575074711790844') - , ROW('289525561255060732', 'oidc_session', 'V2_289575178579535100', 'ct3', 1, '{"key": "value"}', 'c1', '289575074711790844') - , ROW('', 'system', 'SYSTEM', 'ct3', 1, '{"key": "value"}', 'c1', 'SYSTEM') -]::eventstore.command[] -) )c; -*/ - diff --git a/cmd/setup/40/postgres/00_in_tx_order_type.sql b/cmd/setup/40/postgres/00_in_tx_order_type.sql deleted file mode 100644 index 68b7daf984..0000000000 --- a/cmd/setup/40/postgres/00_in_tx_order_type.sql +++ /dev/null @@ -1,5 +0,0 @@ -SELECT data_type -FROM information_schema.columns -WHERE table_schema = 'eventstore' -AND table_name = 'events2' -AND column_name = 'in_tx_order'; diff --git a/cmd/setup/43.go b/cmd/setup/43.go index 844c25cf24..1fa09773bc 100644 --- a/cmd/setup/43.go +++ b/cmd/setup/43.go @@ -12,8 +12,7 @@ import ( ) var ( - //go:embed 43/cockroach/*.sql - //go:embed 43/postgres/*.sql + //go:embed 43/*.sql createFieldsDomainIndex embed.FS ) @@ -22,7 +21,7 @@ type CreateFieldsDomainIndex struct { } func (mig *CreateFieldsDomainIndex) Execute(ctx context.Context, _ eventstore.Event) error { - statements, err := readStatements(createFieldsDomainIndex, "43", mig.dbClient.Type()) + statements, err := readStatements(createFieldsDomainIndex, "43") if err != nil { return err } diff --git a/cmd/setup/43/postgres/43.sql b/cmd/setup/43/43.sql similarity index 100% rename from cmd/setup/43/postgres/43.sql rename to cmd/setup/43/43.sql diff --git a/cmd/setup/43/cockroach/43.sql b/cmd/setup/43/cockroach/43.sql deleted file mode 100644 index 9152130970..0000000000 --- a/cmd/setup/43/cockroach/43.sql +++ /dev/null @@ -1,3 +0,0 @@ -CREATE INDEX CONCURRENTLY IF NOT EXISTS fields_instance_domains_idx -ON eventstore.fields (object_id) -WHERE object_type = 'instance_domain' AND field_name = 'domain'; \ No newline at end of file diff --git a/cmd/setup/44.go b/cmd/setup/44.go index 11c355a053..5eb2f8d5c1 100644 --- a/cmd/setup/44.go +++ b/cmd/setup/44.go @@ -21,7 +21,7 @@ type ReplaceCurrentSequencesIndex struct { } func (mig *ReplaceCurrentSequencesIndex) Execute(ctx context.Context, _ eventstore.Event) error { - statements, err := readStatements(replaceCurrentSequencesIndex, "44", "") + statements, err := readStatements(replaceCurrentSequencesIndex, "44") if err != nil { return err } diff --git a/cmd/setup/46.go b/cmd/setup/46.go index e48b16e4b0..3593a1b668 100644 --- a/cmd/setup/46.go +++ b/cmd/setup/46.go @@ -21,7 +21,7 @@ var ( ) func (mig *InitPermissionFunctions) Execute(ctx context.Context, _ eventstore.Event) error { - statements, err := readStatements(permissionFunctions, "46", "") + statements, err := readStatements(permissionFunctions, "46") if err != nil { return err } diff --git a/cmd/setup/49.go b/cmd/setup/49.go index 28bf797110..8465589140 100644 --- a/cmd/setup/49.go +++ b/cmd/setup/49.go @@ -21,7 +21,7 @@ var ( ) func (mig *InitPermittedOrgsFunction) Execute(ctx context.Context, _ eventstore.Event) error { - statements, err := readStatements(permittedOrgsFunction, "49", "") + statements, err := readStatements(permittedOrgsFunction, "49") if err != nil { return err } diff --git a/cmd/setup/cleanup.go b/cmd/setup/cleanup.go index 943ac164ea..e0a07c0a9d 100644 --- a/cmd/setup/cleanup.go +++ b/cmd/setup/cleanup.go @@ -35,7 +35,7 @@ func Cleanup(config *Config) { logging.OnError(err).Fatal("unable to connect to database") config.Eventstore.Pusher = new_es.NewEventstore(dbClient) - config.Eventstore.Querier = old_es.NewCRDB(dbClient) + config.Eventstore.Querier = old_es.NewPostgres(dbClient) es := eventstore.NewEventstore(config.Eventstore) step, err := migration.LastStuckStep(ctx, es) diff --git a/cmd/setup/config.go b/cmd/setup/config.go index 6706d219e6..a5903cb48a 100644 --- a/cmd/setup/config.go +++ b/cmd/setup/config.go @@ -69,7 +69,7 @@ func MustNewConfig(v *viper.Viper) *Config { hooks.SliceTypeStringDecode[internal_authz.RoleMapping], hooks.MapTypeStringDecode[string, *internal_authz.SystemAPIUser], hooks.MapHTTPHeaderStringDecode, - database.DecodeHook, + database.DecodeHook(false), actions.HTTPConfigDecodeHook, hook.EnumHookFunc(internal_authz.MemberTypeString), hook.Base64ToBytesHookFunc(), diff --git a/cmd/setup/river_queue_repeatable.go b/cmd/setup/river_queue_repeatable.go index 5248894a8f..bfbd3ee581 100644 --- a/cmd/setup/river_queue_repeatable.go +++ b/cmd/setup/river_queue_repeatable.go @@ -13,9 +13,6 @@ type RiverMigrateRepeatable struct { } func (mig *RiverMigrateRepeatable) Execute(ctx context.Context, _ eventstore.Event) error { - if mig.client.Type() != "postgres" { - return nil - } return queue.NewMigrator(mig.client).Execute(ctx) } diff --git a/cmd/setup/setup.go b/cmd/setup/setup.go index 9d57928d06..2e4d5317a5 100644 --- a/cmd/setup/setup.go +++ b/cmd/setup/setup.go @@ -55,7 +55,7 @@ func New() *cobra.Command { Short: "setup ZITADEL instance", Long: `sets up data to start ZITADEL. Requirements: -- cockroachdb`, +- postgreSQL`, Run: func(cmd *cobra.Command, args []string) { err := tls.ModeFromFlag(cmd) logging.OnError(err).Fatal("invalid tlsMode") @@ -107,7 +107,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string) dbClient, err := database.Connect(config.Database, false) logging.OnError(err).Fatal("unable to connect to database") - config.Eventstore.Querier = old_es.NewCRDB(dbClient) + config.Eventstore.Querier = old_es.NewPostgres(dbClient) esV3 := new_es.NewEventstore(dbClient) config.Eventstore.Pusher = esV3 config.Eventstore.Searcher = esV3 @@ -137,7 +137,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string) steps.s5LastFailed = &LastFailed{dbClient: dbClient.DB} steps.s6OwnerRemoveColumns = &OwnerRemoveColumns{dbClient: dbClient.DB} - steps.s7LogstoreTables = &LogstoreTables{dbClient: dbClient.DB, username: config.Database.Username(), dbType: config.Database.Type()} + steps.s7LogstoreTables = &LogstoreTables{dbClient: dbClient.DB, username: config.Database.Username()} steps.s8AuthTokens = &AuthTokenIndexes{dbClient: dbClient} steps.CorrectCreationDate.dbClient = dbClient steps.s12AddOTPColumns = &AddOTPColumns{dbClient: dbClient} @@ -304,8 +304,8 @@ func mustExecuteMigration(ctx context.Context, eventstoreClient *eventstore.Even // under the folder/typ/filename path. // Typ describes the database dialect and may be omitted if no // dialect specific migration is specified. -func readStmt(fs embed.FS, folder, typ, filename string) (string, error) { - stmt, err := fs.ReadFile(path.Join(folder, typ, filename)) +func readStmt(fs embed.FS, folder, filename string) (string, error) { + stmt, err := fs.ReadFile(path.Join(folder, filename)) return string(stmt), err } @@ -318,16 +318,15 @@ type statement struct { // under the folder/type path. // Typ describes the database dialect and may be omitted if no // dialect specific migration is specified. -func readStatements(fs embed.FS, folder, typ string) ([]statement, error) { - basePath := path.Join(folder, typ) - dir, err := fs.ReadDir(basePath) +func readStatements(fs embed.FS, folder string) ([]statement, error) { + dir, err := fs.ReadDir(folder) if err != nil { return nil, err } statements := make([]statement, len(dir)) for i, file := range dir { statements[i].file = file.Name() - statements[i].query, err = readStmt(fs, folder, typ, file.Name()) + statements[i].query, err = readStmt(fs, folder, file.Name()) if err != nil { return nil, err } @@ -468,9 +467,6 @@ func startCommandsQueries( ) logging.OnError(err).Fatal("unable to start commands") - if !config.Notifications.LegacyEnabled && dbClient.Type() == "cockroach" { - logging.Fatal("notifications must be set to LegacyEnabled=true when using CockroachDB") - } q, err := queue.NewQueue(&queue.Config{ Client: dbClient, }) diff --git a/cmd/start/config.go b/cmd/start/config.go index 910759b653..cab39b6c85 100644 --- a/cmd/start/config.go +++ b/cmd/start/config.go @@ -97,7 +97,7 @@ func MustNewConfig(v *viper.Viper) *Config { hooks.SliceTypeStringDecode[internal_authz.RoleMapping], hooks.MapTypeStringDecode[string, *internal_authz.SystemAPIUser], hooks.MapHTTPHeaderStringDecode, - database.DecodeHook, + database.DecodeHook(false), actions.HTTPConfigDecodeHook, hook.EnumHookFunc(internal_authz.MemberTypeString), hooks.MapTypeStringDecode[domain.Feature, any], diff --git a/cmd/start/start.go b/cmd/start/start.go index 7e102b711c..5ab313a48a 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -4,7 +4,6 @@ import ( "context" "crypto/tls" _ "embed" - "errors" "fmt" "math" "net/http" @@ -107,7 +106,7 @@ func New(server chan<- *Server) *cobra.Command { Short: "starts ZITADEL instance", Long: `starts ZITADEL. Requirements: -- cockroachdb`, +- postgreSQL`, RunE: func(cmd *cobra.Command, args []string) error { err := cmd_tls.ModeFromFlag(cmd) if err != nil { @@ -163,7 +162,7 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server config.Eventstore.Pusher = new_es.NewEventstore(dbClient) config.Eventstore.Searcher = new_es.NewEventstore(dbClient) - config.Eventstore.Querier = old_es.NewCRDB(dbClient) + config.Eventstore.Querier = old_es.NewPostgres(dbClient) eventstoreClient := eventstore.NewEventstore(config.Eventstore) eventstoreV4 := es_v4.NewEventstoreFromOne(es_v4_pg.New(dbClient, &es_v4_pg.Config{ MaxRetries: config.Eventstore.MaxRetries, @@ -269,9 +268,6 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server actionsLogstoreSvc := logstore.New(queries, actionsExecutionDBEmitter, actionsExecutionStdoutEmitter) actions.SetLogstoreService(actionsLogstoreSvc) - if !config.Notifications.LegacyEnabled && dbClient.Type() == "cockroach" { - return errors.New("notifications must be set to LegacyEnabled=true when using CockroachDB") - } q, err := queue.NewQueue(&queue.Config{ Client: dbClient, }) diff --git a/cmd/start/start_from_init.go b/cmd/start/start_from_init.go index 38a6a6c4d1..62d705b33c 100644 --- a/cmd/start/start_from_init.go +++ b/cmd/start/start_from_init.go @@ -21,7 +21,7 @@ Second the initial events are created. Last ZITADEL starts. Requirements: -- cockroachdb`, +- postgreSQL`, Run: func(cmd *cobra.Command, args []string) { err := tls.ModeFromFlag(cmd) logging.OnError(err).Fatal("invalid tlsMode") diff --git a/docs/docs/self-hosting/manage/cli/mirror.mdx b/docs/docs/self-hosting/manage/cli/mirror.mdx index 1c32dc8741..4428abec73 100644 --- a/docs/docs/self-hosting/manage/cli/mirror.mdx +++ b/docs/docs/self-hosting/manage/cli/mirror.mdx @@ -87,8 +87,6 @@ Source: Database: zitadel # ZITADEL_SOURCE_COCKROACH_DATABASE MaxOpenConns: 6 # ZITADEL_SOURCE_COCKROACH_MAXOPENCONNS MaxIdleConns: 6 # ZITADEL_SOURCE_COCKROACH_MAXIDLECONNS - EventPushConnRatio: 0.33 # ZITADEL_SOURCE_COCKROACH_EVENTPUSHCONNRATIO - ProjectionSpoolerConnRatio: 0.33 # ZITADEL_SOURCE_COCKROACH_PROJECTIONSPOOLERCONNRATIO MaxConnLifetime: 30m # ZITADEL_SOURCE_COCKROACH_MAXCONNLIFETIME MaxConnIdleTime: 5m # ZITADEL_SOURCE_COCKROACH_MAXCONNIDLETIME Options: "" # ZITADEL_SOURCE_COCKROACH_OPTIONS @@ -122,44 +120,23 @@ Source: # The destination database the data are copied to. Use either cockroach or postgres, by default cockroach is used Destination: - cockroach: - Host: localhost # ZITADEL_DESTINATION_COCKROACH_HOST - Port: 26257 # ZITADEL_DESTINATION_COCKROACH_PORT - Database: zitadel # ZITADEL_DESTINATION_COCKROACH_DATABASE - MaxOpenConns: 0 # ZITADEL_DESTINATION_COCKROACH_MAXOPENCONNS - MaxIdleConns: 0 # ZITADEL_DESTINATION_COCKROACH_MAXIDLECONNS - MaxConnLifetime: 30m # ZITADEL_DESTINATION_COCKROACH_MAXCONNLIFETIME - MaxConnIdleTime: 5m # ZITADEL_DESTINATION_COCKROACH_MAXCONNIDLETIME - EventPushConnRatio: 0.01 # ZITADEL_DESTINATION_COCKROACH_EVENTPUSHCONNRATIO - ProjectionSpoolerConnRatio: 0.5 # ZITADEL_DESTINATION_COCKROACH_PROJECTIONSPOOLERCONNRATIO - Options: "" # ZITADEL_DESTINATION_COCKROACH_OPTIONS - User: - Username: zitadel # ZITADEL_DESTINATION_COCKROACH_USER_USERNAME - Password: "" # ZITADEL_DESTINATION_COCKROACH_USER_PASSWORD - SSL: - Mode: disable # ZITADEL_DESTINATION_COCKROACH_USER_SSL_MODE - RootCert: "" # ZITADEL_DESTINATION_COCKROACH_USER_SSL_ROOTCERT - Cert: "" # ZITADEL_DESTINATION_COCKROACH_USER_SSL_CERT - Key: "" # ZITADEL_DESTINATION_COCKROACH_USER_SSL_KEY - # Postgres is used as soon as a value is set - # The values describe the possible fields to set values postgres: - Host: # ZITADEL_DESTINATION_POSTGRES_HOST - Port: # ZITADEL_DESTINATION_POSTGRES_PORT - Database: # ZITADEL_DESTINATION_POSTGRES_DATABASE - MaxOpenConns: # ZITADEL_DESTINATION_POSTGRES_MAXOPENCONNS - MaxIdleConns: # ZITADEL_DESTINATION_POSTGRES_MAXIDLECONNS - MaxConnLifetime: # ZITADEL_DESTINATION_POSTGRES_MAXCONNLIFETIME - MaxConnIdleTime: # ZITADEL_DESTINATION_POSTGRES_MAXCONNIDLETIME - Options: # ZITADEL_DESTINATION_POSTGRES_OPTIONS + Host: localhost # ZITADEL_DATABASE_POSTGRES_HOST + Port: 5432 # ZITADEL_DATABASE_POSTGRES_PORT + Database: zitadel # ZITADEL_DATABASE_POSTGRES_DATABASE + MaxOpenConns: 5 # ZITADEL_DATABASE_POSTGRES_MAXOPENCONNS + MaxIdleConns: 2 # ZITADEL_DATABASE_POSTGRES_MAXIDLECONNS + MaxConnLifetime: 30m # ZITADEL_DATABASE_POSTGRES_MAXCONNLIFETIME + MaxConnIdleTime: 5m # ZITADEL_DATABASE_POSTGRES_MAXCONNIDLETIME + Options: "" # ZITADEL_DATABASE_POSTGRES_OPTIONS User: - Username: # ZITADEL_DESTINATION_POSTGRES_USER_USERNAME - Password: # ZITADEL_DESTINATION_POSTGRES_USER_PASSWORD + Username: zitadel # ZITADEL_DATABASE_POSTGRES_USER_USERNAME + Password: # ZITADEL_DATABASE_POSTGRES_USER_PASSWORD SSL: - Mode: # ZITADEL_DESTINATION_POSTGRES_USER_SSL_MODE - RootCert: # ZITADEL_DESTINATION_POSTGRES_USER_SSL_ROOTCERT - Cert: # ZITADEL_DESTINATION_POSTGRES_USER_SSL_CERT - Key: # ZITADEL_DESTINATION_POSTGRES_USER_SSL_KEY + Mode: disable # ZITADEL_DATABASE_POSTGRES_USER_SSL_MODE + RootCert: # ZITADEL_DATABASE_POSTGRES_USER_SSL_ROOTCERT + Cert: # ZITADEL_DATABASE_POSTGRES_USER_SSL_CERT + Key: # ZITADEL_DATABASE_POSTGRES_USER_SSL_KEY # As cockroachdb first copies the data into memory this parameter is used to iterate through the events table and fetch only the given amount of events per iteration EventBulkSize: 10000 # ZITADEL_EVENTBULKSIZE diff --git a/e2e/config/host.docker.internal/zitadel.yaml b/e2e/config/host.docker.internal/zitadel.yaml index cb7e985be1..203dd16437 100644 --- a/e2e/config/host.docker.internal/zitadel.yaml +++ b/e2e/config/host.docker.internal/zitadel.yaml @@ -5,12 +5,23 @@ ExternalDomain: host.docker.internal ExternalSecure: false Database: - cockroach: + postgres: # This makes the e2e config reusable with an out-of-docker zitadel process and an /etc/hosts entry - Host: host.docker.internal - EventPushConnRatio: 0.2 + Host: host.docker.internal + Port: 5432 MaxOpenConns: 15 MaxIdleConns: 10 + Database: zitadel + User: + Username: zitadel + Password: zitadel + SSL: + Mode: disable + Admin: + Username: postgres + Password: postgres + SSL: + Mode: disable TLS: Enabled: false diff --git a/e2e/config/localhost/docker-compose.yaml b/e2e/config/localhost/docker-compose.yaml index f90ee158f0..41334d92f9 100644 --- a/e2e/config/localhost/docker-compose.yaml +++ b/e2e/config/localhost/docker-compose.yaml @@ -30,14 +30,15 @@ services: db: restart: 'always' - image: 'cockroachdb/cockroach:latest' - command: 'start-single-node --insecure --http-addr :9090' + image: 'postgres:17-alpine' + environment: + - POSTGRES_USER=postgres + - POSTGRES_PASSWORD=postgres healthcheck: - test: ['CMD', 'curl', '-f', 'http://localhost:9090/health?ready=1'] + test: ["CMD-SHELL", "pg_isready", "-d", "zitadel", "-U", "postgres"] interval: '10s' timeout: '30s' retries: 5 start_period: '20s' ports: - - "26257:26257" - - "9090:9090" + - "5432:5432" diff --git a/e2e/config/localhost/zitadel.yaml b/e2e/config/localhost/zitadel.yaml index 649f35fa9d..966bb4f6b7 100644 --- a/e2e/config/localhost/zitadel.yaml +++ b/e2e/config/localhost/zitadel.yaml @@ -5,12 +5,24 @@ ExternalDomain: localhost ExternalSecure: false Database: - cockroach: + postgres: # This makes the e2e config reusable with an out-of-docker zitadel process and an /etc/hosts entry Host: host.docker.internal - EventPushConnRatio: 0.2 + Port: 5432 + database: zitadel MaxOpenConns: 15 MaxIdleConns: 10 + Database: zitadel + User: + Username: zitadel + Password: zitadel + SSL: + Mode: disable + Admin: + Username: postgres + Password: postgres + SSL: + Mode: disable TLS: Enabled: false diff --git a/go.mod b/go.mod index 3f81b16ac5..c8ca6f1dfc 100644 --- a/go.mod +++ b/go.mod @@ -24,6 +24,7 @@ require ( github.com/drone/envsubst v1.0.3 github.com/envoyproxy/protoc-gen-validate v1.0.4 github.com/fatih/color v1.17.0 + github.com/fergusstrange/embedded-postgres v1.30.0 github.com/gabriel-vasile/mimetype v1.4.4 github.com/go-chi/chi/v5 v5.1.0 github.com/go-jose/go-jose/v4 v4.0.4 @@ -135,6 +136,7 @@ require ( github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/sjson v1.2.5 // indirect + github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 // indirect github.com/yuin/gopher-lua v1.1.1 // indirect github.com/zenazn/goji v1.0.1 // indirect go.uber.org/goleak v1.3.0 // indirect @@ -174,7 +176,6 @@ require ( github.com/go-errors/errors v1.5.1 // indirect github.com/go-sourcemap/sourcemap v2.1.4+incompatible // indirect github.com/go-xmlfmt/xmlfmt v1.1.2 // indirect - github.com/gofrs/flock v0.8.1 // indirect github.com/golang/geo v0.0.0-20230421003525-6adc56603217 // indirect github.com/golang/groupcache v0.0.0-20210331224755-41bb18bfe9da // indirect github.com/google/uuid v1.6.0 diff --git a/go.sum b/go.sum index 9c992f3662..cb86173165 100644 --- a/go.sum +++ b/go.sum @@ -202,6 +202,8 @@ github.com/fatih/color v1.17.0 h1:GlRw1BRJxkpqUCBKzKOw098ed57fEsKeNjpTe3cSjK4= github.com/fatih/color v1.17.0/go.mod h1:YZ7TlrGPkiz6ku9fK3TLD/pl3CpsiFyu8N92HLgmosI= github.com/felixge/httpsnoop v1.0.4 h1:NFTV2Zj1bL4mc9sqWACXbQFVBBg2W3GPvqp8/ESS2Wg= github.com/felixge/httpsnoop v1.0.4/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= +github.com/fergusstrange/embedded-postgres v1.30.0 h1:ewv1e6bBlqOIYtgGgRcEnNDpfGlmfPxB8T3PO9tV68Q= +github.com/fergusstrange/embedded-postgres v1.30.0/go.mod h1:w0YvnCgf19o6tskInrOOACtnqfVlOvluz3hlNLY7tRk= github.com/franela/goblin v0.0.0-20200105215937-c9ffbefa60db/go.mod h1:7dvUGVsVBjqR7JHJk0brhHOZYGmfBYOrK0ZhYMEtBr4= github.com/franela/goreq v0.0.0-20171204163338-bcd34c9993f8/go.mod h1:ZhphrRTfi2rbfLwlschooIH4+wKKDR4Pdxhh+TRoA20= github.com/frankban/quicktest v1.14.6 h1:7Xjx+VpznH+oBnejlPUj8oUpdxnVs4f8XU8WnHkI4W8= @@ -760,6 +762,8 @@ github.com/urfave/cli v1.22.1/go.mod h1:Gos4lmkARVdJ6EkW0WaNv/tZAAMe9V7XWyB60NtX github.com/wcharczuk/go-chart/v2 v2.1.0/go.mod h1:yx7MvAVNcP/kN9lKXM/NTce4au4DFN99j6i1OwDclNA= github.com/x448/float16 v0.8.4 h1:qLwI1I70+NjRFUR3zs1JPUCgaCXSh3SW62uAKT1mSBM= github.com/x448/float16 v0.8.4/go.mod h1:14CWIYCyZA/cWjXOioeEpHeN/83MdbZDRQHoFcYsOfg= +github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8 h1:nIPpBwaJSVYIxUFsDv3M8ofmx9yWTog9BfvIu0q41lo= +github.com/xi2/xz v0.0.0-20171230120015-48954b6210f8/go.mod h1:HUYIGzjTL3rfEspMxjDjgmT5uz5wzYJKVo23qUhYTos= github.com/xiang90/probing v0.0.0-20190116061207-43a291ad63a2/go.mod h1:UETIi67q53MR2AWcXfiuqkDkRtnGDLqkBTpCHuJHxtU= github.com/xrash/smetrics v0.0.0-20201216005158-039620a65673/go.mod h1:N3UwUGtsrSj3ccvlPHLoLsHnpR27oXr4ZE984MbSER8= github.com/xrash/smetrics v0.0.0-20240312152122-5f08fbb34913 h1:+qGGcbkzsfDQNPPe9UDgpxAWQrhbbBXOYJFQDq/dtJw= diff --git a/internal/api/grpc/auth/user.go b/internal/api/grpc/auth/user.go index 90e0ddc1d6..13f955fd81 100644 --- a/internal/api/grpc/auth/user.go +++ b/internal/api/grpc/auth/user.go @@ -69,7 +69,6 @@ func (s *Server) ListMyUserChanges(ctx context.Context, req *auth_pb.ListMyUserC } query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). - AllowTimeTravel(). Limit(limit). OrderDesc(). AwaitOpenTransactions(). diff --git a/internal/api/grpc/management/org.go b/internal/api/grpc/management/org.go index abc179a763..a6a934160a 100644 --- a/internal/api/grpc/management/org.go +++ b/internal/api/grpc/management/org.go @@ -50,7 +50,6 @@ func (s *Server) ListOrgChanges(ctx context.Context, req *mgmt_pb.ListOrgChanges } query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). - AllowTimeTravel(). Limit(limit). OrderDesc(). AwaitOpenTransactions(). diff --git a/internal/api/grpc/management/project.go b/internal/api/grpc/management/project.go index 00ccbd215c..52b6b10e9a 100644 --- a/internal/api/grpc/management/project.go +++ b/internal/api/grpc/management/project.go @@ -70,7 +70,6 @@ func (s *Server) ListProjectGrantChanges(ctx context.Context, req *mgmt_pb.ListP } query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). - AllowTimeTravel(). Limit(limit). OrderDesc(). ResourceOwner(authz.GetCtxData(ctx).OrgID). @@ -152,7 +151,6 @@ func (s *Server) ListProjectChanges(ctx context.Context, req *mgmt_pb.ListProjec } query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). - AllowTimeTravel(). Limit(limit). AwaitOpenTransactions(). OrderDesc(). diff --git a/internal/api/grpc/management/project_application.go b/internal/api/grpc/management/project_application.go index 4b65808776..3a0e1d5f92 100644 --- a/internal/api/grpc/management/project_application.go +++ b/internal/api/grpc/management/project_application.go @@ -52,7 +52,6 @@ func (s *Server) ListAppChanges(ctx context.Context, req *mgmt_pb.ListAppChanges } query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). - AllowTimeTravel(). Limit(limit). AwaitOpenTransactions(). OrderDesc(). diff --git a/internal/api/grpc/management/user.go b/internal/api/grpc/management/user.go index 17bca58993..b876999584 100644 --- a/internal/api/grpc/management/user.go +++ b/internal/api/grpc/management/user.go @@ -92,7 +92,6 @@ func (s *Server) ListUserChanges(ctx context.Context, req *mgmt_pb.ListUserChang } query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). - AllowTimeTravel(). Limit(limit). AwaitOpenTransactions(). OrderDesc(). diff --git a/internal/api/grpc/system/integration_test/limits_auditlogretention_test.go b/internal/api/grpc/system/integration_test/limits_auditlogretention_test.go index 24c224b0fe..b705618f68 100644 --- a/internal/api/grpc/system/integration_test/limits_auditlogretention_test.go +++ b/internal/api/grpc/system/integration_test/limits_auditlogretention_test.go @@ -73,6 +73,7 @@ func requireEventually( assertCounts func(assert.TestingT, *eventCounts), msg string, ) (counts *eventCounts) { + t.Helper() countTimeout := 30 * time.Second assertTimeout := countTimeout + time.Second countCtx, cancel := context.WithTimeout(ctx, time.Minute) diff --git a/internal/api/http/middleware/middleware_test.go b/internal/api/http/middleware/middleware_test.go index 4d7cb6636d..60d4099e06 100644 --- a/internal/api/http/middleware/middleware_test.go +++ b/internal/api/http/middleware/middleware_test.go @@ -1,6 +1,7 @@ package middleware import ( + "os" "testing" "golang.org/x/text/language" @@ -14,5 +15,5 @@ var ( func TestMain(m *testing.M) { i18n.SupportLanguages(SupportedLanguages...) - m.Run() + os.Exit(m.Run()) } diff --git a/internal/api/oidc/key.go b/internal/api/oidc/key.go index 76f78ab5ab..81f3b1c466 100644 --- a/internal/api/oidc/key.go +++ b/internal/api/oidc/key.go @@ -417,7 +417,6 @@ func (o *OPStorage) getMaxKeySequence(ctx context.Context) (float64, error) { eventstore.NewSearchQueryBuilder(eventstore.ColumnsMaxSequence). ResourceOwner(authz.GetInstance(ctx).InstanceID()). AwaitOpenTransactions(). - AllowTimeTravel(). AddQuery(). AggregateTypes( keypair.AggregateType, diff --git a/internal/api/scim/integration_test/users_list_test.go b/internal/api/scim/integration_test/users_list_test.go index 7945d2039d..ed8b79dc4e 100644 --- a/internal/api/scim/integration_test/users_list_test.go +++ b/internal/api/scim/integration_test/users_list_test.go @@ -145,8 +145,8 @@ func TestListUser(t *testing.T) { assert.Equal(t, totalCountOfHumanUsers, resp.TotalResults) assert.Equal(t, 5, resp.StartIndex) assert.Len(t, resp.Resources, 2) - assert.True(t, strings.HasPrefix(resp.Resources[0].UserName, "scim-username-1: ")) - assert.True(t, strings.HasPrefix(resp.Resources[1].UserName, "scim-username-2: ")) + assert.True(t, strings.HasPrefix(resp.Resources[0].UserName, "scim-username-1: "), "got %q", resp.Resources[0].UserName) + assert.True(t, strings.HasPrefix(resp.Resources[1].UserName, "scim-username-2: "), "got %q", resp.Resources[1].UserName) }, }, { diff --git a/internal/auth/repository/eventsourcing/view/view.go b/internal/auth/repository/eventsourcing/view/view.go index 56e2676b87..c67844dbad 100644 --- a/internal/auth/repository/eventsourcing/view/view.go +++ b/internal/auth/repository/eventsourcing/view/view.go @@ -1,11 +1,8 @@ package view import ( - "context" - "github.com/jinzhu/gorm" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/eventstore" @@ -37,7 +34,3 @@ func StartView(sqlClient *database.DB, keyAlgorithm crypto.EncryptionAlgorithm, func (v *View) Health() (err error) { return v.Db.DB().Ping() } - -func (v *View) TimeTravel(ctx context.Context, tableName string) string { - return tableName + v.client.Timetravel(call.Took(ctx)) -} diff --git a/internal/authz/repository/eventsourcing/view/view.go b/internal/authz/repository/eventsourcing/view/view.go index f25b764f53..21a15c45fc 100644 --- a/internal/authz/repository/eventsourcing/view/view.go +++ b/internal/authz/repository/eventsourcing/view/view.go @@ -1,11 +1,8 @@ package view import ( - "context" - "github.com/jinzhu/gorm" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/query" ) @@ -31,7 +28,3 @@ func StartView(sqlClient *database.DB, queries *query.Queries) (*View, error) { func (v *View) Health() (err error) { return v.Db.DB().Ping() } - -func (v *View) TimeTravel(ctx context.Context, tableName string) string { - return tableName + v.client.Timetravel(call.Took(ctx)) -} diff --git a/internal/cache/connector/pg/connector.go b/internal/cache/connector/pg/connector.go index 9a89cf5f6a..e919aea49d 100644 --- a/internal/cache/connector/pg/connector.go +++ b/internal/cache/connector/pg/connector.go @@ -12,8 +12,7 @@ type Config struct { type Connector struct { PGXPool - Dialect string - Config Config + Config Config } func NewConnector(config Config, client *database.DB) *Connector { @@ -22,7 +21,6 @@ func NewConnector(config Config, client *database.DB) *Connector { } return &Connector{ PGXPool: client.Pool, - Dialect: client.Type(), Config: config, } } diff --git a/internal/cache/connector/pg/pg.go b/internal/cache/connector/pg/pg.go index 18215b68ed..1a64ef85f7 100644 --- a/internal/cache/connector/pg/pg.go +++ b/internal/cache/connector/pg/pg.go @@ -58,10 +58,8 @@ func NewCache[I ~int, K ~string, V cache.Entry[I, K]](ctx context.Context, purpo } c.logger.InfoContext(ctx, "pg cache logging enabled") - if connector.Dialect == "postgres" { - if err := c.createPartition(ctx); err != nil { - return nil, err - } + if err := c.createPartition(ctx); err != nil { + return nil, err } return c, nil } diff --git a/internal/cache/connector/pg/pg_test.go b/internal/cache/connector/pg/pg_test.go index f5980ad845..bb9b681b15 100644 --- a/internal/cache/connector/pg/pg_test.go +++ b/internal/cache/connector/pg/pg_test.go @@ -78,7 +78,6 @@ func TestNewCache(t *testing.T) { tt.expect(pool) connector := &Connector{ PGXPool: pool, - Dialect: "postgres", } c, err := NewCache[testIndex, string, *testObject](context.Background(), cachePurpose, conf, testIndices, connector) @@ -518,7 +517,6 @@ func prepareCache(t *testing.T, conf cache.Config) (cache.PrunerCache[testIndex, WillReturnResult(pgxmock.NewResult("CREATE TABLE", 0)) connector := &Connector{ PGXPool: pool, - Dialect: "postgres", } c, err := NewCache[testIndex, string, *testObject](context.Background(), cachePurpose, conf, testIndices, connector) require.NoError(t, err) diff --git a/internal/command/command.go b/internal/command/command.go index f9c78fbaab..486e044a41 100644 --- a/internal/command/command.go +++ b/internal/command/command.go @@ -218,33 +218,6 @@ func (c *Commands) pushAppendAndReduce(ctx context.Context, object AppendReducer return AppendAndReduce(object, events...) } -// pushChunked pushes the commands in chunks of size to the eventstore. -// This can be used to reduce the amount of events in a single transaction. -// When an error occurs, the events that have been pushed so far will be returned. -// -// Warning: chunks are pushed in separate transactions. -// Successful pushes will not be rolled back if a later chunk fails. -// Only use this function when the caller is able to handle partial success -// and is able to consolidate the state on errors. -func (c *Commands) pushChunked(ctx context.Context, size uint16, cmds ...eventstore.Command) (_ []eventstore.Event, err error) { - ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() - - events := make([]eventstore.Event, 0, len(cmds)) - for i := 0; i < len(cmds); i += int(size) { - end := i + int(size) - if end > len(cmds) { - end = len(cmds) - } - chunk, err := c.eventstore.Push(ctx, cmds[i:end]...) - if err != nil { - return events, err - } - events = append(events, chunk...) - } - return events, nil -} - type AppendReducerDetails interface { AppendEvents(...eventstore.Event) // TODO: Why is it allowed to return an error here? diff --git a/internal/command/command_test.go b/internal/command/command_test.go index 7224f047b5..2367930b89 100644 --- a/internal/command/command_test.go +++ b/internal/command/command_test.go @@ -2,7 +2,6 @@ package command import ( "context" - "fmt" "io" "os" "testing" @@ -14,7 +13,6 @@ import ( "github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/i18n" - "github.com/zitadel/zitadel/internal/repository/permission" "github.com/zitadel/zitadel/internal/repository/user" ) @@ -31,93 +29,6 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -func TestCommands_pushChunked(t *testing.T) { - aggregate := permission.NewAggregate("instanceID") - cmds := make([]eventstore.Command, 100) - for i := 0; i < 100; i++ { - cmds[i] = permission.NewAddedEvent(context.Background(), aggregate, "role", fmt.Sprintf("permission%d", i)) - } - type args struct { - size uint16 - } - tests := []struct { - name string - args args - eventstore func(*testing.T) *eventstore.Eventstore - wantEvents int - wantErr error - }{ - { - name: "push error", - args: args{ - size: 100, - }, - eventstore: expectEventstore( - expectPushFailed(io.ErrClosedPipe, cmds...), - ), - wantEvents: 0, - wantErr: io.ErrClosedPipe, - }, - { - name: "single chunk", - args: args{ - size: 100, - }, - eventstore: expectEventstore( - expectPush(cmds...), - ), - wantEvents: len(cmds), - }, - { - name: "aligned chunks", - args: args{ - size: 50, - }, - eventstore: expectEventstore( - expectPush(cmds[0:50]...), - expectPush(cmds[50:100]...), - ), - wantEvents: len(cmds), - }, - { - name: "odd chunks", - args: args{ - size: 30, - }, - eventstore: expectEventstore( - expectPush(cmds[0:30]...), - expectPush(cmds[30:60]...), - expectPush(cmds[60:90]...), - expectPush(cmds[90:100]...), - ), - wantEvents: len(cmds), - }, - { - name: "partial error", - args: args{ - size: 30, - }, - eventstore: expectEventstore( - expectPush(cmds[0:30]...), - expectPush(cmds[30:60]...), - expectPushFailed(io.ErrClosedPipe, cmds[60:90]...), - ), - wantEvents: len(cmds[0:60]), - wantErr: io.ErrClosedPipe, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - c := &Commands{ - eventstore: tt.eventstore(t), - } - gotEvents, err := c.pushChunked(context.Background(), tt.args.size, cmds...) - require.ErrorIs(t, err, tt.wantErr) - assert.Len(t, gotEvents, tt.wantEvents) - }) - } -} - func TestCommands_asyncPush(t *testing.T) { // make sure the test terminates on deadlock background := context.Background() diff --git a/internal/command/instance_role_permissions.go b/internal/command/instance_role_permissions.go index c0c6355dd6..fce272cc12 100644 --- a/internal/command/instance_role_permissions.go +++ b/internal/command/instance_role_permissions.go @@ -17,15 +17,9 @@ import ( "github.com/zitadel/zitadel/internal/zerrors" ) -const ( - CockroachRollPermissionChunkSize uint16 = 50 -) - // SynchronizeRolePermission checks the current state of role permissions in the eventstore for the aggregate. // It pushes the commands required to reach the desired state passed in target. // For system level permissions aggregateID must be set to `SYSTEM`, else it is the instance ID. -// -// In case cockroachDB is used, the commands are pushed in chunks of CockroachRollPermissionChunkSize. func (c *Commands) SynchronizeRolePermission(ctx context.Context, aggregateID string, target []authz.RoleMapping) (_ *domain.ObjectDetails, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -36,13 +30,9 @@ func (c *Commands) SynchronizeRolePermission(ctx context.Context, aggregateID st if err != nil { return nil, zerrors.ThrowInternal(err, "COMMA-Iej2r", "Errors.Internal") } - var events []eventstore.Event - if c.eventstore.Client().Database.Type() == "cockroach" { - events, err = c.pushChunked(ctx, CockroachRollPermissionChunkSize, cmds...) - } else { - events, err = c.eventstore.Push(ctx, cmds...) - } + events, err := c.eventstore.Push(ctx, cmds...) if err != nil { + logging.WithError(err).Error("failed to push role permission commands") return nil, zerrors.ThrowInternal(err, "COMMA-AiV3u", "Errors.Internal") } return pushedEventsToObjectDetails(events), nil diff --git a/internal/database/cockroach/crdb.go b/internal/database/cockroach/crdb.go index a5b3208a86..d1e7437d2b 100644 --- a/internal/database/cockroach/crdb.go +++ b/internal/database/cockroach/crdb.go @@ -18,7 +18,7 @@ import ( func init() { config := new(Config) - dialect.Register(config, config, true) + dialect.Register(config, config, false) } const ( @@ -52,7 +52,7 @@ func (c *Config) MatchName(name string) bool { return false } -func (_ *Config) Decode(configs []interface{}) (dialect.Connector, error) { +func (_ *Config) Decode(configs []any) (dialect.Connector, error) { connector := new(Config) decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ DecodeHook: mapstructure.StringToTimeDurationHookFunc(), @@ -149,12 +149,8 @@ func (c *Config) Password() string { return c.User.Password } -func (c *Config) Type() string { - return "cockroach" -} - -func (c *Config) Timetravel(d time.Duration) string { - return "" +func (c *Config) Type() dialect.DatabaseType { + return dialect.DatabaseTypeCockroach } type User struct { diff --git a/internal/database/database.go b/internal/database/database.go index e254edadc1..561717e3ea 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -149,33 +149,40 @@ func Connect(config Config, useAdmin bool) (*DB, error) { }, nil } -func DecodeHook(from, to reflect.Value) (_ interface{}, err error) { - if to.Type() != reflect.TypeOf(Config{}) { - return from.Interface(), nil - } - - config := new(Config) - if err = mapstructure.Decode(from.Interface(), config); err != nil { - return nil, err - } - - configuredDialect := dialect.SelectByConfig(config.Dialects) - configs := make([]interface{}, 0, len(config.Dialects)-1) - - for name, dialectConfig := range config.Dialects { - if !configuredDialect.Matcher.MatchName(name) { - continue +func DecodeHook(allowCockroach bool) func(from, to reflect.Value) (_ interface{}, err error) { + return func(from, to reflect.Value) (_ interface{}, err error) { + if to.Type() != reflect.TypeOf(Config{}) { + return from.Interface(), nil } - configs = append(configs, dialectConfig) - } + config := new(Config) + if err = mapstructure.Decode(from.Interface(), config); err != nil { + return nil, err + } - config.connector, err = configuredDialect.Matcher.Decode(configs) - if err != nil { - return nil, err - } + configuredDialect := dialect.SelectByConfig(config.Dialects) + configs := make([]any, 0, len(config.Dialects)) - return config, nil + for name, dialectConfig := range config.Dialects { + if !configuredDialect.Matcher.MatchName(name) { + continue + } + + configs = append(configs, dialectConfig) + } + + if !allowCockroach && configuredDialect.Matcher.Type() == dialect.DatabaseTypeCockroach { + logging.Info("Cockroach support was removed with Zitadel v3, please refer to https://zitadel.com/docs/self-hosting/manage/cli/mirror to migrate your data to postgres") + return nil, zerrors.ThrowPreconditionFailed(nil, "DATAB-0pIWD", "Cockroach support was removed with Zitadel v3") + } + + config.connector, err = configuredDialect.Matcher.Decode(configs) + if err != nil { + return nil, err + } + + return config, nil + } } func (c Config) DatabaseName() string { @@ -190,10 +197,6 @@ func (c Config) Password() string { return c.connector.Password() } -func (c Config) Type() string { - return c.connector.Type() -} - func EscapeLikeWildcards(value string) string { value = strings.ReplaceAll(value, "%", "\\%") value = strings.ReplaceAll(value, "_", "\\_") diff --git a/internal/database/dialect/config.go b/internal/database/dialect/config.go index 71fb477ea1..e544ea2878 100644 --- a/internal/database/dialect/config.go +++ b/internal/database/dialect/config.go @@ -3,7 +3,6 @@ package dialect import ( "database/sql" "sync" - "time" "github.com/jackc/pgx/v5/pgxpool" ) @@ -22,9 +21,17 @@ var ( type Matcher interface { MatchName(string) bool - Decode([]interface{}) (Connector, error) + Decode([]any) (Connector, error) + Type() DatabaseType } +type DatabaseType uint8 + +const ( + DatabaseTypePostgres DatabaseType = iota + DatabaseTypeCockroach +) + const ( DefaultAppName = "zitadel" ) @@ -38,8 +45,6 @@ type Connector interface { type Database interface { DatabaseName() string Username() string - Type() string - Timetravel(time.Duration) string } func Register(matcher Matcher, config Connector, isDefault bool) { diff --git a/internal/database/postgres/embedded.go b/internal/database/postgres/embedded.go new file mode 100644 index 0000000000..57aec756f0 --- /dev/null +++ b/internal/database/postgres/embedded.go @@ -0,0 +1,38 @@ +package postgres + +import ( + "net" + "os" + + embeddedpostgres "github.com/fergusstrange/embedded-postgres" + "github.com/zitadel/logging" +) + +func StartEmbedded() (embeddedpostgres.Config, func()) { + path, err := os.MkdirTemp("", "zitadel-embedded-postgres-*") + logging.OnError(err).Fatal("unable to create temp dir") + + port, close := getPort() + + config := embeddedpostgres.DefaultConfig().Version(embeddedpostgres.V16).Port(uint32(port)).RuntimePath(path) + embedded := embeddedpostgres.NewDatabase(config) + + close() + err = embedded.Start() + logging.OnError(err).Fatal("unable to start db") + + return config, func() { + logging.OnError(embedded.Stop()).Error("unable to stop db") + } +} + +// getPort returns a free port and locks it until close is called +func getPort() (port uint16, close func()) { + l, err := net.Listen("tcp", ":0") + logging.OnError(err).Fatal("unable to get port") + port = uint16(l.Addr().(*net.TCPAddr).Port) + logging.WithFields("port", port).Info("Port is available") + return port, func() { + logging.OnError(l.Close()).Error("unable to close port listener") + } +} diff --git a/internal/database/postgres/pg.go b/internal/database/postgres/pg.go index c847cc0a58..2f8bb29e17 100644 --- a/internal/database/postgres/pg.go +++ b/internal/database/postgres/pg.go @@ -18,7 +18,7 @@ import ( func init() { config := new(Config) - dialect.Register(config, config, false) + dialect.Register(config, config, true) } const ( @@ -29,16 +29,15 @@ const ( ) type Config struct { - Host string - Port int32 - Database string - EventPushConnRatio float64 - MaxOpenConns uint32 - MaxIdleConns uint32 - MaxConnLifetime time.Duration - MaxConnIdleTime time.Duration - User User - Admin AdminUser + Host string + Port int32 + Database string + MaxOpenConns uint32 + MaxIdleConns uint32 + MaxConnLifetime time.Duration + MaxConnIdleTime time.Duration + User User + Admin AdminUser // Additional options to be appended as options= // The value will be taken as is. Multiple options are space separated. Options string @@ -148,12 +147,8 @@ func (c *Config) Password() string { return c.User.Password } -func (c *Config) Type() string { - return "postgres" -} - -func (c *Config) Timetravel(time.Duration) string { - return "" +func (c *Config) Type() dialect.DatabaseType { + return dialect.DatabaseTypePostgres } type User struct { diff --git a/internal/eventstore/eventstore_pusher_test.go b/internal/eventstore/eventstore_pusher_test.go index 4e8e663667..318cf1a37e 100644 --- a/internal/eventstore/eventstore_pusher_test.go +++ b/internal/eventstore/eventstore_pusher_test.go @@ -11,7 +11,7 @@ import ( "github.com/zitadel/zitadel/internal/eventstore" ) -func TestCRDB_Push_OneAggregate(t *testing.T) { +func TestEventstore_Push_OneAggregate(t *testing.T) { type args struct { ctx context.Context commands []eventstore.Command @@ -202,7 +202,7 @@ func TestCRDB_Push_OneAggregate(t *testing.T) { } } if _, err := db.Push(tt.args.ctx, tt.args.commands...); (err != nil) != tt.res.wantErr { - t.Errorf("CRDB.Push() error = %v, wantErr %v", err, tt.res.wantErr) + t.Errorf("eventstore.Push() error = %v, wantErr %v", err, tt.res.wantErr) } assertEventCount(t, @@ -218,7 +218,7 @@ func TestCRDB_Push_OneAggregate(t *testing.T) { } } -func TestCRDB_Push_MultipleAggregate(t *testing.T) { +func TestEventstore_Push_MultipleAggregate(t *testing.T) { type args struct { commands []eventstore.Command } @@ -312,7 +312,7 @@ func TestCRDB_Push_MultipleAggregate(t *testing.T) { }, ) if _, err := db.Push(context.Background(), tt.args.commands...); (err != nil) != tt.res.wantErr { - t.Errorf("CRDB.Push() error = %v, wantErr %v", err, tt.res.wantErr) + t.Errorf("eventstore.Push() error = %v, wantErr %v", err, tt.res.wantErr) } assertEventCount(t, clients[pusherName], tt.res.eventsRes.aggType, tt.res.eventsRes.aggID, tt.res.eventsRes.pushedEventsCount) @@ -321,7 +321,7 @@ func TestCRDB_Push_MultipleAggregate(t *testing.T) { } } -func TestCRDB_Push_Parallel(t *testing.T) { +func TestEventstore_Push_Parallel(t *testing.T) { type args struct { commands [][]eventstore.Command } @@ -453,7 +453,7 @@ func TestCRDB_Push_Parallel(t *testing.T) { } } -func TestCRDB_Push_ResourceOwner(t *testing.T) { +func TestEventstore_Push_ResourceOwner(t *testing.T) { type args struct { commands []eventstore.Command } @@ -587,7 +587,7 @@ func TestCRDB_Push_ResourceOwner(t *testing.T) { events, err := db.Push(context.Background(), tt.args.commands...) if err != nil { - t.Errorf("CRDB.Push() error = %v", err) + t.Errorf("eventstore.Push() error = %v", err) } if len(events) != len(tt.res.resourceOwners) { diff --git a/internal/eventstore/eventstore_querier_test.go b/internal/eventstore/eventstore_querier_test.go index 4b7ad78b25..3f23c5da75 100644 --- a/internal/eventstore/eventstore_querier_test.go +++ b/internal/eventstore/eventstore_querier_test.go @@ -7,7 +7,7 @@ import ( "github.com/zitadel/zitadel/internal/eventstore" ) -func TestCRDB_Filter(t *testing.T) { +func TestEventstore_Filter(t *testing.T) { type args struct { searchQuery *eventstore.SearchQueryBuilder } @@ -120,18 +120,18 @@ func TestCRDB_Filter(t *testing.T) { events, err := db.Filter(context.Background(), tt.args.searchQuery) if (err != nil) != tt.wantErr { - t.Errorf("CRDB.query() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("eventstore.query() error = %v, wantErr %v", err, tt.wantErr) } if len(events) != tt.res.eventCount { - t.Errorf("CRDB.query() expected event count: %d got %d", tt.res.eventCount, len(events)) + t.Errorf("eventstore.query() expected event count: %d got %d", tt.res.eventCount, len(events)) } }) } } } -func TestCRDB_LatestSequence(t *testing.T) { +func TestEventstore_LatestSequence(t *testing.T) { type args struct { searchQuery *eventstore.SearchQueryBuilder } @@ -204,10 +204,10 @@ func TestCRDB_LatestSequence(t *testing.T) { sequence, err := db.LatestSequence(context.Background(), tt.args.searchQuery) if (err != nil) != tt.wantErr { - t.Errorf("CRDB.query() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("eventstore.query() error = %v, wantErr %v", err, tt.wantErr) } if tt.res.sequence > sequence { - t.Errorf("CRDB.query() expected sequence: %v got %v", tt.res.sequence, sequence) + t.Errorf("eventstore.query() expected sequence: %v got %v", tt.res.sequence, sequence) } }) } diff --git a/internal/eventstore/example_test.go b/internal/eventstore/example_test.go index ee053d16da..2b6c205ddd 100644 --- a/internal/eventstore/example_test.go +++ b/internal/eventstore/example_test.go @@ -289,8 +289,8 @@ func (rm *UserReadModel) Reduce() error { func TestUserReadModel(t *testing.T) { es := eventstore.NewEventstore( &eventstore.Config{ - Querier: query_repo.NewCRDB(testCRDBClient), - Pusher: v3.NewEventstore(testCRDBClient), + Querier: query_repo.NewPostgres(testClient), + Pusher: v3.NewEventstore(testClient), }, ) diff --git a/internal/eventstore/handler/v2/handler.go b/internal/eventstore/handler/v2/handler.go index 052f965e22..09d5a63825 100644 --- a/internal/eventstore/handler/v2/handler.go +++ b/internal/eventstore/handler/v2/handler.go @@ -650,7 +650,6 @@ func (h *Handler) eventQuery(currentState *state) *eventstore.SearchQueryBuilder builder := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). AwaitOpenTransactions(). Limit(uint64(h.bulkLimit)). - AllowTimeTravel(). OrderAsc(). InstanceID(currentState.instanceID) diff --git a/internal/eventstore/local_crdb_test.go b/internal/eventstore/local_postgres_test.go similarity index 63% rename from internal/eventstore/local_crdb_test.go rename to internal/eventstore/local_postgres_test.go index 87c5084fe7..1344611511 100644 --- a/internal/eventstore/local_crdb_test.go +++ b/internal/eventstore/local_postgres_test.go @@ -8,111 +8,99 @@ import ( "testing" "time" - "github.com/cockroachdb/cockroach-go/v2/testserver" "github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/stdlib" "github.com/zitadel/logging" "github.com/zitadel/zitadel/cmd/initialise" "github.com/zitadel/zitadel/internal/database" - "github.com/zitadel/zitadel/internal/database/cockroach" + "github.com/zitadel/zitadel/internal/database/postgres" "github.com/zitadel/zitadel/internal/eventstore" es_sql "github.com/zitadel/zitadel/internal/eventstore/repository/sql" new_es "github.com/zitadel/zitadel/internal/eventstore/v3" ) var ( - testCRDBClient *database.DB - queriers map[string]eventstore.Querier = make(map[string]eventstore.Querier) - pushers map[string]eventstore.Pusher = make(map[string]eventstore.Pusher) - clients map[string]*database.DB = make(map[string]*database.DB) + testClient *database.DB + queriers map[string]eventstore.Querier = make(map[string]eventstore.Querier) + pushers map[string]eventstore.Pusher = make(map[string]eventstore.Pusher) + clients map[string]*database.DB = make(map[string]*database.DB) ) func TestMain(m *testing.M) { - opts := make([]testserver.TestServerOpt, 0, 1) - if version := os.Getenv("ZITADEL_CRDB_VERSION"); version != "" { - opts = append(opts, testserver.CustomVersionOpt(version)) - } - ts, err := testserver.NewTestServer(opts...) - if err != nil { - logging.WithFields("error", err).Fatal("unable to start db") - } + os.Exit(func() int { + config, cleanup := postgres.StartEmbedded() + defer cleanup() - testCRDBClient = &database.DB{ - Database: new(testDB), - } - - connConfig, err := pgxpool.ParseConfig(ts.PGURL().String()) - if err != nil { - logging.WithFields("error", err).Fatal("unable to parse db url") - } - connConfig.AfterConnect = new_es.RegisterEventstoreTypes - pool, err := pgxpool.NewWithConfig(context.Background(), connConfig) - if err != nil { - logging.WithFields("error", err).Fatal("unable to create db pool") - } - testCRDBClient.DB = stdlib.OpenDBFromPool(pool) - if err = testCRDBClient.Ping(); err != nil { - logging.WithFields("error", err).Fatal("unable to ping db") - } - - v2 := &es_sql.CRDB{DB: testCRDBClient} - queriers["v2(inmemory)"] = v2 - clients["v2(inmemory)"] = testCRDBClient - - pushers["v3(inmemory)"] = new_es.NewEventstore(testCRDBClient) - clients["v3(inmemory)"] = testCRDBClient - - if localDB, err := connectLocalhost(); err == nil { - if err = initDB(context.Background(), localDB); err != nil { - logging.WithFields("error", err).Fatal("migrations failed") + testClient = &database.DB{ + Database: new(testDB), } - pushers["v3(singlenode)"] = new_es.NewEventstore(localDB) - clients["v3(singlenode)"] = localDB - } - // pushers["v2(inmemory)"] = v2 + connConfig, err := pgxpool.ParseConfig(config.GetConnectionURL()) + logging.OnError(err).Fatal("unable to parse db url") - defer func() { - testCRDBClient.Close() - ts.Stop() - }() + connConfig.AfterConnect = new_es.RegisterEventstoreTypes + pool, err := pgxpool.NewWithConfig(context.Background(), connConfig) + logging.OnError(err).Fatal("unable to create db pool") - if err = initDB(context.Background(), testCRDBClient); err != nil { - logging.WithFields("error", err).Fatal("migrations failed") - } + testClient.DB = stdlib.OpenDBFromPool(pool) + err = testClient.Ping() + logging.OnError(err).Fatal("unable to ping db") - os.Exit(m.Run()) + v2 := &es_sql.Postgres{DB: testClient} + queriers["v2(inmemory)"] = v2 + clients["v2(inmemory)"] = testClient + + pushers["v3(inmemory)"] = new_es.NewEventstore(testClient) + clients["v3(inmemory)"] = testClient + + if localDB, err := connectLocalhost(); err == nil { + err = initDB(context.Background(), localDB) + logging.OnError(err).Fatal("migrations failed") + + pushers["v3(singlenode)"] = new_es.NewEventstore(localDB) + clients["v3(singlenode)"] = localDB + } + + defer func() { + logging.OnError(testClient.Close()).Error("unable to close db") + }() + + err = initDB(context.Background(), &database.DB{DB: testClient.DB, Database: &postgres.Config{Database: "zitadel"}}) + logging.OnError(err).Fatal("migrations failed") + + return m.Run() + }()) } func initDB(ctx context.Context, db *database.DB) error { - initialise.ReadStmts("cockroach") config := new(database.Config) - config.SetConnector(&cockroach.Config{ - User: cockroach.User{ - Username: "zitadel", - }, - Database: "zitadel", - }) + config.SetConnector(&postgres.Config{User: postgres.User{Username: "zitadel"}, Database: "zitadel"}) + + if err := initialise.ReadStmts(); err != nil { + return err + } + err := initialise.Init(ctx, db, initialise.VerifyUser(config.Username(), ""), initialise.VerifyDatabase(config.DatabaseName()), - initialise.VerifyGrant(config.DatabaseName(), config.Username()), - initialise.VerifySettings(config.DatabaseName(), config.Username())) + initialise.VerifyGrant(config.DatabaseName(), config.Username())) if err != nil { return err } + err = initialise.VerifyZitadel(ctx, db, *config) if err != nil { return err } + // create old events _, err = db.Exec(oldEventsTable) return err } func connectLocalhost() (*database.DB, error) { - client, err := sql.Open("pgx", "postgresql://root@localhost:26257/defaultdb?sslmode=disable") + client, err := sql.Open("pgx", "postgresql://postgres@localhost:5432/postgres?sslmode=disable") if err != nil { return nil, err } @@ -134,7 +122,7 @@ func (*testDB) DatabaseName() string { return "db" } func (*testDB) Username() string { return "user" } -func (*testDB) Type() string { return "cockroach" } +func (*testDB) Type() string { return "postgres" } func generateCommand(aggregateType eventstore.AggregateType, aggregateID string, opts ...func(*testEvent)) eventstore.Command { e := &testEvent{ @@ -177,7 +165,7 @@ func canceledCtx() context.Context { } func fillUniqueData(unique_type, field, instanceID string) error { - _, err := testCRDBClient.Exec("INSERT INTO eventstore.unique_constraints (unique_type, unique_field, instance_id) VALUES ($1, $2, $3)", unique_type, field, instanceID) + _, err := testClient.Exec("INSERT INTO eventstore.unique_constraints (unique_type, unique_field, instance_id) VALUES ($1, $2, $3)", unique_type, field, instanceID) return err } @@ -251,5 +239,5 @@ const oldEventsTable = `CREATE TABLE IF NOT EXISTS eventstore.events ( , "position" DECIMAL NOT NULL , in_tx_order INTEGER NOT NULL - , PRIMARY KEY (instance_id, aggregate_type, aggregate_id, event_sequence DESC) + , PRIMARY KEY (instance_id, aggregate_type, aggregate_id, event_sequence) );` diff --git a/internal/eventstore/repository/search_query.go b/internal/eventstore/repository/search_query.go index f84c7f1201..6ffba31ca8 100644 --- a/internal/eventstore/repository/search_query.go +++ b/internal/eventstore/repository/search_query.go @@ -16,7 +16,6 @@ type SearchQuery struct { Tx *sql.Tx LockRows bool LockOption eventstore.LockOption - AllowTimeTravel bool AwaitOpenTransactions bool Limit uint64 Offset uint32 @@ -51,11 +50,11 @@ const ( OperationGreater // OperationLess compares if the given values is less than the stored one OperationLess - //OperationIn checks if a stored value matches one of the passed value list + // OperationIn checks if a stored value matches one of the passed value list OperationIn - //OperationJSONContains checks if a stored value matches the given json + // OperationJSONContains checks if a stored value matches the given json OperationJSONContains - //OperationNotIn checks if a stored value does not match one of the passed value list + // OperationNotIn checks if a stored value does not match one of the passed value list OperationNotIn operationCount @@ -65,25 +64,25 @@ const ( type Field int32 const ( - //FieldAggregateType represents the aggregate type field + // FieldAggregateType represents the aggregate type field FieldAggregateType Field = iota + 1 - //FieldAggregateID represents the aggregate id field + // FieldAggregateID represents the aggregate id field FieldAggregateID - //FieldSequence represents the sequence field + // FieldSequence represents the sequence field FieldSequence - //FieldResourceOwner represents the resource owner field + // FieldResourceOwner represents the resource owner field FieldResourceOwner - //FieldInstanceID represents the instance id field + // FieldInstanceID represents the instance id field FieldInstanceID - //FieldEditorService represents the editor service field + // FieldEditorService represents the editor service field FieldEditorService - //FieldEditorUser represents the editor user field + // FieldEditorUser represents the editor user field FieldEditorUser - //FieldEventType represents the event type field + // FieldEventType represents the event type field FieldEventType - //FieldEventData represents the event data field + // FieldEventData represents the event data field FieldEventData - //FieldCreationDate represents the creation date field + // FieldCreationDate represents the creation date field FieldCreationDate // FieldPosition represents the field of the global sequence FieldPosition @@ -129,7 +128,6 @@ func QueryFromBuilder(builder *eventstore.SearchQueryBuilder) (*SearchQuery, err Offset: builder.GetOffset(), Desc: builder.GetDesc(), Tx: builder.GetTx(), - AllowTimeTravel: builder.GetAllowTimeTravel(), AwaitOpenTransactions: builder.GetAwaitOpenTransactions(), SubQueries: make([][]*Filter, len(builder.GetQueries())), } diff --git a/internal/eventstore/repository/sql/crdb.go b/internal/eventstore/repository/sql/crdb.go deleted file mode 100644 index 68610676c3..0000000000 --- a/internal/eventstore/repository/sql/crdb.go +++ /dev/null @@ -1,455 +0,0 @@ -package sql - -import ( - "context" - "database/sql" - "encoding/json" - "errors" - "regexp" - "strconv" - "strings" - - "github.com/cockroachdb/cockroach-go/v2/crdb" - "github.com/jackc/pgx/v5/pgconn" - "github.com/zitadel/logging" - - "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/database" - "github.com/zitadel/zitadel/internal/eventstore" - "github.com/zitadel/zitadel/internal/eventstore/repository" - "github.com/zitadel/zitadel/internal/telemetry/tracing" - "github.com/zitadel/zitadel/internal/zerrors" -) - -const ( - //as soon as stored procedures are possible in crdb - // we could move the code to migrations and call the procedure - // traking issue: https://github.com/cockroachdb/cockroach/issues/17511 - // - //previous_data selects the needed data of the latest event of the aggregate - // and buffers it (crdb inmemory) - crdbInsert = "WITH previous_data (aggregate_type_sequence, aggregate_sequence, resource_owner) AS (" + - "SELECT agg_type.seq, agg.seq, agg.ro FROM " + - "(" + - //max sequence of requested aggregate type - " SELECT MAX(event_sequence) seq, 1 join_me" + - " FROM eventstore.events" + - " WHERE aggregate_type = $2" + - " AND (CASE WHEN $9::TEXT IS NULL THEN instance_id is null else instance_id = $9::TEXT END)" + - ") AS agg_type " + - // combined with - "LEFT JOIN " + - "(" + - // max sequence and resource owner of aggregate root - " SELECT event_sequence seq, resource_owner ro, 1 join_me" + - " FROM eventstore.events" + - " WHERE aggregate_type = $2 AND aggregate_id = $3" + - " AND (CASE WHEN $9::TEXT IS NULL THEN instance_id is null else instance_id = $9::TEXT END)" + - " ORDER BY event_sequence DESC" + - " LIMIT 1" + - ") AS agg USING(join_me)" + - ") " + - "INSERT INTO eventstore.events (" + - " event_type," + - " aggregate_type," + - " aggregate_id," + - " aggregate_version," + - " creation_date," + - " position," + - " event_data," + - " editor_user," + - " editor_service," + - " resource_owner," + - " instance_id," + - " event_sequence," + - " previous_aggregate_sequence," + - " previous_aggregate_type_sequence," + - " in_tx_order" + - ") " + - // defines the data to be inserted - "SELECT" + - " $1::VARCHAR AS event_type," + - " $2::VARCHAR AS aggregate_type," + - " $3::VARCHAR AS aggregate_id," + - " $4::VARCHAR AS aggregate_version," + - " hlc_to_timestamp(cluster_logical_timestamp()) AS creation_date," + - " cluster_logical_timestamp() AS position," + - " $5::JSONB AS event_data," + - " $6::VARCHAR AS editor_user," + - " $7::VARCHAR AS editor_service," + - " COALESCE((resource_owner), $8::VARCHAR) AS resource_owner," + - " $9::VARCHAR AS instance_id," + - " COALESCE(aggregate_sequence, 0)+1," + - " aggregate_sequence AS previous_aggregate_sequence," + - " aggregate_type_sequence AS previous_aggregate_type_sequence," + - " $10 AS in_tx_order " + - "FROM previous_data " + - "RETURNING id, event_sequence, creation_date, resource_owner, instance_id" - - uniqueInsert = `INSERT INTO eventstore.unique_constraints - ( - unique_type, - unique_field, - instance_id - ) - VALUES ( - $1, - $2, - $3 - )` - - uniqueDelete = `DELETE FROM eventstore.unique_constraints - WHERE unique_type = $1 and unique_field = $2 and instance_id = $3` - uniqueDeleteInstance = `DELETE FROM eventstore.unique_constraints - WHERE instance_id = $1` -) - -// awaitOpenTransactions ensures event ordering, so we don't events younger that open transactions -var ( - awaitOpenTransactionsV1 string - awaitOpenTransactionsV2 string -) - -func awaitOpenTransactions(useV1 bool) string { - if useV1 { - return awaitOpenTransactionsV1 - } - return awaitOpenTransactionsV2 -} - -type CRDB struct { - *database.DB -} - -func NewCRDB(client *database.DB) *CRDB { - switch client.Type() { - case "cockroach": - awaitOpenTransactionsV1 = " AND creation_date::TIMESTAMP < (SELECT COALESCE(MIN(start), NOW())::TIMESTAMP FROM crdb_internal.cluster_transactions where application_name = ANY(?))" - awaitOpenTransactionsV2 = ` AND hlc_to_timestamp("position") < (SELECT COALESCE(MIN(start), NOW())::TIMESTAMP FROM crdb_internal.cluster_transactions where application_name = ANY(?))` - case "postgres": - awaitOpenTransactionsV1 = ` AND EXTRACT(EPOCH FROM created_at) < (SELECT COALESCE(EXTRACT(EPOCH FROM min(xact_start)), EXTRACT(EPOCH FROM now())) FROM pg_stat_activity WHERE datname = current_database() AND application_name = ANY(?) AND state <> 'idle')` - awaitOpenTransactionsV2 = ` AND "position" < (SELECT COALESCE(EXTRACT(EPOCH FROM min(xact_start)), EXTRACT(EPOCH FROM now())) FROM pg_stat_activity WHERE datname = current_database() AND application_name = ANY(?) AND state <> 'idle')` - } - - return &CRDB{client} -} - -func (db *CRDB) Health(ctx context.Context) error { return db.Ping() } - -// Push adds all events to the eventstreams of the aggregates. -// This call is transaction save. The transaction will be rolled back if one event fails -func (db *CRDB) Push(ctx context.Context, commands ...eventstore.Command) (events []eventstore.Event, err error) { - events = make([]eventstore.Event, len(commands)) - - err = crdb.ExecuteTx(ctx, db.DB.DB, nil, func(tx *sql.Tx) error { - - var uniqueConstraints []*eventstore.UniqueConstraint - - for i, command := range commands { - if command.Aggregate().InstanceID == "" { - command.Aggregate().InstanceID = authz.GetInstance(ctx).InstanceID() - } - - var payload []byte - if command.Payload() != nil { - payload, err = json.Marshal(command.Payload()) - if err != nil { - return err - } - } - e := &repository.Event{ - Typ: command.Type(), - Data: payload, - EditorUser: command.Creator(), - Version: command.Aggregate().Version, - AggregateID: command.Aggregate().ID, - AggregateType: command.Aggregate().Type, - ResourceOwner: sql.NullString{String: command.Aggregate().ResourceOwner, Valid: command.Aggregate().ResourceOwner != ""}, - InstanceID: command.Aggregate().InstanceID, - } - - err := tx.QueryRowContext(ctx, crdbInsert, - e.Type(), - e.Aggregate().Type, - e.Aggregate().ID, - e.Aggregate().Version, - payload, - e.Creator(), - "zitadel", - e.Aggregate().ResourceOwner, - e.Aggregate().InstanceID, - i, - ).Scan(&e.ID, &e.Seq, &e.CreationDate, &e.ResourceOwner, &e.InstanceID) - - if err != nil { - logging.WithFields( - "aggregate", e.Aggregate().Type, - "aggregateId", e.Aggregate().ID, - "aggregateType", e.Aggregate().Type, - "eventType", e.Type(), - "instanceID", e.Aggregate().InstanceID, - ).WithError(err).Debug("query failed") - return zerrors.ThrowInternal(err, "SQL-SBP37", "unable to create event") - } - - uniqueConstraints = append(uniqueConstraints, command.UniqueConstraints()...) - events[i] = e - } - - return db.handleUniqueConstraints(ctx, tx, uniqueConstraints...) - }) - if err != nil && !errors.Is(err, &zerrors.ZitadelError{}) { - err = zerrors.ThrowInternal(err, "SQL-DjgtG", "unable to store events") - } - - return events, err -} - -// handleUniqueConstraints adds or removes unique constraints -func (db *CRDB) handleUniqueConstraints(ctx context.Context, tx *sql.Tx, uniqueConstraints ...*eventstore.UniqueConstraint) (err error) { - if len(uniqueConstraints) == 0 || (len(uniqueConstraints) == 1 && uniqueConstraints[0] == nil) { - return nil - } - - for _, uniqueConstraint := range uniqueConstraints { - uniqueConstraint.UniqueField = strings.ToLower(uniqueConstraint.UniqueField) - switch uniqueConstraint.Action { - case eventstore.UniqueConstraintAdd: - _, err := tx.ExecContext(ctx, uniqueInsert, uniqueConstraint.UniqueType, uniqueConstraint.UniqueField, authz.GetInstance(ctx).InstanceID()) - if err != nil { - logging.WithFields( - "unique_type", uniqueConstraint.UniqueType, - "unique_field", uniqueConstraint.UniqueField).WithError(err).Info("insert unique constraint failed") - - if db.isUniqueViolationError(err) { - return zerrors.ThrowAlreadyExists(err, "SQL-wHcEq", uniqueConstraint.ErrorMessage) - } - - return zerrors.ThrowInternal(err, "SQL-dM9ds", "unable to create unique constraint") - } - case eventstore.UniqueConstraintRemove: - _, err := tx.ExecContext(ctx, uniqueDelete, uniqueConstraint.UniqueType, uniqueConstraint.UniqueField, authz.GetInstance(ctx).InstanceID()) - if err != nil { - logging.WithFields( - "unique_type", uniqueConstraint.UniqueType, - "unique_field", uniqueConstraint.UniqueField).WithError(err).Info("delete unique constraint failed") - return zerrors.ThrowInternal(err, "SQL-6n88i", "unable to remove unique constraint") - } - case eventstore.UniqueConstraintInstanceRemove: - _, err := tx.ExecContext(ctx, uniqueDeleteInstance, authz.GetInstance(ctx).InstanceID()) - if err != nil { - logging.WithFields( - "instance_id", authz.GetInstance(ctx).InstanceID()).WithError(err).Info("delete instance unique constraints failed") - return zerrors.ThrowInternal(err, "SQL-6n88i", "unable to remove unique constraints of instance") - } - } - } - return nil -} - -// FilterToReducer finds all events matching the given search query and passes them to the reduce function. -func (crdb *CRDB) FilterToReducer(ctx context.Context, searchQuery *eventstore.SearchQueryBuilder, reduce eventstore.Reducer) (err error) { - ctx, span := tracing.NewSpan(ctx) - defer func() { span.EndWithError(err) }() - - err = query(ctx, crdb, searchQuery, reduce, false) - if err == nil { - return nil - } - pgErr := new(pgconn.PgError) - // check events2 not exists - if errors.As(err, &pgErr) && pgErr.Code == "42P01" { - return query(ctx, crdb, searchQuery, reduce, true) - } - return err -} - -// LatestSequence returns the latest sequence found by the search query -func (db *CRDB) LatestSequence(ctx context.Context, searchQuery *eventstore.SearchQueryBuilder) (float64, error) { - var position sql.NullFloat64 - err := query(ctx, db, searchQuery, &position, false) - return position.Float64, err -} - -// InstanceIDs returns the instance ids found by the search query -func (db *CRDB) InstanceIDs(ctx context.Context, searchQuery *eventstore.SearchQueryBuilder) ([]string, error) { - var ids []string - err := query(ctx, db, searchQuery, &ids, false) - if err != nil { - return nil, err - } - return ids, nil -} - -func (db *CRDB) Client() *database.DB { - return db.DB -} - -func (db *CRDB) orderByEventSequence(desc, shouldOrderBySequence, useV1 bool) string { - if useV1 { - if desc { - return ` ORDER BY event_sequence DESC` - } - return ` ORDER BY event_sequence` - } - if shouldOrderBySequence { - if desc { - return ` ORDER BY "sequence" DESC` - } - return ` ORDER BY "sequence"` - } - - if desc { - return ` ORDER BY "position" DESC, in_tx_order DESC` - } - return ` ORDER BY "position", in_tx_order` -} - -func (db *CRDB) eventQuery(useV1 bool) string { - if useV1 { - return "SELECT" + - " creation_date" + - ", event_type" + - ", event_sequence" + - ", event_data" + - ", editor_user" + - ", resource_owner" + - ", instance_id" + - ", aggregate_type" + - ", aggregate_id" + - ", aggregate_version" + - " FROM eventstore.events" - } - return "SELECT" + - " created_at" + - ", event_type" + - `, "sequence"` + - `, "position"` + - ", payload" + - ", creator" + - `, "owner"` + - ", instance_id" + - ", aggregate_type" + - ", aggregate_id" + - ", revision" + - " FROM eventstore.events2" -} - -func (db *CRDB) maxSequenceQuery(useV1 bool) string { - if useV1 { - return `SELECT event_sequence FROM eventstore.events` - } - return `SELECT "position" FROM eventstore.events2` -} - -func (db *CRDB) instanceIDsQuery(useV1 bool) string { - table := "eventstore.events2" - if useV1 { - table = "eventstore.events" - } - return "SELECT DISTINCT instance_id FROM " + table -} - -func (db *CRDB) columnName(col repository.Field, useV1 bool) string { - switch col { - case repository.FieldAggregateID: - return "aggregate_id" - case repository.FieldAggregateType: - return "aggregate_type" - case repository.FieldSequence: - if useV1 { - return "event_sequence" - } - return `"sequence"` - case repository.FieldResourceOwner: - if useV1 { - return "resource_owner" - } - return `"owner"` - case repository.FieldInstanceID: - return "instance_id" - case repository.FieldEditorService: - if useV1 { - return "editor_service" - } - return "" - case repository.FieldEditorUser: - if useV1 { - return "editor_user" - } - return "creator" - case repository.FieldEventType: - return "event_type" - case repository.FieldEventData: - if useV1 { - return "event_data" - } - return "payload" - case repository.FieldCreationDate: - if useV1 { - return "creation_date" - } - return "created_at" - case repository.FieldPosition: - return `"position"` - default: - return "" - } -} - -func (db *CRDB) conditionFormat(operation repository.Operation) string { - switch operation { - case repository.OperationIn: - return "%s %s ANY(?)" - case repository.OperationNotIn: - return "%s %s ALL(?)" - } - return "%s %s ?" -} - -func (db *CRDB) operation(operation repository.Operation) string { - switch operation { - case repository.OperationEquals, repository.OperationIn: - return "=" - case repository.OperationGreater: - return ">" - case repository.OperationLess: - return "<" - case repository.OperationJSONContains: - return "@>" - case repository.OperationNotIn: - return "<>" - } - return "" -} - -var ( - placeholder = regexp.MustCompile(`\?`) -) - -// placeholder replaces all "?" with postgres placeholders ($) -func (db *CRDB) placeholder(query string) string { - occurances := placeholder.FindAllStringIndex(query, -1) - if len(occurances) == 0 { - return query - } - replaced := query[:occurances[0][0]] - - for i, l := range occurances { - nextIDX := len(query) - if i < len(occurances)-1 { - nextIDX = occurances[i+1][0] - } - replaced = replaced + "$" + strconv.Itoa(i+1) + query[l[1]:nextIDX] - } - return replaced -} - -func (db *CRDB) isUniqueViolationError(err error) bool { - if pgxErr, ok := err.(*pgconn.PgError); ok { - if pgxErr.Code == "23505" { - return true - } - } - return false -} diff --git a/internal/eventstore/repository/sql/local_crdb_test.go b/internal/eventstore/repository/sql/local_postgres_test.go similarity index 54% rename from internal/eventstore/repository/sql/local_crdb_test.go rename to internal/eventstore/repository/sql/local_postgres_test.go index 0f8c934b47..4ec3f6a78d 100644 --- a/internal/eventstore/repository/sql/local_crdb_test.go +++ b/internal/eventstore/repository/sql/local_postgres_test.go @@ -7,72 +7,60 @@ import ( "testing" "time" - "github.com/cockroachdb/cockroach-go/v2/testserver" "github.com/jackc/pgx/v5/pgxpool" "github.com/jackc/pgx/v5/stdlib" "github.com/zitadel/logging" "github.com/zitadel/zitadel/cmd/initialise" "github.com/zitadel/zitadel/internal/database" - "github.com/zitadel/zitadel/internal/database/cockroach" + "github.com/zitadel/zitadel/internal/database/postgres" new_es "github.com/zitadel/zitadel/internal/eventstore/v3" ) var ( - testCRDBClient *sql.DB + testClient *sql.DB ) func TestMain(m *testing.M) { - opts := make([]testserver.TestServerOpt, 0, 1) - if version := os.Getenv("ZITADEL_CRDB_VERSION"); version != "" { - opts = append(opts, testserver.CustomVersionOpt(version)) - } - ts, err := testserver.NewTestServer(opts...) - if err != nil { - logging.WithFields("error", err).Fatal("unable to start db") - } + os.Exit(func() int { + config, cleanup := postgres.StartEmbedded() + defer cleanup() - connConfig, err := pgxpool.ParseConfig(ts.PGURL().String()) - if err != nil { - logging.WithFields("error", err).Fatal("unable to parse db url") - } - connConfig.AfterConnect = new_es.RegisterEventstoreTypes - pool, err := pgxpool.NewWithConfig(context.Background(), connConfig) - if err != nil { - logging.WithFields("error", err).Fatal("unable to create db pool") - } + connConfig, err := pgxpool.ParseConfig(config.GetConnectionURL()) + logging.OnError(err).Fatal("unable to parse db url") - testCRDBClient = stdlib.OpenDBFromPool(pool) + connConfig.AfterConnect = new_es.RegisterEventstoreTypes + pool, err := pgxpool.NewWithConfig(context.Background(), connConfig) + logging.OnError(err).Fatal("unable to create db pool") - if err = testCRDBClient.Ping(); err != nil { - logging.WithFields("error", err).Fatal("unable to ping db") - } + testClient = stdlib.OpenDBFromPool(pool) - defer func() { - testCRDBClient.Close() - ts.Stop() - }() + err = testClient.Ping() + logging.OnError(err).Fatal("unable to ping db") - if err = initDB(context.Background(), &database.DB{DB: testCRDBClient, Database: &cockroach.Config{Database: "zitadel"}}); err != nil { - logging.WithFields("error", err).Fatal("migrations failed") - } + defer func() { + logging.OnError(testClient.Close()).Error("unable to close db") + }() - os.Exit(m.Run()) + err = initDB(context.Background(), &database.DB{DB: testClient, Database: &postgres.Config{Database: "zitadel"}}) + logging.OnError(err).Fatal("migrations failed") + + return m.Run() + }()) } func initDB(ctx context.Context, db *database.DB) error { config := new(database.Config) - config.SetConnector(&cockroach.Config{User: cockroach.User{Username: "zitadel"}, Database: "zitadel"}) + config.SetConnector(&postgres.Config{User: postgres.User{Username: "zitadel"}, Database: "zitadel"}) - if err := initialise.ReadStmts("cockroach"); err != nil { + if err := initialise.ReadStmts(); err != nil { return err } err := initialise.Init(ctx, db, initialise.VerifyUser(config.Username(), ""), initialise.VerifyDatabase(config.DatabaseName()), - initialise.VerifyGrant(config.DatabaseName(), config.Username()), - initialise.VerifySettings(config.DatabaseName(), config.Username())) + initialise.VerifyGrant(config.DatabaseName(), config.Username())) if err != nil { return err } @@ -95,7 +83,7 @@ func (*testDB) DatabaseName() string { return "db" } func (*testDB) Username() string { return "user" } -func (*testDB) Type() string { return "cockroach" } +func (*testDB) Type() string { return "postgres" } const oldEventsTable = `CREATE TABLE IF NOT EXISTS eventstore.events ( id UUID DEFAULT gen_random_uuid() @@ -116,5 +104,5 @@ const oldEventsTable = `CREATE TABLE IF NOT EXISTS eventstore.events ( , "position" DECIMAL NOT NULL , in_tx_order INTEGER NOT NULL - , PRIMARY KEY (instance_id, aggregate_type, aggregate_id, event_sequence DESC) + , PRIMARY KEY (instance_id, aggregate_type, aggregate_id, event_sequence) );` diff --git a/internal/eventstore/repository/sql/postgres.go b/internal/eventstore/repository/sql/postgres.go new file mode 100644 index 0000000000..bc9ad2e029 --- /dev/null +++ b/internal/eventstore/repository/sql/postgres.go @@ -0,0 +1,240 @@ +package sql + +import ( + "context" + "database/sql" + "errors" + "regexp" + "strconv" + + "github.com/jackc/pgx/v5/pgconn" + + "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/eventstore/repository" + "github.com/zitadel/zitadel/internal/telemetry/tracing" +) + +// awaitOpenTransactions ensures event ordering, so we don't events younger that open transactions +var ( + awaitOpenTransactionsV1 = ` AND EXTRACT(EPOCH FROM created_at) < (SELECT COALESCE(EXTRACT(EPOCH FROM min(xact_start)), EXTRACT(EPOCH FROM now())) FROM pg_stat_activity WHERE datname = current_database() AND application_name = ANY(?) AND state <> 'idle')` + awaitOpenTransactionsV2 = ` AND "position" < (SELECT COALESCE(EXTRACT(EPOCH FROM min(xact_start)), EXTRACT(EPOCH FROM now())) FROM pg_stat_activity WHERE datname = current_database() AND application_name = ANY(?) AND state <> 'idle')` +) + +func awaitOpenTransactions(useV1 bool) string { + if useV1 { + return awaitOpenTransactionsV1 + } + return awaitOpenTransactionsV2 +} + +type Postgres struct { + *database.DB +} + +func NewPostgres(client *database.DB) *Postgres { + return &Postgres{client} +} + +func (db *Postgres) Health(ctx context.Context) error { return db.Ping() } + +// FilterToReducer finds all events matching the given search query and passes them to the reduce function. +func (psql *Postgres) FilterToReducer(ctx context.Context, searchQuery *eventstore.SearchQueryBuilder, reduce eventstore.Reducer) (err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + err = query(ctx, psql, searchQuery, reduce, false) + if err == nil { + return nil + } + pgErr := new(pgconn.PgError) + // check events2 not exists + if errors.As(err, &pgErr) && pgErr.Code == "42P01" { + return query(ctx, psql, searchQuery, reduce, true) + } + return err +} + +// LatestSequence returns the latest sequence found by the search query +func (db *Postgres) LatestSequence(ctx context.Context, searchQuery *eventstore.SearchQueryBuilder) (float64, error) { + var position sql.NullFloat64 + err := query(ctx, db, searchQuery, &position, false) + return position.Float64, err +} + +// InstanceIDs returns the instance ids found by the search query +func (db *Postgres) InstanceIDs(ctx context.Context, searchQuery *eventstore.SearchQueryBuilder) ([]string, error) { + var ids []string + err := query(ctx, db, searchQuery, &ids, false) + if err != nil { + return nil, err + } + return ids, nil +} + +func (db *Postgres) Client() *database.DB { + return db.DB +} + +func (db *Postgres) orderByEventSequence(desc, shouldOrderBySequence, useV1 bool) string { + if useV1 { + if desc { + return ` ORDER BY event_sequence DESC` + } + return ` ORDER BY event_sequence` + } + if shouldOrderBySequence { + if desc { + return ` ORDER BY "sequence" DESC` + } + return ` ORDER BY "sequence"` + } + + if desc { + return ` ORDER BY "position" DESC, in_tx_order DESC` + } + return ` ORDER BY "position", in_tx_order` +} + +func (db *Postgres) eventQuery(useV1 bool) string { + if useV1 { + return "SELECT" + + " creation_date" + + ", event_type" + + ", event_sequence" + + ", event_data" + + ", editor_user" + + ", resource_owner" + + ", instance_id" + + ", aggregate_type" + + ", aggregate_id" + + ", aggregate_version" + + " FROM eventstore.events" + } + return "SELECT" + + " created_at" + + ", event_type" + + `, "sequence"` + + `, "position"` + + ", payload" + + ", creator" + + `, "owner"` + + ", instance_id" + + ", aggregate_type" + + ", aggregate_id" + + ", revision" + + " FROM eventstore.events2" +} + +func (db *Postgres) maxSequenceQuery(useV1 bool) string { + if useV1 { + return `SELECT event_sequence FROM eventstore.events` + } + return `SELECT "position" FROM eventstore.events2` +} + +func (db *Postgres) instanceIDsQuery(useV1 bool) string { + table := "eventstore.events2" + if useV1 { + table = "eventstore.events" + } + return "SELECT DISTINCT instance_id FROM " + table +} + +func (db *Postgres) columnName(col repository.Field, useV1 bool) string { + switch col { + case repository.FieldAggregateID: + return "aggregate_id" + case repository.FieldAggregateType: + return "aggregate_type" + case repository.FieldSequence: + if useV1 { + return "event_sequence" + } + return `"sequence"` + case repository.FieldResourceOwner: + if useV1 { + return "resource_owner" + } + return `"owner"` + case repository.FieldInstanceID: + return "instance_id" + case repository.FieldEditorService: + if useV1 { + return "editor_service" + } + return "" + case repository.FieldEditorUser: + if useV1 { + return "editor_user" + } + return "creator" + case repository.FieldEventType: + return "event_type" + case repository.FieldEventData: + if useV1 { + return "event_data" + } + return "payload" + case repository.FieldCreationDate: + if useV1 { + return "creation_date" + } + return "created_at" + case repository.FieldPosition: + return `"position"` + default: + return "" + } +} + +func (db *Postgres) conditionFormat(operation repository.Operation) string { + switch operation { + case repository.OperationIn: + return "%s %s ANY(?)" + case repository.OperationNotIn: + return "%s %s ALL(?)" + case repository.OperationEquals, repository.OperationGreater, repository.OperationLess, repository.OperationJSONContains: + fallthrough + default: + return "%s %s ?" + } +} + +func (db *Postgres) operation(operation repository.Operation) string { + switch operation { + case repository.OperationEquals, repository.OperationIn: + return "=" + case repository.OperationGreater: + return ">" + case repository.OperationLess: + return "<" + case repository.OperationJSONContains: + return "@>" + case repository.OperationNotIn: + return "<>" + } + return "" +} + +var ( + placeholder = regexp.MustCompile(`\?`) +) + +// placeholder replaces all "?" with postgres placeholders ($) +func (db *Postgres) placeholder(query string) string { + occurrences := placeholder.FindAllStringIndex(query, -1) + if len(occurrences) == 0 { + return query + } + replaced := query[:occurrences[0][0]] + + for i, l := range occurrences { + nextIDX := len(query) + if i < len(occurrences)-1 { + nextIDX = occurrences[i+1][0] + } + replaced = replaced + "$" + strconv.Itoa(i+1) + query[l[1]:nextIDX] + } + return replaced +} diff --git a/internal/eventstore/repository/sql/crdb_test.go b/internal/eventstore/repository/sql/postgres_test.go similarity index 90% rename from internal/eventstore/repository/sql/crdb_test.go rename to internal/eventstore/repository/sql/postgres_test.go index a3f3331a82..151fdd1b6a 100644 --- a/internal/eventstore/repository/sql/crdb_test.go +++ b/internal/eventstore/repository/sql/postgres_test.go @@ -8,7 +8,7 @@ import ( "github.com/zitadel/zitadel/internal/eventstore/repository" ) -func TestCRDB_placeholder(t *testing.T) { +func TestPostgres_placeholder(t *testing.T) { type args struct { query string } @@ -50,15 +50,15 @@ func TestCRDB_placeholder(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - db := &CRDB{} + db := &Postgres{} if query := db.placeholder(tt.args.query); query != tt.res.query { - t.Errorf("CRDB.placeholder() = %v, want %v", query, tt.res.query) + t.Errorf("Postgres.placeholder() = %v, want %v", query, tt.res.query) } }) } } -func TestCRDB_operation(t *testing.T) { +func TestPostgres_operation(t *testing.T) { type res struct { op string } @@ -118,15 +118,15 @@ func TestCRDB_operation(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - db := &CRDB{} + db := &Postgres{} if got := db.operation(tt.args.operation); got != tt.res.op { - t.Errorf("CRDB.operation() = %v, want %v", got, tt.res.op) + t.Errorf("Postgres.operation() = %v, want %v", got, tt.res.op) } }) } } -func TestCRDB_conditionFormat(t *testing.T) { +func TestPostgres_conditionFormat(t *testing.T) { type res struct { format string } @@ -159,15 +159,15 @@ func TestCRDB_conditionFormat(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - db := &CRDB{} + db := &Postgres{} if got := db.conditionFormat(tt.args.operation); got != tt.res.format { - t.Errorf("CRDB.conditionFormat() = %v, want %v", got, tt.res.format) + t.Errorf("Postgres.conditionFormat() = %v, want %v", got, tt.res.format) } }) } } -func TestCRDB_columnName(t *testing.T) { +func TestPostgres_columnName(t *testing.T) { type res struct { name string } @@ -295,9 +295,9 @@ func TestCRDB_columnName(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - db := &CRDB{} + db := &Postgres{} if got := db.columnName(tt.args.field, tt.args.useV1); got != tt.res.name { - t.Errorf("CRDB.operation() = %v, want %v", got, tt.res.name) + t.Errorf("Postgres.operation() = %v, want %v", got, tt.res.name) } }) } diff --git a/internal/eventstore/repository/sql/query.go b/internal/eventstore/repository/sql/query.go index 4e1cc87aff..a545225d9e 100644 --- a/internal/eventstore/repository/sql/query.go +++ b/internal/eventstore/repository/sql/query.go @@ -11,7 +11,6 @@ import ( "github.com/zitadel/logging" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/database/dialect" "github.com/zitadel/zitadel/internal/eventstore" @@ -65,11 +64,6 @@ func query(ctx context.Context, criteria querier, searchQuery *eventstore.Search if where == "" || query == "" { return zerrors.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory") } - if q.Tx == nil { - if travel := prepareTimeTravel(ctx, criteria, q.AllowTimeTravel); travel != "" { - query += travel - } - } query += where // instead of using the max function of the database (which doesn't work for postgres) @@ -158,15 +152,7 @@ func prepareColumns(criteria querier, columns eventstore.Columns, useV1 bool) (s } } -func prepareTimeTravel(ctx context.Context, criteria querier, allow bool) string { - if !allow { - return "" - } - took := call.Took(ctx) - return criteria.Timetravel(took) -} - -func maxSequenceScanner(row scan, dest interface{}) (err error) { +func maxSequenceScanner(row scan, dest any) (err error) { position, ok := dest.(*sql.NullFloat64) if !ok { return zerrors.ThrowInvalidArgumentf(nil, "SQL-NBjA9", "type must be sql.NullInt64 got: %T", dest) diff --git a/internal/eventstore/repository/sql/query_test.go b/internal/eventstore/repository/sql/query_test.go index abac19ead0..3df819be64 100644 --- a/internal/eventstore/repository/sql/query_test.go +++ b/internal/eventstore/repository/sql/query_test.go @@ -14,10 +14,11 @@ import ( "github.com/stretchr/testify/assert" "github.com/zitadel/zitadel/internal/database" - "github.com/zitadel/zitadel/internal/database/cockroach" db_mock "github.com/zitadel/zitadel/internal/database/mock" + "github.com/zitadel/zitadel/internal/database/postgres" "github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore/repository" + new_es "github.com/zitadel/zitadel/internal/eventstore/v3" "github.com/zitadel/zitadel/internal/zerrors" ) @@ -68,7 +69,7 @@ func Test_getCondition(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - db := &CRDB{} + db := &Postgres{} if got := getCondition(db, tt.args.filter, false); got != tt.want { t.Errorf("getCondition() = %v, want %v", got, tt.want) } @@ -236,8 +237,7 @@ func Test_prepareColumns(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - crdb := &CRDB{} - query, rowScanner := prepareColumns(crdb, tt.args.columns, tt.args.useV1) + query, rowScanner := prepareColumns(new(Postgres), tt.args.columns, tt.args.useV1) if query != tt.res.query { t.Errorf("prepareColumns() got = %s, want %s", query, tt.res.query) } @@ -267,7 +267,7 @@ func Test_prepareColumns(t *testing.T) { got := reflect.Indirect(reflect.ValueOf(tt.args.dest)).Interface() if !reflect.DeepEqual(got, tt.res.expected) { - t.Errorf("unexpected result from rowScanner \nwant: %+v \ngot: %+v", tt.res.expected, got) + t.Errorf("unexpected result from rowScanner nwant: %+v ngot: %+v", tt.res.expected, got) } }) } @@ -403,7 +403,7 @@ func Test_prepareCondition(t *testing.T) { useV1: true, }, res: res{ - clause: " WHERE aggregate_type = ANY(?) AND creation_date::TIMESTAMP < (SELECT COALESCE(MIN(start), NOW())::TIMESTAMP FROM crdb_internal.cluster_transactions where application_name = ANY(?))", + clause: " WHERE aggregate_type = ANY(?) AND EXTRACT(EPOCH FROM created_at) < (SELECT COALESCE(EXTRACT(EPOCH FROM min(xact_start)), EXTRACT(EPOCH FROM now())) FROM pg_stat_activity WHERE datname = current_database() AND application_name = ANY(?) AND state <> 'idle')", values: []interface{}{[]eventstore.AggregateType{"user", "org"}, database.TextArray[string]{}}, }, }, @@ -420,7 +420,7 @@ func Test_prepareCondition(t *testing.T) { }, }, res: res{ - clause: ` WHERE aggregate_type = ANY(?) AND hlc_to_timestamp("position") < (SELECT COALESCE(MIN(start), NOW())::TIMESTAMP FROM crdb_internal.cluster_transactions where application_name = ANY(?))`, + clause: ` WHERE aggregate_type = ANY(?) AND "position" < (SELECT COALESCE(EXTRACT(EPOCH FROM min(xact_start)), EXTRACT(EPOCH FROM now())) FROM pg_stat_activity WHERE datname = current_database() AND application_name = ANY(?) AND state <> 'idle')`, values: []interface{}{[]eventstore.AggregateType{"user", "org"}, database.TextArray[string]{}}, }, }, @@ -440,7 +440,7 @@ func Test_prepareCondition(t *testing.T) { useV1: true, }, res: res{ - clause: " WHERE aggregate_type = ANY(?) AND aggregate_id = ? AND event_type = ANY(?) AND creation_date::TIMESTAMP < (SELECT COALESCE(MIN(start), NOW())::TIMESTAMP FROM crdb_internal.cluster_transactions where application_name = ANY(?))", + clause: " WHERE aggregate_type = ANY(?) AND aggregate_id = ? AND event_type = ANY(?) AND EXTRACT(EPOCH FROM created_at) < (SELECT COALESCE(EXTRACT(EPOCH FROM min(xact_start)), EXTRACT(EPOCH FROM now())) FROM pg_stat_activity WHERE datname = current_database() AND application_name = ANY(?) AND state <> 'idle')", values: []interface{}{[]eventstore.AggregateType{"user", "org"}, "1234", []eventstore.EventType{"user.created", "org.created"}, database.TextArray[string]{}}, }, }, @@ -459,15 +459,14 @@ func Test_prepareCondition(t *testing.T) { }, }, res: res{ - clause: ` WHERE aggregate_type = ANY(?) AND aggregate_id = ? AND event_type = ANY(?) AND hlc_to_timestamp("position") < (SELECT COALESCE(MIN(start), NOW())::TIMESTAMP FROM crdb_internal.cluster_transactions where application_name = ANY(?))`, + clause: ` WHERE aggregate_type = ANY(?) AND aggregate_id = ? AND event_type = ANY(?) AND "position" < (SELECT COALESCE(EXTRACT(EPOCH FROM min(xact_start)), EXTRACT(EPOCH FROM now())) FROM pg_stat_activity WHERE datname = current_database() AND application_name = ANY(?) AND state <> 'idle')`, values: []interface{}{[]eventstore.AggregateType{"user", "org"}, "1234", []eventstore.EventType{"user.created", "org.created"}, database.TextArray[string]{}}, }, }, } - crdb := NewCRDB(&database.DB{Database: new(cockroach.Config)}) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - gotClause, gotValues := prepareConditions(crdb, tt.args.query, tt.args.useV1) + gotClause, gotValues := prepareConditions(NewPostgres(&database.DB{Database: new(postgres.Config)}), tt.args.query, tt.args.useV1) if gotClause != tt.res.clause { t.Errorf("prepareCondition() gotClause = %v, want %v", gotClause, tt.res.clause) } @@ -484,7 +483,7 @@ func Test_prepareCondition(t *testing.T) { } } -func Test_query_events_with_crdb(t *testing.T) { +func Test_query_events_with_postgres(t *testing.T) { type args struct { searchQuery *eventstore.SearchQueryBuilder } @@ -511,7 +510,7 @@ func Test_query_events_with_crdb(t *testing.T) { Builder(), }, fields: fields{ - client: testCRDBClient, + client: testClient, existingEvents: []eventstore.Command{ generateEvent(t, "300"), generateEvent(t, "300"), @@ -532,7 +531,7 @@ func Test_query_events_with_crdb(t *testing.T) { Builder(), }, fields: fields{ - client: testCRDBClient, + client: testClient, existingEvents: []eventstore.Command{ generateEvent(t, "301"), generateEvent(t, "302"), @@ -555,7 +554,7 @@ func Test_query_events_with_crdb(t *testing.T) { Builder(), }, fields: fields{ - client: testCRDBClient, + client: testClient, existingEvents: []eventstore.Command{ generateEvent(t, "303"), generateEvent(t, "303"), @@ -576,7 +575,7 @@ func Test_query_events_with_crdb(t *testing.T) { ResourceOwner("caos"), }, fields: fields{ - client: testCRDBClient, + client: testClient, existingEvents: []eventstore.Command{ generateEvent(t, "306", func(e *repository.Event) { e.ResourceOwner = sql.NullString{String: "caos", Valid: true} }), generateEvent(t, "307", func(e *repository.Event) { e.ResourceOwner = sql.NullString{String: "caos", Valid: true} }), @@ -599,7 +598,7 @@ func Test_query_events_with_crdb(t *testing.T) { Builder(), }, fields: fields{ - client: testCRDBClient, + client: testClient, existingEvents: []eventstore.Command{ generateEvent(t, "311", func(e *repository.Event) { e.Typ = "user.created" }), generateEvent(t, "311", func(e *repository.Event) { e.Typ = "user.updated" }), @@ -623,7 +622,7 @@ func Test_query_events_with_crdb(t *testing.T) { searchQuery: eventstore.NewSearchQueryBuilder(eventstore.Columns(-1)), }, fields: fields{ - client: testCRDBClient, + client: testClient, existingEvents: []eventstore.Command{}, }, res: res{ @@ -634,117 +633,37 @@ func Test_query_events_with_crdb(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - db := &CRDB{ - DB: &database.DB{ - DB: tt.fields.client, - Database: new(testDB), - }, + dbClient := &database.DB{ + DB: tt.fields.client, + Database: new(testDB), } + client := &Postgres{ + DB: dbClient, + } + + pusher := new_es.NewEventstore(dbClient) // setup initial data for query - if _, err := db.Push(context.Background(), tt.fields.existingEvents...); err != nil { + if _, err := pusher.Push(context.Background(), dbClient.DB, tt.fields.existingEvents...); err != nil { t.Errorf("error in setup = %v", err) return } events := []eventstore.Event{} - if err := query(context.Background(), db, tt.args.searchQuery, eventstore.Reducer(func(event eventstore.Event) error { + if err := query(context.Background(), client, tt.args.searchQuery, eventstore.Reducer(func(event eventstore.Event) error { events = append(events, event) return nil }), true); (err != nil) != tt.wantErr { - t.Errorf("CRDB.query() error = %v, wantErr %v", err, tt.wantErr) + t.Errorf("eventstore.query() error = %v, wantErr %v", err, tt.wantErr) } }) } } -/* Cockroach test DB doesn't seem to lock -func Test_query_events_with_crdb_locking(t *testing.T) { - type args struct { - searchQuery *eventstore.SearchQueryBuilder - } - type fields struct { - existingEvents []eventstore.Command - client *sql.DB - } - tests := []struct { - name string - fields fields - args args - lockOption eventstore.LockOption - wantErr bool - }{ - { - name: "skip locked", - fields: fields{ - client: testCRDBClient, - existingEvents: []eventstore.Command{ - generateEvent(t, "306", func(e *repository.Event) { e.ResourceOwner = sql.NullString{String: "caos", Valid: true} }), - generateEvent(t, "307", func(e *repository.Event) { e.ResourceOwner = sql.NullString{String: "caos", Valid: true} }), - generateEvent(t, "308", func(e *repository.Event) { e.ResourceOwner = sql.NullString{String: "caos", Valid: true} }), - }, - }, - args: args{ - searchQuery: eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). - ResourceOwner("caos"), - }, - lockOption: eventstore.LockOptionNoWait, - }, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - db := &CRDB{ - DB: &database.DB{ - DB: tt.fields.client, - Database: new(testDB), - }, - } - // setup initial data for query - if _, err := db.Push(context.Background(), tt.fields.existingEvents...); err != nil { - t.Errorf("error in setup = %v", err) - return - } - // first TX should lock and return all events - tx1, err := db.DB.Begin() - require.NoError(t, err) - defer func() { - require.NoError(t, tx1.Rollback()) - }() - searchQuery1 := tt.args.searchQuery.LockRowsDuringTx(tx1, tt.lockOption) - gotEvents1 := []eventstore.Event{} - err = query(context.Background(), db, searchQuery1, eventstore.Reducer(func(event eventstore.Event) error { - gotEvents1 = append(gotEvents1, event) - return nil - }), true) - require.NoError(t, err) - assert.Len(t, gotEvents1, len(tt.fields.existingEvents)) - - // second TX should not return the events, and might return an error - tx2, err := db.DB.Begin() - require.NoError(t, err) - defer func() { - require.NoError(t, tx2.Rollback()) - }() - searchQuery2 := tt.args.searchQuery.LockRowsDuringTx(tx1, tt.lockOption) - gotEvents2 := []eventstore.Event{} - err = query(context.Background(), db, searchQuery2, eventstore.Reducer(func(event eventstore.Event) error { - gotEvents2 = append(gotEvents2, event) - return nil - }), true) - if tt.wantErr { - require.Error(t, err) - } - require.NoError(t, err) - assert.Len(t, gotEvents2, 0) - }) - } -} -*/ - func Test_query_events_mocked(t *testing.T) { type args struct { query *eventstore.SearchQueryBuilder - dest interface{} + dest any useV1 bool } type res struct { @@ -772,8 +691,8 @@ func Test_query_events_mocked(t *testing.T) { useV1: true, }, fields: fields{ - mock: newMockClient(t).expectQuery(t, - `SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = \$1 AND creation_date::TIMESTAMP < \(SELECT COALESCE\(MIN\(start\), NOW\(\)\)::TIMESTAMP FROM crdb_internal\.cluster_transactions where application_name = ANY\(\$2\)\) ORDER BY event_sequence DESC`, + mock: newMockClient(t).expectQuery( + regexp.QuoteMeta(`SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 AND EXTRACT(EPOCH FROM created_at) < (SELECT COALESCE(EXTRACT(EPOCH FROM min(xact_start)), EXTRACT(EPOCH FROM now())) FROM pg_stat_activity WHERE datname = current_database() AND application_name = ANY($2) AND state <> 'idle') ORDER BY event_sequence DESC`), []driver.Value{eventstore.AggregateType("user"), database.TextArray[string]{}}, ), }, @@ -795,8 +714,8 @@ func Test_query_events_mocked(t *testing.T) { useV1: true, }, fields: fields{ - mock: newMockClient(t).expectQuery(t, - `SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = \$1 AND creation_date::TIMESTAMP < \(SELECT COALESCE\(MIN\(start\), NOW\(\)\)::TIMESTAMP FROM crdb_internal\.cluster_transactions where application_name = ANY\(\$2\)\) ORDER BY event_sequence LIMIT \$3`, + mock: newMockClient(t).expectQuery( + regexp.QuoteMeta(`SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 AND EXTRACT(EPOCH FROM created_at) < (SELECT COALESCE(EXTRACT(EPOCH FROM min(xact_start)), EXTRACT(EPOCH FROM now())) FROM pg_stat_activity WHERE datname = current_database() AND application_name = ANY($2) AND state <> 'idle') ORDER BY event_sequence LIMIT $3`), []driver.Value{eventstore.AggregateType("user"), database.TextArray[string]{}, uint64(5)}, ), }, @@ -818,32 +737,8 @@ func Test_query_events_mocked(t *testing.T) { useV1: true, }, fields: fields{ - mock: newMockClient(t).expectQuery(t, - `SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = \$1 AND creation_date::TIMESTAMP < \(SELECT COALESCE\(MIN\(start\), NOW\(\)\)::TIMESTAMP FROM crdb_internal\.cluster_transactions where application_name = ANY\(\$2\)\) ORDER BY event_sequence DESC LIMIT \$3`, - []driver.Value{eventstore.AggregateType("user"), database.TextArray[string]{}, uint64(5)}, - ), - }, - res: res{ - wantErr: false, - }, - }, - { - name: "with limit and order by desc as of system time", - args: args{ - dest: &[]*repository.Event{}, - query: eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). - OrderDesc(). - AwaitOpenTransactions(). - Limit(5). - AllowTimeTravel(). - AddQuery(). - AggregateTypes("user"). - Builder(), - useV1: true, - }, - fields: fields{ - mock: newMockClient(t).expectQuery(t, - `SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events AS OF SYSTEM TIME '-1 ms' WHERE aggregate_type = \$1 AND creation_date::TIMESTAMP < \(SELECT COALESCE\(MIN\(start\), NOW\(\)\)::TIMESTAMP FROM crdb_internal\.cluster_transactions where application_name = ANY\(\$2\)\) ORDER BY event_sequence DESC LIMIT \$3`, + mock: newMockClient(t).expectQuery( + regexp.QuoteMeta(`SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 AND EXTRACT(EPOCH FROM created_at) < (SELECT COALESCE(EXTRACT(EPOCH FROM min(xact_start)), EXTRACT(EPOCH FROM now())) FROM pg_stat_activity WHERE datname = current_database() AND application_name = ANY($2) AND state <> 'idle') ORDER BY event_sequence DESC LIMIT $3`), []driver.Value{eventstore.AggregateType("user"), database.TextArray[string]{}, uint64(5)}, ), }, @@ -864,8 +759,8 @@ func Test_query_events_mocked(t *testing.T) { useV1: true, }, fields: fields{ - mock: newMockClient(t).expectQuery(t, - `SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = \$1 ORDER BY event_sequence DESC LIMIT \$2 FOR UPDATE`, + mock: newMockClient(t).expectQuery( + regexp.QuoteMeta(`SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC LIMIT $2 FOR UPDATE`), []driver.Value{eventstore.AggregateType("user"), uint64(5)}, ), }, @@ -886,8 +781,8 @@ func Test_query_events_mocked(t *testing.T) { useV1: true, }, fields: fields{ - mock: newMockClient(t).expectQuery(t, - `SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = \$1 ORDER BY event_sequence DESC LIMIT \$2 FOR UPDATE NOWAIT`, + mock: newMockClient(t).expectQuery( + regexp.QuoteMeta(`SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC LIMIT $2 FOR UPDATE NOWAIT`), []driver.Value{eventstore.AggregateType("user"), uint64(5)}, ), }, @@ -908,8 +803,8 @@ func Test_query_events_mocked(t *testing.T) { useV1: true, }, fields: fields{ - mock: newMockClient(t).expectQuery(t, - `SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = \$1 ORDER BY event_sequence DESC LIMIT \$2 FOR UPDATE SKIP LOCKED`, + mock: newMockClient(t).expectQuery( + regexp.QuoteMeta(`SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 ORDER BY event_sequence DESC LIMIT $2 FOR UPDATE SKIP LOCKED`), []driver.Value{eventstore.AggregateType("user"), uint64(5)}, ), }, @@ -931,8 +826,8 @@ func Test_query_events_mocked(t *testing.T) { useV1: true, }, fields: fields{ - mock: newMockClient(t).expectQueryErr(t, - `SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = \$1 AND creation_date::TIMESTAMP < \(SELECT COALESCE\(MIN\(start\), NOW\(\)\)::TIMESTAMP FROM crdb_internal\.cluster_transactions where application_name = ANY\(\$2\)\) ORDER BY event_sequence DESC`, + mock: newMockClient(t).expectQueryErr( + regexp.QuoteMeta(`SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 AND EXTRACT(EPOCH FROM created_at) < (SELECT COALESCE(EXTRACT(EPOCH FROM min(xact_start)), EXTRACT(EPOCH FROM now())) FROM pg_stat_activity WHERE datname = current_database() AND application_name = ANY($2) AND state <> 'idle') ORDER BY event_sequence DESC`), []driver.Value{eventstore.AggregateType("user"), database.TextArray[string]{}}, sql.ErrConnDone), }, @@ -954,8 +849,8 @@ func Test_query_events_mocked(t *testing.T) { useV1: true, }, fields: fields{ - mock: newMockClient(t).expectQueryScanErr(t, - `SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = \$1 AND creation_date::TIMESTAMP < \(SELECT COALESCE\(MIN\(start\), NOW\(\)\)::TIMESTAMP FROM crdb_internal\.cluster_transactions where application_name = ANY\(\$2\)\) ORDER BY event_sequence DESC`, + mock: newMockClient(t).expectQueryScanErr( + regexp.QuoteMeta(`SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE aggregate_type = $1 AND EXTRACT(EPOCH FROM created_at) < (SELECT COALESCE(EXTRACT(EPOCH FROM min(xact_start)), EXTRACT(EPOCH FROM now())) FROM pg_stat_activity WHERE datname = current_database() AND application_name = ANY($2) AND state <> 'idle') ORDER BY event_sequence DESC`), []driver.Value{eventstore.AggregateType("user"), database.TextArray[string]{}}, &repository.Event{Seq: 100}), }, @@ -989,8 +884,8 @@ func Test_query_events_mocked(t *testing.T) { useV1: true, }, fields: fields{ - mock: newMockClient(t).expectQuery(t, - `SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \(aggregate_type = \$1 OR \(aggregate_type = \$2 AND aggregate_id = \$3\)\) AND creation_date::TIMESTAMP < \(SELECT COALESCE\(MIN\(start\), NOW\(\)\)::TIMESTAMP FROM crdb_internal\.cluster_transactions where application_name = ANY\(\$4\)\) ORDER BY event_sequence DESC LIMIT \$5`, + mock: newMockClient(t).expectQuery( + regexp.QuoteMeta(`SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE (aggregate_type = $1 OR (aggregate_type = $2 AND aggregate_id = $3)) AND EXTRACT(EPOCH FROM created_at) < (SELECT COALESCE(EXTRACT(EPOCH FROM min(xact_start)), EXTRACT(EPOCH FROM now())) FROM pg_stat_activity WHERE datname = current_database() AND application_name = ANY($4) AND state <> 'idle') ORDER BY event_sequence DESC LIMIT $5`), []driver.Value{eventstore.AggregateType("user"), eventstore.AggregateType("org"), "asdf42", database.TextArray[string]{}, uint64(5)}, ), }, @@ -1018,10 +913,8 @@ func Test_query_events_mocked(t *testing.T) { useV1: true, }, fields: fields{ - mock: newMockClient(t).expectQuery(t, - regexp.QuoteMeta( - `SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE instance_id = $1 AND aggregate_type = $2 AND event_type = $3 AND "position" > $4 AND aggregate_id NOT IN (SELECT aggregate_id FROM eventstore.events WHERE aggregate_type = $5 AND event_type = ANY($6) AND instance_id = $7 AND "position" > $8) ORDER BY event_sequence DESC LIMIT $9`, - ), + mock: newMockClient(t).expectQuery( + regexp.QuoteMeta(`SELECT creation_date, event_type, event_sequence, event_data, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE instance_id = $1 AND aggregate_type = $2 AND event_type = $3 AND "position" > $4 AND aggregate_id NOT IN (SELECT aggregate_id FROM eventstore.events WHERE aggregate_type = $5 AND event_type = ANY($6) AND instance_id = $7 AND "position" > $8) ORDER BY event_sequence DESC LIMIT $9`), []driver.Value{"instanceID", eventstore.AggregateType("notify"), eventstore.EventType("notify.foo.bar"), 123.456, eventstore.AggregateType("notify"), []eventstore.EventType{"notification.failed", "notification.success"}, "instanceID", 123.456, uint64(5)}, ), }, @@ -1049,10 +942,8 @@ func Test_query_events_mocked(t *testing.T) { useV1: false, }, fields: fields{ - mock: newMockClient(t).expectQuery(t, - regexp.QuoteMeta( - `SELECT created_at, event_type, "sequence", "position", payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = $1 AND aggregate_type = $2 AND event_type = $3 AND "position" > $4 AND aggregate_id NOT IN (SELECT aggregate_id FROM eventstore.events2 WHERE aggregate_type = $5 AND event_type = ANY($6) AND instance_id = $7 AND "position" > $8) ORDER BY "position" DESC, in_tx_order DESC LIMIT $9`, - ), + mock: newMockClient(t).expectQuery( + regexp.QuoteMeta(`SELECT created_at, event_type, "sequence", "position", payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = $1 AND aggregate_type = $2 AND event_type = $3 AND "position" > $4 AND aggregate_id NOT IN (SELECT aggregate_id FROM eventstore.events2 WHERE aggregate_type = $5 AND event_type = ANY($6) AND instance_id = $7 AND "position" > $8) ORDER BY "position" DESC, in_tx_order DESC LIMIT $9`), []driver.Value{"instanceID", eventstore.AggregateType("notify"), eventstore.EventType("notify.foo.bar"), 123.456, eventstore.AggregateType("notify"), []eventstore.EventType{"notification.failed", "notification.success"}, "instanceID", 123.456, uint64(5)}, ), }, @@ -1080,10 +971,8 @@ func Test_query_events_mocked(t *testing.T) { useV1: false, }, fields: fields{ - mock: newMockClient(t).expectQuery(t, - regexp.QuoteMeta( - `SELECT created_at, event_type, "sequence", "position", payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = $1 AND aggregate_type = $2 AND event_type = $3 AND created_at > $4 AND aggregate_id NOT IN (SELECT aggregate_id FROM eventstore.events2 WHERE aggregate_type = $5 AND event_type = ANY($6) AND instance_id = $7 AND created_at > $8) ORDER BY "position" DESC, in_tx_order DESC LIMIT $9`, - ), + mock: newMockClient(t).expectQuery( + regexp.QuoteMeta(`SELECT created_at, event_type, "sequence", "position", payload, creator, "owner", instance_id, aggregate_type, aggregate_id, revision FROM eventstore.events2 WHERE instance_id = $1 AND aggregate_type = $2 AND event_type = $3 AND created_at > $4 AND aggregate_id NOT IN (SELECT aggregate_id FROM eventstore.events2 WHERE aggregate_type = $5 AND event_type = ANY($6) AND instance_id = $7 AND created_at > $8) ORDER BY "position" DESC, in_tx_order DESC LIMIT $9`), []driver.Value{"instanceID", eventstore.AggregateType("notify"), eventstore.EventType("notify.foo.bar"), time.Unix(123, 456), eventstore.AggregateType("notify"), []eventstore.EventType{"notification.failed", "notification.success"}, "instanceID", time.Unix(123, 456), uint64(5)}, ), }, @@ -1092,14 +981,14 @@ func Test_query_events_mocked(t *testing.T) { }, }, } - crdb := NewCRDB(&database.DB{Database: new(testDB)}) + client := NewPostgres(&database.DB{Database: new(testDB)}) for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { if tt.fields.mock != nil { - crdb.DB.DB = tt.fields.mock.client + client.DB.DB = tt.fields.mock.client } - err := query(context.Background(), crdb, tt.args.query, tt.args.dest, tt.args.useV1) + err := query(context.Background(), client, tt.args.query, tt.args.dest, tt.args.useV1) if (err != nil) != tt.res.wantErr { t.Errorf("query() error = %v, wantErr %v", err, tt.res.wantErr) } @@ -1120,7 +1009,7 @@ type dbMock struct { client *sql.DB } -func (m *dbMock) expectQuery(t *testing.T, expectedQuery string, args []driver.Value, events ...*repository.Event) *dbMock { +func (m *dbMock) expectQuery(expectedQuery string, args []driver.Value, events ...*repository.Event) *dbMock { query := m.mock.ExpectQuery(expectedQuery).WithArgs(args...) rows := m.mock.NewRows([]string{"sequence"}) for _, event := range events { @@ -1130,7 +1019,7 @@ func (m *dbMock) expectQuery(t *testing.T, expectedQuery string, args []driver.V return m } -func (m *dbMock) expectQueryScanErr(t *testing.T, expectedQuery string, args []driver.Value, events ...*repository.Event) *dbMock { +func (m *dbMock) expectQueryScanErr(expectedQuery string, args []driver.Value, events ...*repository.Event) *dbMock { query := m.mock.ExpectQuery(expectedQuery).WithArgs(args...) rows := m.mock.NewRows([]string{"sequence"}) for _, event := range events { @@ -1140,7 +1029,7 @@ func (m *dbMock) expectQueryScanErr(t *testing.T, expectedQuery string, args []d return m } -func (m *dbMock) expectQueryErr(t *testing.T, expectedQuery string, args []driver.Value, err error) *dbMock { +func (m *dbMock) expectQueryErr(expectedQuery string, args []driver.Value, err error) *dbMock { m.mock.ExpectQuery(expectedQuery).WithArgs(args...).WillReturnError(err) return m } diff --git a/internal/eventstore/search_query.go b/internal/eventstore/search_query.go index df38d15def..1596936a36 100644 --- a/internal/eventstore/search_query.go +++ b/internal/eventstore/search_query.go @@ -25,7 +25,6 @@ type SearchQueryBuilder struct { tx *sql.Tx lockRows bool lockOption LockOption - allowTimeTravel bool positionAfter float64 awaitOpenTransactions bool creationDateAfter time.Time @@ -77,10 +76,6 @@ func (b *SearchQueryBuilder) GetTx() *sql.Tx { return b.tx } -func (b *SearchQueryBuilder) GetAllowTimeTravel() bool { - return b.allowTimeTravel -} - func (b SearchQueryBuilder) GetPositionAfter() float64 { return b.positionAfter } @@ -289,13 +284,6 @@ func (builder *SearchQueryBuilder) EditorUser(id string) *SearchQueryBuilder { return builder } -// AllowTimeTravel activates the time travel feature of the database if supported -// The queries will be made based on the call time -func (builder *SearchQueryBuilder) AllowTimeTravel() *SearchQueryBuilder { - builder.allowTimeTravel = true - return builder -} - // PositionAfter filters for events which happened after the specified time func (builder *SearchQueryBuilder) PositionAfter(position float64) *SearchQueryBuilder { builder.positionAfter = position diff --git a/internal/eventstore/search_query_test.go b/internal/eventstore/search_query_test.go index 8c654911ea..b8f570dc0d 100644 --- a/internal/eventstore/search_query_test.go +++ b/internal/eventstore/search_query_test.go @@ -45,16 +45,6 @@ func testSetLimit(limit uint64) func(builder *SearchQueryBuilder) *SearchQueryBu } } -func testOr(queryFuncs ...func(*SearchQuery) *SearchQuery) func(*SearchQuery) *SearchQuery { - return func(query *SearchQuery) *SearchQuery { - subQuery := query.Or() - for _, queryFunc := range queryFuncs { - queryFunc(subQuery) - } - return subQuery - } -} - func testSetAggregateTypes(types ...AggregateType) func(*SearchQuery) *SearchQuery { return func(query *SearchQuery) *SearchQuery { query = query.AggregateTypes(types...) diff --git a/internal/eventstore/v3/eventstore.go b/internal/eventstore/v3/eventstore.go index 1bb515527c..424805c882 100644 --- a/internal/eventstore/v3/eventstore.go +++ b/internal/eventstore/v3/eventstore.go @@ -24,9 +24,9 @@ func init() { var ( // pushPlaceholderFmt defines how data are inserted into the events table - pushPlaceholderFmt string + pushPlaceholderFmt = "($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $%d)" // uniqueConstraintPlaceholderFmt defines the format of the unique constraint error returned from the database - uniqueConstraintPlaceholderFmt string + uniqueConstraintPlaceholderFmt = "(%s, %s, %s)" _ eventstore.Pusher = (*Eventstore)(nil) ) @@ -158,15 +158,6 @@ func (es *Eventstore) Client() *database.DB { } func NewEventstore(client *database.DB) *Eventstore { - switch client.Type() { - case "cockroach": - pushPlaceholderFmt = "($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, hlc_to_timestamp(cluster_logical_timestamp()), cluster_logical_timestamp(), $%d)" - uniqueConstraintPlaceholderFmt = "('%s', '%s', '%s')" - case "postgres": - pushPlaceholderFmt = "($%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, $%d, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $%d)" - uniqueConstraintPlaceholderFmt = "(%s, %s, %s)" - } - return &Eventstore{client: client} } @@ -200,14 +191,8 @@ func (es *Eventstore) pushTx(ctx context.Context, client database.ContextQueryEx beginner = es.client } - isolationLevel := sql.LevelReadCommitted - // cockroach requires serializable to execute the push function - // because we use [cluster_logical_timestamp()](https://www.cockroachlabs.com/docs/stable/functions-and-operators#system-info-functions) - if es.client.Type() == "cockroach" { - isolationLevel = sql.LevelSerializable - } tx, err = beginner.BeginTx(ctx, &sql.TxOptions{ - Isolation: isolationLevel, + Isolation: sql.LevelReadCommitted, ReadOnly: false, }) if err != nil { diff --git a/internal/eventstore/v3/push_test.go b/internal/eventstore/v3/push_test.go index a6c4f515fd..da583891e9 100644 --- a/internal/eventstore/v3/push_test.go +++ b/internal/eventstore/v3/push_test.go @@ -8,7 +8,7 @@ import ( "github.com/stretchr/testify/require" "github.com/zitadel/zitadel/internal/database" - "github.com/zitadel/zitadel/internal/database/cockroach" + "github.com/zitadel/zitadel/internal/database/postgres" "github.com/zitadel/zitadel/internal/eventstore" ) @@ -65,7 +65,7 @@ func Test_mapCommands(t *testing.T) { ), }, placeHolders: []string{ - "($1, $2, $3, $4, $5, $6, $7, $8, $9, hlc_to_timestamp(cluster_logical_timestamp()), cluster_logical_timestamp(), $10)", + "($1, $2, $3, $4, $5, $6, $7, $8, $9, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $10)", }, args: []any{ "instance", @@ -114,8 +114,8 @@ func Test_mapCommands(t *testing.T) { ), }, placeHolders: []string{ - "($1, $2, $3, $4, $5, $6, $7, $8, $9, hlc_to_timestamp(cluster_logical_timestamp()), cluster_logical_timestamp(), $10)", - "($11, $12, $13, $14, $15, $16, $17, $18, $19, hlc_to_timestamp(cluster_logical_timestamp()), cluster_logical_timestamp(), $20)", + "($1, $2, $3, $4, $5, $6, $7, $8, $9, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $10)", + "($11, $12, $13, $14, $15, $16, $17, $18, $19, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $20)", }, args: []any{ // first event @@ -180,8 +180,8 @@ func Test_mapCommands(t *testing.T) { ), }, placeHolders: []string{ - "($1, $2, $3, $4, $5, $6, $7, $8, $9, hlc_to_timestamp(cluster_logical_timestamp()), cluster_logical_timestamp(), $10)", - "($11, $12, $13, $14, $15, $16, $17, $18, $19, hlc_to_timestamp(cluster_logical_timestamp()), cluster_logical_timestamp(), $20)", + "($1, $2, $3, $4, $5, $6, $7, $8, $9, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $10)", + "($11, $12, $13, $14, $15, $16, $17, $18, $19, statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()), $20)", }, args: []any{ // first event @@ -236,7 +236,7 @@ func Test_mapCommands(t *testing.T) { } } // is used to set the the [pushPlaceholderFmt] - NewEventstore(&database.DB{Database: new(cockroach.Config)}) + NewEventstore(&database.DB{Database: new(postgres.Config)}) t.Run(tt.name, func(t *testing.T) { defer func() { cause := recover() diff --git a/internal/eventstore/v3/push_without_func.go b/internal/eventstore/v3/push_without_func.go index 914b880204..b94a9e8f54 100644 --- a/internal/eventstore/v3/push_without_func.go +++ b/internal/eventstore/v3/push_without_func.go @@ -7,7 +7,6 @@ import ( "fmt" "strings" - "github.com/cockroachdb/cockroach-go/v2/crdb" "github.com/jackc/pgx/v5/pgconn" "github.com/zitadel/logging" @@ -16,25 +15,6 @@ import ( "github.com/zitadel/zitadel/internal/zerrors" ) -type transaction struct { - database.Tx -} - -var _ crdb.Tx = (*transaction)(nil) - -func (t *transaction) Exec(ctx context.Context, query string, args ...interface{}) error { - _, err := t.Tx.ExecContext(ctx, query, args...) - return err -} - -func (t *transaction) Commit(ctx context.Context) error { - return t.Tx.Commit() -} - -func (t *transaction) Rollback(ctx context.Context) error { - return t.Tx.Rollback() -} - // checks whether the error is caused because setup step 39 was not executed func isSetupNotExecutedError(err error) bool { if err == nil { @@ -64,7 +44,6 @@ func (es *Eventstore) pushWithoutFunc(ctx context.Context, client database.Conte err = closeTx(err) }() - // tx is not closed because [crdb.ExecuteInTx] takes care of that var ( sequences []*latestSequence ) diff --git a/internal/execution/execution_test.go b/internal/execution/execution_test.go index 5a45d96625..40731a840a 100644 --- a/internal/execution/execution_test.go +++ b/internal/execution/execution_test.go @@ -61,7 +61,7 @@ func Test_Call(t *testing.T) { args{ ctx: context.Background(), timeout: time.Second, - sleep: time.Second, + sleep: 2 * time.Second, method: http.MethodPost, body: []byte("{\"request\": \"values\"}"), respBody: []byte("{\"response\": \"values\"}"), diff --git a/internal/integration/config/cockroach.yaml b/internal/integration/config/cockroach.yaml deleted file mode 100644 index 920e3cd6ec..0000000000 --- a/internal/integration/config/cockroach.yaml +++ /dev/null @@ -1,10 +0,0 @@ -Database: - cockroach: - Host: localhost - Port: 26257 - Database: zitadel - Options: "" - User: - Username: zitadel - Admin: - Username: root diff --git a/internal/integration/config/docker-compose.yaml b/internal/integration/config/docker-compose.yaml index 19c68ae405..8b54a22aec 100644 --- a/internal/integration/config/docker-compose.yaml +++ b/internal/integration/config/docker-compose.yaml @@ -1,11 +1,6 @@ version: '3.8' services: - cockroach: - extends: - file: '../../../e2e/config/localhost/docker-compose.yaml' - service: 'db' - postgres: restart: 'always' image: 'postgres:latest' diff --git a/internal/integration/config/postgres.yaml b/internal/integration/config/postgres.yaml index df1d08d3bc..904f973d56 100644 --- a/internal/integration/config/postgres.yaml +++ b/internal/integration/config/postgres.yaml @@ -1,16 +1,11 @@ Database: - EventPushConnRatio: 0.2 # 4 - ProjectionSpoolerConnRatio: 0.3 # 6 postgres: - Host: localhost - Port: 5432 - Database: zitadel MaxOpenConns: 20 MaxIdleConns: 20 MaxConnLifetime: 1h MaxConnIdleTime: 5m User: - Username: zitadel + Password: zitadel SSL: Mode: disable Admin: diff --git a/internal/query/action.go b/internal/query/action.go index 30ded403d1..45017572e2 100644 --- a/internal/query/action.go +++ b/internal/query/action.go @@ -9,7 +9,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -118,7 +117,7 @@ func (q *Queries) SearchActions(ctx context.Context, queries *ActionSearchQuerie ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareActionsQuery(ctx, q.client) + query, scan := prepareActionsQuery() eq := sq.Eq{ ActionColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), } @@ -146,7 +145,7 @@ func (q *Queries) GetActionByID(ctx context.Context, id string, orgID string, wi ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareActionQuery(ctx, q.client) + stmt, scan := prepareActionQuery() eq := sq.Eq{ ActionColumnID.identifier(): id, ActionColumnResourceOwner.identifier(): orgID, @@ -183,7 +182,7 @@ func NewActionIDSearchQuery(id string) (SearchQuery, error) { return NewTextQuery(ActionColumnID, id, TextEquals) } -func prepareActionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(rows *sql.Rows) (*Actions, error)) { +func prepareActionsQuery() (sq.SelectBuilder, func(rows *sql.Rows) (*Actions, error)) { return sq.Select( ActionColumnID.identifier(), ActionColumnCreationDate.identifier(), @@ -196,7 +195,7 @@ func prepareActionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil ActionColumnTimeout.identifier(), ActionColumnAllowedToFail.identifier(), countColumn.identifier(), - ).From(actionTable.identifier() + db.Timetravel(call.Took(ctx))). + ).From(actionTable.identifier()). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*Actions, error) { actions := make([]*Action, 0) @@ -235,7 +234,7 @@ func prepareActionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil } } -func prepareActionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(row *sql.Row) (*Action, error)) { +func prepareActionQuery() (sq.SelectBuilder, func(row *sql.Row) (*Action, error)) { return sq.Select( ActionColumnID.identifier(), ActionColumnCreationDate.identifier(), @@ -247,7 +246,7 @@ func prepareActionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuild ActionColumnScript.identifier(), ActionColumnTimeout.identifier(), ActionColumnAllowedToFail.identifier(), - ).From(actionTable.identifier() + db.Timetravel(call.Took(ctx))). + ).From(actionTable.identifier()). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*Action, error) { action := new(Action) diff --git a/internal/query/action_flow.go b/internal/query/action_flow.go index c5263d6c43..6011b3f6e1 100644 --- a/internal/query/action_flow.go +++ b/internal/query/action_flow.go @@ -8,7 +8,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -67,7 +66,7 @@ func (q *Queries) GetFlow(ctx context.Context, flowType domain.FlowType, orgID s ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareFlowQuery(ctx, q.client, flowType) + query, scan := prepareFlowQuery(flowType) eq := sq.Eq{ FlowsTriggersColumnFlowType.identifier(): flowType, FlowsTriggersColumnResourceOwner.identifier(): orgID, @@ -89,7 +88,7 @@ func (q *Queries) GetActiveActionsByFlowAndTriggerType(ctx context.Context, flow ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareTriggerActionsQuery(ctx, q.client) + stmt, scan := prepareTriggerActionsQuery() eq := sq.Eq{ FlowsTriggersColumnFlowType.identifier(): flowType, FlowsTriggersColumnTriggerType.identifier(): triggerType, @@ -113,7 +112,7 @@ func (q *Queries) GetFlowTypesOfActionID(ctx context.Context, actionID string) ( ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareFlowTypesQuery(ctx, q.client) + stmt, scan := prepareFlowTypesQuery() eq := sq.Eq{ FlowsTriggersColumnActionID.identifier(): actionID, FlowsTriggersColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -130,11 +129,11 @@ func (q *Queries) GetFlowTypesOfActionID(ctx context.Context, actionID string) ( return types, err } -func prepareFlowTypesQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) ([]domain.FlowType, error)) { +func prepareFlowTypesQuery() (sq.SelectBuilder, func(*sql.Rows) ([]domain.FlowType, error)) { return sq.Select( FlowsTriggersColumnFlowType.identifier(), ). - From(flowsTriggersTable.identifier() + db.Timetravel(call.Took(ctx))). + From(flowsTriggersTable.identifier()). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) ([]domain.FlowType, error) { types := []domain.FlowType{} @@ -153,7 +152,7 @@ func prepareFlowTypesQuery(ctx context.Context, db prepareDatabase) (sq.SelectBu } -func prepareTriggerActionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) ([]*Action, error)) { +func prepareTriggerActionsQuery() (sq.SelectBuilder, func(*sql.Rows) ([]*Action, error)) { return sq.Select( ActionColumnID.identifier(), ActionColumnCreationDate.identifier(), @@ -167,7 +166,7 @@ func prepareTriggerActionsQuery(ctx context.Context, db prepareDatabase) (sq.Sel ActionColumnTimeout.identifier(), ). From(flowsTriggersTable.name). - LeftJoin(join(ActionColumnID, FlowsTriggersColumnActionID) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(ActionColumnID, FlowsTriggersColumnActionID)). OrderBy(FlowsTriggersColumnTriggerSequence.identifier()). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) ([]*Action, error) { @@ -200,7 +199,7 @@ func prepareTriggerActionsQuery(ctx context.Context, db prepareDatabase) (sq.Sel } } -func prepareFlowQuery(ctx context.Context, db prepareDatabase, flowType domain.FlowType) (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) { +func prepareFlowQuery(flowType domain.FlowType) (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) { return sq.Select( ActionColumnID.identifier(), ActionColumnCreationDate.identifier(), @@ -220,7 +219,7 @@ func prepareFlowQuery(ctx context.Context, db prepareDatabase, flowType domain.F FlowsTriggersColumnResourceOwner.identifier(), ). From(flowsTriggersTable.name). - LeftJoin(join(ActionColumnID, FlowsTriggersColumnActionID) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(ActionColumnID, FlowsTriggersColumnActionID)). OrderBy(FlowsTriggersColumnTriggerSequence.identifier()). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*Flow, error) { diff --git a/internal/query/action_flow_test.go b/internal/query/action_flow_test.go index af0db27278..7447313064 100644 --- a/internal/query/action_flow_test.go +++ b/internal/query/action_flow_test.go @@ -1,7 +1,6 @@ package query import ( - "context" "database/sql" "database/sql/driver" "errors" @@ -34,7 +33,6 @@ var ( ` projections.flow_triggers3.resource_owner` + ` FROM projections.flow_triggers3` + ` LEFT JOIN projections.actions3 ON projections.flow_triggers3.action_id = projections.actions3.id AND projections.flow_triggers3.instance_id = projections.actions3.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'` + ` ORDER BY projections.flow_triggers3.trigger_sequence` prepareFlowCols = []string{ "id", @@ -68,7 +66,6 @@ var ( ` projections.actions3.timeout` + ` FROM projections.flow_triggers3` + ` LEFT JOIN projections.actions3 ON projections.flow_triggers3.action_id = projections.actions3.id AND projections.flow_triggers3.instance_id = projections.actions3.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'` + ` ORDER BY projections.flow_triggers3.trigger_sequence` prepareTriggerActionCols = []string{ @@ -86,7 +83,6 @@ var ( prepareFlowTypeStmt = `SELECT projections.flow_triggers3.flow_type` + ` FROM projections.flow_triggers3` - // ` AS OF SYSTEM TIME '-1 ms'` prepareFlowTypeCols = []string{ "flow_type", @@ -106,8 +102,8 @@ func Test_FlowPrepares(t *testing.T) { }{ { name: "prepareFlowQuery no result", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) { - return prepareFlowQuery(ctx, db, domain.FlowTypeExternalAuthentication) + prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) { + return prepareFlowQuery(domain.FlowTypeExternalAuthentication) }, want: want{ sqlExpectations: mockQueries( @@ -123,8 +119,8 @@ func Test_FlowPrepares(t *testing.T) { }, { name: "prepareFlowQuery one action", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) { - return prepareFlowQuery(ctx, db, domain.FlowTypeExternalAuthentication) + prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) { + return prepareFlowQuery(domain.FlowTypeExternalAuthentication) }, want: want{ sqlExpectations: mockQueries( @@ -177,8 +173,8 @@ func Test_FlowPrepares(t *testing.T) { }, { name: "prepareFlowQuery multiple actions", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) { - return prepareFlowQuery(ctx, db, domain.FlowTypeExternalAuthentication) + prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) { + return prepareFlowQuery(domain.FlowTypeExternalAuthentication) }, want: want{ sqlExpectations: mockQueries( @@ -263,8 +259,8 @@ func Test_FlowPrepares(t *testing.T) { }, { name: "prepareFlowQuery no action", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) { - return prepareFlowQuery(ctx, db, domain.FlowTypeExternalAuthentication) + prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) { + return prepareFlowQuery(domain.FlowTypeExternalAuthentication) }, want: want{ sqlExpectations: mockQueries( @@ -302,8 +298,8 @@ func Test_FlowPrepares(t *testing.T) { }, { name: "prepareFlowQuery sql err", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) { - return prepareFlowQuery(ctx, db, domain.FlowTypeExternalAuthentication) + prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*Flow, error)) { + return prepareFlowQuery(domain.FlowTypeExternalAuthentication) }, want: want{ sqlExpectations: mockQueryErr( @@ -520,7 +516,7 @@ func Test_FlowPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/action_test.go b/internal/query/action_test.go index f6ba5be4b9..e5cad0e269 100644 --- a/internal/query/action_test.go +++ b/internal/query/action_test.go @@ -26,7 +26,6 @@ var ( ` projections.actions3.allowed_to_fail,` + ` COUNT(*) OVER ()` + ` FROM projections.actions3` - // ` AS OF SYSTEM TIME '-1 ms'` prepareActionsCols = []string{ "id", "creation_date", @@ -52,7 +51,6 @@ var ( ` projections.actions3.timeout,` + ` projections.actions3.allowed_to_fail` + ` FROM projections.actions3` - // ` AS OF SYSTEM TIME '-1 ms'` prepareActionCols = []string{ "id", "creation_date", @@ -289,7 +287,7 @@ func Test_ActionPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/app.go b/internal/query/app.go index fafbbe72d9..5fed1e3ced 100644 --- a/internal/query/app.go +++ b/internal/query/app.go @@ -11,7 +11,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore/handler/v2" @@ -290,7 +289,7 @@ func (q *Queries) AppByProjectAndAppID(ctx context.Context, shouldTriggerBulk bo traceSpan.EndWithError(err) } - stmt, scan := prepareAppQuery(ctx, q.client, false) + stmt, scan := prepareAppQuery(false) eq := sq.Eq{ AppColumnID.identifier(): appID, AppColumnProjectID.identifier(): projectID, @@ -312,7 +311,7 @@ func (q *Queries) AppByID(ctx context.Context, appID string, activeOnly bool) (a ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareAppQuery(ctx, q.client, activeOnly) + stmt, scan := prepareAppQuery(activeOnly) eq := sq.Eq{ AppColumnID.identifier(): appID, AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -338,7 +337,7 @@ func (q *Queries) ProjectByClientID(ctx context.Context, appID string) (project ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareProjectByAppQuery(ctx, q.client) + stmt, scan := prepareProjectByAppQuery() eq := sq.Eq{AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()} query, args, err := stmt.Where(sq.And{ eq, @@ -434,7 +433,7 @@ func (q *Queries) ProjectIDFromClientID(ctx context.Context, appID string) (id s ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareProjectIDByAppQuery(ctx, q.client) + stmt, scan := prepareProjectIDByAppQuery() eq := sq.Eq{AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()} where := sq.And{ eq, @@ -460,7 +459,7 @@ func (q *Queries) ProjectByOIDCClientID(ctx context.Context, id string) (project ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareProjectByOIDCAppQuery(ctx, q.client) + stmt, scan := prepareProjectByOIDCAppQuery() eq := sq.Eq{ AppOIDCConfigColumnClientID.identifier(): id, AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -502,7 +501,7 @@ func (q *Queries) AppByClientID(ctx context.Context, clientID string) (app *App, ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareAppQuery(ctx, q.client, true) + stmt, scan := prepareAppQuery(true) eq := sq.Eq{ AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), AppColumnState.identifier(): domain.AppStateActive, @@ -531,7 +530,7 @@ func (q *Queries) SearchApps(ctx context.Context, queries *AppSearchQueries, wit ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareAppsQuery(ctx, q.client) + query, scan := prepareAppsQuery() eq := sq.Eq{AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()} stmt, args, err := queries.toQuery(query).Where(eq).ToSql() if err != nil { @@ -560,7 +559,7 @@ func (q *Queries) SearchClientIDs(ctx context.Context, queries *AppSearchQueries traceSpan.EndWithError(err) } - query, scan := prepareClientIDsQuery(ctx, q.client) + query, scan := prepareClientIDsQuery() eq := sq.Eq{AppColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()} stmt, args, err := queries.toQuery(query).Where(eq).ToSql() if err != nil { @@ -581,7 +580,7 @@ func (q *Queries) OIDCClientLoginVersion(ctx context.Context, clientID string) ( ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareLoginVersionByOIDCClientID(ctx, q.client) + query, scan := prepareLoginVersionByOIDCClientID() eq := sq.Eq{ AppOIDCConfigColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), AppOIDCConfigColumnClientID.identifier(): clientID, @@ -605,7 +604,7 @@ func (q *Queries) SAMLAppLoginVersion(ctx context.Context, appID string) (loginV ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareLoginVersionBySAMLAppID(ctx, q.client) + query, scan := prepareLoginVersionBySAMLAppID() eq := sq.Eq{ AppSAMLConfigColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), AppSAMLConfigColumnAppID.identifier(): appID, @@ -633,7 +632,7 @@ func NewAppProjectIDSearchQuery(id string) (SearchQuery, error) { return NewTextQuery(AppColumnProjectID, id, TextEquals) } -func prepareAppQuery(ctx context.Context, db prepareDatabase, activeOnly bool) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { +func prepareAppQuery(activeOnly bool) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { query := sq.Select( AppColumnID.identifier(), AppColumnName.identifier(), @@ -684,13 +683,13 @@ func prepareAppQuery(ctx context.Context, db prepareDatabase, activeOnly bool) ( LeftJoin(join(AppOIDCConfigColumnAppID, AppColumnID)). LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID)). LeftJoin(join(ProjectColumnID, AppColumnProjectID)). - LeftJoin(join(OrgColumnID, AppColumnResourceOwner) + db.Timetravel(call.Took(ctx))), + LeftJoin(join(OrgColumnID, AppColumnResourceOwner)), scanApp } return query. LeftJoin(join(AppAPIConfigColumnAppID, AppColumnID)). LeftJoin(join(AppOIDCConfigColumnAppID, AppColumnID)). - LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID) + db.Timetravel(call.Took(ctx))), + LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID)), scanApp } @@ -845,13 +844,13 @@ func prepareOIDCAppQuery() (sq.SelectBuilder, func(*sql.Row) (*App, error)) { } } -func prepareProjectIDByAppQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (projectID string, err error)) { +func prepareProjectIDByAppQuery() (sq.SelectBuilder, func(*sql.Row) (projectID string, err error)) { return sq.Select( AppColumnProjectID.identifier(), ).From(appsTable.identifier()). LeftJoin(join(AppAPIConfigColumnAppID, AppColumnID)). LeftJoin(join(AppOIDCConfigColumnAppID, AppColumnID)). - LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID)). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (projectID string, err error) { err = row.Scan( &projectID, @@ -868,7 +867,7 @@ func prepareProjectIDByAppQuery(ctx context.Context, db prepareDatabase) (sq.Sel } } -func prepareProjectByOIDCAppQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Project, error)) { +func prepareProjectByOIDCAppQuery() (sq.SelectBuilder, func(*sql.Row) (*Project, error)) { return sq.Select( ProjectColumnID.identifier(), ProjectColumnCreationDate.identifier(), @@ -910,7 +909,7 @@ func prepareProjectByOIDCAppQuery(ctx context.Context, db prepareDatabase) (sq.S } } -func prepareProjectByAppQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Project, error)) { +func prepareProjectByAppQuery() (sq.SelectBuilder, func(*sql.Row) (*Project, error)) { return sq.Select( ProjectColumnID.identifier(), ProjectColumnCreationDate.identifier(), @@ -927,7 +926,7 @@ func prepareProjectByAppQuery(ctx context.Context, db prepareDatabase) (sq.Selec Join(join(AppColumnProjectID, ProjectColumnID)). LeftJoin(join(AppAPIConfigColumnAppID, AppColumnID)). LeftJoin(join(AppOIDCConfigColumnAppID, AppColumnID)). - LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID)). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*Project, error) { p := new(Project) @@ -954,7 +953,7 @@ func prepareProjectByAppQuery(ctx context.Context, db prepareDatabase) (sq.Selec } } -func prepareAppsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Apps, error)) { +func prepareAppsQuery() (sq.SelectBuilder, func(*sql.Rows) (*Apps, error)) { return sq.Select( AppColumnID.identifier(), AppColumnName.identifier(), @@ -1000,7 +999,7 @@ func prepareAppsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder ).From(appsTable.identifier()). LeftJoin(join(AppAPIConfigColumnAppID, AppColumnID)). LeftJoin(join(AppOIDCConfigColumnAppID, AppColumnID)). - LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(AppSAMLConfigColumnAppID, AppColumnID)). PlaceholderFormat(sq.Dollar), func(row *sql.Rows) (*Apps, error) { apps := &Apps{Apps: []*App{}} @@ -1072,13 +1071,13 @@ func prepareAppsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder } } -func prepareClientIDsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) ([]string, error)) { +func prepareClientIDsQuery() (sq.SelectBuilder, func(*sql.Rows) ([]string, error)) { return sq.Select( AppAPIConfigColumnClientID.identifier(), AppOIDCConfigColumnClientID.identifier(), ).From(appsTable.identifier()). LeftJoin(join(AppAPIConfigColumnAppID, AppColumnID)). - LeftJoin(join(AppOIDCConfigColumnAppID, AppColumnID) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(AppOIDCConfigColumnAppID, AppColumnID)). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) ([]string, error) { ids := database.TextArray[string]{} @@ -1102,7 +1101,7 @@ func prepareClientIDsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBu } } -func prepareLoginVersionByOIDCClientID(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (domain.LoginVersion, error)) { +func prepareLoginVersionByOIDCClientID() (sq.SelectBuilder, func(*sql.Row) (domain.LoginVersion, error)) { return sq.Select( AppOIDCConfigColumnLoginVersion.identifier(), ).From(appOIDCConfigsTable.identifier()). @@ -1117,7 +1116,7 @@ func prepareLoginVersionByOIDCClientID(ctx context.Context, db prepareDatabase) } } -func prepareLoginVersionBySAMLAppID(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (domain.LoginVersion, error)) { +func prepareLoginVersionBySAMLAppID() (sq.SelectBuilder, func(*sql.Row) (domain.LoginVersion, error)) { return sq.Select( AppSAMLConfigColumnLoginVersion.identifier(), ).From(appSAMLConfigsTable.identifier()). diff --git a/internal/query/app_test.go b/internal/query/app_test.go index dbbcaef47c..c24060a60c 100644 --- a/internal/query/app_test.go +++ b/internal/query/app_test.go @@ -1,7 +1,6 @@ package query import ( - "context" "database/sql" "database/sql/driver" "errors" @@ -111,20 +110,17 @@ var ( ` FROM projections.apps7` + ` LEFT JOIN projections.apps7_api_configs ON projections.apps7.id = projections.apps7_api_configs.app_id AND projections.apps7.instance_id = projections.apps7_api_configs.instance_id` + ` LEFT JOIN projections.apps7_oidc_configs ON projections.apps7.id = projections.apps7_oidc_configs.app_id AND projections.apps7.instance_id = projections.apps7_oidc_configs.instance_id` + - ` LEFT JOIN projections.apps7_saml_configs ON projections.apps7.id = projections.apps7_saml_configs.app_id AND projections.apps7.instance_id = projections.apps7_saml_configs.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'`) + ` LEFT JOIN projections.apps7_saml_configs ON projections.apps7.id = projections.apps7_saml_configs.app_id AND projections.apps7.instance_id = projections.apps7_saml_configs.instance_id`) expectedAppIDsQuery = regexp.QuoteMeta(`SELECT projections.apps7_api_configs.client_id,` + ` projections.apps7_oidc_configs.client_id` + ` FROM projections.apps7` + ` LEFT JOIN projections.apps7_api_configs ON projections.apps7.id = projections.apps7_api_configs.app_id AND projections.apps7.instance_id = projections.apps7_api_configs.instance_id` + - ` LEFT JOIN projections.apps7_oidc_configs ON projections.apps7.id = projections.apps7_oidc_configs.app_id AND projections.apps7.instance_id = projections.apps7_oidc_configs.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'`) + ` LEFT JOIN projections.apps7_oidc_configs ON projections.apps7.id = projections.apps7_oidc_configs.app_id AND projections.apps7.instance_id = projections.apps7_oidc_configs.instance_id`) expectedProjectIDByAppQuery = regexp.QuoteMeta(`SELECT projections.apps7.project_id` + ` FROM projections.apps7` + ` LEFT JOIN projections.apps7_api_configs ON projections.apps7.id = projections.apps7_api_configs.app_id AND projections.apps7.instance_id = projections.apps7_api_configs.instance_id` + ` LEFT JOIN projections.apps7_oidc_configs ON projections.apps7.id = projections.apps7_oidc_configs.app_id AND projections.apps7.instance_id = projections.apps7_oidc_configs.instance_id` + - ` LEFT JOIN projections.apps7_saml_configs ON projections.apps7.id = projections.apps7_saml_configs.app_id AND projections.apps7.instance_id = projections.apps7_saml_configs.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'`) + ` LEFT JOIN projections.apps7_saml_configs ON projections.apps7.id = projections.apps7_saml_configs.app_id AND projections.apps7.instance_id = projections.apps7_saml_configs.instance_id`) expectedProjectByAppQuery = regexp.QuoteMeta(`SELECT projections.projects4.id,` + ` projections.projects4.creation_date,` + ` projections.projects4.change_date,` + @@ -140,8 +136,7 @@ var ( ` JOIN projections.apps7 ON projections.projects4.id = projections.apps7.project_id AND projections.projects4.instance_id = projections.apps7.instance_id` + ` LEFT JOIN projections.apps7_api_configs ON projections.apps7.id = projections.apps7_api_configs.app_id AND projections.apps7.instance_id = projections.apps7_api_configs.instance_id` + ` LEFT JOIN projections.apps7_oidc_configs ON projections.apps7.id = projections.apps7_oidc_configs.app_id AND projections.apps7.instance_id = projections.apps7_oidc_configs.instance_id` + - ` LEFT JOIN projections.apps7_saml_configs ON projections.apps7.id = projections.apps7_saml_configs.app_id AND projections.apps7.instance_id = projections.apps7_saml_configs.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'`) + ` LEFT JOIN projections.apps7_saml_configs ON projections.apps7.id = projections.apps7_saml_configs.app_id AND projections.apps7.instance_id = projections.apps7_saml_configs.instance_id`) appCols = database.TextArray[string]{ "id", @@ -1228,7 +1223,7 @@ func Test_AppsPrepare(t *testing.T) { if tt.name == "prepareAppsQuery oidc app" { _ = tt.name } - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } @@ -1246,8 +1241,8 @@ func Test_AppPrepare(t *testing.T) { }{ { name: "prepareAppQuery no result", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { - return prepareAppQuery(ctx, db, false) + prepare: func() (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + return prepareAppQuery(false) }, want: want{ sqlExpectations: mockQueriesScanErr( @@ -1266,8 +1261,8 @@ func Test_AppPrepare(t *testing.T) { }, { name: "prepareAppQuery found", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { - return prepareAppQuery(ctx, db, false) + prepare: func() (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + return prepareAppQuery(false) }, want: want{ sqlExpectations: mockQuery( @@ -1330,8 +1325,8 @@ func Test_AppPrepare(t *testing.T) { }, { name: "prepareAppQuery api app", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { - return prepareAppQuery(ctx, db, false) + prepare: func() (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + return prepareAppQuery(false) }, want: want{ sqlExpectations: mockQueries( @@ -1400,8 +1395,8 @@ func Test_AppPrepare(t *testing.T) { }, { name: "prepareAppQuery oidc app", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { - return prepareAppQuery(ctx, db, false) + prepare: func() (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + return prepareAppQuery(false) }, want: want{ sqlExpectations: mockQueries( @@ -1489,8 +1484,8 @@ func Test_AppPrepare(t *testing.T) { }, { name: "prepareAppQuery oidc app active only", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { - return prepareAppQuery(ctx, db, true) + prepare: func() (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + return prepareAppQuery(true) }, want: want{ sqlExpectations: mockQueries( @@ -1578,8 +1573,8 @@ func Test_AppPrepare(t *testing.T) { }, { name: "prepareAppQuery saml app", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { - return prepareAppQuery(ctx, db, false) + prepare: func() (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + return prepareAppQuery(false) }, want: want{ sqlExpectations: mockQueries( @@ -1651,8 +1646,8 @@ func Test_AppPrepare(t *testing.T) { }, { name: "prepareAppQuery oidc app IsDevMode inactive", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { - return prepareAppQuery(ctx, db, false) + prepare: func() (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + return prepareAppQuery(false) }, want: want{ sqlExpectations: mockQueries( @@ -1740,8 +1735,8 @@ func Test_AppPrepare(t *testing.T) { }, { name: "prepareAppQuery oidc app AssertAccessTokenRole inactive", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { - return prepareAppQuery(ctx, db, false) + prepare: func() (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + return prepareAppQuery(false) }, want: want{ sqlExpectations: mockQueries( @@ -1829,8 +1824,8 @@ func Test_AppPrepare(t *testing.T) { }, { name: "prepareAppQuery oidc app AssertIDTokenRole inactive", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { - return prepareAppQuery(ctx, db, false) + prepare: func() (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + return prepareAppQuery(false) }, want: want{ sqlExpectations: mockQueries( @@ -1918,8 +1913,8 @@ func Test_AppPrepare(t *testing.T) { }, { name: "prepareAppQuery oidc app AssertIDTokenUserinfo inactive", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { - return prepareAppQuery(ctx, db, false) + prepare: func() (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + return prepareAppQuery(false) }, want: want{ sqlExpectations: mockQueries( @@ -2007,8 +2002,8 @@ func Test_AppPrepare(t *testing.T) { }, { name: "prepareAppQuery sql err", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*App, error)) { - return prepareAppQuery(ctx, db, false) + prepare: func() (sq.SelectBuilder, func(*sql.Row) (*App, error)) { + return prepareAppQuery(false) }, want: want{ sqlExpectations: mockQueryErr( @@ -2027,7 +2022,7 @@ func Test_AppPrepare(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } @@ -2113,7 +2108,7 @@ func Test_AppIDsPrepare(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } @@ -2179,7 +2174,7 @@ func Test_ProjectIDByAppPrepare(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } @@ -2377,7 +2372,7 @@ func Test_ProjectByAppPrepare(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/auth_request.go b/internal/query/auth_request.go index 20ac0f5abd..eaf5e52491 100644 --- a/internal/query/auth_request.go +++ b/internal/query/auth_request.go @@ -5,13 +5,11 @@ import ( "database/sql" _ "embed" "errors" - "fmt" "time" "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore/handler/v2" @@ -44,10 +42,6 @@ func (a *AuthRequest) checkLoginClient(ctx context.Context, permissionCheck doma //go:embed auth_request_by_id.sql var authRequestByIDQuery string -func (q *Queries) authRequestByIDQuery(ctx context.Context) string { - return fmt.Sprintf(authRequestByIDQuery, q.client.Timetravel(call.Took(ctx))) -} - func (q *Queries) AuthRequestByID(ctx context.Context, shouldTriggerBulk bool, id string, checkLoginClient bool) (_ *AuthRequest, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -74,7 +68,7 @@ func (q *Queries) AuthRequestByID(ctx context.Context, shouldTriggerBulk bool, i &prompt, &locales, &dst.LoginHint, &dst.MaxAge, &dst.HintUserID, ) }, - q.authRequestByIDQuery(ctx), + authRequestByIDQuery, id, authz.GetInstance(ctx).InstanceID(), ) if errors.Is(err, sql.ErrNoRows) { diff --git a/internal/query/auth_request_by_id.sql b/internal/query/auth_request_by_id.sql index ffc18fccd6..f842719d0e 100644 --- a/internal/query/auth_request_by_id.sql +++ b/internal/query/auth_request_by_id.sql @@ -10,6 +10,6 @@ select login_hint, max_age, hint_user_id -from projections.auth_requests %s +from projections.auth_requests where id = $1 and instance_id = $2 limit 1; diff --git a/internal/query/auth_request_test.go b/internal/query/auth_request_test.go index 479282f9f7..152a032cd8 100644 --- a/internal/query/auth_request_test.go +++ b/internal/query/auth_request_test.go @@ -24,7 +24,6 @@ import ( func TestQueries_AuthRequestByID(t *testing.T) { expQuery := regexp.QuoteMeta(fmt.Sprintf( authRequestByIDQuery, - asOfSystemTime, )) cols := []string{ @@ -207,8 +206,7 @@ func TestQueries_AuthRequestByID(t *testing.T) { execMock(t, tt.expect, func(db *sql.DB) { q := &Queries{ client: &database.DB{ - DB: db, - Database: &prepareDB{}, + DB: db, }, checkPermission: tt.permissionCheck, } diff --git a/internal/query/authn_key.go b/internal/query/authn_key.go index 6c05a03f6f..8075422e63 100644 --- a/internal/query/authn_key.go +++ b/internal/query/authn_key.go @@ -11,7 +11,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore/handler/v2" "github.com/zitadel/zitadel/internal/query/projection" @@ -129,7 +128,7 @@ func (q *Queries) SearchAuthNKeys(ctx context.Context, queries *AuthNKeySearchQu ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareAuthNKeysQuery(ctx, q.client) + query, scan := prepareAuthNKeysQuery() query = queries.toQuery(query) eq := sq.Eq{ AuthNKeyColumnEnabled.identifier(): true, @@ -156,7 +155,7 @@ func (q *Queries) SearchAuthNKeysData(ctx context.Context, queries *AuthNKeySear ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareAuthNKeysDataQuery(ctx, q.client) + query, scan := prepareAuthNKeysDataQuery() query = queries.toQuery(query) eq := sq.Eq{ AuthNKeyColumnEnabled.identifier(): true, @@ -189,7 +188,7 @@ func (q *Queries) GetAuthNKeyByID(ctx context.Context, shouldTriggerBulk bool, i traceSpan.EndWithError(err) } - query, scan := prepareAuthNKeyQuery(ctx, q.client) + query, scan := prepareAuthNKeyQuery() for _, q := range queries { query = q.toQuery(query) } @@ -214,7 +213,7 @@ func (q *Queries) GetAuthNKeyPublicKeyByIDAndIdentifier(ctx context.Context, id ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareAuthNKeyPublicKeyQuery(ctx, q.client) + stmt, scan := prepareAuthNKeyPublicKeyQuery() eq := sq.And{ sq.Eq{ AuthNKeyColumnID.identifier(): id, @@ -288,7 +287,7 @@ func (q *Queries) GetAuthNKeyUser(ctx context.Context, keyID, userID string) (_ return dst, nil } -func prepareAuthNKeysQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(rows *sql.Rows) (*AuthNKeys, error)) { +func prepareAuthNKeysQuery() (sq.SelectBuilder, func(rows *sql.Rows) (*AuthNKeys, error)) { return sq.Select( AuthNKeyColumnID.identifier(), AuthNKeyColumnCreationDate.identifier(), @@ -298,7 +297,7 @@ func prepareAuthNKeysQuery(ctx context.Context, db prepareDatabase) (sq.SelectBu AuthNKeyColumnExpiration.identifier(), AuthNKeyColumnType.identifier(), countColumn.identifier(), - ).From(authNKeyTable.identifier() + db.Timetravel(call.Took(ctx))). + ).From(authNKeyTable.identifier()). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*AuthNKeys, error) { authNKeys := make([]*AuthNKey, 0) @@ -334,7 +333,7 @@ func prepareAuthNKeysQuery(ctx context.Context, db prepareDatabase) (sq.SelectBu } } -func prepareAuthNKeyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(row *sql.Row) (*AuthNKey, error)) { +func prepareAuthNKeyQuery() (sq.SelectBuilder, func(row *sql.Row) (*AuthNKey, error)) { return sq.Select( AuthNKeyColumnID.identifier(), AuthNKeyColumnCreationDate.identifier(), @@ -343,7 +342,7 @@ func prepareAuthNKeyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui AuthNKeyColumnSequence.identifier(), AuthNKeyColumnExpiration.identifier(), AuthNKeyColumnType.identifier(), - ).From(authNKeyTable.identifier() + db.Timetravel(call.Took(ctx))). + ).From(authNKeyTable.identifier()). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*AuthNKey, error) { authNKey := new(AuthNKey) @@ -366,10 +365,10 @@ func prepareAuthNKeyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui } } -func prepareAuthNKeyPublicKeyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(row *sql.Row) ([]byte, error)) { +func prepareAuthNKeyPublicKeyQuery() (sq.SelectBuilder, func(row *sql.Row) ([]byte, error)) { return sq.Select( AuthNKeyColumnPublicKey.identifier(), - ).From(authNKeyTable.identifier() + db.Timetravel(call.Took(ctx))). + ).From(authNKeyTable.identifier()). PlaceholderFormat(sq.Dollar), func(row *sql.Row) ([]byte, error) { var publicKey []byte @@ -386,7 +385,7 @@ func prepareAuthNKeyPublicKeyQuery(ctx context.Context, db prepareDatabase) (sq. } } -func prepareAuthNKeysDataQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(rows *sql.Rows) (*AuthNKeysData, error)) { +func prepareAuthNKeysDataQuery() (sq.SelectBuilder, func(rows *sql.Rows) (*AuthNKeysData, error)) { return sq.Select( AuthNKeyColumnID.identifier(), AuthNKeyColumnCreationDate.identifier(), @@ -398,7 +397,7 @@ func prepareAuthNKeysDataQuery(ctx context.Context, db prepareDatabase) (sq.Sele AuthNKeyColumnIdentifier.identifier(), AuthNKeyColumnPublicKey.identifier(), countColumn.identifier(), - ).From(authNKeyTable.identifier() + db.Timetravel(call.Took(ctx))). + ).From(authNKeyTable.identifier()). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*AuthNKeysData, error) { authNKeys := make([]*AuthNKeyData, 0) diff --git a/internal/query/authn_key_test.go b/internal/query/authn_key_test.go index 19005893f8..c7441f8dae 100644 --- a/internal/query/authn_key_test.go +++ b/internal/query/authn_key_test.go @@ -26,8 +26,7 @@ var ( ` projections.authn_keys2.expiration,` + ` projections.authn_keys2.type,` + ` COUNT(*) OVER ()` + - ` FROM projections.authn_keys2` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.authn_keys2` prepareAuthNKeysCols = []string{ "id", "creation_date", @@ -49,8 +48,7 @@ var ( ` projections.authn_keys2.identifier,` + ` projections.authn_keys2.public_key,` + ` COUNT(*) OVER ()` + - ` FROM projections.authn_keys2` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.authn_keys2` prepareAuthNKeysDataCols = []string{ "id", "creation_date", @@ -71,8 +69,7 @@ var ( ` projections.authn_keys2.sequence,` + ` projections.authn_keys2.expiration,` + ` projections.authn_keys2.type` + - ` FROM projections.authn_keys2` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.authn_keys2` prepareAuthNKeyCols = []string{ "id", "creation_date", @@ -84,8 +81,7 @@ var ( } prepareAuthNKeyPublicKeyStmt = `SELECT projections.authn_keys2.public_key` + - ` FROM projections.authn_keys2` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.authn_keys2` prepareAuthNKeyPublicKeyCols = []string{ "public_key", } @@ -471,7 +467,7 @@ func Test_AuthNKeyPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } @@ -525,8 +521,7 @@ low2kyJov38V4Uk2I8kuXpLcnrpw5Tio2ooiUE27b0vHZqBKOei9Uo88qCrn3EKx execMock(t, tt.mock, func(db *sql.DB) { q := &Queries{ client: &database.DB{ - DB: db, - Database: &prepareDB{}, + DB: db, }, } ctx := authz.NewMockContext("instanceID", "orgID", "userID") diff --git a/internal/query/certificate.go b/internal/query/certificate.go index e4d53213cf..ebe4b249f4 100644 --- a/internal/query/certificate.go +++ b/internal/query/certificate.go @@ -8,7 +8,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -69,7 +68,7 @@ func (q *Queries) ActiveCertificates(ctx context.Context, t time.Time, usage cry ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareCertificateQuery(ctx, q.client) + query, scan := prepareCertificateQuery() if t.IsZero() { t = time.Now() } @@ -102,7 +101,7 @@ func (q *Queries) ActiveCertificates(ctx context.Context, t time.Time, usage cry return certs, nil } -func prepareCertificateQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Certificates, error)) { +func prepareCertificateQuery() (sq.SelectBuilder, func(*sql.Rows) (*Certificates, error)) { return sq.Select( KeyColID.identifier(), KeyColCreationDate.identifier(), @@ -117,7 +116,7 @@ func prepareCertificateQuery(ctx context.Context, db prepareDatabase) (sq.Select countColumn.identifier(), ).From(keyTable.identifier()). LeftJoin(join(CertificateColID, KeyColID)). - LeftJoin(join(KeyPrivateColID, KeyColID) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(KeyPrivateColID, KeyColID)). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*Certificates, error) { certificates := make([]Certificate, 0) diff --git a/internal/query/certificate_test.go b/internal/query/certificate_test.go index 01e563de11..eae011bb69 100644 --- a/internal/query/certificate_test.go +++ b/internal/query/certificate_test.go @@ -26,8 +26,7 @@ var ( ` COUNT(*) OVER ()` + ` FROM projections.keys4` + ` LEFT JOIN projections.keys4_certificate ON projections.keys4.id = projections.keys4_certificate.id AND projections.keys4.instance_id = projections.keys4_certificate.instance_id` + - ` LEFT JOIN projections.keys4_private ON projections.keys4.id = projections.keys4_private.id AND projections.keys4.instance_id = projections.keys4_private.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'` + ` LEFT JOIN projections.keys4_private ON projections.keys4.id = projections.keys4_private.id AND projections.keys4.instance_id = projections.keys4_private.instance_id` prepareCertificateCols = []string{ "id", "creation_date", @@ -142,7 +141,7 @@ func Test_CertificatePrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/current_state.go b/internal/query/current_state.go index 29497e6eec..6fae52713f 100644 --- a/internal/query/current_state.go +++ b/internal/query/current_state.go @@ -12,7 +12,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/logging" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -68,7 +67,7 @@ func (q *Queries) SearchCurrentStates(ctx context.Context, queries *CurrentState ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareCurrentStateQuery(ctx, q.client) + query, scan := prepareCurrentStateQuery() stmt, args, err := queries.toQuery(query).ToSql() if err != nil { return nil, zerrors.ThrowInvalidArgument(err, "QUERY-MmFef", "Errors.Query.InvalidRequest") @@ -210,12 +209,12 @@ func reset(ctx context.Context, tx *sql.Tx, tables []string, projectionName stri return nil } -func prepareLatestState(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*State, error)) { +func prepareLatestState() (sq.SelectBuilder, func(*sql.Row) (*State, error)) { return sq.Select( CurrentStateColEventDate.identifier(), CurrentStateColPosition.identifier(), CurrentStateColLastUpdated.identifier()). - From(currentStateTable.identifier() + db.Timetravel(call.Took(ctx))). + From(currentStateTable.identifier()). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*State, error) { var ( @@ -239,7 +238,7 @@ func prepareLatestState(ctx context.Context, db prepareDatabase) (sq.SelectBuild } } -func prepareCurrentStateQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*CurrentStates, error)) { +func prepareCurrentStateQuery() (sq.SelectBuilder, func(*sql.Rows) (*CurrentStates, error)) { return sq.Select( CurrentStateColLastUpdated.identifier(), CurrentStateColEventDate.identifier(), @@ -249,7 +248,7 @@ func prepareCurrentStateQuery(ctx context.Context, db prepareDatabase) (sq.Selec CurrentStateColAggregateID.identifier(), CurrentStateColSequence.identifier(), countColumn.identifier()). - From(currentStateTable.identifier() + db.Timetravel(call.Took(ctx))). + From(currentStateTable.identifier()). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*CurrentStates, error) { states := make([]*CurrentState, 0) diff --git a/internal/query/current_state_test.go b/internal/query/current_state_test.go index c76dae710e..c0895dc439 100644 --- a/internal/query/current_state_test.go +++ b/internal/query/current_state_test.go @@ -19,8 +19,7 @@ var ( ` projections.current_states.aggregate_id,` + ` projections.current_states.sequence,` + ` COUNT(*) OVER ()` + - ` FROM projections.current_states` + - " AS OF SYSTEM TIME '-1 ms' " + ` FROM projections.current_states` currentSequenceCols = []string{ "last_updated", @@ -175,7 +174,7 @@ func Test_CurrentSequencesPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/custom_text.go b/internal/query/custom_text.go index e92c910b69..0bc909d614 100644 --- a/internal/query/custom_text.go +++ b/internal/query/custom_text.go @@ -13,7 +13,6 @@ import ( "sigs.k8s.io/yaml" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore/v1/models" "github.com/zitadel/zitadel/internal/i18n" @@ -90,7 +89,7 @@ func (q *Queries) CustomTextList(ctx context.Context, aggregateID, template, lan ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareCustomTextsQuery(ctx, q.client) + stmt, scan := prepareCustomTextsQuery() eq := sq.Eq{ CustomTextColAggregateID.identifier(): aggregateID, CustomTextColTemplate.identifier(): template, @@ -121,7 +120,7 @@ func (q *Queries) CustomTextListByTemplate(ctx context.Context, aggregateID, tem ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareCustomTextsQuery(ctx, q.client) + stmt, scan := prepareCustomTextsQuery() eq := sq.Eq{ CustomTextColAggregateID.identifier(): aggregateID, CustomTextColTemplate.identifier(): template, @@ -230,7 +229,7 @@ func (q *Queries) readLoginTranslationFile(ctx context.Context, lang string) ([] return contents, nil } -func prepareCustomTextsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*CustomTexts, error)) { +func prepareCustomTextsQuery() (sq.SelectBuilder, func(*sql.Rows) (*CustomTexts, error)) { return sq.Select( CustomTextColAggregateID.identifier(), CustomTextColSequence.identifier(), @@ -241,7 +240,7 @@ func prepareCustomTextsQuery(ctx context.Context, db prepareDatabase) (sq.Select CustomTextColKey.identifier(), CustomTextColText.identifier(), countColumn.identifier()). - From(customTextTable.identifier() + db.Timetravel(call.Took(ctx))). + From(customTextTable.identifier()). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*CustomTexts, error) { customTexts := make([]*CustomText, 0) diff --git a/internal/query/custom_text_test.go b/internal/query/custom_text_test.go index 0453f71a2a..c31072793b 100644 --- a/internal/query/custom_text_test.go +++ b/internal/query/custom_text_test.go @@ -23,8 +23,7 @@ var ( ` projections.custom_texts2.key,` + ` projections.custom_texts2.text,` + ` COUNT(*) OVER ()` + - ` FROM projections.custom_texts2` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.custom_texts2` prepareCustomTextsCols = []string{ "aggregate_id", "sequence", @@ -185,7 +184,7 @@ func Test_CustomTextPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/device_auth.go b/internal/query/device_auth.go index e42b5a114e..d2f86a44af 100644 --- a/internal/query/device_auth.go +++ b/internal/query/device_auth.go @@ -63,7 +63,7 @@ func (q *Queries) DeviceAuthRequestByUserCode(ctx context.Context, userCode stri ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareDeviceAuthQuery(ctx, q.client) + stmt, scan := prepareDeviceAuthQuery() eq := sq.Eq{ DeviceAuthRequestColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), DeviceAuthRequestColumnUserCode.identifier(): userCode, @@ -90,7 +90,7 @@ var deviceAuthSelectColumns = []string{ ProjectColumnName.identifier(), } -func prepareDeviceAuthQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*domain.AuthRequestDevice, error)) { +func prepareDeviceAuthQuery() (sq.SelectBuilder, func(*sql.Row) (*domain.AuthRequestDevice, error)) { return sq.Select(deviceAuthSelectColumns...). From(deviceAuthRequestTable.identifier()). LeftJoin(join(AppOIDCConfigColumnClientID, DeviceAuthRequestColumnClientID)). diff --git a/internal/query/device_auth_test.go b/internal/query/device_auth_test.go index 6f0f82b3be..52ac50abb7 100644 --- a/internal/query/device_auth_test.go +++ b/internal/query/device_auth_test.go @@ -138,7 +138,7 @@ func Test_prepareDeviceAuthQuery(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, prepareDeviceAuthQuery, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, prepareDeviceAuthQuery, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/domain_policy.go b/internal/query/domain_policy.go index d971723bcf..3eba664e75 100644 --- a/internal/query/domain_policy.go +++ b/internal/query/domain_policy.go @@ -10,7 +10,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore/handler/v2" "github.com/zitadel/zitadel/internal/query/projection" @@ -118,7 +117,7 @@ func (q *Queries) DomainPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, } } - stmt, scan := prepareDomainPolicyQuery(ctx, q.client) + stmt, scan := prepareDomainPolicyQuery() query, args, err := stmt.Where(eq).OrderBy(DomainPolicyColIsDefault.identifier()). Limit(1).ToSql() if err != nil { @@ -136,7 +135,7 @@ func (q *Queries) DefaultDomainPolicy(ctx context.Context) (policy *DomainPolicy ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareDomainPolicyQuery(ctx, q.client) + stmt, scan := prepareDomainPolicyQuery() query, args, err := stmt.Where(sq.Eq{ DomainPolicyColID.identifier(): authz.GetInstance(ctx).InstanceID(), DomainPolicyColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -154,7 +153,7 @@ func (q *Queries) DefaultDomainPolicy(ctx context.Context) (policy *DomainPolicy return policy, err } -func prepareDomainPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*DomainPolicy, error)) { +func prepareDomainPolicyQuery() (sq.SelectBuilder, func(*sql.Row) (*DomainPolicy, error)) { return sq.Select( DomainPolicyColID.identifier(), DomainPolicyColSequence.identifier(), @@ -167,7 +166,7 @@ func prepareDomainPolicyQuery(ctx context.Context, db prepareDatabase) (sq.Selec DomainPolicyColIsDefault.identifier(), DomainPolicyColState.identifier(), ). - From(domainPolicyTable.identifier() + db.Timetravel(call.Took(ctx))). + From(domainPolicyTable.identifier()). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*DomainPolicy, error) { policy := new(DomainPolicy) diff --git a/internal/query/domain_policy_test.go b/internal/query/domain_policy_test.go index 70d3ddc391..0ff2567979 100644 --- a/internal/query/domain_policy_test.go +++ b/internal/query/domain_policy_test.go @@ -23,8 +23,7 @@ var ( ` projections.domain_policies2.smtp_sender_address_matches_instance_domain,` + ` projections.domain_policies2.is_default,` + ` projections.domain_policies2.state` + - ` FROM projections.domain_policies2` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.domain_policies2` prepareDomainPolicyCols = []string{ "id", "sequence", @@ -122,7 +121,7 @@ func Test_DomainPolicyPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/execution.go b/internal/query/execution.go index b98c680f57..0a2a989918 100644 --- a/internal/query/execution.go +++ b/internal/query/execution.go @@ -101,8 +101,8 @@ func (q *Queries) SearchExecutions(ctx context.Context, queries *ExecutionSearch eq := sq.Eq{ ExecutionColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), } - query, scan := prepareExecutionsQuery(ctx, q.client) - return genericRowsQueryWithState[*Executions](ctx, q.client, executionTable, combineToWhereStmt(query, queries.toQuery, eq), scan) + query, scan := prepareExecutionsQuery() + return genericRowsQueryWithState(ctx, q.client, executionTable, combineToWhereStmt(query, queries.toQuery, eq), scan) } func (q *Queries) GetExecutionByID(ctx context.Context, id string) (execution *Execution, err error) { @@ -110,8 +110,8 @@ func (q *Queries) GetExecutionByID(ctx context.Context, id string) (execution *E ExecutionColumnID.identifier(): id, ExecutionColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), } - query, scan := prepareExecutionQuery(ctx, q.client) - return genericRowQuery[*Execution](ctx, q.client, query.Where(eq), scan) + query, scan := prepareExecutionQuery() + return genericRowQuery(ctx, q.client, query.Where(eq), scan) } func NewExecutionInIDsSearchQuery(values []string) (SearchQuery, error) { @@ -219,7 +219,7 @@ func (q *Queries) TargetsByExecutionIDs(ctx context.Context, ids1, ids2 []string return execution, err } -func prepareExecutionQuery(context.Context, prepareDatabase) (sq.SelectBuilder, func(row *sql.Row) (*Execution, error)) { +func prepareExecutionQuery() (sq.SelectBuilder, func(row *sql.Row) (*Execution, error)) { return sq.Select( ExecutionColumnInstanceID.identifier(), ExecutionColumnID.identifier(), @@ -235,7 +235,7 @@ func prepareExecutionQuery(context.Context, prepareDatabase) (sq.SelectBuilder, scanExecution } -func prepareExecutionsQuery(context.Context, prepareDatabase) (sq.SelectBuilder, func(rows *sql.Rows) (*Executions, error)) { +func prepareExecutionsQuery() (sq.SelectBuilder, func(rows *sql.Rows) (*Executions, error)) { return sq.Select( ExecutionColumnInstanceID.identifier(), ExecutionColumnID.identifier(), diff --git a/internal/query/execution_test.go b/internal/query/execution_test.go index ee6bdc4d96..eaaac1e9ba 100644 --- a/internal/query/execution_test.go +++ b/internal/query/execution_test.go @@ -263,7 +263,7 @@ func Test_ExecutionPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/failed_events.go b/internal/query/failed_events.go index 7d2e875cee..c5ad1ae1d9 100644 --- a/internal/query/failed_events.go +++ b/internal/query/failed_events.go @@ -7,7 +7,6 @@ import ( sq "github.com/Masterminds/squirrel" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/zerrors" ) @@ -83,7 +82,7 @@ type FailedEventSearchQueries struct { } func (q *Queries) SearchFailedEvents(ctx context.Context, queries *FailedEventSearchQueries) (failedEvents *FailedEvents, err error) { - query, scan := prepareFailedEventsQuery(ctx, q.client) + query, scan := prepareFailedEventsQuery() stmt, args, err := queries.toQuery(query).ToSql() if err != nil { return nil, zerrors.ThrowInvalidArgument(err, "QUERY-n8rjJ", "Errors.Query.InvalidRequest") @@ -139,7 +138,7 @@ func (q *FailedEventSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuil return query } -func prepareFailedEventsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*FailedEvents, error)) { +func prepareFailedEventsQuery() (sq.SelectBuilder, func(*sql.Rows) (*FailedEvents, error)) { return sq.Select( FailedEventsColumnProjectionName.identifier(), FailedEventsColumnFailedSequence.identifier(), @@ -149,7 +148,7 @@ func prepareFailedEventsQuery(ctx context.Context, db prepareDatabase) (sq.Selec FailedEventsColumnLastFailed.identifier(), FailedEventsColumnError.identifier(), countColumn.identifier()). - From(failedEventsTable.identifier() + db.Timetravel(call.Took(ctx))). + From(failedEventsTable.identifier()). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*FailedEvents, error) { failedEvents := make([]*FailedEvent, 0) diff --git a/internal/query/failed_events_test.go b/internal/query/failed_events_test.go index 7e575b5891..25c15e9b8f 100644 --- a/internal/query/failed_events_test.go +++ b/internal/query/failed_events_test.go @@ -19,8 +19,7 @@ var ( ` projections.failed_events2.last_failed,` + ` projections.failed_events2.error,` + ` COUNT(*) OVER ()` + - ` FROM projections.failed_events2` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.failed_events2` prepareFailedEventsCols = []string{ "projection_name", @@ -168,7 +167,7 @@ func Test_FailedEventsPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/generic.go b/internal/query/generic.go index 70fc2884a2..ea2257c013 100644 --- a/internal/query/generic.go +++ b/internal/query/generic.go @@ -44,7 +44,7 @@ func genericRowsQueryWithState[R Stateful]( scan func(rows *sql.Rows) (R, error), ) (resp R, err error) { var rnil R - resp, err = genericRowsQuery[R](ctx, client, query, scan) + resp, err = genericRowsQuery(ctx, client, query, scan) if err != nil { return rnil, err } @@ -60,7 +60,7 @@ func latestState(ctx context.Context, client *database.DB, projections ...table) ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareLatestState(ctx, client) + query, scan := prepareLatestState() or := make(sq.Or, len(projections)) for i, projection := range projections { or[i] = sq.Eq{CurrentStateColProjectionName.identifier(): projection.name} diff --git a/internal/query/iam_member.go b/internal/query/iam_member.go index 87b906aa51..139208c7b8 100644 --- a/internal/query/iam_member.go +++ b/internal/query/iam_member.go @@ -7,7 +7,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -71,7 +70,7 @@ func (q *Queries) IAMMembers(ctx context.Context, queries *IAMMembersQuery) (mem ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareInstanceMembersQuery(ctx, q.client) + query, scan := prepareInstanceMembersQuery() eq := sq.Eq{InstanceMemberInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()} stmt, args, err := queries.toQuery(query).Where(eq).ToSql() if err != nil { @@ -94,7 +93,7 @@ func (q *Queries) IAMMembers(ctx context.Context, queries *IAMMembersQuery) (mem return members, err } -func prepareInstanceMembersQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Members, error)) { +func prepareInstanceMembersQuery() (sq.SelectBuilder, func(*sql.Rows) (*Members, error)) { return sq.Select( InstanceMemberCreationDate.identifier(), InstanceMemberChangeDate.identifier(), @@ -116,7 +115,7 @@ func prepareInstanceMembersQuery(ctx context.Context, db prepareDatabase) (sq.Se LeftJoin(join(HumanUserIDCol, InstanceMemberUserID)). LeftJoin(join(MachineUserIDCol, InstanceMemberUserID)). LeftJoin(join(UserIDCol, InstanceMemberUserID)). - LeftJoin(join(LoginNameUserIDCol, InstanceMemberUserID) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(LoginNameUserIDCol, InstanceMemberUserID)). Where( sq.Eq{LoginNameIsPrimaryCol.identifier(): true}, ).PlaceholderFormat(sq.Dollar), diff --git a/internal/query/iam_member_test.go b/internal/query/iam_member_test.go index 82cea360c8..5c10ebc5bc 100644 --- a/internal/query/iam_member_test.go +++ b/internal/query/iam_member_test.go @@ -39,7 +39,6 @@ var ( "ON members.user_id = projections.users14.id AND members.instance_id = projections.users14.instance_id " + "LEFT JOIN projections.login_names3 " + "ON members.user_id = projections.login_names3.user_id AND members.instance_id = projections.login_names3.instance_id " + - "AS OF SYSTEM TIME '-1 ms' " + "WHERE projections.login_names3.is_primary = $1") instanceMembersColumns = []string{ "creation_date", @@ -295,7 +294,7 @@ func Test_IAMMemberPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/idp.go b/internal/query/idp.go index 2687397330..4f0d77c62b 100644 --- a/internal/query/idp.go +++ b/internal/query/idp.go @@ -10,7 +10,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" @@ -215,7 +214,7 @@ func (q *Queries) IDPByIDAndResourceOwner(ctx context.Context, shouldTriggerBulk sq.Eq{IDPResourceOwnerCol.identifier(): authz.GetInstance(ctx).InstanceID()}, }, } - stmt, scan := prepareIDPByIDQuery(ctx, q.client) + stmt, scan := prepareIDPByIDQuery() query, args, err := stmt.Where(where).ToSql() if err != nil { return nil, zerrors.ThrowInternal(err, "QUERY-0gocI", "Errors.Query.SQLStatement") @@ -233,7 +232,7 @@ func (q *Queries) IDPs(ctx context.Context, queries *IDPSearchQueries, withOwner ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareIDPsQuery(ctx, q.client) + query, scan := prepareIDPsQuery() eq := sq.Eq{ IDPInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(), } @@ -293,7 +292,7 @@ func (q *IDPSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuilder { return query } -func prepareIDPByIDQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*IDP, error)) { +func prepareIDPByIDQuery() (sq.SelectBuilder, func(*sql.Row) (*IDP, error)) { return sq.Select( IDPIDCol.identifier(), IDPResourceOwnerCol.identifier(), @@ -321,7 +320,7 @@ func prepareIDPByIDQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil JWTIDPColEndpoint.identifier(), ).From(idpTable.identifier()). LeftJoin(join(OIDCIDPColIDPID, IDPIDCol)). - LeftJoin(join(JWTIDPColIDPID, IDPIDCol) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(JWTIDPColIDPID, IDPIDCol)). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*IDP, error) { idp := new(IDP) @@ -401,7 +400,7 @@ func prepareIDPByIDQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil } } -func prepareIDPsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*IDPs, error)) { +func prepareIDPsQuery() (sq.SelectBuilder, func(*sql.Rows) (*IDPs, error)) { return sq.Select( IDPIDCol.identifier(), IDPResourceOwnerCol.identifier(), @@ -430,7 +429,7 @@ func prepareIDPsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder countColumn.identifier(), ).From(idpTable.identifier()). LeftJoin(join(OIDCIDPColIDPID, IDPIDCol)). - LeftJoin(join(JWTIDPColIDPID, IDPIDCol) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(JWTIDPColIDPID, IDPIDCol)). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*IDPs, error) { idps := make([]*IDP, 0) diff --git a/internal/query/idp_login_policy_link.go b/internal/query/idp_login_policy_link.go index bdc2ef15b1..65f855bc51 100644 --- a/internal/query/idp_login_policy_link.go +++ b/internal/query/idp_login_policy_link.go @@ -7,7 +7,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -97,7 +96,7 @@ func (q *Queries) IDPLoginPolicyLinks(ctx context.Context, resourceOwner string, ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareIDPLoginPolicyLinksQuery(ctx, q.client, resourceOwner) + query, scan := prepareIDPLoginPolicyLinksQuery(ctx, resourceOwner) eq := sq.Eq{ IDPLoginPolicyLinkInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(), } @@ -122,7 +121,8 @@ func (q *Queries) IDPLoginPolicyLinks(ctx context.Context, resourceOwner string, return idps, err } -func prepareIDPLoginPolicyLinksQuery(ctx context.Context, db prepareDatabase, resourceOwner string) (sq.SelectBuilder, func(*sql.Rows) (*IDPLoginPolicyLinks, error)) { +//nolint:gocognit +func prepareIDPLoginPolicyLinksQuery(ctx context.Context, resourceOwner string) (sq.SelectBuilder, func(*sql.Rows) (*IDPLoginPolicyLinks, error)) { resourceOwnerQuery, resourceOwnerArgs, err := prepareIDPLoginPolicyLinksResourceOwnerQuery(ctx, resourceOwner) if err != nil { return sq.SelectBuilder{}, nil @@ -142,8 +142,7 @@ func prepareIDPLoginPolicyLinksQuery(ctx context.Context, db prepareDatabase, re LeftJoin(join(IDPTemplateIDCol, IDPLoginPolicyLinkIDPIDCol)). RightJoin("("+resourceOwnerQuery+") AS "+idpLoginPolicyOwnerTable.alias+" ON "+ idpLoginPolicyOwnerIDCol.identifier()+" = "+IDPLoginPolicyLinkResourceOwnerCol.identifier()+" AND "+ - idpLoginPolicyOwnerInstanceIDCol.identifier()+" = "+IDPLoginPolicyLinkInstanceIDCol.identifier()+ - " "+db.Timetravel(call.Took(ctx)), + idpLoginPolicyOwnerInstanceIDCol.identifier()+" = "+IDPLoginPolicyLinkInstanceIDCol.identifier(), resourceOwnerArgs...). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*IDPLoginPolicyLinks, error) { diff --git a/internal/query/idp_login_policy_link_test.go b/internal/query/idp_login_policy_link_test.go index 245eb22ccc..9f66e118ea 100644 --- a/internal/query/idp_login_policy_link_test.go +++ b/internal/query/idp_login_policy_link_test.go @@ -6,6 +6,7 @@ import ( "database/sql/driver" "errors" "fmt" + "reflect" "regexp" "testing" @@ -29,8 +30,7 @@ var ( ` LEFT JOIN projections.idp_templates6 ON projections.idp_login_policy_links5.idp_id = projections.idp_templates6.id AND projections.idp_login_policy_links5.instance_id = projections.idp_templates6.instance_id` + ` RIGHT JOIN (SELECT login_policy_owner.aggregate_id, login_policy_owner.instance_id, login_policy_owner.owner_removed FROM projections.login_policies5 AS login_policy_owner` + ` WHERE (login_policy_owner.instance_id = $1 AND (login_policy_owner.aggregate_id = $2 OR login_policy_owner.aggregate_id = $3)) ORDER BY login_policy_owner.is_default LIMIT 1) AS login_policy_owner` + - ` ON login_policy_owner.aggregate_id = projections.idp_login_policy_links5.resource_owner AND login_policy_owner.instance_id = projections.idp_login_policy_links5.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'`) + ` ON login_policy_owner.aggregate_id = projections.idp_login_policy_links5.resource_owner AND login_policy_owner.instance_id = projections.idp_login_policy_links5.instance_id`) loginPolicyIDPLinksCols = []string{ "idp_id", "name", @@ -52,14 +52,14 @@ func Test_IDPLoginPolicyLinkPrepares(t *testing.T) { } tests := []struct { name string - prepare interface{} + prepare any want want - object interface{} + object any }{ { name: "prepareIDPsQuery found", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*IDPLoginPolicyLinks, error)) { - return prepareIDPLoginPolicyLinksQuery(ctx, db, "resourceOwner") + prepare: func(ctx context.Context) (sq.SelectBuilder, func(*sql.Rows) (*IDPLoginPolicyLinks, error)) { + return prepareIDPLoginPolicyLinksQuery(ctx, "resourceOwner") }, want: want{ sqlExpectations: mockQueries( @@ -101,8 +101,8 @@ func Test_IDPLoginPolicyLinkPrepares(t *testing.T) { }, { name: "prepareIDPsQuery no idp", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*IDPLoginPolicyLinks, error)) { - return prepareIDPLoginPolicyLinksQuery(ctx, db, "resourceOwner") + prepare: func(ctx context.Context) (sq.SelectBuilder, func(*sql.Rows) (*IDPLoginPolicyLinks, error)) { + return prepareIDPLoginPolicyLinksQuery(ctx, "resourceOwner") }, want: want{ sqlExpectations: mockQueries( @@ -143,8 +143,8 @@ func Test_IDPLoginPolicyLinkPrepares(t *testing.T) { }, { name: "prepareIDPsQuery sql err", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*IDPLoginPolicyLinks, error)) { - return prepareIDPLoginPolicyLinksQuery(ctx, db, "resourceOwner") + prepare: func(ctx context.Context) (sq.SelectBuilder, func(*sql.Rows) (*IDPLoginPolicyLinks, error)) { + return prepareIDPLoginPolicyLinksQuery(ctx, "resourceOwner") }, want: want{ sqlExpectations: mockQueryErr( @@ -163,7 +163,7 @@ func Test_IDPLoginPolicyLinkPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, reflect.ValueOf(context.Background())) }) } } diff --git a/internal/query/idp_template.go b/internal/query/idp_template.go index a63cb6f485..f51e9a11a7 100644 --- a/internal/query/idp_template.go +++ b/internal/query/idp_template.go @@ -10,7 +10,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" @@ -767,7 +766,7 @@ func (q *Queries) idpTemplateByID(ctx context.Context, shouldTriggerBulk bool, i if !withOwnerRemoved { eq[IDPTemplateOwnerRemovedCol.identifier()] = false } - query, scan := prepareIDPTemplateByIDQuery(ctx, q.client) + query, scan := prepareIDPTemplateByIDQuery() for _, q := range queries { query = q.toQuery(query) } @@ -788,7 +787,7 @@ func (q *Queries) IDPTemplates(ctx context.Context, queries *IDPTemplateSearchQu ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareIDPTemplatesQuery(ctx, q.client) + query, scan := prepareIDPTemplatesQuery() eq := sq.Eq{ IDPTemplateInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(), } @@ -864,7 +863,7 @@ func (q *IDPTemplateSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuil return query } -func prepareIDPTemplateByIDQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*IDPTemplate, error)) { +func prepareIDPTemplateByIDQuery() (sq.SelectBuilder, func(*sql.Row) (*IDPTemplate, error)) { return sq.Select( IDPTemplateIDCol.identifier(), IDPTemplateResourceOwnerCol.identifier(), @@ -993,7 +992,7 @@ func prepareIDPTemplateByIDQuery(ctx context.Context, db prepareDatabase) (sq.Se LeftJoin(join(GoogleIDCol, IDPTemplateIDCol)). LeftJoin(join(SAMLIDCol, IDPTemplateIDCol)). LeftJoin(join(LDAPIDCol, IDPTemplateIDCol)). - LeftJoin(join(AppleIDCol, IDPTemplateIDCol) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(AppleIDCol, IDPTemplateIDCol)). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*IDPTemplate, error) { idpTemplate := new(IDPTemplate) @@ -1371,7 +1370,8 @@ func prepareIDPTemplateByIDQuery(ctx context.Context, db prepareDatabase) (sq.Se } } -func prepareIDPTemplatesQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*IDPTemplates, error)) { +//nolint:gocognit +func prepareIDPTemplatesQuery() (sq.SelectBuilder, func(*sql.Rows) (*IDPTemplates, error)) { return sq.Select( IDPTemplateIDCol.identifier(), IDPTemplateResourceOwnerCol.identifier(), @@ -1502,7 +1502,7 @@ func prepareIDPTemplatesQuery(ctx context.Context, db prepareDatabase) (sq.Selec LeftJoin(join(GoogleIDCol, IDPTemplateIDCol)). LeftJoin(join(SAMLIDCol, IDPTemplateIDCol)). LeftJoin(join(LDAPIDCol, IDPTemplateIDCol)). - LeftJoin(join(AppleIDCol, IDPTemplateIDCol) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(AppleIDCol, IDPTemplateIDCol)). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*IDPTemplates, error) { templates := make([]*IDPTemplate, 0) diff --git a/internal/query/idp_template_test.go b/internal/query/idp_template_test.go index bee19e492a..a30d501bae 100644 --- a/internal/query/idp_template_test.go +++ b/internal/query/idp_template_test.go @@ -143,8 +143,7 @@ var ( ` LEFT JOIN projections.idp_templates6_google ON projections.idp_templates6.id = projections.idp_templates6_google.idp_id AND projections.idp_templates6.instance_id = projections.idp_templates6_google.instance_id` + ` LEFT JOIN projections.idp_templates6_saml ON projections.idp_templates6.id = projections.idp_templates6_saml.idp_id AND projections.idp_templates6.instance_id = projections.idp_templates6_saml.instance_id` + ` LEFT JOIN projections.idp_templates6_ldap3 ON projections.idp_templates6.id = projections.idp_templates6_ldap3.idp_id AND projections.idp_templates6.instance_id = projections.idp_templates6_ldap3.instance_id` + - ` LEFT JOIN projections.idp_templates6_apple ON projections.idp_templates6.id = projections.idp_templates6_apple.idp_id AND projections.idp_templates6.instance_id = projections.idp_templates6_apple.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'` + ` LEFT JOIN projections.idp_templates6_apple ON projections.idp_templates6.id = projections.idp_templates6_apple.idp_id AND projections.idp_templates6.instance_id = projections.idp_templates6_apple.instance_id` idpTemplateCols = []string{ "id", "resource_owner", @@ -390,8 +389,7 @@ var ( ` LEFT JOIN projections.idp_templates6_google ON projections.idp_templates6.id = projections.idp_templates6_google.idp_id AND projections.idp_templates6.instance_id = projections.idp_templates6_google.instance_id` + ` LEFT JOIN projections.idp_templates6_saml ON projections.idp_templates6.id = projections.idp_templates6_saml.idp_id AND projections.idp_templates6.instance_id = projections.idp_templates6_saml.instance_id` + ` LEFT JOIN projections.idp_templates6_ldap3 ON projections.idp_templates6.id = projections.idp_templates6_ldap3.idp_id AND projections.idp_templates6.instance_id = projections.idp_templates6_ldap3.instance_id` + - ` LEFT JOIN projections.idp_templates6_apple ON projections.idp_templates6.id = projections.idp_templates6_apple.idp_id AND projections.idp_templates6.instance_id = projections.idp_templates6_apple.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'` + ` LEFT JOIN projections.idp_templates6_apple ON projections.idp_templates6.id = projections.idp_templates6_apple.idp_id AND projections.idp_templates6.instance_id = projections.idp_templates6_apple.instance_id` idpTemplatesCols = []string{ "id", "resource_owner", @@ -3485,7 +3483,7 @@ func Test_IDPTemplateTemplatesPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/idp_test.go b/internal/query/idp_test.go index 9474a0c751..a7f6fb95c1 100644 --- a/internal/query/idp_test.go +++ b/internal/query/idp_test.go @@ -733,7 +733,7 @@ func Test_IDPPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/idp_user_link.go b/internal/query/idp_user_link.go index 5caf6c6646..23305dfd6e 100644 --- a/internal/query/idp_user_link.go +++ b/internal/query/idp_user_link.go @@ -8,7 +8,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -129,7 +128,7 @@ func (q *Queries) idpUserLinks(ctx context.Context, queries *IDPUserLinksSearchQ ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareIDPUserLinksQuery(ctx, q.client) + query, scan := prepareIDPUserLinksQuery() eq := sq.Eq{IDPUserLinkInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID()} if !withOwnerRemoved { eq[IDPUserLinkOwnerRemovedCol.identifier()] = false @@ -166,7 +165,7 @@ func NewIDPUserLinksExternalIDSearchQuery(value string) (SearchQuery, error) { return NewTextQuery(IDPUserLinkExternalUserIDCol, value, TextEquals) } -func prepareIDPUserLinksQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*IDPUserLinks, error)) { +func prepareIDPUserLinksQuery() (sq.SelectBuilder, func(*sql.Rows) (*IDPUserLinks, error)) { return sq.Select( IDPUserLinkIDPIDCol.identifier(), IDPUserLinkUserIDCol.identifier(), @@ -177,7 +176,7 @@ func prepareIDPUserLinksQuery(ctx context.Context, db prepareDatabase) (sq.Selec IDPUserLinkResourceOwnerCol.identifier(), countColumn.identifier()). From(idpUserLinkTable.identifier()). - LeftJoin(join(IDPTemplateIDCol, IDPUserLinkIDPIDCol) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(IDPTemplateIDCol, IDPUserLinkIDPIDCol)). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*IDPUserLinks, error) { idps := make([]*IDPUserLink, 0) diff --git a/internal/query/idp_user_link_test.go b/internal/query/idp_user_link_test.go index b8ba2d087a..eac9669110 100644 --- a/internal/query/idp_user_link_test.go +++ b/internal/query/idp_user_link_test.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "regexp" + "slices" "testing" "github.com/stretchr/testify/require" @@ -165,10 +166,8 @@ func TestUser_idpLinksCheckPermission(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { checkPermission := func(ctx context.Context, permission, orgID, resourceID string) (err error) { - for _, perm := range tt.permissions { - if resourceID == perm { - return nil - } + if slices.Contains(tt.permissions, resourceID) { + return nil } return errors.New("failed") } @@ -188,8 +187,7 @@ var ( ` projections.idp_user_links3.resource_owner,` + ` COUNT(*) OVER ()` + ` FROM projections.idp_user_links3` + - ` LEFT JOIN projections.idp_templates6 ON projections.idp_user_links3.idp_id = projections.idp_templates6.id AND projections.idp_user_links3.instance_id = projections.idp_templates6.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'`) + ` LEFT JOIN projections.idp_templates6 ON projections.idp_user_links3.idp_id = projections.idp_templates6.id AND projections.idp_user_links3.instance_id = projections.idp_templates6.instance_id`) idpUserLinksCols = []string{ "idp_id", "user_id", @@ -307,7 +305,7 @@ func Test_IDPUserLinkPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/instance.go b/internal/query/instance.go index d7d66b1607..1b3bb055cb 100644 --- a/internal/query/instance.go +++ b/internal/query/instance.go @@ -16,7 +16,6 @@ import ( "golang.org/x/text/language" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore/handler/v2" @@ -151,7 +150,7 @@ func (q *Queries) SearchInstances(ctx context.Context, queries *InstanceSearchQu ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - filter, query, scan := prepareInstancesQuery(ctx, q.client) + filter, query, scan := prepareInstancesQuery() stmt, args, err := query(queries.toQuery(filter)).ToSql() if err != nil { return nil, zerrors.ThrowInvalidArgument(err, "QUERY-M9fow", "Errors.Query.SQLStatement") @@ -178,7 +177,7 @@ func (q *Queries) Instance(ctx context.Context, shouldTriggerBulk bool) (instanc traceSpan.EndWithError(err) } - stmt, scan := prepareInstanceDomainQuery(ctx, q.client) + stmt, scan := prepareInstanceDomainQuery() query, args, err := stmt.Where(sq.Eq{ InstanceColumnID.identifier(): authz.GetInstance(ctx).InstanceID(), }).ToSql() @@ -261,7 +260,7 @@ func (q *Queries) GetDefaultLanguage(ctx context.Context) language.Tag { return instance.DefaultLang } -func prepareInstancesQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(sq.SelectBuilder) sq.SelectBuilder, func(*sql.Rows) (*Instances, error)) { +func prepareInstancesQuery() (sq.SelectBuilder, func(sq.SelectBuilder) sq.SelectBuilder, func(*sql.Rows) (*Instances, error)) { instanceFilterTable := instanceTable.setAlias(InstancesFilterTableAlias) instanceFilterIDColumn := InstanceColumnID.setTable(instanceFilterTable) instanceFilterCountColumn := InstancesFilterTableAlias + ".count" @@ -291,7 +290,7 @@ func prepareInstancesQuery(ctx context.Context, db prepareDatabase) (sq.SelectBu InstanceDomainSequenceCol.identifier(), ).FromSelect(builder, InstancesFilterTableAlias). LeftJoin(join(InstanceColumnID, instanceFilterIDColumn)). - LeftJoin(join(InstanceDomainInstanceIDCol, instanceFilterIDColumn) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(InstanceDomainInstanceIDCol, instanceFilterIDColumn)). PlaceholderFormat(sq.Dollar) }, func(rows *sql.Rows) (*Instances, error) { @@ -366,7 +365,7 @@ func prepareInstancesQuery(ctx context.Context, db prepareDatabase) (sq.SelectBu } } -func prepareInstanceDomainQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Instance, error)) { +func prepareInstanceDomainQuery() (sq.SelectBuilder, func(*sql.Rows) (*Instance, error)) { return sq.Select( InstanceColumnID.identifier(), InstanceColumnCreationDate.identifier(), @@ -386,7 +385,7 @@ func prepareInstanceDomainQuery(ctx context.Context, db prepareDatabase) (sq.Sel InstanceDomainSequenceCol.identifier(), ). From(instanceTable.identifier()). - LeftJoin(join(InstanceDomainInstanceIDCol, InstanceColumnID) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(InstanceDomainInstanceIDCol, InstanceColumnID)). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*Instance, error) { instance := &Instance{ diff --git a/internal/query/instance_domain.go b/internal/query/instance_domain.go index 285bd12936..47b5fab27f 100644 --- a/internal/query/instance_domain.go +++ b/internal/query/instance_domain.go @@ -8,7 +8,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/zerrors" @@ -62,7 +61,7 @@ func (q *Queries) SearchInstanceDomains(ctx context.Context, queries *InstanceDo ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareInstanceDomainsQuery(ctx, q.client) + query, scan := prepareInstanceDomainsQuery() stmt, args, err := queries.toQuery(query). Where(sq.Eq{ InstanceDomainInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -78,7 +77,7 @@ func (q *Queries) SearchInstanceDomainsGlobal(ctx context.Context, queries *Inst ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareInstanceDomainsQuery(ctx, q.client) + query, scan := prepareInstanceDomainsQuery() stmt, args, err := queries.toQuery(query).ToSql() if err != nil { return nil, zerrors.ThrowInvalidArgument(err, "QUERY-IHhLR", "Errors.Query.SQLStatement") @@ -99,7 +98,7 @@ func (q *Queries) queryInstanceDomains(ctx context.Context, stmt string, scan fu return domains, err } -func prepareInstanceDomainsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*InstanceDomains, error)) { +func prepareInstanceDomainsQuery() (sq.SelectBuilder, func(*sql.Rows) (*InstanceDomains, error)) { return sq.Select( InstanceDomainCreationDateCol.identifier(), InstanceDomainChangeDateCol.identifier(), @@ -109,7 +108,7 @@ func prepareInstanceDomainsQuery(ctx context.Context, db prepareDatabase) (sq.Se InstanceDomainIsGeneratedCol.identifier(), InstanceDomainIsPrimaryCol.identifier(), countColumn.identifier(), - ).From(instanceDomainsTable.identifier() + db.Timetravel(call.Took(ctx))). + ).From(instanceDomainsTable.identifier()). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*InstanceDomains, error) { domains := make([]*InstanceDomain, 0) diff --git a/internal/query/instance_domain_test.go b/internal/query/instance_domain_test.go index 4f72c0def4..fd147bf4b7 100644 --- a/internal/query/instance_domain_test.go +++ b/internal/query/instance_domain_test.go @@ -18,8 +18,7 @@ var ( ` projections.instance_domains.is_generated,` + ` projections.instance_domains.is_primary,` + ` COUNT(*) OVER ()` + - ` FROM projections.instance_domains` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.instance_domains` prepareInstanceDomainsCols = []string{ "creation_date", "change_date", @@ -167,7 +166,7 @@ func Test_InstanceDomainPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/instance_features_test.go b/internal/query/instance_features_test.go index e182f4002f..903c2872a9 100644 --- a/internal/query/instance_features_test.go +++ b/internal/query/instance_features_test.go @@ -84,28 +84,28 @@ func TestQueries_GetInstanceFeatures(t *testing.T) { { name: "all features set", eventstore: expectEventstore( - expectFilter(eventFromEventPusher(feature_v2.NewSetEvent[bool]( + expectFilter(eventFromEventPusher(feature_v2.NewSetEvent( context.Background(), aggregate, feature_v2.SystemLoginDefaultOrgEventType, true, ))), expectFilter( - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( ctx, aggregate, feature_v2.InstanceLoginDefaultOrgEventType, false, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( ctx, aggregate, feature_v2.InstanceTriggerIntrospectionProjectionsEventType, true, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( ctx, aggregate, feature_v2.InstanceLegacyIntrospectionEventType, false, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( ctx, aggregate, feature_v2.InstanceUserSchemaEventType, false, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( ctx, aggregate, feature_v2.InstanceActionsEventType, false, )), @@ -141,28 +141,28 @@ func TestQueries_GetInstanceFeatures(t *testing.T) { { name: "all features set, reset, set some feature, cascaded", eventstore: expectEventstore( - expectFilter(eventFromEventPusher(feature_v2.NewSetEvent[bool]( + expectFilter(eventFromEventPusher(feature_v2.NewSetEvent( context.Background(), aggregate, feature_v2.SystemLoginDefaultOrgEventType, true, ))), expectFilter( - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( ctx, aggregate, feature_v2.InstanceLoginDefaultOrgEventType, false, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( ctx, aggregate, feature_v2.InstanceTriggerIntrospectionProjectionsEventType, true, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( ctx, aggregate, feature_v2.InstanceLegacyIntrospectionEventType, false, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( ctx, aggregate, feature_v2.InstanceUserSchemaEventType, false, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( ctx, aggregate, feature_v2.InstanceActionsEventType, false, )), @@ -170,7 +170,7 @@ func TestQueries_GetInstanceFeatures(t *testing.T) { ctx, aggregate, feature_v2.InstanceResetEventType, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( ctx, aggregate, feature_v2.InstanceTriggerIntrospectionProjectionsEventType, true, )), @@ -207,23 +207,23 @@ func TestQueries_GetInstanceFeatures(t *testing.T) { name: "all features set, reset, set some feature, not cascaded", eventstore: expectEventstore( expectFilter( - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( ctx, aggregate, feature_v2.InstanceLoginDefaultOrgEventType, false, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( ctx, aggregate, feature_v2.InstanceTriggerIntrospectionProjectionsEventType, true, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( ctx, aggregate, feature_v2.InstanceLegacyIntrospectionEventType, false, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( ctx, aggregate, feature_v2.InstanceUserSchemaEventType, false, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( ctx, aggregate, feature_v2.InstanceActionsEventType, false, )), @@ -231,7 +231,7 @@ func TestQueries_GetInstanceFeatures(t *testing.T) { ctx, aggregate, feature_v2.InstanceResetEventType, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( ctx, aggregate, feature_v2.InstanceTriggerIntrospectionProjectionsEventType, true, )), diff --git a/internal/query/instance_test.go b/internal/query/instance_test.go index 8d9c7e1597..55b1c8314b 100644 --- a/internal/query/instance_test.go +++ b/internal/query/instance_test.go @@ -1,7 +1,6 @@ package query import ( - "context" "database/sql" "database/sql/driver" "errors" @@ -34,8 +33,7 @@ var ( ` FROM (SELECT DISTINCT projections.instances.id, COUNT(*) OVER () FROM projections.instances` + ` LEFT JOIN projections.instance_domains ON projections.instances.id = projections.instance_domains.instance_id) AS f` + ` LEFT JOIN projections.instances ON f.id = projections.instances.id` + - ` LEFT JOIN projections.instance_domains ON f.id = projections.instance_domains.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'` + ` LEFT JOIN projections.instance_domains ON f.id = projections.instance_domains.instance_id` instancesCols = []string{ "count", "id", @@ -64,15 +62,15 @@ func Test_InstancePrepares(t *testing.T) { } tests := []struct { name string - prepare interface{} + prepare any additionalArgs []reflect.Value want want - object interface{} + object any }{ { name: "prepareInstancesQuery no result", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Instances, error)) { - filter, query, scan := prepareInstancesQuery(ctx, db) + prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*Instances, error)) { + filter, query, scan := prepareInstancesQuery() return query(filter), scan }, want: want{ @@ -86,8 +84,8 @@ func Test_InstancePrepares(t *testing.T) { }, { name: "prepareInstancesQuery one result", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Instances, error)) { - filter, query, scan := prepareInstancesQuery(ctx, db) + prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*Instances, error)) { + filter, query, scan := prepareInstancesQuery() return query(filter), scan }, want: want{ @@ -150,8 +148,8 @@ func Test_InstancePrepares(t *testing.T) { }, { name: "prepareInstancesQuery multiple results", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Instances, error)) { - filter, query, scan := prepareInstancesQuery(ctx, db) + prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*Instances, error)) { + filter, query, scan := prepareInstancesQuery() return query(filter), scan }, want: want{ @@ -283,8 +281,8 @@ func Test_InstancePrepares(t *testing.T) { }, { name: "prepareInstancesQuery sql err", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Instances, error)) { - filter, query, scan := prepareInstancesQuery(ctx, db) + prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*Instances, error)) { + filter, query, scan := prepareInstancesQuery() return query(filter), scan }, want: want{ @@ -304,7 +302,7 @@ func Test_InstancePrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, append(defaultPrepareArgs, tt.additionalArgs...)...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, tt.additionalArgs...) }) } } diff --git a/internal/query/instance_trusted_domain.go b/internal/query/instance_trusted_domain.go index 2847c3969a..8c3fd99987 100644 --- a/internal/query/instance_trusted_domain.go +++ b/internal/query/instance_trusted_domain.go @@ -8,7 +8,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/zerrors" @@ -48,7 +47,7 @@ func (q *Queries) SearchInstanceTrustedDomains(ctx context.Context, queries *Ins ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareInstanceTrustedDomainsQuery(ctx, q.client) + query, scan := prepareInstanceTrustedDomainsQuery() stmt, args, err := queries.toQuery(query). Where(sq.Eq{ InstanceTrustedDomainInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -72,7 +71,7 @@ func (q *Queries) queryInstanceTrustedDomains(ctx context.Context, stmt string, return domains, err } -func prepareInstanceTrustedDomainsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*InstanceTrustedDomains, error)) { +func prepareInstanceTrustedDomainsQuery() (sq.SelectBuilder, func(*sql.Rows) (*InstanceTrustedDomains, error)) { return sq.Select( InstanceTrustedDomainCreationDateCol.identifier(), InstanceTrustedDomainChangeDateCol.identifier(), @@ -80,7 +79,7 @@ func prepareInstanceTrustedDomainsQuery(ctx context.Context, db prepareDatabase) InstanceTrustedDomainDomainCol.identifier(), InstanceTrustedDomainInstanceIDCol.identifier(), countColumn.identifier(), - ).From(instanceTrustedDomainsTable.identifier() + db.Timetravel(call.Took(ctx))). + ).From(instanceTrustedDomainsTable.identifier()). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*InstanceTrustedDomains, error) { domains := make([]*InstanceTrustedDomain, 0) diff --git a/internal/query/instance_trusted_domain_test.go b/internal/query/instance_trusted_domain_test.go index 6e3eea027e..518d2edb6b 100644 --- a/internal/query/instance_trusted_domain_test.go +++ b/internal/query/instance_trusted_domain_test.go @@ -16,8 +16,7 @@ var ( ` projections.instance_trusted_domains.domain,` + ` projections.instance_trusted_domains.instance_id,` + ` COUNT(*) OVER ()` + - ` FROM projections.instance_trusted_domains` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.instance_trusted_domains` prepareInstanceTrustedDomainsCols = []string{ "creation_date", "change_date", @@ -151,7 +150,7 @@ func Test_InstanceTrustedDomainPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/introspection_test.go b/internal/query/introspection_test.go index 4346842bf9..92c571ebf9 100644 --- a/internal/query/introspection_test.go +++ b/internal/query/introspection_test.go @@ -91,8 +91,7 @@ func TestQueries_ActiveIntrospectionClientByID(t *testing.T) { execMock(t, tt.mock, func(db *sql.DB) { q := &Queries{ client: &database.DB{ - DB: db, - Database: &prepareDB{}, + DB: db, }, } ctx := authz.NewMockContext("instanceID", "orgID", "userID") diff --git a/internal/query/key.go b/internal/query/key.go index d7475e424b..4831d88654 100644 --- a/internal/query/key.go +++ b/internal/query/key.go @@ -9,7 +9,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/query/projection" @@ -182,7 +181,7 @@ func (q *Queries) ActivePublicKeys(ctx context.Context, t time.Time) (keys *Publ ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := preparePublicKeysQuery(ctx, q.client) + query, scan := preparePublicKeysQuery() if t.IsZero() { t = time.Now() } @@ -214,7 +213,7 @@ func (q *Queries) ActivePrivateSigningKey(ctx context.Context, t time.Time) (key ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := preparePrivateKeysQuery(ctx, q.client) + stmt, scan := preparePrivateKeysQuery() if t.IsZero() { t = time.Now() } @@ -244,7 +243,7 @@ func (q *Queries) ActivePrivateSigningKey(ctx context.Context, t time.Time) (key return keys, nil } -func preparePublicKeysQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*PublicKeys, error)) { +func preparePublicKeysQuery() (sq.SelectBuilder, func(*sql.Rows) (*PublicKeys, error)) { return sq.Select( KeyColID.identifier(), KeyColCreationDate.identifier(), @@ -257,7 +256,7 @@ func preparePublicKeysQuery(ctx context.Context, db prepareDatabase) (sq.SelectB KeyPublicColKey.identifier(), countColumn.identifier(), ).From(keyTable.identifier()). - LeftJoin(join(KeyPublicColID, KeyColID) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(KeyPublicColID, KeyColID)). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*PublicKeys, error) { keys := make([]PublicKey, 0) @@ -300,7 +299,7 @@ func preparePublicKeysQuery(ctx context.Context, db prepareDatabase) (sq.SelectB } } -func preparePrivateKeysQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*PrivateKeys, error)) { +func preparePrivateKeysQuery() (sq.SelectBuilder, func(*sql.Rows) (*PrivateKeys, error)) { return sq.Select( KeyColID.identifier(), KeyColCreationDate.identifier(), @@ -313,7 +312,7 @@ func preparePrivateKeysQuery(ctx context.Context, db prepareDatabase) (sq.Select KeyPrivateColKey.identifier(), countColumn.identifier(), ).From(keyTable.identifier()). - LeftJoin(join(KeyPrivateColID, KeyColID) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(KeyPrivateColID, KeyColID)). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*PrivateKeys, error) { keys := make([]PrivateKey, 0) diff --git a/internal/query/key_test.go b/internal/query/key_test.go index a977bfb58e..7bc029fd7f 100644 --- a/internal/query/key_test.go +++ b/internal/query/key_test.go @@ -36,8 +36,7 @@ var ( ` projections.keys4_public.key,` + ` COUNT(*) OVER ()` + ` FROM projections.keys4` + - ` LEFT JOIN projections.keys4_public ON projections.keys4.id = projections.keys4_public.id AND projections.keys4.instance_id = projections.keys4_public.instance_id` + - ` AS OF SYSTEM TIME '-1 ms' ` + ` LEFT JOIN projections.keys4_public ON projections.keys4.id = projections.keys4_public.id AND projections.keys4.instance_id = projections.keys4_public.instance_id` preparePublicKeysCols = []string{ "id", "creation_date", @@ -62,8 +61,7 @@ var ( ` projections.keys4_private.key,` + ` COUNT(*) OVER ()` + ` FROM projections.keys4` + - ` LEFT JOIN projections.keys4_private ON projections.keys4.id = projections.keys4_private.id AND projections.keys4.instance_id = projections.keys4_private.instance_id` + - ` AS OF SYSTEM TIME '-1 ms' ` + ` LEFT JOIN projections.keys4_private ON projections.keys4.id = projections.keys4_private.id AND projections.keys4.instance_id = projections.keys4_private.instance_id` ) func Test_KeyPrepares(t *testing.T) { @@ -244,7 +242,7 @@ func Test_KeyPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/label_policy.go b/internal/query/label_policy.go index 6dc7b00922..d3952a210a 100644 --- a/internal/query/label_policy.go +++ b/internal/query/label_policy.go @@ -9,7 +9,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -47,7 +46,7 @@ func (q *Queries) ActiveLabelPolicyByOrg(ctx context.Context, orgID string, with ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareLabelPolicyQuery(ctx, q.client) + stmt, scan := prepareLabelPolicyQuery() eq := sq.Eq{ LabelPolicyColState.identifier(): domain.LabelPolicyStateActive, LabelPolicyColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -80,7 +79,7 @@ func (q *Queries) PreviewLabelPolicyByOrg(ctx context.Context, orgID string) (po ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareLabelPolicyQuery(ctx, q.client) + stmt, scan := prepareLabelPolicyQuery() query, args, err := stmt.Where( sq.And{ sq.Or{ @@ -113,7 +112,7 @@ func (q *Queries) DefaultActiveLabelPolicy(ctx context.Context) (policy *LabelPo ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareLabelPolicyQuery(ctx, q.client) + stmt, scan := prepareLabelPolicyQuery() query, args, err := stmt.Where(sq.Eq{ LabelPolicyColID.identifier(): authz.GetInstance(ctx).InstanceID(), LabelPolicyColState.identifier(): domain.LabelPolicyStateActive, @@ -136,7 +135,7 @@ func (q *Queries) DefaultPreviewLabelPolicy(ctx context.Context) (policy *LabelP ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareLabelPolicyQuery(ctx, q.client) + stmt, scan := prepareLabelPolicyQuery() query, args, err := stmt.Where(sq.Eq{ LabelPolicyColID.identifier(): authz.GetInstance(ctx).InstanceID(), LabelPolicyColState.identifier(): domain.LabelPolicyStatePreview, @@ -240,7 +239,7 @@ var ( } ) -func prepareLabelPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*LabelPolicy, error)) { +func prepareLabelPolicyQuery() (sq.SelectBuilder, func(*sql.Row) (*LabelPolicy, error)) { return sq.Select( LabelPolicyColCreationDate.identifier(), LabelPolicyColChangeDate.identifier(), @@ -270,7 +269,7 @@ func prepareLabelPolicyQuery(ctx context.Context, db prepareDatabase) (sq.Select LabelPolicyColDarkLogoURL.identifier(), LabelPolicyColDarkIconURL.identifier(), ). - From(labelPolicyTable.identifier() + db.Timetravel(call.Took(ctx))). + From(labelPolicyTable.identifier()). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*LabelPolicy, error) { policy := new(LabelPolicy) diff --git a/internal/query/lockout_policy.go b/internal/query/lockout_policy.go index be4b162785..078c743413 100644 --- a/internal/query/lockout_policy.go +++ b/internal/query/lockout_policy.go @@ -10,7 +10,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore/handler/v2" "github.com/zitadel/zitadel/internal/query/projection" @@ -98,7 +97,7 @@ func (q *Queries) LockoutPolicyByOrg(ctx context.Context, shouldTriggerBulk bool LockoutColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), } - stmt, scan := prepareLockoutPolicyQuery(ctx, q.client) + stmt, scan := prepareLockoutPolicyQuery() query, args, err := stmt.Where( sq.And{ eq, @@ -124,7 +123,7 @@ func (q *Queries) DefaultLockoutPolicy(ctx context.Context) (policy *LockoutPoli ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareLockoutPolicyQuery(ctx, q.client) + stmt, scan := prepareLockoutPolicyQuery() query, args, err := stmt.Where(sq.Eq{ LockoutColID.identifier(): authz.GetInstance(ctx).InstanceID(), LockoutColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -142,7 +141,7 @@ func (q *Queries) DefaultLockoutPolicy(ctx context.Context) (policy *LockoutPoli return policy, err } -func prepareLockoutPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*LockoutPolicy, error)) { +func prepareLockoutPolicyQuery() (sq.SelectBuilder, func(*sql.Row) (*LockoutPolicy, error)) { return sq.Select( LockoutColID.identifier(), LockoutColSequence.identifier(), @@ -155,7 +154,7 @@ func prepareLockoutPolicyQuery(ctx context.Context, db prepareDatabase) (sq.Sele LockoutColIsDefault.identifier(), LockoutColState.identifier(), ). - From(lockoutTable.identifier() + db.Timetravel(call.Took(ctx))). + From(lockoutTable.identifier()). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*LockoutPolicy, error) { policy := new(LockoutPolicy) diff --git a/internal/query/lockout_policy_test.go b/internal/query/lockout_policy_test.go index 2805ef8fdc..0c0a9f04eb 100644 --- a/internal/query/lockout_policy_test.go +++ b/internal/query/lockout_policy_test.go @@ -23,8 +23,7 @@ var ( ` projections.lockout_policies3.max_otp_attempts,` + ` projections.lockout_policies3.is_default,` + ` projections.lockout_policies3.state` + - ` FROM projections.lockout_policies3` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.lockout_policies3` prepareLockoutPolicyCols = []string{ "id", @@ -123,7 +122,7 @@ func Test_LockoutPolicyPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/login_policy.go b/internal/query/login_policy.go index 5ab54cfa55..946dbb04de 100644 --- a/internal/query/login_policy.go +++ b/internal/query/login_policy.go @@ -10,7 +10,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore/handler/v2" @@ -183,7 +182,7 @@ func (q *Queries) LoginPolicyByID(ctx context.Context, shouldTriggerBulk bool, o eq[LoginPolicyColumnOwnerRemoved.identifier()] = false } - query, scan := prepareLoginPolicyQuery(ctx, q.client) + query, scan := prepareLoginPolicyQuery() stmt, args, err := query.Where( sq.And{ eq, @@ -219,7 +218,7 @@ func (q *Queries) DefaultLoginPolicy(ctx context.Context) (policy *LoginPolicy, ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareLoginPolicyQuery(ctx, q.client) + query, scan := prepareLoginPolicyQuery() stmt, args, err := query.Where(sq.Eq{ LoginPolicyColumnOrgID.identifier(): authz.GetInstance(ctx).InstanceID(), LoginPolicyColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -242,7 +241,7 @@ func (q *Queries) SecondFactorsByOrg(ctx context.Context, orgID string) (factors ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareLoginPolicy2FAsQuery(ctx, q.client) + query, scan := prepareLoginPolicy2FAsQuery() stmt, args, err := query.Where( sq.And{ sq.Eq{ @@ -278,7 +277,7 @@ func (q *Queries) DefaultSecondFactors(ctx context.Context) (factors *SecondFact ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareLoginPolicy2FAsQuery(ctx, q.client) + query, scan := prepareLoginPolicy2FAsQuery() stmt, args, err := query.Where(sq.Eq{ LoginPolicyColumnOrgID.identifier(): authz.GetInstance(ctx).InstanceID(), LoginPolicyColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -302,7 +301,7 @@ func (q *Queries) MultiFactorsByOrg(ctx context.Context, orgID string) (factors ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareLoginPolicyMFAsQuery(ctx, q.client) + query, scan := prepareLoginPolicyMFAsQuery() stmt, args, err := query.Where( sq.And{ sq.Eq{ @@ -338,7 +337,7 @@ func (q *Queries) DefaultMultiFactors(ctx context.Context) (factors *MultiFactor ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareLoginPolicyMFAsQuery(ctx, q.client) + query, scan := prepareLoginPolicyMFAsQuery() stmt, args, err := query.Where(sq.Eq{ LoginPolicyColumnOrgID.identifier(): authz.GetInstance(ctx).InstanceID(), LoginPolicyColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -358,7 +357,7 @@ func (q *Queries) DefaultMultiFactors(ctx context.Context) (factors *MultiFactor return factors, err } -func prepareLoginPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*LoginPolicy, error)) { +func prepareLoginPolicyQuery() (sq.SelectBuilder, func(*sql.Rows) (*LoginPolicy, error)) { return sq.Select( LoginPolicyColumnOrgID.identifier(), LoginPolicyColumnCreationDate.identifier(), @@ -384,7 +383,7 @@ func prepareLoginPolicyQuery(ctx context.Context, db prepareDatabase) (sq.Select LoginPolicyColumnMFAInitSkipLifetime.identifier(), LoginPolicyColumnSecondFactorCheckLifetime.identifier(), LoginPolicyColumnMultiFactorCheckLifetime.identifier(), - ).From(loginPolicyTable.identifier() + db.Timetravel(call.Took(ctx))). + ).From(loginPolicyTable.identifier()). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*LoginPolicy, error) { p := new(LoginPolicy) @@ -428,10 +427,10 @@ func prepareLoginPolicyQuery(ctx context.Context, db prepareDatabase) (sq.Select } } -func prepareLoginPolicy2FAsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*SecondFactors, error)) { +func prepareLoginPolicy2FAsQuery() (sq.SelectBuilder, func(*sql.Row) (*SecondFactors, error)) { return sq.Select( LoginPolicyColumnSecondFactors.identifier(), - ).From(loginPolicyTable.identifier() + db.Timetravel(call.Took(ctx))). + ).From(loginPolicyTable.identifier()). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*SecondFactors, error) { p := new(SecondFactors) @@ -450,10 +449,10 @@ func prepareLoginPolicy2FAsQuery(ctx context.Context, db prepareDatabase) (sq.Se } } -func prepareLoginPolicyMFAsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*MultiFactors, error)) { +func prepareLoginPolicyMFAsQuery() (sq.SelectBuilder, func(*sql.Row) (*MultiFactors, error)) { return sq.Select( LoginPolicyColumnMultiFactors.identifier(), - ).From(loginPolicyTable.identifier() + db.Timetravel(call.Took(ctx))). + ).From(loginPolicyTable.identifier()). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*MultiFactors, error) { p := new(MultiFactors) diff --git a/internal/query/login_policy_test.go b/internal/query/login_policy_test.go index f64c94e275..792f30517a 100644 --- a/internal/query/login_policy_test.go +++ b/internal/query/login_policy_test.go @@ -39,8 +39,7 @@ var ( ` projections.login_policies5.mfa_init_skip_lifetime,` + ` projections.login_policies5.second_factor_check_lifetime,` + ` projections.login_policies5.multi_factor_check_lifetime` + - ` FROM projections.login_policies5` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.login_policies5` loginPolicyCols = []string{ "aggregate_id", "creation_date", @@ -69,15 +68,13 @@ var ( } prepareLoginPolicy2FAsStmt = `SELECT projections.login_policies5.second_factors` + - ` FROM projections.login_policies5` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.login_policies5` prepareLoginPolicy2FAsCols = []string{ "second_factors", } prepareLoginPolicyMFAsStmt = `SELECT projections.login_policies5.multi_factors` + - ` FROM projections.login_policies5` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.login_policies5` prepareLoginPolicyMFAsCols = []string{ "multi_factors", } @@ -331,7 +328,7 @@ func Test_LoginPolicyPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/mail_template.go b/internal/query/mail_template.go index 9d5ff83162..518ab7aec5 100644 --- a/internal/query/mail_template.go +++ b/internal/query/mail_template.go @@ -9,7 +9,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -74,7 +73,7 @@ func (q *Queries) MailTemplateByOrg(ctx context.Context, orgID string, withOwner ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareMailTemplateQuery(ctx, q.client) + stmt, scan := prepareMailTemplateQuery() eq := sq.Eq{MailTemplateColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()} if !withOwnerRemoved { eq[MailTemplateColOwnerRemoved.identifier()] = false @@ -104,7 +103,7 @@ func (q *Queries) DefaultMailTemplate(ctx context.Context) (template *MailTempla ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareMailTemplateQuery(ctx, q.client) + stmt, scan := prepareMailTemplateQuery() query, args, err := stmt.Where(sq.Eq{ MailTemplateColAggregateID.identifier(): authz.GetInstance(ctx).InstanceID(), MailTemplateColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -122,7 +121,7 @@ func (q *Queries) DefaultMailTemplate(ctx context.Context) (template *MailTempla return template, err } -func prepareMailTemplateQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*MailTemplate, error)) { +func prepareMailTemplateQuery() (sq.SelectBuilder, func(*sql.Row) (*MailTemplate, error)) { return sq.Select( MailTemplateColAggregateID.identifier(), MailTemplateColSequence.identifier(), @@ -132,7 +131,7 @@ func prepareMailTemplateQuery(ctx context.Context, db prepareDatabase) (sq.Selec MailTemplateColIsDefault.identifier(), MailTemplateColState.identifier(), ). - From(mailTemplateTable.identifier() + db.Timetravel(call.Took(ctx))). + From(mailTemplateTable.identifier()). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*MailTemplate, error) { policy := new(MailTemplate) diff --git a/internal/query/message_text.go b/internal/query/message_text.go index cb524d289a..b64c93e4ec 100644 --- a/internal/query/message_text.go +++ b/internal/query/message_text.go @@ -15,7 +15,6 @@ import ( "sigs.k8s.io/yaml" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/i18n" "github.com/zitadel/zitadel/internal/query/projection" @@ -131,7 +130,7 @@ func (q *Queries) DefaultMessageText(ctx context.Context) (text *MessageText, er ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareMessageTextQuery(ctx, q.client) + stmt, scan := prepareMessageTextQuery() query, args, err := stmt.Where(sq.Eq{ MessageTextColAggregateID.identifier(): authz.GetInstance(ctx).InstanceID(), MessageTextColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -167,7 +166,7 @@ func (q *Queries) CustomMessageTextByTypeAndLanguage(ctx context.Context, aggreg ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareMessageTextQuery(ctx, q.client) + stmt, scan := prepareMessageTextQuery() eq := sq.Eq{ MessageTextColLanguage.identifier(): language, MessageTextColType.identifier(): messageType, @@ -249,7 +248,7 @@ func (q *Queries) readNotificationTextMessages(ctx context.Context, language str return contents, nil } -func prepareMessageTextQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*MessageText, error)) { +func prepareMessageTextQuery() (sq.SelectBuilder, func(*sql.Row) (*MessageText, error)) { return sq.Select( MessageTextColAggregateID.identifier(), MessageTextColSequence.identifier(), @@ -266,7 +265,7 @@ func prepareMessageTextQuery(ctx context.Context, db prepareDatabase) (sq.Select MessageTextColButtonText.identifier(), MessageTextColFooter.identifier(), ). - From(messageTextTable.identifier() + db.Timetravel(call.Took(ctx))). + From(messageTextTable.identifier()). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*MessageText, error) { msg := new(MessageText) diff --git a/internal/query/message_text_test.go b/internal/query/message_text_test.go index 09df5dcd83..4e78f4813d 100644 --- a/internal/query/message_text_test.go +++ b/internal/query/message_text_test.go @@ -29,8 +29,7 @@ var ( ` projections.message_texts2.text,` + ` projections.message_texts2.button_text,` + ` projections.message_texts2.footer_text` + - ` FROM projections.message_texts2` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.message_texts2` prepareMessgeTextCols = []string{ "aggregate_id", "sequence", @@ -140,7 +139,7 @@ func Test_MessageTextPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/milestone.go b/internal/query/milestone.go index 4277b8e68a..21edeba55a 100644 --- a/internal/query/milestone.go +++ b/internal/query/milestone.go @@ -8,7 +8,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/repository/milestone" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -68,7 +67,7 @@ var ( func (q *Queries) SearchMilestones(ctx context.Context, instanceIDs []string, queries *MilestonesSearchQueries) (milestones *Milestones, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareMilestonesQuery(ctx, q.client) + query, scan := prepareMilestonesQuery() if len(instanceIDs) == 0 { instanceIDs = []string{authz.GetInstance(ctx).InstanceID()} } @@ -92,7 +91,7 @@ func (q *Queries) SearchMilestones(ctx context.Context, instanceIDs []string, qu } -func prepareMilestonesQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Milestones, error)) { +func prepareMilestonesQuery() (sq.SelectBuilder, func(*sql.Rows) (*Milestones, error)) { return sq.Select( MilestoneInstanceIDColID.identifier(), InstanceDomainDomainCol.identifier(), @@ -101,7 +100,7 @@ func prepareMilestonesQuery(ctx context.Context, db prepareDatabase) (sq.SelectB MilestoneTypeColID.identifier(), countColumn.identifier(), ). - From(milestonesTable.identifier() + db.Timetravel(call.Took(ctx))). + From(milestonesTable.identifier()). LeftJoin(join(InstanceDomainInstanceIDCol, MilestoneInstanceIDColID)). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*Milestones, error) { diff --git a/internal/query/milestone_test.go b/internal/query/milestone_test.go index ee99474ec2..027ebca48c 100644 --- a/internal/query/milestone_test.go +++ b/internal/query/milestone_test.go @@ -17,7 +17,7 @@ var ( projections.milestones3.last_pushed_date, projections.milestones3.type, COUNT(*) OVER () - FROM projections.milestones3 AS OF SYSTEM TIME '-1 ms' + FROM projections.milestones3 LEFT JOIN projections.instance_domains ON projections.milestones3.instance_id = projections.instance_domains.instance_id `) @@ -184,7 +184,7 @@ func Test_MilestonesPrepare(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/notification_policy.go b/internal/query/notification_policy.go index f3878e7987..45779762d9 100644 --- a/internal/query/notification_policy.go +++ b/internal/query/notification_policy.go @@ -9,7 +9,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore/handler/v2" "github.com/zitadel/zitadel/internal/query/projection" @@ -93,7 +92,7 @@ func (q *Queries) NotificationPolicyByOrg(ctx context.Context, shouldTriggerBulk if !withOwnerRemoved { eq[NotificationPolicyColOwnerRemoved.identifier()] = false } - stmt, scan := prepareNotificationPolicyQuery(ctx, q.client) + stmt, scan := prepareNotificationPolicyQuery() query, args, err := stmt.Where( sq.And{ eq, @@ -127,7 +126,7 @@ func (q *Queries) DefaultNotificationPolicy(ctx context.Context, shouldTriggerBu } } - stmt, scan := prepareNotificationPolicyQuery(ctx, q.client) + stmt, scan := prepareNotificationPolicyQuery() query, args, err := stmt.Where(sq.Eq{ NotificationPolicyColID.identifier(): authz.GetInstance(ctx).InstanceID(), NotificationPolicyColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -145,7 +144,7 @@ func (q *Queries) DefaultNotificationPolicy(ctx context.Context, shouldTriggerBu return policy, err } -func prepareNotificationPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*NotificationPolicy, error)) { +func prepareNotificationPolicyQuery() (sq.SelectBuilder, func(*sql.Row) (*NotificationPolicy, error)) { return sq.Select( NotificationPolicyColID.identifier(), NotificationPolicyColSequence.identifier(), @@ -156,7 +155,7 @@ func prepareNotificationPolicyQuery(ctx context.Context, db prepareDatabase) (sq NotificationPolicyColIsDefault.identifier(), NotificationPolicyColState.identifier(), ). - From(notificationPolicyTable.identifier() + db.Timetravel(call.Took(ctx))). + From(notificationPolicyTable.identifier()). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*NotificationPolicy, error) { policy := new(NotificationPolicy) diff --git a/internal/query/notification_policy_test.go b/internal/query/notification_policy_test.go index d755bdc544..bbb40d4e5b 100644 --- a/internal/query/notification_policy_test.go +++ b/internal/query/notification_policy_test.go @@ -21,8 +21,7 @@ var ( ` projections.notification_policies.password_change,` + ` projections.notification_policies.is_default,` + ` projections.notification_policies.state` + - ` FROM projections.notification_policies` + - ` AS OF SYSTEM TIME '-1 ms'`) + ` FROM projections.notification_policies`) notificationPolicyCols = []string{ "id", "sequence", @@ -114,7 +113,7 @@ func Test_NotificationPolicyPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/notification_provider.go b/internal/query/notification_provider.go index b2038c603d..fa48e42c9b 100644 --- a/internal/query/notification_provider.go +++ b/internal/query/notification_provider.go @@ -9,7 +9,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -74,7 +73,7 @@ func (q *Queries) NotificationProviderByIDAndType(ctx context.Context, aggID str ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareDebugNotificationProviderQuery(ctx, q.client) + query, scan := prepareDebugNotificationProviderQuery() stmt, args, err := query.Where( sq.And{ sq.Eq{NotificationProviderColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()}, @@ -97,7 +96,7 @@ func (q *Queries) NotificationProviderByIDAndType(ctx context.Context, aggID str return provider, err } -func prepareDebugNotificationProviderQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*DebugNotificationProvider, error)) { +func prepareDebugNotificationProviderQuery() (sq.SelectBuilder, func(*sql.Row) (*DebugNotificationProvider, error)) { return sq.Select( NotificationProviderColumnAggID.identifier(), NotificationProviderColumnCreationDate.identifier(), @@ -107,7 +106,7 @@ func prepareDebugNotificationProviderQuery(ctx context.Context, db prepareDataba NotificationProviderColumnState.identifier(), NotificationProviderColumnType.identifier(), NotificationProviderColumnCompact.identifier(), - ).From(notificationProviderTable.identifier() + db.Timetravel(call.Took(ctx))). + ).From(notificationProviderTable.identifier()). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*DebugNotificationProvider, error) { p := new(DebugNotificationProvider) diff --git a/internal/query/notification_provider_test.go b/internal/query/notification_provider_test.go index 2fce31e118..a2c88ccbcb 100644 --- a/internal/query/notification_provider_test.go +++ b/internal/query/notification_provider_test.go @@ -21,8 +21,7 @@ var ( ` projections.notification_providers.state,` + ` projections.notification_providers.provider_type,` + ` projections.notification_providers.compact` + - ` FROM projections.notification_providers` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.notification_providers` prepareNotificationProviderCols = []string{ "aggregate_id", "creation_date", @@ -114,7 +113,7 @@ func Test_NotificationProviderPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/oidc_client_test.go b/internal/query/oidc_client_test.go index 25e069da85..826e5071db 100644 --- a/internal/query/oidc_client_test.go +++ b/internal/query/oidc_client_test.go @@ -268,8 +268,7 @@ low2kyJov38V4Uk2I8kuXpLcnrpw5Tio2ooiUE27b0vHZqBKOei9Uo88qCrn3EKx execMock(t, tt.mock, func(db *sql.DB) { q := &Queries{ client: &database.DB{ - DB: db, - Database: &prepareDB{}, + DB: db, }, } ctx := authz.NewMockContext("instanceID", "orgID", "loginClient") diff --git a/internal/query/oidc_settings.go b/internal/query/oidc_settings.go index 32cbc32429..bdd21cfd15 100644 --- a/internal/query/oidc_settings.go +++ b/internal/query/oidc_settings.go @@ -9,7 +9,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/zerrors" @@ -79,7 +78,7 @@ func (q *Queries) OIDCSettingsByAggID(ctx context.Context, aggregateID string) ( ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareOIDCSettingsQuery(ctx, q.client) + stmt, scan := prepareOIDCSettingsQuery() query, args, err := stmt.Where(sq.Eq{ OIDCSettingsColumnAggregateID.identifier(): aggregateID, OIDCSettingsColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -95,7 +94,7 @@ func (q *Queries) OIDCSettingsByAggID(ctx context.Context, aggregateID string) ( return settings, err } -func prepareOIDCSettingsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*OIDCSettings, error)) { +func prepareOIDCSettingsQuery() (sq.SelectBuilder, func(*sql.Row) (*OIDCSettings, error)) { return sq.Select( OIDCSettingsColumnAggregateID.identifier(), OIDCSettingsColumnCreationDate.identifier(), @@ -106,7 +105,7 @@ func prepareOIDCSettingsQuery(ctx context.Context, db prepareDatabase) (sq.Selec OIDCSettingsColumnIdTokenLifetime.identifier(), OIDCSettingsColumnRefreshTokenIdleExpiration.identifier(), OIDCSettingsColumnRefreshTokenExpiration.identifier()). - From(oidcSettingsTable.identifier() + db.Timetravel(call.Took(ctx))). + From(oidcSettingsTable.identifier()). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*OIDCSettings, error) { oidcSettings := new(OIDCSettings) diff --git a/internal/query/oidc_settings_test.go b/internal/query/oidc_settings_test.go index bdb5cb96ec..625c16f34c 100644 --- a/internal/query/oidc_settings_test.go +++ b/internal/query/oidc_settings_test.go @@ -22,8 +22,7 @@ var ( ` projections.oidc_settings2.id_token_lifetime,` + ` projections.oidc_settings2.refresh_token_idle_expiration,` + ` projections.oidc_settings2.refresh_token_expiration` + - ` FROM projections.oidc_settings2` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.oidc_settings2` prepareOIDCSettingsCols = []string{ "aggregate_id", "creation_date", @@ -118,7 +117,7 @@ func Test_OIDCConfigsPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/org.go b/internal/query/org.go index e5bfc5140f..b1f5eaea02 100644 --- a/internal/query/org.go +++ b/internal/query/org.go @@ -11,7 +11,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" domain_pkg "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore/handler/v2" "github.com/zitadel/zitadel/internal/feature" @@ -164,7 +163,7 @@ func (q *Queries) oldOrgByID(ctx context.Context, shouldTriggerBulk bool, id str traceSpan.EndWithError(err) } - stmt, scan := prepareOrgQuery(ctx, q.client) + stmt, scan := prepareOrgQuery() query, args, err := stmt.Where(sq.Eq{ OrgColumnID.identifier(): id, OrgColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -189,7 +188,7 @@ func (q *Queries) OrgByPrimaryDomain(ctx context.Context, domain string) (org *O return org, nil } - stmt, scan := prepareOrgQuery(ctx, q.client) + stmt, scan := prepareOrgQuery() query, args, err := stmt.Where(sq.Eq{ OrgColumnDomain.identifier(): domain, OrgColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -213,7 +212,7 @@ func (q *Queries) OrgByVerifiedDomain(ctx context.Context, domain string) (org * ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareOrgWithDomainsQuery(ctx, q.client) + stmt, scan := prepareOrgWithDomainsQuery() query, args, err := stmt.Where(sq.Eq{ OrgDomainDomainCol.identifier(): domain, OrgDomainIsVerifiedCol.identifier(): true, @@ -237,7 +236,7 @@ func (q *Queries) IsOrgUnique(ctx context.Context, name, domain string) (isUniqu if name == "" && domain == "" { return false, zerrors.ThrowInvalidArgument(nil, "QUERY-DGqfd", "Errors.Query.InvalidRequest") } - query, scan := prepareOrgUniqueQuery(ctx, q.client) + query, scan := prepareOrgUniqueQuery() stmt, args, err := query.Where( sq.And{ sq.Eq{ @@ -298,7 +297,7 @@ func (q *Queries) searchOrgs(ctx context.Context, queries *OrgSearchQueries) (or ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareOrgsQuery(ctx, q.client) + query, scan := prepareOrgsQuery() stmt, args, err := queries.toQuery(query). Where(sq.And{ sq.Eq{ @@ -361,7 +360,7 @@ func NewOrgIDsSearchQuery(ids ...string) (SearchQuery, error) { return NewListQuery(OrgColumnID, list, ListIn) } -func prepareOrgsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Orgs, error)) { +func prepareOrgsQuery() (sq.SelectBuilder, func(*sql.Rows) (*Orgs, error)) { return sq.Select( OrgColumnID.identifier(), OrgColumnCreationDate.identifier(), @@ -372,7 +371,7 @@ func prepareOrgsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder OrgColumnName.identifier(), OrgColumnDomain.identifier(), countColumn.identifier()). - From(orgsTable.identifier() + db.Timetravel(call.Took(ctx))). + From(orgsTable.identifier()). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*Orgs, error) { orgs := make([]*Org, 0) @@ -409,42 +408,7 @@ func prepareOrgsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder } } -func prepareOrgQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Org, error)) { - return sq.Select( - OrgColumnID.identifier(), - OrgColumnCreationDate.identifier(), - OrgColumnChangeDate.identifier(), - OrgColumnResourceOwner.identifier(), - OrgColumnState.identifier(), - OrgColumnSequence.identifier(), - OrgColumnName.identifier(), - OrgColumnDomain.identifier(), - ). - From(orgsTable.identifier() + db.Timetravel(call.Took(ctx))). - PlaceholderFormat(sq.Dollar), - func(row *sql.Row) (*Org, error) { - o := new(Org) - err := row.Scan( - &o.ID, - &o.CreationDate, - &o.ChangeDate, - &o.ResourceOwner, - &o.State, - &o.Sequence, - &o.Name, - &o.Domain, - ) - if err != nil { - if errors.Is(err, sql.ErrNoRows) { - return nil, zerrors.ThrowNotFound(err, "QUERY-iTTGJ", "Errors.Org.NotFound") - } - return nil, zerrors.ThrowInternal(err, "QUERY-pWS5H", "Errors.Internal") - } - return o, nil - } -} - -func prepareOrgWithDomainsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Org, error)) { +func prepareOrgQuery() (sq.SelectBuilder, func(*sql.Row) (*Org, error)) { return sq.Select( OrgColumnID.identifier(), OrgColumnCreationDate.identifier(), @@ -456,7 +420,6 @@ func prepareOrgWithDomainsQuery(ctx context.Context, db prepareDatabase) (sq.Sel OrgColumnDomain.identifier(), ). From(orgsTable.identifier()). - LeftJoin(join(OrgDomainOrgIDCol, OrgColumnID) + db.Timetravel(call.Took(ctx))). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*Org, error) { o := new(Org) @@ -480,10 +443,46 @@ func prepareOrgWithDomainsQuery(ctx context.Context, db prepareDatabase) (sq.Sel } } -func prepareOrgUniqueQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (bool, error)) { +func prepareOrgWithDomainsQuery() (sq.SelectBuilder, func(*sql.Row) (*Org, error)) { + return sq.Select( + OrgColumnID.identifier(), + OrgColumnCreationDate.identifier(), + OrgColumnChangeDate.identifier(), + OrgColumnResourceOwner.identifier(), + OrgColumnState.identifier(), + OrgColumnSequence.identifier(), + OrgColumnName.identifier(), + OrgColumnDomain.identifier(), + ). + From(orgsTable.identifier()). + LeftJoin(join(OrgDomainOrgIDCol, OrgColumnID)). + PlaceholderFormat(sq.Dollar), + func(row *sql.Row) (*Org, error) { + o := new(Org) + err := row.Scan( + &o.ID, + &o.CreationDate, + &o.ChangeDate, + &o.ResourceOwner, + &o.State, + &o.Sequence, + &o.Name, + &o.Domain, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, zerrors.ThrowNotFound(err, "QUERY-iTTGJ", "Errors.Org.NotFound") + } + return nil, zerrors.ThrowInternal(err, "QUERY-pWS5H", "Errors.Internal") + } + return o, nil + } +} + +func prepareOrgUniqueQuery() (sq.SelectBuilder, func(*sql.Row) (bool, error)) { return sq.Select(uniqueColumn.identifier()). From(orgsTable.identifier()). - LeftJoin(join(OrgDomainOrgIDCol, OrgColumnID) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(OrgDomainOrgIDCol, OrgColumnID)). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (isUnique bool, err error) { err = row.Scan(&isUnique) diff --git a/internal/query/org_domain.go b/internal/query/org_domain.go index 595ba897d0..ed0dba9c17 100644 --- a/internal/query/org_domain.go +++ b/internal/query/org_domain.go @@ -8,7 +8,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -60,7 +59,7 @@ func (q *Queries) SearchOrgDomains(ctx context.Context, queries *OrgDomainSearch ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareDomainsQuery(ctx, q.client) + query, scan := prepareDomainsQuery() eq := sq.Eq{OrgDomainInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID()} if !withOwnerRemoved { eq[OrgDomainOwnerRemovedCol.identifier()] = false @@ -82,7 +81,7 @@ func (q *Queries) SearchOrgDomains(ctx context.Context, queries *OrgDomainSearch return domains, err } -func prepareDomainsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Domains, error)) { +func prepareDomainsQuery() (sq.SelectBuilder, func(*sql.Rows) (*Domains, error)) { return sq.Select( OrgDomainCreationDateCol.identifier(), OrgDomainChangeDateCol.identifier(), @@ -93,7 +92,7 @@ func prepareDomainsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil OrgDomainIsPrimaryCol.identifier(), OrgDomainValidationTypeCol.identifier(), countColumn.identifier(), - ).From(orgDomainsTable.identifier() + db.Timetravel(call.Took(ctx))). + ).From(orgDomainsTable.identifier()). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*Domains, error) { domains := make([]*Domain, 0) diff --git a/internal/query/org_domain_test.go b/internal/query/org_domain_test.go index 5757eda657..6668528241 100644 --- a/internal/query/org_domain_test.go +++ b/internal/query/org_domain_test.go @@ -21,8 +21,7 @@ var ( ` projections.org_domains2.is_primary,` + ` projections.org_domains2.validation_type,` + ` COUNT(*) OVER ()` + - ` FROM projections.org_domains2` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.org_domains2` prepareOrgDomainsCols = []string{ "id", "creation_date", @@ -177,7 +176,7 @@ func Test_OrgDomainPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/org_member.go b/internal/query/org_member.go index 4daa31d341..a85c5d5f6a 100644 --- a/internal/query/org_member.go +++ b/internal/query/org_member.go @@ -7,7 +7,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -73,7 +72,7 @@ func (q *Queries) OrgMembers(ctx context.Context, queries *OrgMembersQuery) (mem ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareOrgMembersQuery(ctx, q.client) + query, scan := prepareOrgMembersQuery() eq := sq.Eq{OrgMemberInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()} stmt, args, err := queries.toQuery(query).Where(eq).ToSql() if err != nil { @@ -97,7 +96,7 @@ func (q *Queries) OrgMembers(ctx context.Context, queries *OrgMembersQuery) (mem return members, err } -func prepareOrgMembersQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Members, error)) { +func prepareOrgMembersQuery() (sq.SelectBuilder, func(*sql.Rows) (*Members, error)) { return sq.Select( OrgMemberCreationDate.identifier(), OrgMemberChangeDate.identifier(), @@ -119,7 +118,7 @@ func prepareOrgMembersQuery(ctx context.Context, db prepareDatabase) (sq.SelectB LeftJoin(join(HumanUserIDCol, OrgMemberUserID)). LeftJoin(join(MachineUserIDCol, OrgMemberUserID)). LeftJoin(join(UserIDCol, OrgMemberUserID)). - LeftJoin(join(LoginNameUserIDCol, OrgMemberUserID) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(LoginNameUserIDCol, OrgMemberUserID)). Where( sq.Eq{LoginNameIsPrimaryCol.identifier(): true}, ).PlaceholderFormat(sq.Dollar), diff --git a/internal/query/org_member_test.go b/internal/query/org_member_test.go index 8433c338ee..cb0b64d55f 100644 --- a/internal/query/org_member_test.go +++ b/internal/query/org_member_test.go @@ -43,7 +43,6 @@ var ( "LEFT JOIN projections.login_names3 " + "ON members.user_id = projections.login_names3.user_id " + "AND members.instance_id = projections.login_names3.instance_id " + - "AS OF SYSTEM TIME '-1 ms' " + "WHERE projections.login_names3.is_primary = $1") orgMembersColumns = []string{ "creation_date", @@ -299,7 +298,7 @@ func Test_OrgMemberPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/org_metadata.go b/internal/query/org_metadata.go index 1ce95e2880..84b204de2b 100644 --- a/internal/query/org_metadata.go +++ b/internal/query/org_metadata.go @@ -10,7 +10,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/eventstore/handler/v2" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -90,7 +89,7 @@ func (q *Queries) GetOrgMetadataByKey(ctx context.Context, shouldTriggerBulk boo traceSpan.EndWithError(err) } - query, scan := prepareOrgMetadataQuery(ctx, q.client) + query, scan := prepareOrgMetadataQuery() for _, q := range queries { query = q.toQuery(query) } @@ -131,7 +130,7 @@ func (q *Queries) SearchOrgMetadata(ctx context.Context, shouldTriggerBulk bool, if !withOwnerRemoved { eq[OrgMetadataOwnerRemovedCol.identifier()] = false } - query, scan := prepareOrgMetadataListQuery(ctx, q.client) + query, scan := prepareOrgMetadataListQuery() stmt, args, err := queries.toQuery(query).Where(eq).ToSql() if err != nil { return nil, zerrors.ThrowInternal(err, "QUERY-Egbld", "Errors.Query.SQLStatment") @@ -174,7 +173,7 @@ func NewOrgMetadataKeySearchQuery(value string, comparison TextComparison) (Sear return NewTextQuery(OrgMetadataKeyCol, value, comparison) } -func prepareOrgMetadataQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*OrgMetadata, error)) { +func prepareOrgMetadataQuery() (sq.SelectBuilder, func(*sql.Row) (*OrgMetadata, error)) { return sq.Select( OrgMetadataCreationDateCol.identifier(), OrgMetadataChangeDateCol.identifier(), @@ -183,7 +182,7 @@ func prepareOrgMetadataQuery(ctx context.Context, db prepareDatabase) (sq.Select OrgMetadataKeyCol.identifier(), OrgMetadataValueCol.identifier(), ). - From(orgMetadataTable.identifier() + db.Timetravel(call.Took(ctx))). + From(orgMetadataTable.identifier()). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*OrgMetadata, error) { m := new(OrgMetadata) @@ -206,7 +205,7 @@ func prepareOrgMetadataQuery(ctx context.Context, db prepareDatabase) (sq.Select } } -func prepareOrgMetadataListQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*OrgMetadataList, error)) { +func prepareOrgMetadataListQuery() (sq.SelectBuilder, func(*sql.Rows) (*OrgMetadataList, error)) { return sq.Select( OrgMetadataCreationDateCol.identifier(), OrgMetadataChangeDateCol.identifier(), @@ -215,7 +214,7 @@ func prepareOrgMetadataListQuery(ctx context.Context, db prepareDatabase) (sq.Se OrgMetadataKeyCol.identifier(), OrgMetadataValueCol.identifier(), countColumn.identifier()). - From(orgMetadataTable.identifier() + db.Timetravel(call.Took(ctx))). + From(orgMetadataTable.identifier()). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*OrgMetadataList, error) { metadata := make([]*OrgMetadata, 0) diff --git a/internal/query/org_metadata_test.go b/internal/query/org_metadata_test.go index 0225ef1c2a..666fddd0fd 100644 --- a/internal/query/org_metadata_test.go +++ b/internal/query/org_metadata_test.go @@ -18,8 +18,7 @@ var ( ` projections.org_metadata2.sequence,` + ` projections.org_metadata2.key,` + ` projections.org_metadata2.value` + - ` FROM projections.org_metadata2` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.org_metadata2` orgMetadataCols = []string{ "creation_date", "change_date", @@ -35,8 +34,7 @@ var ( ` projections.org_metadata2.key,` + ` projections.org_metadata2.value,` + ` COUNT(*) OVER ()` + - ` FROM projections.org_metadata2` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.org_metadata2` orgMetadataListCols = []string{ "creation_date", "change_date", @@ -244,7 +242,7 @@ func Test_OrgMetadataPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/org_test.go b/internal/query/org_test.go index db41f9ffd1..d704d2901a 100644 --- a/internal/query/org_test.go +++ b/internal/query/org_test.go @@ -19,7 +19,7 @@ import ( ) var ( - orgUniqueQuery = "SELECT COUNT(*) = 0 FROM projections.orgs1 LEFT JOIN projections.org_domains2 ON projections.orgs1.id = projections.org_domains2.org_id AND projections.orgs1.instance_id = projections.org_domains2.instance_id AS OF SYSTEM TIME '-1 ms' WHERE (projections.org_domains2.is_verified = $1 AND projections.orgs1.instance_id = $2 AND (projections.org_domains2.domain ILIKE $3 OR projections.orgs1.name ILIKE $4) AND projections.orgs1.org_state <> $5)" + orgUniqueQuery = "SELECT COUNT(*) = 0 FROM projections.orgs1 LEFT JOIN projections.org_domains2 ON projections.orgs1.id = projections.org_domains2.org_id AND projections.orgs1.instance_id = projections.org_domains2.instance_id WHERE (projections.org_domains2.is_verified = $1 AND projections.orgs1.instance_id = $2 AND (projections.org_domains2.domain ILIKE $3 OR projections.orgs1.name ILIKE $4) AND projections.orgs1.org_state <> $5)" orgUniqueCols = []string{"is_unique"} prepareOrgsQueryStmt = `SELECT projections.orgs1.id,` + @@ -31,8 +31,7 @@ var ( ` projections.orgs1.name,` + ` projections.orgs1.primary_domain,` + ` COUNT(*) OVER ()` + - ` FROM projections.orgs1` + - ` AS OF SYSTEM TIME '-1 ms' ` + ` FROM projections.orgs1` prepareOrgsQueryCols = []string{ "id", "creation_date", @@ -53,8 +52,7 @@ var ( ` projections.orgs1.sequence,` + ` projections.orgs1.name,` + ` projections.orgs1.primary_domain` + - ` FROM projections.orgs1` + - ` AS OF SYSTEM TIME '-1 ms' ` + ` FROM projections.orgs1` prepareOrgQueryCols = []string{ "id", "creation_date", @@ -68,8 +66,7 @@ var ( prepareOrgUniqueStmt = `SELECT COUNT(*) = 0` + ` FROM projections.orgs1` + - ` LEFT JOIN projections.org_domains2 ON projections.orgs1.id = projections.org_domains2.org_id AND projections.orgs1.instance_id = projections.org_domains2.instance_id` + - ` AS OF SYSTEM TIME '-1 ms' ` + ` LEFT JOIN projections.org_domains2 ON projections.orgs1.id = projections.org_domains2.org_id AND projections.orgs1.instance_id = projections.org_domains2.instance_id` prepareOrgUniqueCols = []string{ "count", } @@ -330,7 +327,7 @@ func Test_OrgPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } @@ -421,8 +418,7 @@ func TestQueries_IsOrgUnique(t *testing.T) { t.Run(tt.name, func(t *testing.T) { q := &Queries{ client: &database.DB{ - DB: client, - Database: new(prepareDB), + DB: client, }, } diff --git a/internal/query/password_age_policy.go b/internal/query/password_age_policy.go index 15b1b248c8..f5f0491d7b 100644 --- a/internal/query/password_age_policy.go +++ b/internal/query/password_age_policy.go @@ -10,7 +10,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore/handler/v2" "github.com/zitadel/zitadel/internal/query/projection" @@ -97,7 +96,7 @@ func (q *Queries) PasswordAgePolicyByOrg(ctx context.Context, shouldTriggerBulk if !withOwnerRemoved { eq[PasswordAgeColOwnerRemoved.identifier()] = false } - stmt, scan := preparePasswordAgePolicyQuery(ctx, q.client) + stmt, scan := preparePasswordAgePolicyQuery() query, args, err := stmt.Where( sq.And{ eq, @@ -130,7 +129,7 @@ func (q *Queries) DefaultPasswordAgePolicy(ctx context.Context, shouldTriggerBul traceSpan.EndWithError(err) } - stmt, scan := preparePasswordAgePolicyQuery(ctx, q.client) + stmt, scan := preparePasswordAgePolicyQuery() query, args, err := stmt.Where(sq.Eq{ PasswordAgeColID.identifier(): authz.GetInstance(ctx).InstanceID(), }). @@ -147,7 +146,7 @@ func (q *Queries) DefaultPasswordAgePolicy(ctx context.Context, shouldTriggerBul return policy, err } -func preparePasswordAgePolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*PasswordAgePolicy, error)) { +func preparePasswordAgePolicyQuery() (sq.SelectBuilder, func(*sql.Row) (*PasswordAgePolicy, error)) { return sq.Select( PasswordAgeColID.identifier(), PasswordAgeColSequence.identifier(), @@ -159,7 +158,7 @@ func preparePasswordAgePolicyQuery(ctx context.Context, db prepareDatabase) (sq. PasswordAgeColIsDefault.identifier(), PasswordAgeColState.identifier(), ). - From(passwordAgeTable.identifier() + db.Timetravel(call.Took(ctx))). + From(passwordAgeTable.identifier()). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*PasswordAgePolicy, error) { policy := new(PasswordAgePolicy) diff --git a/internal/query/password_age_policy_test.go b/internal/query/password_age_policy_test.go index b140f82a06..f40acdb559 100644 --- a/internal/query/password_age_policy_test.go +++ b/internal/query/password_age_policy_test.go @@ -22,8 +22,7 @@ var ( ` projections.password_age_policies2.max_age_days,` + ` projections.password_age_policies2.is_default,` + ` projections.password_age_policies2.state` + - ` FROM projections.password_age_policies2` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.password_age_policies2` preparePasswordAgePolicyCols = []string{ "id", "sequence", @@ -118,7 +117,7 @@ func Test_PasswordAgePolicyPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/password_complexity_policy.go b/internal/query/password_complexity_policy.go index a895c98b75..fa0e5b2691 100644 --- a/internal/query/password_complexity_policy.go +++ b/internal/query/password_complexity_policy.go @@ -10,7 +10,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore/handler/v2" "github.com/zitadel/zitadel/internal/query/projection" @@ -49,7 +48,7 @@ func (q *Queries) PasswordComplexityPolicyByOrg(ctx context.Context, shouldTrigg if !withOwnerRemoved { eq[PasswordComplexityColOwnerRemoved.identifier()] = false } - stmt, scan := preparePasswordComplexityPolicyQuery(ctx, q.client) + stmt, scan := preparePasswordComplexityPolicyQuery() query, args, err := stmt.Where( sq.And{ eq, @@ -82,7 +81,7 @@ func (q *Queries) DefaultPasswordComplexityPolicy(ctx context.Context, shouldTri traceSpan.EndWithError(err) } - stmt, scan := preparePasswordComplexityPolicyQuery(ctx, q.client) + stmt, scan := preparePasswordComplexityPolicyQuery() query, args, err := stmt.Where(sq.Eq{ PasswordComplexityColID.identifier(): authz.GetInstance(ctx).InstanceID(), PasswordComplexityColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -163,7 +162,7 @@ var ( } ) -func preparePasswordComplexityPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*PasswordComplexityPolicy, error)) { +func preparePasswordComplexityPolicyQuery() (sq.SelectBuilder, func(*sql.Row) (*PasswordComplexityPolicy, error)) { return sq.Select( PasswordComplexityColID.identifier(), PasswordComplexityColSequence.identifier(), @@ -178,7 +177,7 @@ func preparePasswordComplexityPolicyQuery(ctx context.Context, db prepareDatabas PasswordComplexityColIsDefault.identifier(), PasswordComplexityColState.identifier(), ). - From(passwordComplexityTable.identifier() + db.Timetravel(call.Took(ctx))). + From(passwordComplexityTable.identifier()). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*PasswordComplexityPolicy, error) { policy := new(PasswordComplexityPolicy) diff --git a/internal/query/password_complexity_policy_test.go b/internal/query/password_complexity_policy_test.go index ac471f3994..e5738049dd 100644 --- a/internal/query/password_complexity_policy_test.go +++ b/internal/query/password_complexity_policy_test.go @@ -25,8 +25,7 @@ var ( ` projections.password_complexity_policies2.has_symbol,` + ` projections.password_complexity_policies2.is_default,` + ` projections.password_complexity_policies2.state` + - ` FROM projections.password_complexity_policies2` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.password_complexity_policies2` preparePasswordComplexityPolicyCols = []string{ "id", "sequence", @@ -130,7 +129,7 @@ func Test_PasswordComplexityPolicyPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/prepare_test.go b/internal/query/prepare_test.go index f8cf31cdef..e243426260 100644 --- a/internal/query/prepare_test.go +++ b/internal/query/prepare_test.go @@ -1,7 +1,6 @@ package query import ( - "context" "database/sql" "database/sql/driver" "errors" @@ -32,7 +31,7 @@ var ( // func() (sq.SelectBuilder, func(*sql.Row) (*struct, error)) // expectedObject represents the return value of scan // sqlExpectation represents the query executed on the database -func assertPrepare(t *testing.T, prepareFunc, expectedObject interface{}, sqlExpectation sqlExpectation, isErr checkErr, prepareArgs ...reflect.Value) bool { +func assertPrepare(t *testing.T, prepareFunc, expectedObject any, sqlExpectation sqlExpectation, isErr checkErr, prepareArgs ...reflect.Value) bool { t.Helper() client, mock, err := sqlmock.New(sqlmock.ValueConverterOption(new(db_mock.TypeConverter))) @@ -243,9 +242,9 @@ func validateScan(scanType reflect.Type) error { return nil } -func execPrepare(prepare interface{}, args []reflect.Value) (builder sq.SelectBuilder, scan interface{}, err error) { +func execPrepare(prepare any, args []reflect.Value) (builder sq.SelectBuilder, scan interface{}, err error) { prepareVal := reflect.ValueOf(prepare) - if err := validatePrepare(prepareVal.Type()); err != nil { + if err := validatePrepare(prepareVal.Type(), len(args)); err != nil { return sq.SelectBuilder{}, nil, err } res := prepareVal.Call(args) @@ -253,12 +252,12 @@ func execPrepare(prepare interface{}, args []reflect.Value) (builder sq.SelectBu return res[0].Interface().(sq.SelectBuilder), res[1].Interface(), nil } -func validatePrepare(prepareType reflect.Type) error { +func validatePrepare(prepareType reflect.Type, numArgs int) error { if prepareType.Kind() != reflect.Func { return errors.New("prepare is not a function") } - if prepareType.NumIn() != 0 && prepareType.NumIn() != 2 { - return fmt.Errorf("prepare: invalid number of inputs: want: 0 or 2 got %d", prepareType.NumIn()) + if prepareType.NumIn() != numArgs { + return fmt.Errorf("prepare: invalid number of inputs: want: %d got %d", numArgs, prepareType.NumIn()) } if prepareType.NumOut() != 2 { return fmt.Errorf("prepare: invalid number of outputs: want: 2 got %d", prepareType.NumOut()) @@ -363,7 +362,7 @@ func TestValidatePrepare(t *testing.T) { }, { name: "correct", - t: reflect.TypeOf(func(context.Context, prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (interface{}, error)) { + t: reflect.TypeOf(func() (sq.SelectBuilder, func(*sql.Rows) (interface{}, error)) { log.Fatal("should not be executed") return sq.SelectBuilder{}, nil }), @@ -372,24 +371,10 @@ func TestValidatePrepare(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - err := validatePrepare(tt.t) + err := validatePrepare(tt.t, 0) if (err != nil) != tt.expectErr { t.Errorf("unexpected err: %v", err) } }) } } - -type prepareDB struct{} - -const asOfSystemTime = " AS OF SYSTEM TIME '-1 ms' " - -func (*prepareDB) Timetravel(time.Duration) string { return asOfSystemTime } - -var defaultPrepareArgs = []reflect.Value{reflect.ValueOf(context.Background()), reflect.ValueOf(new(prepareDB))} - -func (*prepareDB) DatabaseName() string { return "db" } - -func (*prepareDB) Username() string { return "user" } - -func (*prepareDB) Type() string { return "type" } diff --git a/internal/query/privacy_policy.go b/internal/query/privacy_policy.go index 59394e92b1..e26948f478 100644 --- a/internal/query/privacy_policy.go +++ b/internal/query/privacy_policy.go @@ -10,7 +10,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore/handler/v2" "github.com/zitadel/zitadel/internal/query/projection" @@ -122,7 +121,7 @@ func (q *Queries) PrivacyPolicyByOrg(ctx context.Context, shouldTriggerBulk bool if !withOwnerRemoved { eq[PrivacyColOwnerRemoved.identifier()] = false } - stmt, scan := preparePrivacyPolicyQuery(ctx, q.client) + stmt, scan := preparePrivacyPolicyQuery() query, args, err := stmt.Where( sq.And{ eq, @@ -154,7 +153,7 @@ func (q *Queries) DefaultPrivacyPolicy(ctx context.Context, shouldTriggerBulk bo traceSpan.EndWithError(err) } - stmt, scan := preparePrivacyPolicyQuery(ctx, q.client) + stmt, scan := preparePrivacyPolicyQuery() query, args, err := stmt.Where(sq.Eq{ PrivacyColID.identifier(): authz.GetInstance(ctx).InstanceID(), PrivacyColInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -172,7 +171,7 @@ func (q *Queries) DefaultPrivacyPolicy(ctx context.Context, shouldTriggerBulk bo return policy, err } -func preparePrivacyPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*PrivacyPolicy, error)) { +func preparePrivacyPolicyQuery() (sq.SelectBuilder, func(*sql.Row) (*PrivacyPolicy, error)) { return sq.Select( PrivacyColID.identifier(), PrivacyColSequence.identifier(), @@ -189,7 +188,7 @@ func preparePrivacyPolicyQuery(ctx context.Context, db prepareDatabase) (sq.Sele PrivacyColIsDefault.identifier(), PrivacyColState.identifier(), ). - From(privacyTable.identifier() + db.Timetravel(call.Took(ctx))). + From(privacyTable.identifier()). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*PrivacyPolicy, error) { policy := new(PrivacyPolicy) diff --git a/internal/query/privacy_policy_test.go b/internal/query/privacy_policy_test.go index ade541d0cc..1777ca1991 100644 --- a/internal/query/privacy_policy_test.go +++ b/internal/query/privacy_policy_test.go @@ -27,8 +27,7 @@ var ( ` projections.privacy_policies4.custom_link_text,` + ` projections.privacy_policies4.is_default,` + ` projections.privacy_policies4.state` + - ` FROM projections.privacy_policies4` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.privacy_policies4` preparePrivacyPolicyCols = []string{ "id", "sequence", @@ -138,7 +137,7 @@ func Test_PrivacyPolicyPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/project.go b/internal/query/project.go index a92448f25d..7501047182 100644 --- a/internal/query/project.go +++ b/internal/query/project.go @@ -10,7 +10,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore/handler/v2" "github.com/zitadel/zitadel/internal/query/projection" @@ -109,7 +108,7 @@ func (q *Queries) ProjectByID(ctx context.Context, shouldTriggerBulk bool, id st traceSpan.EndWithError(err) } - stmt, scan := prepareProjectQuery(ctx, q.client) + stmt, scan := prepareProjectQuery() eq := sq.Eq{ ProjectColumnID.identifier(): id, ProjectColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -130,7 +129,7 @@ func (q *Queries) SearchProjects(ctx context.Context, queries *ProjectSearchQuer ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareProjectsQuery(ctx, q.client) + query, scan := prepareProjectsQuery() eq := sq.Eq{ProjectColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()} stmt, args, err := queries.toQuery(query).Where(eq).ToSql() if err != nil { @@ -193,7 +192,7 @@ func (q *ProjectSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuilder return query } -func prepareProjectQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Project, error)) { +func prepareProjectQuery() (sq.SelectBuilder, func(*sql.Row) (*Project, error)) { return sq.Select( ProjectColumnID.identifier(), ProjectColumnCreationDate.identifier(), @@ -206,7 +205,7 @@ func prepareProjectQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil ProjectColumnProjectRoleCheck.identifier(), ProjectColumnHasProjectCheck.identifier(), ProjectColumnPrivateLabelingSetting.identifier()). - From(projectsTable.identifier() + db.Timetravel(call.Took(ctx))). + From(projectsTable.identifier()). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*Project, error) { p := new(Project) @@ -233,7 +232,7 @@ func prepareProjectQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil } } -func prepareProjectsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Projects, error)) { +func prepareProjectsQuery() (sq.SelectBuilder, func(*sql.Rows) (*Projects, error)) { return sq.Select( ProjectColumnID.identifier(), ProjectColumnCreationDate.identifier(), @@ -247,7 +246,7 @@ func prepareProjectsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui ProjectColumnHasProjectCheck.identifier(), ProjectColumnPrivateLabelingSetting.identifier(), countColumn.identifier()). - From(projectsTable.identifier() + db.Timetravel(call.Took(ctx))). + From(projectsTable.identifier()). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*Projects, error) { projects := make([]*Project, 0) diff --git a/internal/query/project_grant.go b/internal/query/project_grant.go index 1bc68e984a..b971593c77 100644 --- a/internal/query/project_grant.go +++ b/internal/query/project_grant.go @@ -10,7 +10,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore/handler/v2" @@ -116,7 +115,7 @@ func (q *Queries) ProjectGrantByID(ctx context.Context, shouldTriggerBulk bool, traceSpan.EndWithError(err) } - stmt, scan := prepareProjectGrantQuery(ctx, q.client) + stmt, scan := prepareProjectGrantQuery() eq := sq.Eq{ ProjectGrantColumnGrantID.identifier(): id, ProjectGrantColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -137,7 +136,7 @@ func (q *Queries) ProjectGrantByIDAndGrantedOrg(ctx context.Context, id, granted ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareProjectGrantQuery(ctx, q.client) + stmt, scan := prepareProjectGrantQuery() eq := sq.Eq{ ProjectGrantColumnGrantID.identifier(): id, ProjectGrantColumnGrantedOrgID.identifier(): grantedOrg, @@ -159,7 +158,7 @@ func (q *Queries) SearchProjectGrants(ctx context.Context, queries *ProjectGrant ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareProjectGrantsQuery(ctx, q.client) + query, scan := prepareProjectGrantsQuery() eq := sq.Eq{ ProjectGrantColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), } @@ -264,7 +263,7 @@ func (q *ProjectGrantSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBui return query } -func prepareProjectGrantQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*ProjectGrant, error)) { +func prepareProjectGrantQuery() (sq.SelectBuilder, func(*sql.Row) (*ProjectGrant, error)) { resourceOwnerOrgTable := orgsTable.setAlias(ProjectGrantResourceOwnerTableAlias) resourceOwnerIDColumn := OrgColumnID.setTable(resourceOwnerOrgTable) grantedOrgTable := orgsTable.setAlias(ProjectGrantGrantedOrgTableAlias) @@ -286,7 +285,7 @@ func prepareProjectGrantQuery(ctx context.Context, db prepareDatabase) (sq.Selec PlaceholderFormat(sq.Dollar). LeftJoin(join(ProjectColumnID, ProjectGrantColumnProjectID)). LeftJoin(join(resourceOwnerIDColumn, ProjectGrantColumnResourceOwner)). - LeftJoin(join(grantedOrgIDColumn, ProjectGrantColumnGrantedOrgID) + db.Timetravel(call.Took(ctx))), + LeftJoin(join(grantedOrgIDColumn, ProjectGrantColumnGrantedOrgID)), func(row *sql.Row) (*ProjectGrant, error) { grant := new(ProjectGrant) var ( @@ -323,7 +322,7 @@ func prepareProjectGrantQuery(ctx context.Context, db prepareDatabase) (sq.Selec } } -func prepareProjectGrantsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*ProjectGrants, error)) { +func prepareProjectGrantsQuery() (sq.SelectBuilder, func(*sql.Rows) (*ProjectGrants, error)) { resourceOwnerOrgTable := orgsTable.setAlias(ProjectGrantResourceOwnerTableAlias) resourceOwnerIDColumn := OrgColumnID.setTable(resourceOwnerOrgTable) grantedOrgTable := orgsTable.setAlias(ProjectGrantGrantedOrgTableAlias) @@ -346,7 +345,7 @@ func prepareProjectGrantsQuery(ctx context.Context, db prepareDatabase) (sq.Sele PlaceholderFormat(sq.Dollar). LeftJoin(join(ProjectColumnID, ProjectGrantColumnProjectID)). LeftJoin(join(resourceOwnerIDColumn, ProjectGrantColumnResourceOwner)). - LeftJoin(join(grantedOrgIDColumn, ProjectGrantColumnGrantedOrgID) + db.Timetravel(call.Took(ctx))), + LeftJoin(join(grantedOrgIDColumn, ProjectGrantColumnGrantedOrgID)), func(rows *sql.Rows) (*ProjectGrants, error) { projects := make([]*ProjectGrant, 0) var ( diff --git a/internal/query/project_grant_member.go b/internal/query/project_grant_member.go index 0820ada826..a9cc49c498 100644 --- a/internal/query/project_grant_member.go +++ b/internal/query/project_grant_member.go @@ -7,7 +7,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/zerrors" @@ -82,7 +81,7 @@ func (q *ProjectGrantMembersQuery) toQuery(query sq.SelectBuilder) sq.SelectBuil } func (q *Queries) ProjectGrantMembers(ctx context.Context, queries *ProjectGrantMembersQuery) (members *Members, err error) { - query, scan := prepareProjectGrantMembersQuery(ctx, q.client) + query, scan := prepareProjectGrantMembersQuery() eq := sq.Eq{ProjectGrantMemberInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()} stmt, args, err := queries.toQuery(query).Where(eq).ToSql() if err != nil { @@ -106,7 +105,7 @@ func (q *Queries) ProjectGrantMembers(ctx context.Context, queries *ProjectGrant return members, err } -func prepareProjectGrantMembersQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Members, error)) { +func prepareProjectGrantMembersQuery() (sq.SelectBuilder, func(*sql.Rows) (*Members, error)) { return sq.Select( ProjectGrantMemberCreationDate.identifier(), ProjectGrantMemberChangeDate.identifier(), @@ -129,7 +128,7 @@ func prepareProjectGrantMembersQuery(ctx context.Context, db prepareDatabase) (s LeftJoin(join(MachineUserIDCol, ProjectGrantMemberUserID)). LeftJoin(join(UserIDCol, ProjectGrantMemberUserID)). LeftJoin(join(LoginNameUserIDCol, ProjectGrantMemberUserID)). - LeftJoin(join(ProjectGrantColumnGrantID, ProjectGrantMemberGrantID) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(ProjectGrantColumnGrantID, ProjectGrantMemberGrantID)). Where( sq.Eq{LoginNameIsPrimaryCol.identifier(): true}, ).PlaceholderFormat(sq.Dollar), diff --git a/internal/query/project_grant_member_test.go b/internal/query/project_grant_member_test.go index 72eaf76d6e..23d1258b7c 100644 --- a/internal/query/project_grant_member_test.go +++ b/internal/query/project_grant_member_test.go @@ -46,7 +46,6 @@ var ( "LEFT JOIN projections.project_grants4 " + "ON members.grant_id = projections.project_grants4.grant_id " + "AND members.instance_id = projections.project_grants4.instance_id " + - `AS OF SYSTEM TIME '-1 ms' ` + "WHERE projections.login_names3.is_primary = $1") projectGrantMembersColumns = []string{ "creation_date", @@ -302,7 +301,7 @@ func Test_ProjectGrantMemberPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/project_grant_test.go b/internal/query/project_grant_test.go index 6d2131dfc4..2801e0b23e 100644 --- a/internal/query/project_grant_test.go +++ b/internal/query/project_grant_test.go @@ -30,8 +30,7 @@ var ( ` FROM projections.project_grants4 ` + ` LEFT JOIN projections.projects4 ON projections.project_grants4.project_id = projections.projects4.id AND projections.project_grants4.instance_id = projections.projects4.instance_id ` + ` LEFT JOIN projections.orgs1 AS r ON projections.project_grants4.resource_owner = r.id AND projections.project_grants4.instance_id = r.instance_id` + - ` LEFT JOIN projections.orgs1 AS o ON projections.project_grants4.granted_org_id = o.id AND projections.project_grants4.instance_id = o.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'` + ` LEFT JOIN projections.orgs1 AS o ON projections.project_grants4.granted_org_id = o.id AND projections.project_grants4.instance_id = o.instance_id` projectGrantsCols = []string{ "project_id", "grant_id", @@ -62,8 +61,7 @@ var ( ` FROM projections.project_grants4 ` + ` LEFT JOIN projections.projects4 ON projections.project_grants4.project_id = projections.projects4.id AND projections.project_grants4.instance_id = projections.projects4.instance_id ` + ` LEFT JOIN projections.orgs1 AS r ON projections.project_grants4.resource_owner = r.id AND projections.project_grants4.instance_id = r.instance_id` + - ` LEFT JOIN projections.orgs1 AS o ON projections.project_grants4.granted_org_id = o.id AND projections.project_grants4.instance_id = o.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'` + ` LEFT JOIN projections.orgs1 AS o ON projections.project_grants4.granted_org_id = o.id AND projections.project_grants4.instance_id = o.instance_id` projectGrantCols = []string{ "project_id", "grant_id", @@ -573,7 +571,7 @@ func Test_ProjectGrantPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/project_member.go b/internal/query/project_member.go index 347eac12b9..1b66b45ccc 100644 --- a/internal/query/project_member.go +++ b/internal/query/project_member.go @@ -7,7 +7,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -73,7 +72,7 @@ func (q *Queries) ProjectMembers(ctx context.Context, queries *ProjectMembersQue ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareProjectMembersQuery(ctx, q.client) + query, scan := prepareProjectMembersQuery() eq := sq.Eq{ProjectMemberInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()} stmt, args, err := queries.toQuery(query).Where(eq).ToSql() if err != nil { @@ -97,7 +96,7 @@ func (q *Queries) ProjectMembers(ctx context.Context, queries *ProjectMembersQue return members, err } -func prepareProjectMembersQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Members, error)) { +func prepareProjectMembersQuery() (sq.SelectBuilder, func(*sql.Rows) (*Members, error)) { return sq.Select( ProjectMemberCreationDate.identifier(), ProjectMemberChangeDate.identifier(), @@ -119,7 +118,7 @@ func prepareProjectMembersQuery(ctx context.Context, db prepareDatabase) (sq.Sel LeftJoin(join(HumanUserIDCol, ProjectMemberUserID)). LeftJoin(join(MachineUserIDCol, ProjectMemberUserID)). LeftJoin(join(UserIDCol, ProjectMemberUserID)). - LeftJoin(join(LoginNameUserIDCol, ProjectMemberUserID) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(LoginNameUserIDCol, ProjectMemberUserID)). Where( sq.Eq{LoginNameIsPrimaryCol.identifier(): true}, ).PlaceholderFormat(sq.Dollar), diff --git a/internal/query/project_member_test.go b/internal/query/project_member_test.go index 24548dadfb..6e552eb2ec 100644 --- a/internal/query/project_member_test.go +++ b/internal/query/project_member_test.go @@ -43,7 +43,6 @@ var ( "LEFT JOIN projections.login_names3 " + "ON members.user_id = projections.login_names3.user_id " + "AND members.instance_id = projections.login_names3.instance_id " + - `AS OF SYSTEM TIME '-1 ms' ` + "WHERE projections.login_names3.is_primary = $1") projectMembersColumns = []string{ "creation_date", @@ -299,7 +298,7 @@ func Test_ProjectMemberPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/project_role.go b/internal/query/project_role.go index 76e113da65..ab4f40ca38 100644 --- a/internal/query/project_role.go +++ b/internal/query/project_role.go @@ -9,7 +9,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/eventstore/handler/v2" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -94,7 +93,7 @@ func (q *Queries) SearchProjectRoles(ctx context.Context, shouldTriggerBulk bool eq := sq.Eq{ProjectRoleColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()} - query, scan := prepareProjectRolesQuery(ctx, q.client) + query, scan := prepareProjectRolesQuery() stmt, args, err := queries.toQuery(query).Where(eq).ToSql() if err != nil { return nil, zerrors.ThrowInvalidArgument(err, "QUERY-3N9ff", "Errors.Query.InvalidRequest") @@ -126,7 +125,7 @@ func (q *Queries) SearchGrantedProjectRoles(ctx context.Context, grantID, grante eq := sq.Eq{ProjectRoleColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()} - query, scan := prepareProjectRolesQuery(ctx, q.client) + query, scan := prepareProjectRolesQuery() stmt, args, err := queries.toQuery(query).Where(eq).ToSql() if err != nil { return nil, zerrors.ThrowInvalidArgument(err, "QUERY-3N9ff", "Errors.Query.InvalidRequest") @@ -207,7 +206,7 @@ func (q *ProjectRoleSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuil return query } -func prepareProjectRolesQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*ProjectRoles, error)) { +func prepareProjectRolesQuery() (sq.SelectBuilder, func(*sql.Rows) (*ProjectRoles, error)) { return sq.Select( ProjectRoleColumnProjectID.identifier(), ProjectRoleColumnCreationDate.identifier(), @@ -218,7 +217,7 @@ func prepareProjectRolesQuery(ctx context.Context, db prepareDatabase) (sq.Selec ProjectRoleColumnDisplayName.identifier(), ProjectRoleColumnGroupName.identifier(), countColumn.identifier()). - From(projectRolesTable.identifier() + db.Timetravel(call.Took(ctx))). + From(projectRolesTable.identifier()). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*ProjectRoles, error) { projects := make([]*ProjectRole, 0) diff --git a/internal/query/project_role_test.go b/internal/query/project_role_test.go index 516a4df169..468aafaa19 100644 --- a/internal/query/project_role_test.go +++ b/internal/query/project_role_test.go @@ -19,8 +19,7 @@ var ( ` projections.project_roles4.display_name,` + ` projections.project_roles4.group_name,` + ` COUNT(*) OVER ()` + - ` FROM projections.project_roles4` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.project_roles4` prepareProjectRolesCols = []string{ "project_id", "creation_date", @@ -175,7 +174,7 @@ func Test_ProjectRolePrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/project_test.go b/internal/query/project_test.go index a621c27f42..1eafcb69a8 100644 --- a/internal/query/project_test.go +++ b/internal/query/project_test.go @@ -39,8 +39,7 @@ var ( ` projections.projects4.has_project_check,` + ` projections.projects4.private_labeling_setting,` + ` COUNT(*) OVER ()` + - ` FROM projections.projects4` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.projects4` prepareProjectsCols = []string{ "id", "creation_date", @@ -67,8 +66,7 @@ var ( ` projections.projects4.project_role_check,` + ` projections.projects4.has_project_check,` + ` projections.projects4.private_labeling_setting` + - ` FROM projections.projects4` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.projects4` prepareProjectCols = []string{ "id", "creation_date", @@ -314,7 +312,7 @@ func Test_ProjectPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/query.go b/internal/query/query.go index 0a90e9e4f9..c0c051f7b7 100644 --- a/internal/query/query.go +++ b/internal/query/query.go @@ -112,10 +112,6 @@ func (q *Queries) Health(ctx context.Context) error { return q.client.Ping() } -type prepareDatabase interface { - Timetravel(d time.Duration) string -} - // cleanStaticQueries removes whitespaces, // such as ` `, \t, \n, from queries to improve // readability in logs and errors. diff --git a/internal/query/quota.go b/internal/query/quota.go index 50bc28cabc..77d2f3892b 100644 --- a/internal/query/quota.go +++ b/internal/query/quota.go @@ -62,7 +62,7 @@ type Quota struct { func (q *Queries) GetQuota(ctx context.Context, instanceID string, unit quota.Unit) (qu *Quota, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareQuotaQuery(ctx, q.client) + query, scan := prepareQuotaQuery() stmt, args, err := query.Where( sq.Eq{ QuotaColumnInstanceID.identifier(): instanceID, @@ -79,7 +79,7 @@ func (q *Queries) GetQuota(ctx context.Context, instanceID string, unit quota.Un return qu, err } -func prepareQuotaQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Quota, error)) { +func prepareQuotaQuery() (sq.SelectBuilder, func(*sql.Row) (*Quota, error)) { return sq. Select( QuotaColumnID.identifier(), diff --git a/internal/query/quota_notifications.go b/internal/query/quota_notifications.go index 0015278b20..9e3cb1a10c 100644 --- a/internal/query/quota_notifications.go +++ b/internal/query/quota_notifications.go @@ -9,7 +9,6 @@ import ( sq "github.com/Masterminds/squirrel" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/repository/quota" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -59,7 +58,7 @@ func (q *Queries) GetDueQuotaNotifications(ctx context.Context, instanceID strin ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() usedRel := uint16(math.Floor(float64(usedAbs*100) / float64(qu.Amount))) - query, scan := prepareQuotaNotificationsQuery(ctx, q.client) + query, scan := prepareQuotaNotificationsQuery() stmt, args, err := query.Where( sq.And{ sq.Eq{ @@ -149,7 +148,7 @@ func calculateThreshold(usedRel, notificationPercent uint16) uint16 { return uint16(times+percent-1)*100 + notificationPercent } -func prepareQuotaNotificationsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*QuotaNotifications, error)) { +func prepareQuotaNotificationsQuery() (sq.SelectBuilder, func(*sql.Rows) (*QuotaNotifications, error)) { return sq.Select( QuotaNotificationColumnID.identifier(), QuotaNotificationColumnCallURL.identifier(), @@ -157,7 +156,7 @@ func prepareQuotaNotificationsQuery(ctx context.Context, db prepareDatabase) (sq QuotaNotificationColumnRepeat.identifier(), QuotaNotificationColumnNextDueThreshold.identifier(), ). - From(quotaNotificationsTable.identifier() + db.Timetravel(call.Took(ctx))). + From(quotaNotificationsTable.identifier()). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*QuotaNotifications, error) { cfgs := &QuotaNotifications{Configs: []*QuotaNotification{}} for rows.Next() { diff --git a/internal/query/quota_notifications_test.go b/internal/query/quota_notifications_test.go index a86b31df57..5515d2b3a0 100644 --- a/internal/query/quota_notifications_test.go +++ b/internal/query/quota_notifications_test.go @@ -92,8 +92,7 @@ var ( ` projections.quotas_notifications.percent,` + ` projections.quotas_notifications.repeat,` + ` projections.quotas_notifications.next_due_threshold` + - ` FROM projections.quotas_notifications` + - ` AS OF SYSTEM TIME '-1 ms'`) + ` FROM projections.quotas_notifications`) quotaNotificationsCols = []string{ "id", @@ -175,7 +174,7 @@ func Test_prepareQuotaNotificationsQuery(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/quota_periods.go b/internal/query/quota_periods.go index 6ec42deba3..c954108c5f 100644 --- a/internal/query/quota_periods.go +++ b/internal/query/quota_periods.go @@ -7,7 +7,6 @@ import ( sq "github.com/Masterminds/squirrel" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/repository/quota" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -40,7 +39,7 @@ var ( func (q *Queries) GetRemainingQuotaUsage(ctx context.Context, instanceID string, unit quota.Unit) (remaining *uint64, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareRemainingQuotaUsageQuery(ctx, q.client) + stmt, scan := prepareRemainingQuotaUsageQuery() query, args, err := stmt.Where( sq.And{ sq.Eq{ @@ -66,13 +65,13 @@ func (q *Queries) GetRemainingQuotaUsage(ctx context.Context, instanceID string, return remaining, err } -func prepareRemainingQuotaUsageQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*uint64, error)) { +func prepareRemainingQuotaUsageQuery() (sq.SelectBuilder, func(*sql.Row) (*uint64, error)) { return sq. Select( "greatest(0, " + QuotaColumnAmount.identifier() + "-" + QuotaPeriodColumnUsage.identifier() + ")", ). From(quotaPeriodsTable.identifier()). - Join(join(QuotaColumnUnit, QuotaPeriodColumnUnit) + db.Timetravel(call.Took(ctx))). + Join(join(QuotaColumnUnit, QuotaPeriodColumnUnit)). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*uint64, error) { remaining := new(uint64) err := row.Scan(remaining) diff --git a/internal/query/quota_periods_test.go b/internal/query/quota_periods_test.go index 0f44c5e547..385a49c557 100644 --- a/internal/query/quota_periods_test.go +++ b/internal/query/quota_periods_test.go @@ -14,8 +14,7 @@ import ( var ( expectedRemainingQuotaUsageQuery = regexp.QuoteMeta(`SELECT greatest(0, projections.quotas.amount-projections.quotas_periods.usage)` + ` FROM projections.quotas_periods` + - ` JOIN projections.quotas ON projections.quotas_periods.unit = projections.quotas.unit AND projections.quotas_periods.instance_id = projections.quotas.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'`) + ` JOIN projections.quotas ON projections.quotas_periods.unit = projections.quotas.unit AND projections.quotas_periods.instance_id = projections.quotas.instance_id`) remainingQuotaUsageCols = []string{ "usage", } @@ -84,7 +83,7 @@ func Test_prepareRemainingQuotaUsageQuery(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/quota_test.go b/internal/query/quota_test.go index a92938e0cb..1e3ff1e9b2 100644 --- a/internal/query/quota_test.go +++ b/internal/query/quota_test.go @@ -110,7 +110,7 @@ func Test_QuotaPrepare(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/restrictions.go b/internal/query/restrictions.go index 9e0dd37aa6..8cff5737f7 100644 --- a/internal/query/restrictions.go +++ b/internal/query/restrictions.go @@ -10,7 +10,6 @@ import ( "golang.org/x/text/language" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query/projection" @@ -72,7 +71,7 @@ func (q *Queries) GetInstanceRestrictions(ctx context.Context) (restrictions Res ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareRestrictionsQuery(ctx, q.client) + stmt, scan := prepareRestrictionsQuery() instanceID := authz.GetInstance(ctx).InstanceID() query, args, err := stmt.Where(sq.Eq{ RestrictionsColumnInstanceID.identifier(): instanceID, @@ -92,7 +91,7 @@ func (q *Queries) GetInstanceRestrictions(ctx context.Context) (restrictions Res return restrictions, err } -func prepareRestrictionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (Restrictions, error)) { +func prepareRestrictionsQuery() (sq.SelectBuilder, func(*sql.Row) (Restrictions, error)) { return sq.Select( RestrictionsColumnAggregateID.identifier(), RestrictionsColumnCreationDate.identifier(), @@ -102,7 +101,7 @@ func prepareRestrictionsQuery(ctx context.Context, db prepareDatabase) (sq.Selec RestrictionsColumnDisallowPublicOrgRegistration.identifier(), RestrictionsColumnAllowedLanguages.identifier(), ). - From(restrictionsTable.identifier() + db.Timetravel(call.Took(ctx))). + From(restrictionsTable.identifier()). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (restrictions Restrictions, err error) { allowedLanguages := database.TextArray[string](make([]string, 0)) diff --git a/internal/query/restrictions_test.go b/internal/query/restrictions_test.go index cc7ee8442a..69ed81ef6d 100644 --- a/internal/query/restrictions_test.go +++ b/internal/query/restrictions_test.go @@ -21,8 +21,7 @@ var ( " projections.restrictions2.sequence," + " projections.restrictions2.disallow_public_org_registration," + " projections.restrictions2.allowed_languages" + - " FROM projections.restrictions2" + - " AS OF SYSTEM TIME '-1 ms'", + " FROM projections.restrictions2", ) restrictionsCols = []string{ @@ -115,7 +114,7 @@ func Test_RestrictionsPrepare(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.want.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.want.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/saml_request.go b/internal/query/saml_request.go index a81a1b2c34..784627bc59 100644 --- a/internal/query/saml_request.go +++ b/internal/query/saml_request.go @@ -5,13 +5,11 @@ import ( "database/sql" _ "embed" "errors" - "fmt" "time" "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore/handler/v2" "github.com/zitadel/zitadel/internal/query/projection" @@ -39,10 +37,6 @@ func (a *SamlRequest) checkLoginClient(ctx context.Context, permissionCheck doma //go:embed saml_request_by_id.sql var samlRequestByIDQuery string -func (q *Queries) samlRequestByIDQuery(ctx context.Context) string { - return fmt.Sprintf(samlRequestByIDQuery, q.client.Timetravel(call.Took(ctx))) -} - func (q *Queries) SamlRequestByID(ctx context.Context, shouldTriggerBulk bool, id string, checkLoginClient bool) (_ *SamlRequest, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -62,7 +56,7 @@ func (q *Queries) SamlRequestByID(ctx context.Context, shouldTriggerBulk bool, i &dst.ID, &dst.CreationDate, &dst.LoginClient, &dst.Issuer, &dst.ACS, &dst.RelayState, &dst.Binding, ) }, - q.samlRequestByIDQuery(ctx), + samlRequestByIDQuery, id, authz.GetInstance(ctx).InstanceID(), ) if errors.Is(err, sql.ErrNoRows) { diff --git a/internal/query/saml_request_by_id.sql b/internal/query/saml_request_by_id.sql index ac1c60058f..73eadb01ad 100644 --- a/internal/query/saml_request_by_id.sql +++ b/internal/query/saml_request_by_id.sql @@ -6,6 +6,6 @@ select acs, relay_state, binding -from projections.saml_requests %s +from projections.saml_requests where id = $1 and instance_id = $2 limit 1; diff --git a/internal/query/saml_request_test.go b/internal/query/saml_request_test.go index 6c6c2b6ebe..3a062ac5fd 100644 --- a/internal/query/saml_request_test.go +++ b/internal/query/saml_request_test.go @@ -22,7 +22,6 @@ import ( func TestQueries_SamlRequestByID(t *testing.T) { expQuery := regexp.QuoteMeta(fmt.Sprintf( samlRequestByIDQuery, - asOfSystemTime, )) cols := []string{ @@ -148,8 +147,7 @@ func TestQueries_SamlRequestByID(t *testing.T) { q := &Queries{ checkPermission: tt.permissionCheck, client: &database.DB{ - DB: db, - Database: &prepareDB{}, + DB: db, }, } ctx := authz.NewMockContext("instanceID", "orgID", "loginClient") diff --git a/internal/query/saml_sp_test.go b/internal/query/saml_sp_test.go index 4aafd95de1..35bf93c5fe 100644 --- a/internal/query/saml_sp_test.go +++ b/internal/query/saml_sp_test.go @@ -109,8 +109,7 @@ func TestQueries_ActiveSAMLServiceProviderByID(t *testing.T) { execMock(t, tt.mock, func(db *sql.DB) { q := &Queries{ client: &database.DB{ - DB: db, - Database: &prepareDB{}, + DB: db, }, } ctx := authz.NewMockContext("instanceID", "orgID", "loginClient") diff --git a/internal/query/secret_generator_test.go b/internal/query/secret_generator_test.go index 683dc3441e..9ce8e71769 100644 --- a/internal/query/secret_generator_test.go +++ b/internal/query/secret_generator_test.go @@ -26,8 +26,7 @@ var ( ` projections.secret_generators2.include_upper_letters,` + ` projections.secret_generators2.include_digits,` + ` projections.secret_generators2.include_symbols` + - ` FROM projections.secret_generators2` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.secret_generators2` prepareSecretGeneratorCols = []string{ "aggregate_id", "generator_type", @@ -55,8 +54,7 @@ var ( ` projections.secret_generators2.include_digits,` + ` projections.secret_generators2.include_symbols,` + ` COUNT(*) OVER ()` + - ` FROM projections.secret_generators2` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.secret_generators2` prepareSecretGeneratorsCols = []string{ "aggregate_id", "generator_type", @@ -312,7 +310,7 @@ func Test_SecretGeneratorsPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/secret_generators.go b/internal/query/secret_generators.go index 8ee8694d2b..c267d7b290 100644 --- a/internal/query/secret_generators.go +++ b/internal/query/secret_generators.go @@ -10,7 +10,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query/projection" @@ -127,7 +126,7 @@ func (q *Queries) SecretGeneratorByType(ctx context.Context, generatorType domai defer func() { span.EndWithError(err) }() instanceID := authz.GetInstance(ctx).InstanceID() - stmt, scan := prepareSecretGeneratorQuery(ctx, q.client) + stmt, scan := prepareSecretGeneratorQuery() query, args, err := stmt.Where(sq.Eq{ SecretGeneratorColumnGeneratorType.identifier(): generatorType, SecretGeneratorColumnInstanceID.identifier(): instanceID, @@ -148,7 +147,7 @@ func (q *Queries) SearchSecretGenerators(ctx context.Context, queries *SecretGen ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareSecretGeneratorsQuery(ctx, q.client) + query, scan := prepareSecretGeneratorsQuery() stmt, args, err := queries.toQuery(query). Where(sq.Eq{ SecretGeneratorColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -180,7 +179,7 @@ func NewSecretGeneratorTypeSearchQuery(value int32) (SearchQuery, error) { return NewNumberQuery(SecretGeneratorColumnGeneratorType, value, NumberEquals) } -func prepareSecretGeneratorQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*SecretGenerator, error)) { +func prepareSecretGeneratorQuery() (sq.SelectBuilder, func(*sql.Row) (*SecretGenerator, error)) { return sq.Select( SecretGeneratorColumnAggregateID.identifier(), SecretGeneratorColumnGeneratorType.identifier(), @@ -194,7 +193,7 @@ func prepareSecretGeneratorQuery(ctx context.Context, db prepareDatabase) (sq.Se SecretGeneratorColumnIncludeUpperLetters.identifier(), SecretGeneratorColumnIncludeDigits.identifier(), SecretGeneratorColumnIncludeSymbols.identifier()). - From(secretGeneratorsTable.identifier() + db.Timetravel(call.Took(ctx))). + From(secretGeneratorsTable.identifier()). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*SecretGenerator, error) { secretGenerator := new(SecretGenerator) @@ -222,7 +221,7 @@ func prepareSecretGeneratorQuery(ctx context.Context, db prepareDatabase) (sq.Se } } -func prepareSecretGeneratorsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*SecretGenerators, error)) { +func prepareSecretGeneratorsQuery() (sq.SelectBuilder, func(*sql.Rows) (*SecretGenerators, error)) { return sq.Select( SecretGeneratorColumnAggregateID.identifier(), SecretGeneratorColumnGeneratorType.identifier(), @@ -237,7 +236,7 @@ func prepareSecretGeneratorsQuery(ctx context.Context, db prepareDatabase) (sq.S SecretGeneratorColumnIncludeDigits.identifier(), SecretGeneratorColumnIncludeSymbols.identifier(), countColumn.identifier()). - From(secretGeneratorsTable.identifier() + db.Timetravel(call.Took(ctx))). + From(secretGeneratorsTable.identifier()). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*SecretGenerators, error) { secretGenerators := make([]*SecretGenerator, 0) diff --git a/internal/query/security_policy.go b/internal/query/security_policy.go index 51938abdae..7a3fb3fa89 100644 --- a/internal/query/security_policy.go +++ b/internal/query/security_policy.go @@ -9,7 +9,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/zerrors" @@ -63,7 +62,7 @@ type SecurityPolicy struct { } func (q *Queries) SecurityPolicy(ctx context.Context) (policy *SecurityPolicy, err error) { - stmt, scan := prepareSecurityPolicyQuery(ctx, q.client) + stmt, scan := prepareSecurityPolicyQuery() query, args, err := stmt.Where(sq.Eq{ SecurityPolicyColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), }).ToSql() @@ -78,7 +77,7 @@ func (q *Queries) SecurityPolicy(ctx context.Context) (policy *SecurityPolicy, e return policy, err } -func prepareSecurityPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*SecurityPolicy, error)) { +func prepareSecurityPolicyQuery() (sq.SelectBuilder, func(*sql.Row) (*SecurityPolicy, error)) { return sq.Select( SecurityPolicyColumnInstanceID.identifier(), SecurityPolicyColumnCreationDate.identifier(), @@ -88,7 +87,7 @@ func prepareSecurityPolicyQuery(ctx context.Context, db prepareDatabase) (sq.Sel SecurityPolicyColumnEnableIframeEmbedding.identifier(), SecurityPolicyColumnAllowedOrigins.identifier(), SecurityPolicyColumnEnableImpersonation.identifier()). - From(securityPolicyTable.identifier() + db.Timetravel(call.Took(ctx))). + From(securityPolicyTable.identifier()). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*SecurityPolicy, error) { securityPolicy := new(SecurityPolicy) diff --git a/internal/query/session.go b/internal/query/session.go index 706465949e..111eb462a0 100644 --- a/internal/query/session.go +++ b/internal/query/session.go @@ -13,7 +13,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore/handler/v2" @@ -261,7 +260,7 @@ func (q *Queries) sessionByID(ctx context.Context, shouldTriggerBulk bool, id st traceSpan.EndWithError(err) } - query, scan := prepareSessionQuery(ctx, q.client) + query, scan := prepareSessionQuery() stmt, args, err := query.Where( sq.Eq{ SessionColumnID.identifier(): id, @@ -297,7 +296,7 @@ func (q *Queries) searchSessions(ctx context.Context, queries *SessionsSearchQue ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareSessionsQuery(ctx, q.client) + query, scan := prepareSessionsQuery() stmt, args, err := queries.toQuery(query). Where(sq.Eq{ SessionColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -343,7 +342,7 @@ func NewCreationDateQuery(datetime time.Time, compare TimestampComparison) (Sear return NewTimestampQuery(SessionColumnCreationDate, datetime, compare) } -func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Session, string, error)) { +func prepareSessionQuery() (sq.SelectBuilder, func(*sql.Row) (*Session, string, error)) { return sq.Select( SessionColumnID.identifier(), SessionColumnCreationDate.identifier(), @@ -374,7 +373,7 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil ).From(sessionsTable.identifier()). LeftJoin(join(LoginNameUserIDCol, SessionColumnUserID)). LeftJoin(join(HumanUserIDCol, SessionColumnUserID)). - LeftJoin(join(UserIDCol, SessionColumnUserID) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(UserIDCol, SessionColumnUserID)). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*Session, string, error) { session := new(Session) @@ -456,7 +455,7 @@ func prepareSessionQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil } } -func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Sessions, error)) { +func prepareSessionsQuery() (sq.SelectBuilder, func(*sql.Rows) (*Sessions, error)) { return sq.Select( SessionColumnID.identifier(), SessionColumnCreationDate.identifier(), @@ -487,7 +486,7 @@ func prepareSessionsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBui ).From(sessionsTable.identifier()). LeftJoin(join(LoginNameUserIDCol, SessionColumnUserID)). LeftJoin(join(HumanUserIDCol, SessionColumnUserID)). - LeftJoin(join(UserIDCol, SessionColumnUserID) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(UserIDCol, SessionColumnUserID)). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*Sessions, error) { sessions := &Sessions{Sessions: []*Session{}} diff --git a/internal/query/sessions_test.go b/internal/query/sessions_test.go index ba897e6062..e0d9cfda71 100644 --- a/internal/query/sessions_test.go +++ b/internal/query/sessions_test.go @@ -50,8 +50,7 @@ var ( ` FROM projections.sessions8` + ` LEFT JOIN projections.login_names3 ON projections.sessions8.user_id = projections.login_names3.user_id AND projections.sessions8.instance_id = projections.login_names3.instance_id` + ` LEFT JOIN projections.users14_humans ON projections.sessions8.user_id = projections.users14_humans.user_id AND projections.sessions8.instance_id = projections.users14_humans.instance_id` + - ` LEFT JOIN projections.users14 ON projections.sessions8.user_id = projections.users14.id AND projections.sessions8.instance_id = projections.users14.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'`) + ` LEFT JOIN projections.users14 ON projections.sessions8.user_id = projections.users14.id AND projections.sessions8.instance_id = projections.users14.instance_id`) expectedSessionsQuery = regexp.QuoteMeta(`SELECT projections.sessions8.id,` + ` projections.sessions8.creation_date,` + ` projections.sessions8.change_date,` + @@ -81,8 +80,7 @@ var ( ` FROM projections.sessions8` + ` LEFT JOIN projections.login_names3 ON projections.sessions8.user_id = projections.login_names3.user_id AND projections.sessions8.instance_id = projections.login_names3.instance_id` + ` LEFT JOIN projections.users14_humans ON projections.sessions8.user_id = projections.users14_humans.user_id AND projections.sessions8.instance_id = projections.users14_humans.instance_id` + - ` LEFT JOIN projections.users14 ON projections.sessions8.user_id = projections.users14.id AND projections.sessions8.instance_id = projections.users14.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'`) + ` LEFT JOIN projections.users14 ON projections.sessions8.user_id = projections.users14.id AND projections.sessions8.instance_id = projections.users14.instance_id`) sessionCols = []string{ "id", @@ -440,7 +438,7 @@ func Test_SessionsPrepare(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } @@ -577,14 +575,14 @@ func Test_SessionPrepare(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } -func prepareSessionQueryTesting(t *testing.T, token string) func(context.Context, prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Session, error)) { - return func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Session, error)) { - builder, scan := prepareSessionQuery(ctx, db) +func prepareSessionQueryTesting(t *testing.T, token string) func() (sq.SelectBuilder, func(*sql.Row) (*Session, error)) { + return func() (sq.SelectBuilder, func(*sql.Row) (*Session, error)) { + builder, scan := prepareSessionQuery() return builder, func(row *sql.Row) (*Session, error) { session, tokenID, err := scan(row) require.Equal(t, tokenID, token) diff --git a/internal/query/sms.go b/internal/query/sms.go index 310d3d0f14..3659f05daf 100644 --- a/internal/query/sms.go +++ b/internal/query/sms.go @@ -9,7 +9,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query/projection" @@ -149,7 +148,7 @@ func (q *Queries) SMSProviderConfigByID(ctx context.Context, id string) (config ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareSMSConfigQuery(ctx, q.client) + query, scan := prepareSMSConfigQuery() stmt, args, err := query.Where( sq.Eq{ SMSColumnID.identifier(): id, @@ -171,7 +170,7 @@ func (q *Queries) SMSProviderConfigActive(ctx context.Context, instanceID string ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareSMSConfigQuery(ctx, q.client) + query, scan := prepareSMSConfigQuery() stmt, args, err := query.Where( sq.Eq{ SMSColumnInstanceID.identifier(): instanceID, @@ -193,7 +192,7 @@ func (q *Queries) SearchSMSConfigs(ctx context.Context, queries *SMSConfigsSearc ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareSMSConfigsQuery(ctx, q.client) + query, scan := prepareSMSConfigsQuery() stmt, args, err := queries.toQuery(query). Where(sq.Eq{ SMSColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -217,7 +216,7 @@ func NewSMSProviderStateQuery(state domain.SMSConfigState) (SearchQuery, error) return NewNumberQuery(SMSColumnState, state, NumberEquals) } -func prepareSMSConfigQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*SMSConfig, error)) { +func prepareSMSConfigQuery() (sq.SelectBuilder, func(*sql.Row) (*SMSConfig, error)) { return sq.Select( SMSColumnID.identifier(), SMSColumnAggregateID.identifier(), @@ -238,7 +237,7 @@ func prepareSMSConfigQuery(ctx context.Context, db prepareDatabase) (sq.SelectBu SMSHTTPColumnEndpoint.identifier(), ).From(smsConfigsTable.identifier()). LeftJoin(join(SMSTwilioColumnSMSID, SMSColumnID)). - LeftJoin(join(SMSHTTPColumnSMSID, SMSColumnID) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(SMSHTTPColumnSMSID, SMSColumnID)). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*SMSConfig, error) { config := new(SMSConfig) @@ -281,7 +280,7 @@ func prepareSMSConfigQuery(ctx context.Context, db prepareDatabase) (sq.SelectBu } } -func prepareSMSConfigsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*SMSConfigs, error)) { +func prepareSMSConfigsQuery() (sq.SelectBuilder, func(*sql.Rows) (*SMSConfigs, error)) { return sq.Select( SMSColumnID.identifier(), SMSColumnAggregateID.identifier(), @@ -304,7 +303,7 @@ func prepareSMSConfigsQuery(ctx context.Context, db prepareDatabase) (sq.SelectB countColumn.identifier(), ).From(smsConfigsTable.identifier()). LeftJoin(join(SMSTwilioColumnSMSID, SMSColumnID)). - LeftJoin(join(SMSHTTPColumnSMSID, SMSColumnID) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(SMSHTTPColumnSMSID, SMSColumnID)). PlaceholderFormat(sq.Dollar), func(row *sql.Rows) (*SMSConfigs, error) { configs := &SMSConfigs{Configs: []*SMSConfig{}} diff --git a/internal/query/sms_test.go b/internal/query/sms_test.go index 82c3659f2c..e6e79d72bc 100644 --- a/internal/query/sms_test.go +++ b/internal/query/sms_test.go @@ -35,8 +35,7 @@ var ( ` projections.sms_configs3_http.endpoint` + ` FROM projections.sms_configs3` + ` LEFT JOIN projections.sms_configs3_twilio ON projections.sms_configs3.id = projections.sms_configs3_twilio.sms_id AND projections.sms_configs3.instance_id = projections.sms_configs3_twilio.instance_id` + - ` LEFT JOIN projections.sms_configs3_http ON projections.sms_configs3.id = projections.sms_configs3_http.sms_id AND projections.sms_configs3.instance_id = projections.sms_configs3_http.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'`) + ` LEFT JOIN projections.sms_configs3_http ON projections.sms_configs3.id = projections.sms_configs3_http.sms_id AND projections.sms_configs3.instance_id = projections.sms_configs3_http.instance_id`) expectedSMSConfigsQuery = regexp.QuoteMeta(`SELECT projections.sms_configs3.id,` + ` projections.sms_configs3.aggregate_id,` + ` projections.sms_configs3.creation_date,` + @@ -59,8 +58,7 @@ var ( ` COUNT(*) OVER ()` + ` FROM projections.sms_configs3` + ` LEFT JOIN projections.sms_configs3_twilio ON projections.sms_configs3.id = projections.sms_configs3_twilio.sms_id AND projections.sms_configs3.instance_id = projections.sms_configs3_twilio.instance_id` + - ` LEFT JOIN projections.sms_configs3_http ON projections.sms_configs3.id = projections.sms_configs3_http.sms_id AND projections.sms_configs3.instance_id = projections.sms_configs3_http.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'`) + ` LEFT JOIN projections.sms_configs3_http ON projections.sms_configs3.id = projections.sms_configs3_http.sms_id AND projections.sms_configs3.instance_id = projections.sms_configs3_http.instance_id`) smsConfigCols = []string{ "id", @@ -353,7 +351,7 @@ func Test_SMSConfigsPrepare(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } @@ -494,7 +492,7 @@ func Test_SMSConfigPrepare(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/smtp.go b/internal/query/smtp.go index 7c45fe33fe..4238ec121e 100644 --- a/internal/query/smtp.go +++ b/internal/query/smtp.go @@ -9,7 +9,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query/projection" @@ -153,7 +152,7 @@ func (q *Queries) SMTPConfigActive(ctx context.Context, resourceOwner string) (c ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareSMTPConfigQuery(ctx, q.client) + stmt, scan := prepareSMTPConfigQuery() query, args, err := stmt.Where(sq.Eq{ SMTPConfigColumnResourceOwner.identifier(): resourceOwner, SMTPConfigColumnInstanceID.identifier(): resourceOwner, @@ -174,7 +173,7 @@ func (q *Queries) SMTPConfigByID(ctx context.Context, instanceID, id string) (co ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - stmt, scan := prepareSMTPConfigQuery(ctx, q.client) + stmt, scan := prepareSMTPConfigQuery() query, args, err := stmt.Where(sq.Eq{ SMTPConfigColumnInstanceID.identifier(): instanceID, SMTPConfigColumnID.identifier(): id, @@ -190,7 +189,7 @@ func (q *Queries) SMTPConfigByID(ctx context.Context, instanceID, id string) (co return config, err } -func prepareSMTPConfigQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*SMTPConfig, error)) { +func prepareSMTPConfigQuery() (sq.SelectBuilder, func(*sql.Row) (*SMTPConfig, error)) { password := new(crypto.CryptoValue) return sq.Select( @@ -215,7 +214,7 @@ func prepareSMTPConfigQuery(ctx context.Context, db prepareDatabase) (sq.SelectB SMTPConfigHTTPColumnEndpoint.identifier()). From(smtpConfigsTable.identifier()). LeftJoin(join(SMTPConfigSMTPColumnID, SMTPConfigColumnID)). - LeftJoin(join(SMTPConfigHTTPColumnID, SMTPConfigColumnID) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(SMTPConfigHTTPColumnID, SMTPConfigColumnID)). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*SMTPConfig, error) { config := new(SMTPConfig) @@ -255,7 +254,7 @@ func prepareSMTPConfigQuery(ctx context.Context, db prepareDatabase) (sq.SelectB } } -func prepareSMTPConfigsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*SMTPConfigs, error)) { +func prepareSMTPConfigsQuery() (sq.SelectBuilder, func(*sql.Rows) (*SMTPConfigs, error)) { return sq.Select( SMTPConfigColumnCreationDate.identifier(), SMTPConfigColumnChangeDate.identifier(), @@ -279,7 +278,7 @@ func prepareSMTPConfigsQuery(ctx context.Context, db prepareDatabase) (sq.Select countColumn.identifier(), ).From(smtpConfigsTable.identifier()). LeftJoin(join(SMTPConfigSMTPColumnID, SMTPConfigColumnID)). - LeftJoin(join(SMTPConfigHTTPColumnID, SMTPConfigColumnID) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(SMTPConfigHTTPColumnID, SMTPConfigColumnID)). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*SMTPConfigs, error) { configs := &SMTPConfigs{Configs: []*SMTPConfig{}} @@ -329,7 +328,7 @@ func (q *Queries) SearchSMTPConfigs(ctx context.Context, queries *SMTPConfigsSea ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareSMTPConfigsQuery(ctx, q.client) + query, scan := prepareSMTPConfigsQuery() stmt, args, err := queries.toQuery(query). Where(sq.Eq{ SMTPConfigColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), diff --git a/internal/query/smtp_test.go b/internal/query/smtp_test.go index 4d12edcbd3..68ace249aa 100644 --- a/internal/query/smtp_test.go +++ b/internal/query/smtp_test.go @@ -33,8 +33,7 @@ var ( ` projections.smtp_configs5_http.endpoint` + ` FROM projections.smtp_configs5` + ` LEFT JOIN projections.smtp_configs5_smtp ON projections.smtp_configs5.id = projections.smtp_configs5_smtp.id AND projections.smtp_configs5.instance_id = projections.smtp_configs5_smtp.instance_id` + - ` LEFT JOIN projections.smtp_configs5_http ON projections.smtp_configs5.id = projections.smtp_configs5_http.id AND projections.smtp_configs5.instance_id = projections.smtp_configs5_http.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'` + ` LEFT JOIN projections.smtp_configs5_http ON projections.smtp_configs5.id = projections.smtp_configs5_http.id AND projections.smtp_configs5.instance_id = projections.smtp_configs5_http.instance_id` prepareSMTPConfigCols = []string{ "creation_date", "change_date", @@ -287,7 +286,7 @@ func Test_SMTPConfigsPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/system_features_test.go b/internal/query/system_features_test.go index e460d38cec..fcd0f812f5 100644 --- a/internal/query/system_features_test.go +++ b/internal/query/system_features_test.go @@ -45,23 +45,23 @@ func TestQueries_GetSystemFeatures(t *testing.T) { name: "all features set", eventstore: expectEventstore( expectFilter( - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( context.Background(), aggregate, feature_v2.SystemLoginDefaultOrgEventType, false, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( context.Background(), aggregate, feature_v2.SystemTriggerIntrospectionProjectionsEventType, true, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( context.Background(), aggregate, feature_v2.SystemLegacyIntrospectionEventType, false, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( context.Background(), aggregate, feature_v2.SystemUserSchemaEventType, false, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( context.Background(), aggregate, feature_v2.SystemActionsEventType, true, )), @@ -97,23 +97,23 @@ func TestQueries_GetSystemFeatures(t *testing.T) { name: "all features set, reset, set some feature", eventstore: expectEventstore( expectFilter( - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( context.Background(), aggregate, feature_v2.SystemLoginDefaultOrgEventType, false, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( context.Background(), aggregate, feature_v2.SystemTriggerIntrospectionProjectionsEventType, true, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( context.Background(), aggregate, feature_v2.SystemLegacyIntrospectionEventType, false, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( context.Background(), aggregate, feature_v2.SystemUserSchemaEventType, false, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( context.Background(), aggregate, feature_v2.SystemActionsEventType, false, )), @@ -121,7 +121,7 @@ func TestQueries_GetSystemFeatures(t *testing.T) { context.Background(), aggregate, feature_v2.SystemResetEventType, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( context.Background(), aggregate, feature_v2.SystemTriggerIntrospectionProjectionsEventType, true, )), @@ -157,23 +157,23 @@ func TestQueries_GetSystemFeatures(t *testing.T) { name: "all features set, reset, set some feature, not cascaded", eventstore: expectEventstore( expectFilter( - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( context.Background(), aggregate, feature_v2.SystemLoginDefaultOrgEventType, false, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( context.Background(), aggregate, feature_v2.SystemTriggerIntrospectionProjectionsEventType, true, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( context.Background(), aggregate, feature_v2.SystemLegacyIntrospectionEventType, false, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( context.Background(), aggregate, feature_v2.SystemUserSchemaEventType, false, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( context.Background(), aggregate, feature_v2.SystemActionsEventType, false, )), @@ -181,7 +181,7 @@ func TestQueries_GetSystemFeatures(t *testing.T) { context.Background(), aggregate, feature_v2.SystemResetEventType, )), - eventFromEventPusher(feature_v2.NewSetEvent[bool]( + eventFromEventPusher(feature_v2.NewSetEvent( context.Background(), aggregate, feature_v2.SystemTriggerIntrospectionProjectionsEventType, true, )), diff --git a/internal/query/target.go b/internal/query/target.go index 03db85236c..d9b50f4a14 100644 --- a/internal/query/target.go +++ b/internal/query/target.go @@ -116,8 +116,8 @@ func (q *Queries) SearchTargets(ctx context.Context, queries *TargetSearchQuerie eq := sq.Eq{ TargetColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), } - query, scan := prepareTargetsQuery(ctx, q.client) - targets, err := genericRowsQueryWithState[*Targets](ctx, q.client, targetTable, combineToWhereStmt(query, queries.toQuery, eq), scan) + query, scan := prepareTargetsQuery() + targets, err := genericRowsQueryWithState(ctx, q.client, targetTable, combineToWhereStmt(query, queries.toQuery, eq), scan) if err != nil { return nil, err } @@ -134,8 +134,8 @@ func (q *Queries) GetTargetByID(ctx context.Context, id string) (*Target, error) TargetColumnID.identifier(): id, TargetColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), } - query, scan := prepareTargetQuery(ctx, q.client) - target, err := genericRowQuery[*Target](ctx, q.client, query.Where(eq), scan) + query, scan := prepareTargetQuery() + target, err := genericRowQuery(ctx, q.client, query.Where(eq), scan) if err != nil { return nil, err } @@ -153,7 +153,7 @@ func NewTargetInIDsSearchQuery(values []string) (SearchQuery, error) { return NewInTextQuery(TargetColumnID, values) } -func prepareTargetsQuery(context.Context, prepareDatabase) (sq.SelectBuilder, func(rows *sql.Rows) (*Targets, error)) { +func prepareTargetsQuery() (sq.SelectBuilder, func(rows *sql.Rows) (*Targets, error)) { return sq.Select( TargetColumnID.identifier(), TargetColumnCreationDate.identifier(), @@ -205,7 +205,7 @@ func prepareTargetsQuery(context.Context, prepareDatabase) (sq.SelectBuilder, fu } } -func prepareTargetQuery(context.Context, prepareDatabase) (sq.SelectBuilder, func(row *sql.Row) (*Target, error)) { +func prepareTargetQuery() (sq.SelectBuilder, func(row *sql.Row) (*Target, error)) { return sq.Select( TargetColumnID.identifier(), TargetColumnCreationDate.identifier(), diff --git a/internal/query/target_test.go b/internal/query/target_test.go index aa1ad517b7..ef564bf236 100644 --- a/internal/query/target_test.go +++ b/internal/query/target_test.go @@ -372,7 +372,7 @@ func Test_TargetPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/user.go b/internal/query/user.go index 3ee9a48463..233faf4aa9 100644 --- a/internal/query/user.go +++ b/internal/query/user.go @@ -13,7 +13,6 @@ import ( "golang.org/x/text/language" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query/projection" @@ -437,7 +436,7 @@ func (q *Queries) GetUser(ctx context.Context, shouldTriggerBulk bool, queries . triggerUserProjections(ctx) } - query, scan := prepareUserQuery(ctx, q.client) + query, scan := prepareUserQuery() for _, q := range queries { query = q.toQuery(query) } @@ -460,7 +459,7 @@ func (q *Queries) GetHumanProfile(ctx context.Context, userID string, queries .. ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareProfileQuery(ctx, q.client) + query, scan := prepareProfileQuery() for _, q := range queries { query = q.toQuery(query) } @@ -484,7 +483,7 @@ func (q *Queries) GetHumanEmail(ctx context.Context, userID string, queries ...S ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareEmailQuery(ctx, q.client) + query, scan := prepareEmailQuery() for _, q := range queries { query = q.toQuery(query) } @@ -508,7 +507,7 @@ func (q *Queries) GetHumanPhone(ctx context.Context, userID string, queries ...S ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := preparePhoneQuery(ctx, q.client) + query, scan := preparePhoneQuery() for _, q := range queries { query = q.toQuery(query) } @@ -595,7 +594,7 @@ func (q *Queries) GetNotifyUser(ctx context.Context, shouldTriggered bool, queri triggerUserProjections(ctx) } - query, scan := prepareNotifyUserQuery(ctx, q.client) + query, scan := prepareNotifyUserQuery() for _, q := range queries { query = q.toQuery(query) } @@ -650,7 +649,7 @@ func (q *Queries) searchUsers(ctx context.Context, queries *UserSearchQueries, f ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareUsersQuery(ctx, q.client) + query, scan := prepareUsersQuery() query = queries.toQuery(query).Where(sq.Eq{ UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(), }) @@ -678,7 +677,7 @@ func (q *Queries) IsUserUnique(ctx context.Context, username, email, resourceOwn ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareUserUniqueQuery(ctx, q.client) + query, scan := prepareUserUniqueQuery() queries := make([]SearchQuery, 0, 3) if username != "" { usernameQuery, err := NewUserUsernameSearchQuery(username, TextEquals) @@ -961,7 +960,7 @@ func scanUser(row *sql.Row) (*User, error) { return u, nil } -func prepareUserQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*User, error)) { +func prepareUserQuery() (sq.SelectBuilder, func(*sql.Row) (*User, error)) { loginNamesQuery, loginNamesArgs, err := prepareLoginNamesQuery() if err != nil { return sq.SelectBuilder{}, nil @@ -1012,14 +1011,14 @@ func prepareUserQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder loginNamesArgs...). LeftJoin("("+preferredLoginNameQuery+") AS "+userPreferredLoginNameTable.alias+" ON "+ userPreferredLoginNameUserIDCol.identifier()+" = "+UserIDCol.identifier()+" AND "+ - userPreferredLoginNameInstanceIDCol.identifier()+" = "+UserInstanceIDCol.identifier()+db.Timetravel(call.Took(ctx)), + userPreferredLoginNameInstanceIDCol.identifier()+" = "+UserInstanceIDCol.identifier(), preferredLoginNameArgs...). PlaceholderFormat(sq.Dollar), scanUser } -func prepareProfileQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Profile, error)) { +func prepareProfileQuery() (sq.SelectBuilder, func(*sql.Row) (*Profile, error)) { return sq.Select( UserIDCol.identifier(), UserCreationDateCol.identifier(), @@ -1035,7 +1034,7 @@ func prepareProfileQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil HumanGenderCol.identifier(), HumanAvatarURLCol.identifier()). From(userTable.identifier()). - LeftJoin(join(HumanUserIDCol, UserIDCol) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(HumanUserIDCol, UserIDCol)). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*Profile, error) { p := new(Profile) @@ -1085,7 +1084,7 @@ func prepareProfileQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuil } } -func prepareEmailQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Email, error)) { +func prepareEmailQuery() (sq.SelectBuilder, func(*sql.Row) (*Email, error)) { return sq.Select( UserIDCol.identifier(), UserCreationDateCol.identifier(), @@ -1096,7 +1095,7 @@ func prepareEmailQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilde HumanEmailCol.identifier(), HumanIsEmailVerifiedCol.identifier()). From(userTable.identifier()). - LeftJoin(join(HumanUserIDCol, UserIDCol) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(HumanUserIDCol, UserIDCol)). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*Email, error) { e := new(Email) @@ -1132,7 +1131,7 @@ func prepareEmailQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilde } } -func preparePhoneQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*Phone, error)) { +func preparePhoneQuery() (sq.SelectBuilder, func(*sql.Row) (*Phone, error)) { return sq.Select( UserIDCol.identifier(), UserCreationDateCol.identifier(), @@ -1143,7 +1142,7 @@ func preparePhoneQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilde HumanPhoneCol.identifier(), HumanIsPhoneVerifiedCol.identifier()). From(userTable.identifier()). - LeftJoin(join(HumanUserIDCol, UserIDCol) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(HumanUserIDCol, UserIDCol)). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*Phone, error) { e := new(Phone) @@ -1179,7 +1178,7 @@ func preparePhoneQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilde } } -func prepareNotifyUserQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*NotifyUser, error)) { +func prepareNotifyUserQuery() (sq.SelectBuilder, func(*sql.Row) (*NotifyUser, error)) { loginNamesQuery, loginNamesArgs, err := prepareLoginNamesQuery() if err != nil { return sq.SelectBuilder{}, nil @@ -1224,7 +1223,7 @@ func prepareNotifyUserQuery(ctx context.Context, db prepareDatabase) (sq.SelectB loginNamesArgs...). LeftJoin("("+preferredLoginNameQuery+") AS "+userPreferredLoginNameTable.alias+" ON "+ userPreferredLoginNameUserIDCol.identifier()+" = "+UserIDCol.identifier()+" AND "+ - userPreferredLoginNameInstanceIDCol.identifier()+" = "+UserInstanceIDCol.identifier()+db.Timetravel(call.Took(ctx)), + userPreferredLoginNameInstanceIDCol.identifier()+" = "+UserInstanceIDCol.identifier(), preferredLoginNameArgs...). PlaceholderFormat(sq.Dollar), scanNotifyUser @@ -1331,7 +1330,7 @@ func prepareCountUsersQuery() (sq.SelectBuilder, func(*sql.Rows) (uint64, error) } } -func prepareUserUniqueQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (bool, error)) { +func prepareUserUniqueQuery() (sq.SelectBuilder, func(*sql.Row) (bool, error)) { return sq.Select( UserIDCol.identifier(), UserStateCol.identifier(), @@ -1340,7 +1339,7 @@ func prepareUserUniqueQuery(ctx context.Context, db prepareDatabase) (sq.SelectB HumanEmailCol.identifier(), HumanIsEmailVerifiedCol.identifier()). From(userTable.identifier()). - LeftJoin(join(HumanUserIDCol, UserIDCol) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(HumanUserIDCol, UserIDCol)). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (bool, error) { userID := sql.NullString{} @@ -1368,7 +1367,7 @@ func prepareUserUniqueQuery(ctx context.Context, db prepareDatabase) (sq.SelectB } } -func prepareUsersQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Users, error)) { +func prepareUsersQuery() (sq.SelectBuilder, func(*sql.Rows) (*Users, error)) { loginNamesQuery, loginNamesArgs, err := prepareLoginNamesQuery() if err != nil { return sq.SelectBuilder{}, nil @@ -1417,7 +1416,7 @@ func prepareUsersQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilde loginNamesArgs...). LeftJoin("("+preferredLoginNameQuery+") AS "+userPreferredLoginNameTable.alias+" ON "+ userPreferredLoginNameUserIDCol.identifier()+" = "+UserIDCol.identifier()+" AND "+ - userPreferredLoginNameInstanceIDCol.identifier()+" = "+UserInstanceIDCol.identifier()+db.Timetravel(call.Took(ctx)), + userPreferredLoginNameInstanceIDCol.identifier()+" = "+UserInstanceIDCol.identifier(), preferredLoginNameArgs...). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*Users, error) { diff --git a/internal/query/user_auth_method.go b/internal/query/user_auth_method.go index 0687545aef..6949b6bec3 100644 --- a/internal/query/user_auth_method.go +++ b/internal/query/user_auth_method.go @@ -11,7 +11,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -158,7 +157,7 @@ func (q *Queries) searchUserAuthMethods(ctx context.Context, queries *UserAuthMe ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareUserAuthMethodsQuery(ctx, q.client) + query, scan := prepareUserAuthMethodsQuery() stmt, args, err := queries.toQuery(query).Where(sq.Eq{UserAuthMethodColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()}).ToSql() if err != nil { return nil, zerrors.ThrowInvalidArgument(err, "QUERY-j9NJd", "Errors.Query.InvalidRequest") @@ -185,7 +184,7 @@ func (q *Queries) ListUserAuthMethodTypes(ctx context.Context, userID string, ac ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareUserAuthMethodTypesQuery(ctx, q.client, activeOnly, includeWithoutDomain, queryDomain) + query, scan := prepareUserAuthMethodTypesQuery(activeOnly, includeWithoutDomain, queryDomain) eq := sq.Eq{ UserIDCol.identifier(): userID, UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -222,7 +221,7 @@ func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID st ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, q.client) + query, scan := prepareUserAuthMethodTypesRequiredQuery() eq := sq.Eq{ UserIDCol.identifier(): userID, UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -349,7 +348,7 @@ func (q *UserAuthMethodSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectB return query } -func prepareUserAuthMethodsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*AuthMethods, error)) { +func prepareUserAuthMethodsQuery() (sq.SelectBuilder, func(*sql.Rows) (*AuthMethods, error)) { return sq.Select( UserAuthMethodColumnTokenID.identifier(), UserAuthMethodColumnCreationDate.identifier(), @@ -361,7 +360,7 @@ func prepareUserAuthMethodsQuery(ctx context.Context, db prepareDatabase) (sq.Se UserAuthMethodColumnState.identifier(), UserAuthMethodColumnMethodType.identifier(), countColumn.identifier()). - From(userAuthMethodTable.identifier() + db.Timetravel(call.Took(ctx))). + From(userAuthMethodTable.identifier()). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*AuthMethods, error) { userAuthMethods := make([]*AuthMethod, 0) @@ -399,7 +398,7 @@ func prepareUserAuthMethodsQuery(ctx context.Context, db prepareDatabase) (sq.Se } } -func prepareUserAuthMethodTypesQuery(ctx context.Context, db prepareDatabase, activeOnly bool, includeWithoutDomain bool, queryDomain string) (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) { +func prepareUserAuthMethodTypesQuery(activeOnly bool, includeWithoutDomain bool, queryDomain string) (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) { authMethodsQuery, authMethodsArgs, err := prepareAuthMethodQuery(activeOnly, includeWithoutDomain, queryDomain) if err != nil { return sq.SelectBuilder{}, nil @@ -420,7 +419,7 @@ func prepareUserAuthMethodTypesQuery(ctx context.Context, db prepareDatabase, ac authMethodsArgs...). LeftJoin("(" + idpsQuery + ") AS " + userIDPsCountTable.alias + " ON " + userIDPsCountUserID.identifier() + " = " + UserIDCol.identifier() + " AND " + - userIDPsCountInstanceID.identifier() + " = " + UserInstanceIDCol.identifier() + db.Timetravel(call.Took(ctx))). + userIDPsCountInstanceID.identifier() + " = " + UserInstanceIDCol.identifier()). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*AuthMethodTypes, error) { userAuthMethodTypes := make([]domain.UserAuthMethodType, 0) @@ -461,7 +460,7 @@ func prepareUserAuthMethodTypesQuery(ctx context.Context, db prepareDatabase, ac } } -func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { +func prepareUserAuthMethodTypesRequiredQuery() (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { loginPolicyQuery, err := prepareAuthMethodsForceMFAQuery() if err != nil { return sq.SelectBuilder{}, nil diff --git a/internal/query/user_auth_method_test.go b/internal/query/user_auth_method_test.go index 041e4f8e9e..18d2b41bb9 100644 --- a/internal/query/user_auth_method_test.go +++ b/internal/query/user_auth_method_test.go @@ -191,8 +191,7 @@ var ( ` projections.user_auth_methods5.state,` + ` projections.user_auth_methods5.method_type,` + ` COUNT(*) OVER ()` + - ` FROM projections.user_auth_methods5` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.user_auth_methods5` prepareUserAuthMethodsCols = []string{ "token_id", "creation_date", @@ -215,8 +214,7 @@ var ( ` ON auth_method_types.user_id = projections.users14.id AND auth_method_types.instance_id = projections.users14.instance_id` + ` LEFT JOIN (SELECT user_idps_count.user_id, user_idps_count.instance_id, COUNT(user_idps_count.user_id) AS count FROM projections.idp_user_links3 AS user_idps_count` + ` GROUP BY user_idps_count.user_id, user_idps_count.instance_id) AS user_idps_count` + - ` ON user_idps_count.user_id = projections.users14.id AND user_idps_count.instance_id = projections.users14.instance_id` + - ` AS OF SYSTEM TIME '-1 ms` + ` ON user_idps_count.user_id = projections.users14.id AND user_idps_count.instance_id = projections.users14.instance_id` prepareActiveAuthMethodTypesCols = []string{ "password_set", "method_type", @@ -232,8 +230,7 @@ var ( ` ON auth_method_types.user_id = projections.users14.id AND auth_method_types.instance_id = projections.users14.instance_id` + ` LEFT JOIN (SELECT user_idps_count.user_id, user_idps_count.instance_id, COUNT(user_idps_count.user_id) AS count FROM projections.idp_user_links3 AS user_idps_count` + ` GROUP BY user_idps_count.user_id, user_idps_count.instance_id) AS user_idps_count` + - ` ON user_idps_count.user_id = projections.users14.id AND user_idps_count.instance_id = projections.users14.instance_id` + - ` AS OF SYSTEM TIME '-1 ms` + ` ON user_idps_count.user_id = projections.users14.id AND user_idps_count.instance_id = projections.users14.instance_id` prepareActiveAuthMethodTypesDomainCols = []string{ "password_set", "method_type", @@ -249,8 +246,7 @@ var ( ` ON auth_method_types.user_id = projections.users14.id AND auth_method_types.instance_id = projections.users14.instance_id` + ` LEFT JOIN (SELECT user_idps_count.user_id, user_idps_count.instance_id, COUNT(user_idps_count.user_id) AS count FROM projections.idp_user_links3 AS user_idps_count` + ` GROUP BY user_idps_count.user_id, user_idps_count.instance_id) AS user_idps_count` + - ` ON user_idps_count.user_id = projections.users14.id AND user_idps_count.instance_id = projections.users14.instance_id` + - ` AS OF SYSTEM TIME '-1 ms` + ` ON user_idps_count.user_id = projections.users14.id AND user_idps_count.instance_id = projections.users14.instance_id` prepareActiveAuthMethodTypesDomainExternalCols = []string{ "password_set", "method_type", @@ -417,8 +413,8 @@ func Test_UserAuthMethodPrepares(t *testing.T) { }, { name: "prepareUserAuthMethodTypesQuery no result", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) { - builder, scan := prepareUserAuthMethodTypesQuery(ctx, db, true, true, "") + prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) { + builder, scan := prepareUserAuthMethodTypesQuery(true, true, "") return builder, func(rows *sql.Rows) (*AuthMethodTypes, error) { return scan(rows) } @@ -434,8 +430,8 @@ func Test_UserAuthMethodPrepares(t *testing.T) { }, { name: "prepareUserAuthMethodTypesQuery one second factor", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) { - builder, scan := prepareUserAuthMethodTypesQuery(ctx, db, true, true, "") + prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) { + builder, scan := prepareUserAuthMethodTypesQuery(true, true, "") return builder, func(rows *sql.Rows) (*AuthMethodTypes, error) { return scan(rows) } @@ -466,8 +462,8 @@ func Test_UserAuthMethodPrepares(t *testing.T) { }, { name: "prepareUserAuthMethodTypesQuery one second factor with domain", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) { - builder, scan := prepareUserAuthMethodTypesQuery(ctx, db, true, true, "example.com") + prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) { + builder, scan := prepareUserAuthMethodTypesQuery(true, true, "example.com") return builder, func(rows *sql.Rows) (*AuthMethodTypes, error) { return scan(rows) } @@ -498,8 +494,8 @@ func Test_UserAuthMethodPrepares(t *testing.T) { }, { name: "prepareUserAuthMethodTypesQuery one second factor with domain external", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) { - builder, scan := prepareUserAuthMethodTypesQuery(ctx, db, true, false, "example.com") + prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) { + builder, scan := prepareUserAuthMethodTypesQuery(true, false, "example.com") return builder, func(rows *sql.Rows) (*AuthMethodTypes, error) { return scan(rows) } @@ -530,8 +526,8 @@ func Test_UserAuthMethodPrepares(t *testing.T) { }, { name: "prepareUserAuthMethodTypesQuery multiple second factors", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) { - builder, scan := prepareUserAuthMethodTypesQuery(ctx, db, true, true, "") + prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) { + builder, scan := prepareUserAuthMethodTypesQuery(true, true, "") return builder, func(rows *sql.Rows) (*AuthMethodTypes, error) { return scan(rows) } @@ -568,8 +564,8 @@ func Test_UserAuthMethodPrepares(t *testing.T) { }, { name: "prepareUserAuthMethodTypesQuery multiple second factors domain", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) { - builder, scan := prepareUserAuthMethodTypesQuery(ctx, db, true, true, "example.com") + prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) { + builder, scan := prepareUserAuthMethodTypesQuery(true, true, "example.com") return builder, func(rows *sql.Rows) (*AuthMethodTypes, error) { return scan(rows) } @@ -606,8 +602,8 @@ func Test_UserAuthMethodPrepares(t *testing.T) { }, { name: "prepareUserAuthMethodTypesQuery multiple second factors domain external", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) { - builder, scan := prepareUserAuthMethodTypesQuery(ctx, db, true, false, "example.com") + prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) { + builder, scan := prepareUserAuthMethodTypesQuery(true, false, "example.com") return builder, func(rows *sql.Rows) (*AuthMethodTypes, error) { return scan(rows) } @@ -644,8 +640,8 @@ func Test_UserAuthMethodPrepares(t *testing.T) { }, { name: "prepareUserAuthMethodTypesQuery sql err", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) { - builder, scan := prepareUserAuthMethodTypesQuery(ctx, db, true, true, "") + prepare: func() (sq.SelectBuilder, func(*sql.Rows) (*AuthMethodTypes, error)) { + builder, scan := prepareUserAuthMethodTypesQuery(true, true, "") return builder, func(rows *sql.Rows) (*AuthMethodTypes, error) { return scan(rows) } @@ -666,8 +662,8 @@ func Test_UserAuthMethodPrepares(t *testing.T) { }, { name: "prepareUserAuthMethodTypesRequiredQuery no result", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { - builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db) + prepare: func() (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { + builder, scan := prepareUserAuthMethodTypesRequiredQuery() return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) { return scan(row) } @@ -689,8 +685,8 @@ func Test_UserAuthMethodPrepares(t *testing.T) { }, { name: "prepareUserAuthMethodTypesRequiredQuery one second factor", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { - builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db) + prepare: func() (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { + builder, scan := prepareUserAuthMethodTypesRequiredQuery() return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) { return scan(row) } @@ -716,8 +712,8 @@ func Test_UserAuthMethodPrepares(t *testing.T) { }, { name: "prepareUserAuthMethodTypesRequiredQuery multiple second factors", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { - builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db) + prepare: func() (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { + builder, scan := prepareUserAuthMethodTypesRequiredQuery() return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) { return scan(row) } @@ -744,8 +740,8 @@ func Test_UserAuthMethodPrepares(t *testing.T) { }, { name: "prepareUserAuthMethodTypesRequiredQuery sql err", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { - builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db) + prepare: func() (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { + builder, scan := prepareUserAuthMethodTypesRequiredQuery() return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) { return scan(row) } @@ -767,7 +763,7 @@ func Test_UserAuthMethodPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/user_grant.go b/internal/query/user_grant.go index 265d8eaae1..c3f24c066e 100644 --- a/internal/query/user_grant.go +++ b/internal/query/user_grant.go @@ -10,7 +10,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore/handler/v2" @@ -246,7 +245,7 @@ func (q *Queries) UserGrant(ctx context.Context, shouldTriggerBulk bool, queries traceSpan.EndWithError(err) } - query, scan := prepareUserGrantQuery(ctx, q.client) + query, scan := prepareUserGrantQuery() for _, q := range queries { query = q.toQuery(query) } @@ -274,7 +273,7 @@ func (q *Queries) UserGrants(ctx context.Context, queries *UserGrantsQueries, sh traceSpan.EndWithError(err) } - query, scan := prepareUserGrantsQuery(ctx, q.client) + query, scan := prepareUserGrantsQuery() eq := sq.Eq{UserGrantInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()} stmt, args, err := queries.toQuery(query).Where(eq).ToSql() if err != nil { @@ -298,7 +297,7 @@ func (q *Queries) UserGrants(ctx context.Context, queries *UserGrantsQueries, sh return grants, nil } -func prepareUserGrantQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserGrant, error)) { +func prepareUserGrantQuery() (sq.SelectBuilder, func(*sql.Row) (*UserGrant, error)) { return sq.Select( UserGrantID.identifier(), UserGrantCreationDate.identifier(), @@ -336,7 +335,7 @@ func prepareUserGrantQuery(ctx context.Context, db prepareDatabase) (sq.SelectBu LeftJoin(join(OrgColumnID, UserGrantResourceOwner)). LeftJoin(join(ProjectColumnID, UserGrantProjectID)). LeftJoin(join(GrantedOrgColumnId, UserResourceOwnerCol)). - LeftJoin(join(LoginNameUserIDCol, UserGrantUserID) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(LoginNameUserIDCol, UserGrantUserID)). Where( sq.Eq{LoginNameIsPrimaryCol.identifier(): true}, ).PlaceholderFormat(sq.Dollar), @@ -421,7 +420,7 @@ func prepareUserGrantQuery(ctx context.Context, db prepareDatabase) (sq.SelectBu } } -func prepareUserGrantsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserGrants, error)) { +func prepareUserGrantsQuery() (sq.SelectBuilder, func(*sql.Rows) (*UserGrants, error)) { return sq.Select( UserGrantID.identifier(), UserGrantCreationDate.identifier(), @@ -461,7 +460,7 @@ func prepareUserGrantsQuery(ctx context.Context, db prepareDatabase) (sq.SelectB LeftJoin(join(OrgColumnID, UserGrantResourceOwner)). LeftJoin(join(ProjectColumnID, UserGrantProjectID)). LeftJoin(join(GrantedOrgColumnId, UserResourceOwnerCol)). - LeftJoin(join(LoginNameUserIDCol, UserGrantUserID) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(LoginNameUserIDCol, UserGrantUserID)). Where( sq.Eq{LoginNameIsPrimaryCol.identifier(): true}, ).PlaceholderFormat(sq.Dollar), diff --git a/internal/query/user_grant_test.go b/internal/query/user_grant_test.go index 6cfa0b563b..6a640c2ef2 100644 --- a/internal/query/user_grant_test.go +++ b/internal/query/user_grant_test.go @@ -47,7 +47,6 @@ var ( " LEFT JOIN projections.projects4 ON projections.user_grants5.project_id = projections.projects4.id AND projections.user_grants5.instance_id = projections.projects4.instance_id" + " LEFT JOIN projections.orgs1 AS granted_orgs ON projections.users14.resource_owner = granted_orgs.id AND projections.users14.instance_id = granted_orgs.instance_id" + " LEFT JOIN projections.login_names3 ON projections.user_grants5.user_id = projections.login_names3.user_id AND projections.user_grants5.instance_id = projections.login_names3.instance_id" + - ` AS OF SYSTEM TIME '-1 ms' ` + " WHERE projections.login_names3.is_primary = $1") userGrantCols = []string{ "id", @@ -110,7 +109,6 @@ var ( " LEFT JOIN projections.projects4 ON projections.user_grants5.project_id = projections.projects4.id AND projections.user_grants5.instance_id = projections.projects4.instance_id" + " LEFT JOIN projections.orgs1 AS granted_orgs ON projections.users14.resource_owner = granted_orgs.id AND projections.users14.instance_id = granted_orgs.instance_id" + " LEFT JOIN projections.login_names3 ON projections.user_grants5.user_id = projections.login_names3.user_id AND projections.user_grants5.instance_id = projections.login_names3.instance_id" + - ` AS OF SYSTEM TIME '-1 ms' ` + " WHERE projections.login_names3.is_primary = $1") userGrantsCols = append( userGrantCols, @@ -1008,7 +1006,7 @@ func Test_UserGrantPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/user_membership.go b/internal/query/user_membership.go index 7ba2629cfa..cae2b4dae3 100644 --- a/internal/query/user_membership.go +++ b/internal/query/user_membership.go @@ -9,7 +9,6 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/eventstore/handler/v2" "github.com/zitadel/zitadel/internal/query/projection" @@ -138,7 +137,7 @@ func (q *Queries) Memberships(ctx context.Context, queries *MembershipSearchQuer wg.Wait() } - query, queryArgs, scan := prepareMembershipsQuery(ctx, q.client, queries) + query, queryArgs, scan := prepareMembershipsQuery(queries) eq := sq.Eq{membershipInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()} stmt, args, err := queries.toQuery(query).Where(eq).ToSql() if err != nil { @@ -237,7 +236,7 @@ func getMembershipFromQuery(queries *MembershipSearchQuery) (string, []interface args } -func prepareMembershipsQuery(ctx context.Context, db prepareDatabase, queries *MembershipSearchQuery) (sq.SelectBuilder, []interface{}, func(*sql.Rows) (*Memberships, error)) { +func prepareMembershipsQuery(queries *MembershipSearchQuery) (sq.SelectBuilder, []interface{}, func(*sql.Rows) (*Memberships, error)) { query, args := getMembershipFromQuery(queries) return sq.Select( membershipUserID.identifier(), @@ -259,7 +258,7 @@ func prepareMembershipsQuery(ctx context.Context, db prepareDatabase, queries *M LeftJoin(join(ProjectColumnID, membershipProjectID)). LeftJoin(join(OrgColumnID, membershipOrgID)). LeftJoin(join(ProjectGrantColumnGrantID, membershipGrantID)). - LeftJoin(join(InstanceColumnID, membershipInstanceID) + db.Timetravel(call.Took(ctx))). + LeftJoin(join(InstanceColumnID, membershipInstanceID)). PlaceholderFormat(sq.Dollar), args, func(rows *sql.Rows) (*Memberships, error) { diff --git a/internal/query/user_membership_test.go b/internal/query/user_membership_test.go index a0ea3cda31..b0170182d1 100644 --- a/internal/query/user_membership_test.go +++ b/internal/query/user_membership_test.go @@ -1,7 +1,6 @@ package query import ( - "context" "database/sql" "database/sql/driver" "errors" @@ -87,8 +86,7 @@ var ( " LEFT JOIN projections.projects4 ON members.project_id = projections.projects4.id AND members.instance_id = projections.projects4.instance_id" + " LEFT JOIN projections.orgs1 ON members.org_id = projections.orgs1.id AND members.instance_id = projections.orgs1.instance_id" + " LEFT JOIN projections.project_grants4 ON members.grant_id = projections.project_grants4.grant_id AND members.instance_id = projections.project_grants4.instance_id" + - " LEFT JOIN projections.instances ON members.instance_id = projections.instances.id" + - ` AS OF SYSTEM TIME '-1 ms'`) + " LEFT JOIN projections.instances ON members.instance_id = projections.instances.id") membershipCols = []string{ "user_id", "roles", @@ -456,14 +454,14 @@ func Test_MembershipPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } -func prepareMembershipWrapper() func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Memberships, error)) { - return func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Memberships, error)) { - builder, _, fun := prepareMembershipsQuery(ctx, db, &MembershipSearchQuery{}) +func prepareMembershipWrapper() func() (sq.SelectBuilder, func(*sql.Rows) (*Memberships, error)) { + return func() (sq.SelectBuilder, func(*sql.Rows) (*Memberships, error)) { + builder, _, fun := prepareMembershipsQuery(&MembershipSearchQuery{}) return builder, fun } } diff --git a/internal/query/user_metadata.go b/internal/query/user_metadata.go index a3b7c1fd34..ff612f82c8 100644 --- a/internal/query/user_metadata.go +++ b/internal/query/user_metadata.go @@ -10,7 +10,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/eventstore/handler/v2" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -87,7 +86,7 @@ func (q *Queries) GetUserMetadataByKey(ctx context.Context, shouldTriggerBulk bo traceSpan.EndWithError(err) } - query, scan := prepareUserMetadataQuery(ctx, q.client) + query, scan := prepareUserMetadataQuery() for _, q := range queries { query = q.toQuery(query) } @@ -119,7 +118,7 @@ func (q *Queries) SearchUserMetadataForUsers(ctx context.Context, shouldTriggerB traceSpan.EndWithError(err) } - query, scan := prepareUserMetadataListQuery(ctx, q.client) + query, scan := prepareUserMetadataListQuery() eq := sq.Eq{ UserMetadataUserIDCol.identifier(): userIDs, UserMetadataInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -151,7 +150,7 @@ func (q *Queries) SearchUserMetadata(ctx context.Context, shouldTriggerBulk bool traceSpan.EndWithError(err) } - query, scan := prepareUserMetadataListQuery(ctx, q.client) + query, scan := prepareUserMetadataListQuery() eq := sq.Eq{ UserMetadataUserIDCol.identifier(): userID, UserMetadataInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -235,7 +234,7 @@ func NewUserMetadataExistsQuery(key string, value []byte, keyComparison TextComp ) } -func prepareUserMetadataQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserMetadata, error)) { +func prepareUserMetadataQuery() (sq.SelectBuilder, func(*sql.Row) (*UserMetadata, error)) { return sq.Select( UserMetadataCreationDateCol.identifier(), UserMetadataChangeDateCol.identifier(), @@ -244,7 +243,7 @@ func prepareUserMetadataQuery(ctx context.Context, db prepareDatabase) (sq.Selec UserMetadataKeyCol.identifier(), UserMetadataValueCol.identifier(), ). - From(userMetadataTable.identifier() + db.Timetravel(call.Took(ctx))). + From(userMetadataTable.identifier()). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*UserMetadata, error) { m := new(UserMetadata) @@ -267,7 +266,7 @@ func prepareUserMetadataQuery(ctx context.Context, db prepareDatabase) (sq.Selec } } -func prepareUserMetadataListQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserMetadataList, error)) { +func prepareUserMetadataListQuery() (sq.SelectBuilder, func(*sql.Rows) (*UserMetadataList, error)) { return sq.Select( UserMetadataCreationDateCol.identifier(), UserMetadataChangeDateCol.identifier(), @@ -277,7 +276,7 @@ func prepareUserMetadataListQuery(ctx context.Context, db prepareDatabase) (sq.S UserMetadataKeyCol.identifier(), UserMetadataValueCol.identifier(), countColumn.identifier()). - From(userMetadataTable.identifier() + db.Timetravel(call.Took(ctx))). + From(userMetadataTable.identifier()). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*UserMetadataList, error) { metadata := make([]*UserMetadata, 0) diff --git a/internal/query/user_metadata_test.go b/internal/query/user_metadata_test.go index 7f9d1b8ed3..6236272da4 100644 --- a/internal/query/user_metadata_test.go +++ b/internal/query/user_metadata_test.go @@ -18,8 +18,7 @@ var ( ` projections.user_metadata5.sequence,` + ` projections.user_metadata5.key,` + ` projections.user_metadata5.value` + - ` FROM projections.user_metadata5` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.user_metadata5` userMetadataCols = []string{ "creation_date", "change_date", @@ -251,7 +250,7 @@ func Test_UserMetadataPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/user_password.go b/internal/query/user_password.go index ed77d0d3ae..1d0037f721 100644 --- a/internal/query/user_password.go +++ b/internal/query/user_password.go @@ -119,7 +119,6 @@ func (wm *HumanPasswordReadModel) Reduce() error { func (wm *HumanPasswordReadModel) Query() *eventstore.SearchQueryBuilder { query := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent). AwaitOpenTransactions(). - AllowTimeTravel(). AddQuery(). AggregateTypes(user.AggregateType). AggregateIDs(wm.AggregateID). diff --git a/internal/query/user_personal_access_token.go b/internal/query/user_personal_access_token.go index dadd635b6a..8ea33f51a4 100644 --- a/internal/query/user_personal_access_token.go +++ b/internal/query/user_personal_access_token.go @@ -10,7 +10,6 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/eventstore/handler/v2" "github.com/zitadel/zitadel/internal/query/projection" @@ -98,7 +97,7 @@ func (q *Queries) PersonalAccessTokenByID(ctx context.Context, shouldTriggerBulk traceSpan.EndWithError(err) } - query, scan := preparePersonalAccessTokenQuery(ctx, q.client) + query, scan := preparePersonalAccessTokenQuery() for _, q := range queries { query = q.toQuery(query) } @@ -128,7 +127,7 @@ func (q *Queries) SearchPersonalAccessTokens(ctx context.Context, queries *Perso ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() - query, scan := preparePersonalAccessTokensQuery(ctx, q.client) + query, scan := preparePersonalAccessTokensQuery() eq := sq.Eq{ PersonalAccessTokenColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), } @@ -178,7 +177,7 @@ func (q *PersonalAccessTokenSearchQueries) toQuery(query sq.SelectBuilder) sq.Se return query } -func preparePersonalAccessTokenQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*PersonalAccessToken, error)) { +func preparePersonalAccessTokenQuery() (sq.SelectBuilder, func(*sql.Row) (*PersonalAccessToken, error)) { return sq.Select( PersonalAccessTokenColumnID.identifier(), PersonalAccessTokenColumnCreationDate.identifier(), @@ -188,7 +187,7 @@ func preparePersonalAccessTokenQuery(ctx context.Context, db prepareDatabase) (s PersonalAccessTokenColumnUserID.identifier(), PersonalAccessTokenColumnExpiration.identifier(), PersonalAccessTokenColumnScopes.identifier()). - From(personalAccessTokensTable.identifier() + db.Timetravel(call.Took(ctx))). + From(personalAccessTokensTable.identifier()). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*PersonalAccessToken, error) { p := new(PersonalAccessToken) @@ -212,7 +211,7 @@ func preparePersonalAccessTokenQuery(ctx context.Context, db prepareDatabase) (s } } -func preparePersonalAccessTokensQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*PersonalAccessTokens, error)) { +func preparePersonalAccessTokensQuery() (sq.SelectBuilder, func(*sql.Rows) (*PersonalAccessTokens, error)) { return sq.Select( PersonalAccessTokenColumnID.identifier(), PersonalAccessTokenColumnCreationDate.identifier(), @@ -223,7 +222,7 @@ func preparePersonalAccessTokensQuery(ctx context.Context, db prepareDatabase) ( PersonalAccessTokenColumnExpiration.identifier(), PersonalAccessTokenColumnScopes.identifier(), countColumn.identifier()). - From(personalAccessTokensTable.identifier() + db.Timetravel(call.Took(ctx))). + From(personalAccessTokensTable.identifier()). PlaceholderFormat(sq.Dollar), func(rows *sql.Rows) (*PersonalAccessTokens, error) { personalAccessTokens := make([]*PersonalAccessToken, 0) diff --git a/internal/query/user_personal_access_token_test.go b/internal/query/user_personal_access_token_test.go index 79ba700ed5..dd3ed37e62 100644 --- a/internal/query/user_personal_access_token_test.go +++ b/internal/query/user_personal_access_token_test.go @@ -23,8 +23,7 @@ var ( " projections.personal_access_tokens3.user_id," + " projections.personal_access_tokens3.expiration," + " projections.personal_access_tokens3.scopes" + - " FROM projections.personal_access_tokens3" + - ` AS OF SYSTEM TIME '-1 ms'`) + " FROM projections.personal_access_tokens3") personalAccessTokenCols = []string{ "id", "creation_date", @@ -45,8 +44,7 @@ var ( " projections.personal_access_tokens3.expiration," + " projections.personal_access_tokens3.scopes," + " COUNT(*) OVER ()" + - " FROM projections.personal_access_tokens3" + - " AS OF SYSTEM TIME '-1 ms'") + " FROM projections.personal_access_tokens3") personalAccessTokensCols = []string{ "id", "creation_date", @@ -266,7 +264,7 @@ func Test_PersonalAccessTokenPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/user_schema.go b/internal/query/user_schema.go index ff5117d264..d3ea4dee79 100644 --- a/internal/query/user_schema.go +++ b/internal/query/user_schema.go @@ -96,7 +96,7 @@ func (q *Queries) GetUserSchemaByID(ctx context.Context, id string) (userSchema } query, scan := prepareUserSchemaQuery() - return genericRowQuery[*UserSchema](ctx, q.client, query.Where(eq), scan) + return genericRowQuery(ctx, q.client, query.Where(eq), scan) } func (q *Queries) SearchUserSchema(ctx context.Context, queries *UserSchemaSearchQueries) (userSchemas *UserSchemas, err error) { @@ -108,7 +108,7 @@ func (q *Queries) SearchUserSchema(ctx context.Context, queries *UserSchemaSearc } query, scan := prepareUserSchemasQuery() - return genericRowsQueryWithState[*UserSchemas](ctx, q.client, userSchemaTable, combineToWhereStmt(query, queries.toQuery, eq), scan) + return genericRowsQueryWithState(ctx, q.client, userSchemaTable, combineToWhereStmt(query, queries.toQuery, eq), scan) } func (q *UserSchemaSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuilder { diff --git a/internal/query/user_test.go b/internal/query/user_test.go index 16b08611f0..50d65cc1ec 100644 --- a/internal/query/user_test.go +++ b/internal/query/user_test.go @@ -6,7 +6,6 @@ import ( "database/sql/driver" "errors" "fmt" - "reflect" "regexp" "testing" @@ -268,8 +267,7 @@ var ( ` ON login_names.user_id = projections.users14.id AND login_names.instance_id = projections.users14.instance_id` + ` LEFT JOIN` + ` (` + preferredLoginNameQuery + `) AS preferred_login_name` + - ` ON preferred_login_name.user_id = projections.users14.id AND preferred_login_name.instance_id = projections.users14.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'` + ` ON preferred_login_name.user_id = projections.users14.id AND preferred_login_name.instance_id = projections.users14.instance_id` userCols = []string{ "id", "creation_date", @@ -319,8 +317,7 @@ var ( ` projections.users14_humans.gender,` + ` projections.users14_humans.avatar_key` + ` FROM projections.users14` + - ` LEFT JOIN projections.users14_humans ON projections.users14.id = projections.users14_humans.user_id AND projections.users14.instance_id = projections.users14_humans.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'` + ` LEFT JOIN projections.users14_humans ON projections.users14.id = projections.users14_humans.user_id AND projections.users14.instance_id = projections.users14_humans.instance_id` profileCols = []string{ "id", "creation_date", @@ -345,8 +342,7 @@ var ( ` projections.users14_humans.email,` + ` projections.users14_humans.is_email_verified` + ` FROM projections.users14` + - ` LEFT JOIN projections.users14_humans ON projections.users14.id = projections.users14_humans.user_id AND projections.users14.instance_id = projections.users14_humans.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'` + ` LEFT JOIN projections.users14_humans ON projections.users14.id = projections.users14_humans.user_id AND projections.users14.instance_id = projections.users14_humans.instance_id` emailCols = []string{ "id", "creation_date", @@ -366,8 +362,7 @@ var ( ` projections.users14_humans.phone,` + ` projections.users14_humans.is_phone_verified` + ` FROM projections.users14` + - ` LEFT JOIN projections.users14_humans ON projections.users14.id = projections.users14_humans.user_id AND projections.users14.instance_id = projections.users14_humans.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'` + ` LEFT JOIN projections.users14_humans ON projections.users14.id = projections.users14_humans.user_id AND projections.users14.instance_id = projections.users14_humans.instance_id` phoneCols = []string{ "id", "creation_date", @@ -385,8 +380,7 @@ var ( ` projections.users14_humans.email,` + ` projections.users14_humans.is_email_verified` + ` FROM projections.users14` + - ` LEFT JOIN projections.users14_humans ON projections.users14.id = projections.users14_humans.user_id AND projections.users14.instance_id = projections.users14_humans.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'` + ` LEFT JOIN projections.users14_humans ON projections.users14.id = projections.users14_humans.user_id AND projections.users14.instance_id = projections.users14_humans.instance_id` userUniqueCols = []string{ "id", "state", @@ -428,8 +422,7 @@ var ( ` ON login_names.user_id = projections.users14.id AND login_names.instance_id = projections.users14.instance_id` + ` LEFT JOIN` + ` (` + preferredLoginNameQuery + `) AS preferred_login_name` + - ` ON preferred_login_name.user_id = projections.users14.id AND preferred_login_name.instance_id = projections.users14.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'` + ` ON preferred_login_name.user_id = projections.users14.id AND preferred_login_name.instance_id = projections.users14.instance_id` notifyUserCols = []string{ "id", "creation_date", @@ -497,8 +490,7 @@ var ( ` ON login_names.user_id = projections.users14.id AND login_names.instance_id = projections.users14.instance_id` + ` LEFT JOIN` + ` (` + preferredLoginNameQuery + `) AS preferred_login_name` + - ` ON preferred_login_name.user_id = projections.users14.id AND preferred_login_name.instance_id = projections.users14.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'` + ` ON preferred_login_name.user_id = projections.users14.id AND preferred_login_name.instance_id = projections.users14.instance_id` usersCols = []string{ "id", "creation_date", @@ -1572,12 +1564,7 @@ func Test_UserPrepares(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - params := defaultPrepareArgs - if reflect.TypeOf(tt.prepare).NumIn() == 0 { - params = []reflect.Value{} - } - - assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, params...) + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) }) } } diff --git a/internal/query/userinfo_test.go b/internal/query/userinfo_test.go index 6ded7b4eed..5314283635 100644 --- a/internal/query/userinfo_test.go +++ b/internal/query/userinfo_test.go @@ -429,8 +429,7 @@ func TestQueries_GetOIDCUserInfo(t *testing.T) { execMock(t, tt.mock, func(db *sql.DB) { q := &Queries{ client: &database.DB{ - DB: db, - Database: &prepareDB{}, + DB: db, }, } ctx := authz.NewMockContext("instanceID", "orgID", "loginClient") @@ -476,8 +475,7 @@ func TestQueries_GetOIDCUserinfoClientByID(t *testing.T) { execMock(t, tt.mock, func(db *sql.DB) { q := &Queries{ client: &database.DB{ - DB: db, - Database: &prepareDB{}, + DB: db, }, } ctx := authz.NewMockContext("instanceID", "orgID", "loginClient") diff --git a/internal/query/web_key_test.go b/internal/query/web_key_test.go index 6008ec6528..80d07bfa13 100644 --- a/internal/query/web_key_test.go +++ b/internal/query/web_key_test.go @@ -208,8 +208,7 @@ func TestQueries_GetActiveSigningWebKey(t *testing.T) { execMock(t, tt.mock, func(db *sql.DB) { q := &Queries{ client: &database.DB{ - DB: db, - Database: &prepareDB{}, + DB: db, }, keyEncryptionAlgorithm: alg, } @@ -307,8 +306,7 @@ func TestQueries_ListWebKeys(t *testing.T) { execMock(t, tt.mock, func(db *sql.DB) { q := &Queries{ client: &database.DB{ - DB: db, - Database: &prepareDB{}, + DB: db, }, } got, err := q.ListWebKeys(ctx) @@ -369,8 +367,7 @@ func TestQueries_GetWebKeySet(t *testing.T) { execMock(t, tt.mock, func(db *sql.DB) { q := &Queries{ client: &database.DB{ - DB: db, - Database: &prepareDB{}, + DB: db, }, } got, err := q.GetWebKeySet(ctx) diff --git a/internal/queue/queue.go b/internal/queue/queue.go index b45a7eb8cb..d680221753 100644 --- a/internal/queue/queue.go +++ b/internal/queue/queue.go @@ -27,9 +27,6 @@ type Config struct { } func NewQueue(config *Config) (_ *Queue, err error) { - if config.Client.Type() == "cockroach" { - return nil, nil - } return &Queue{ driver: riverpgxv5.New(config.Client.Pool), config: &river.Config{ diff --git a/internal/static/database/crdb.go b/internal/static/database/crdb.go index a031f7d17a..549e0ae505 100644 --- a/internal/static/database/crdb.go +++ b/internal/static/database/crdb.go @@ -14,7 +14,7 @@ import ( "github.com/zitadel/zitadel/internal/zerrors" ) -var _ static.Storage = (*crdbStorage)(nil) +var _ static.Storage = (*storage)(nil) const ( assetsTable = "system.assets" @@ -29,15 +29,15 @@ const ( AssetColUpdatedAt = "updated_at" ) -type crdbStorage struct { +type storage struct { client *sql.DB } func NewStorage(client *sql.DB, _ map[string]interface{}) (static.Storage, error) { - return &crdbStorage{client: client}, nil + return &storage{client: client}, nil } -func (c *crdbStorage) PutObject(ctx context.Context, instanceID, location, resourceOwner, name, contentType string, objectType static.ObjectType, object io.Reader, objectSize int64) (*static.Asset, error) { +func (c *storage) PutObject(ctx context.Context, instanceID, location, resourceOwner, name, contentType string, objectType static.ObjectType, object io.Reader, objectSize int64) (*static.Asset, error) { data, err := io.ReadAll(object) if err != nil { return nil, zerrors.ThrowInternal(err, "DATAB-Dfwvq", "Errors.Internal") @@ -71,7 +71,7 @@ func (c *crdbStorage) PutObject(ctx context.Context, instanceID, location, resou }, nil } -func (c *crdbStorage) GetObject(ctx context.Context, instanceID, resourceOwner, name string) ([]byte, func() (*static.Asset, error), error) { +func (c *storage) GetObject(ctx context.Context, instanceID, resourceOwner, name string) ([]byte, func() (*static.Asset, error), error) { query, args, err := squirrel.Select(AssetColData, AssetColContentType, AssetColHash, AssetColUpdatedAt). From(assetsTable). Where(squirrel.Eq{ @@ -111,7 +111,7 @@ func (c *crdbStorage) GetObject(ctx context.Context, instanceID, resourceOwner, nil } -func (c *crdbStorage) GetObjectInfo(ctx context.Context, instanceID, resourceOwner, name string) (*static.Asset, error) { +func (c *storage) GetObjectInfo(ctx context.Context, instanceID, resourceOwner, name string) (*static.Asset, error) { query, args, err := squirrel.Select(AssetColContentType, AssetColLocation, "length("+AssetColData+")", AssetColHash, AssetColUpdatedAt). From(assetsTable). Where(squirrel.Eq{ @@ -143,7 +143,7 @@ func (c *crdbStorage) GetObjectInfo(ctx context.Context, instanceID, resourceOwn return asset, nil } -func (c *crdbStorage) RemoveObject(ctx context.Context, instanceID, resourceOwner, name string) error { +func (c *storage) RemoveObject(ctx context.Context, instanceID, resourceOwner, name string) error { stmt, args, err := squirrel.Delete(assetsTable). Where(squirrel.Eq{ AssetColInstanceID: instanceID, @@ -162,7 +162,7 @@ func (c *crdbStorage) RemoveObject(ctx context.Context, instanceID, resourceOwne return nil } -func (c *crdbStorage) RemoveObjects(ctx context.Context, instanceID, resourceOwner string, objectType static.ObjectType) error { +func (c *storage) RemoveObjects(ctx context.Context, instanceID, resourceOwner string, objectType static.ObjectType) error { stmt, args, err := squirrel.Delete(assetsTable). Where(squirrel.Eq{ AssetColInstanceID: instanceID, @@ -181,7 +181,7 @@ func (c *crdbStorage) RemoveObjects(ctx context.Context, instanceID, resourceOwn return nil } -func (c *crdbStorage) RemoveInstanceObjects(ctx context.Context, instanceID string) error { +func (c *storage) RemoveInstanceObjects(ctx context.Context, instanceID string) error { stmt, args, err := squirrel.Delete(assetsTable). Where(squirrel.Eq{ AssetColInstanceID: instanceID, diff --git a/internal/static/database/crdb_test.go b/internal/static/database/crdb_test.go index 14a128dbe2..2be76e69fa 100644 --- a/internal/static/database/crdb_test.go +++ b/internal/static/database/crdb_test.go @@ -40,7 +40,7 @@ const ( " WHERE instance_id = $1" ) -func Test_crdbStorage_CreateObject(t *testing.T) { +func Test_dbStorage_CreateObject(t *testing.T) { type fields struct { client db } @@ -112,7 +112,7 @@ func Test_crdbStorage_CreateObject(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &crdbStorage{ + c := &storage{ client: tt.fields.client.db, } got, err := c.PutObject(tt.args.ctx, tt.args.instanceID, tt.args.location, tt.args.resourceOwner, tt.args.name, tt.args.contentType, tt.args.objectType, tt.args.data, tt.args.objectSize) @@ -127,7 +127,7 @@ func Test_crdbStorage_CreateObject(t *testing.T) { } } -func Test_crdbStorage_RemoveObject(t *testing.T) { +func Test_dbStorage_RemoveObject(t *testing.T) { type fields struct { client db } @@ -166,7 +166,7 @@ func Test_crdbStorage_RemoveObject(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &crdbStorage{ + c := &storage{ client: tt.fields.client.db, } err := c.RemoveObject(tt.args.ctx, tt.args.instanceID, tt.args.resourceOwner, tt.args.name) @@ -178,7 +178,7 @@ func Test_crdbStorage_RemoveObject(t *testing.T) { } } -func Test_crdbStorage_RemoveObjects(t *testing.T) { +func Test_dbStorage_RemoveObjects(t *testing.T) { type fields struct { client db } @@ -216,7 +216,7 @@ func Test_crdbStorage_RemoveObjects(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &crdbStorage{ + c := &storage{ client: tt.fields.client.db, } err := c.RemoveObjects(tt.args.ctx, tt.args.instanceID, tt.args.resourceOwner, tt.args.objectType) @@ -227,7 +227,7 @@ func Test_crdbStorage_RemoveObjects(t *testing.T) { }) } } -func Test_crdbStorage_RemoveInstanceObjects(t *testing.T) { +func Test_dbStorage_RemoveInstanceObjects(t *testing.T) { type fields struct { client db } @@ -260,7 +260,7 @@ func Test_crdbStorage_RemoveInstanceObjects(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - c := &crdbStorage{ + c := &storage{ client: tt.fields.client.db, } err := c.RemoveInstanceObjects(tt.args.ctx, tt.args.instanceID) diff --git a/internal/v2/eventstore/postgres/push.go b/internal/v2/eventstore/postgres/push.go index 09f663a086..bde74687c7 100644 --- a/internal/v2/eventstore/postgres/push.go +++ b/internal/v2/eventstore/postgres/push.go @@ -171,8 +171,7 @@ func (s *Storage) push(ctx context.Context, tx *sql.Tx, reducer eventstore.Reduc cmd.position.InPositionOrder, ) - stmt.WriteString(s.pushPositionStmt) - stmt.WriteString(`)`) + stmt.WriteString(", statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp()))") } stmt.WriteString(` RETURNING created_at, "position"`) diff --git a/internal/v2/eventstore/postgres/push_test.go b/internal/v2/eventstore/postgres/push_test.go index 91fdc1fcd7..bb3254427c 100644 --- a/internal/v2/eventstore/postgres/push_test.go +++ b/internal/v2/eventstore/postgres/push_test.go @@ -1288,7 +1288,6 @@ func Test_push(t *testing.T) { }, }, } - initPushStmt("postgres") for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { dbMock := mock.NewSQLMock(t, append([]mock.Expectation{mock.ExpectBegin(nil)}, tt.args.expectations...)...) @@ -1297,9 +1296,7 @@ func Test_push(t *testing.T) { t.Errorf("unexpected error in begin: %v", err) t.FailNow() } - s := Storage{ - pushPositionStmt: initPushStmt("postgres"), - } + s := Storage{} err = s.push(context.Background(), tx, tt.args.reducer, tt.args.commands) tt.want.assertErr(t, err) dbMock.Assert(t) diff --git a/internal/v2/eventstore/postgres/storage.go b/internal/v2/eventstore/postgres/storage.go index 3a703a7d17..d4148f4f1a 100644 --- a/internal/v2/eventstore/postgres/storage.go +++ b/internal/v2/eventstore/postgres/storage.go @@ -3,8 +3,6 @@ package postgres import ( "context" - "github.com/zitadel/logging" - "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/v2/eventstore" ) @@ -15,9 +13,8 @@ var ( ) type Storage struct { - client *database.DB - config *Config - pushPositionStmt string + client *database.DB + config *Config } type Config struct { @@ -25,23 +22,9 @@ type Config struct { } func New(client *database.DB, config *Config) *Storage { - initPushStmt(client.Type()) return &Storage{ - client: client, - config: config, - pushPositionStmt: initPushStmt(client.Type()), - } -} - -func initPushStmt(typ string) string { - switch typ { - case "cockroach": - return ", hlc_to_timestamp(cluster_logical_timestamp()), cluster_logical_timestamp()" - case "postgres": - return ", statement_timestamp(), EXTRACT(EPOCH FROM clock_timestamp())" - default: - logging.WithFields("database_type", typ).Panic("position statement for type not implemented") - return "" + client: client, + config: config, } }