diff --git a/cmd/key/key.go b/cmd/key/key.go index 1ccde775ed..02fa272a8a 100644 --- a/cmd/key/key.go +++ b/cmd/key/key.go @@ -128,5 +128,5 @@ func keyStorage(config database.Config, masterKey string) (crypto.KeyStorage, er if err != nil { return nil, err } - return cryptoDB.NewKeyStorage(db.DB, masterKey) + return cryptoDB.NewKeyStorage(db, masterKey) } diff --git a/cmd/setup/03.go b/cmd/setup/03.go index 24b0942260..73d32fcab6 100644 --- a/cmd/setup/03.go +++ b/cmd/setup/03.go @@ -2,7 +2,6 @@ package setup import ( "context" - "database/sql" "fmt" "os" "strings" @@ -14,6 +13,7 @@ import ( "github.com/zitadel/zitadel/internal/config/systemdefaults" "github.com/zitadel/zitadel/internal/crypto" crypto_db "github.com/zitadel/zitadel/internal/crypto/database" + "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore" ) @@ -30,7 +30,7 @@ type FirstInstance struct { smtpEncryptionKey *crypto.KeyConfig oidcEncryptionKey *crypto.KeyConfig masterKey string - db *sql.DB + db *database.DB es *eventstore.Eventstore defaults systemdefaults.SystemDefaults zitadelRoles []authz.RoleMapping diff --git a/cmd/setup/setup.go b/cmd/setup/setup.go index beba81611a..8097b61643 100644 --- a/cmd/setup/setup.go +++ b/cmd/setup/setup.go @@ -77,7 +77,7 @@ func Setup(config *Config, steps *Steps, masterKey string) { steps.FirstInstance.smtpEncryptionKey = config.EncryptionKeys.SMTP steps.FirstInstance.oidcEncryptionKey = config.EncryptionKeys.OIDC steps.FirstInstance.masterKey = masterKey - steps.FirstInstance.db = dbClient.DB + steps.FirstInstance.db = dbClient steps.FirstInstance.es = eventstoreClient steps.FirstInstance.defaults = config.SystemDefaults steps.FirstInstance.zitadelRoles = config.InternalAuthZ.RolePermissionMappings diff --git a/cmd/start/start.go b/cmd/start/start.go index d0df178c94..c8a5c20e32 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -124,7 +124,7 @@ func startZitadel(config *Config, masterKey string, server chan<- *Server) error return fmt.Errorf("cannot start client for projection: %w", err) } - keyStorage, err := cryptoDB.NewKeyStorage(dbClient.DB, masterKey) + keyStorage, err := cryptoDB.NewKeyStorage(dbClient, masterKey) if err != nil { return fmt.Errorf("cannot start key storage: %w", err) } diff --git a/internal/admin/repository/eventsourcing/view/view.go b/internal/admin/repository/eventsourcing/view/view.go index 095e7c1dfa..9ede972813 100644 --- a/internal/admin/repository/eventsourcing/view/view.go +++ b/internal/admin/repository/eventsourcing/view/view.go @@ -15,7 +15,7 @@ type View struct { } func StartView(sqlClient *database.DB) (*View, error) { - gorm, err := gorm.Open("postgres", sqlClient) + gorm, err := gorm.Open("postgres", sqlClient.DB) if err != nil { return nil, err } diff --git a/internal/auth/repository/eventsourcing/view/view.go b/internal/auth/repository/eventsourcing/view/view.go index b65badf1e5..08a9014d1c 100644 --- a/internal/auth/repository/eventsourcing/view/view.go +++ b/internal/auth/repository/eventsourcing/view/view.go @@ -23,7 +23,7 @@ type View struct { } func StartView(sqlClient *database.DB, keyAlgorithm crypto.EncryptionAlgorithm, queries *query.Queries, idGenerator id.Generator, es eventstore.Eventstore) (*View, error) { - gorm, err := gorm.Open("postgres", sqlClient) + gorm, err := gorm.Open("postgres", sqlClient.DB) if err != nil { return nil, err } diff --git a/internal/auth_request/repository/cache/cache.go b/internal/auth_request/repository/cache/cache.go index 37bdb2276f..10090ef286 100644 --- a/internal/auth_request/repository/cache/cache.go +++ b/internal/auth_request/repository/cache/cache.go @@ -59,7 +59,12 @@ func (c *AuthRequestCache) getAuthRequest(key, value, instanceID string) (*domai var b []byte var requestType domain.AuthRequestType query := fmt.Sprintf("SELECT request, request_type FROM auth.auth_requests WHERE instance_id = $1 and %s = $2", key) - err := c.client.QueryRow(query, instanceID, value).Scan(&b, &requestType) + err := c.client.QueryRow( + func(row *sql.Row) error { + return row.Scan(&b, &requestType) + }, + query, instanceID, value) + if err != nil { if errors.Is(err, sql.ErrNoRows) { return nil, caos_errs.ThrowNotFound(err, "CACHE-d24aD", "Errors.AuthRequest.NotFound") diff --git a/internal/authz/repository/eventsourcing/view/view.go b/internal/authz/repository/eventsourcing/view/view.go index c3ee5b79ac..0b07cd2e0d 100644 --- a/internal/authz/repository/eventsourcing/view/view.go +++ b/internal/authz/repository/eventsourcing/view/view.go @@ -19,7 +19,7 @@ type View struct { } func StartView(sqlClient *database.DB, idGenerator id.Generator, queries *query.Queries) (*View, error) { - gorm, err := gorm.Open("postgres", sqlClient) + gorm, err := gorm.Open("postgres", sqlClient.DB) if err != nil { return nil, err } diff --git a/internal/crypto/database/database.go b/internal/crypto/database/database.go index 228a9261b5..7cbf46dc5b 100644 --- a/internal/crypto/database/database.go +++ b/internal/crypto/database/database.go @@ -6,11 +6,12 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/zitadel/zitadel/internal/crypto" + z_db "github.com/zitadel/zitadel/internal/database" caos_errs "github.com/zitadel/zitadel/internal/errors" ) type database struct { - client *sql.DB + client *z_db.DB masterKey string encrypt func(key, masterKey string) (encryptedKey string, err error) decrypt func(encryptedKey, masterKey string) (key string, err error) @@ -22,7 +23,7 @@ const ( encryptionKeysKeyCol = "key" ) -func NewKeyStorage(client *sql.DB, masterKey string) (*database, error) { +func NewKeyStorage(client *z_db.DB, masterKey string) (*database, error) { if err := checkMasterKeyLength(masterKey); err != nil { return nil, err } @@ -42,32 +43,30 @@ func (d *database) ReadKeys() (crypto.Keys, error) { if err != nil { return nil, caos_errs.ThrowInternal(err, "", "unable to read keys") } - rows, err := d.client.Query(stmt, args...) + err = d.client.Query(func(rows *sql.Rows) error { + for rows.Next() { + var id, encryptionKey string + err = rows.Scan(&id, &encryptionKey) + if err != nil { + return caos_errs.ThrowInternal(err, "", "unable to read keys") + } + key, err := d.decrypt(encryptionKey, d.masterKey) + if err != nil { + return caos_errs.ThrowInternal(err, "", "unable to decrypt key") + } + keys[id] = key + } + return nil + }, stmt, args...) + if err != nil { return nil, caos_errs.ThrowInternal(err, "", "unable to read keys") } - for rows.Next() { - var id, encryptionKey string - err = rows.Scan(&id, &encryptionKey) - if err != nil { - return nil, caos_errs.ThrowInternal(err, "", "unable to read keys") - } - key, err := d.decrypt(encryptionKey, d.masterKey) - if err != nil { - if err := rows.Close(); err != nil { - return nil, caos_errs.ThrowInternal(err, "", "unable to close rows") - } - return nil, caos_errs.ThrowInternal(err, "", "unable to decrypt key") - } - keys[id] = key - } - if err := rows.Close(); err != nil { - return nil, caos_errs.ThrowInternal(err, "", "unable to close rows") - } - return keys, err + + return keys, nil } -func (d *database) ReadKey(id string) (*crypto.Key, error) { +func (d *database) ReadKey(id string) (_ *crypto.Key, err error) { stmt, args, err := sq.Select(encryptionKeysKeyCol). From(EncryptionKeysTable). Where(sq.Eq{encryptionKeysIDCol: id}). @@ -76,19 +75,23 @@ func (d *database) ReadKey(id string) (*crypto.Key, error) { if err != nil { return nil, caos_errs.ThrowInternal(err, "", "unable to read key") } - row := d.client.QueryRow(stmt, args...) + var key string + err = d.client.QueryRow(func(row *sql.Row) error { + var encryptionKey string + err = row.Scan(&encryptionKey) + if err != nil { + return caos_errs.ThrowInternal(err, "", "unable to read key") + } + key, err = d.decrypt(encryptionKey, d.masterKey) + if err != nil { + return caos_errs.ThrowInternal(err, "", "unable to decrypt key") + } + return nil + }, stmt, args...) if err != nil { return nil, caos_errs.ThrowInternal(err, "", "unable to read key") } - var encryptionKey string - err = row.Scan(&encryptionKey) - if err != nil { - return nil, caos_errs.ThrowInternal(err, "", "unable to read key") - } - key, err := d.decrypt(encryptionKey, d.masterKey) - if err != nil { - return nil, caos_errs.ThrowInternal(err, "", "unable to decrypt key") - } + return &crypto.Key{ ID: id, Value: key, diff --git a/internal/crypto/database/database_test.go b/internal/crypto/database/database_test.go index b85c5f882a..62a089ab31 100644 --- a/internal/crypto/database/database_test.go +++ b/internal/crypto/database/database_test.go @@ -12,6 +12,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/zitadel/zitadel/internal/crypto" + z_db "github.com/zitadel/zitadel/internal/database" caos_errs "github.com/zitadel/zitadel/internal/errors" ) @@ -46,7 +47,7 @@ func Test_database_ReadKeys(t *testing.T) { { "decryption error", fields{ - client: dbMock(t, expectQuery( + client: dbMock(t, expectQueryScanErr( "SELECT id, key FROM system.encryption_keys", []string{"id", "key"}, [][]driver.Value{ @@ -172,7 +173,7 @@ func Test_database_ReadKey(t *testing.T) { { "key not found err", fields{ - client: dbMock(t, expectQuery( + client: dbMock(t, expectQueryScanErr( "SELECT key FROM system.encryption_keys WHERE id = $1", nil, nil, @@ -192,7 +193,7 @@ func Test_database_ReadKey(t *testing.T) { { "decryption error", fields{ - client: dbMock(t, expectQuery( + client: dbMock(t, expectQueryScanErr( "SELECT key FROM system.encryption_keys WHERE id = $1", []string{"key"}, [][]driver.Value{ @@ -445,7 +446,7 @@ func Test_checkMasterKeyLength(t *testing.T) { type db struct { mock sqlmock.Sqlmock - db *sql.DB + db *z_db.DB } func dbMock(t *testing.T, expectations ...func(m sqlmock.Sqlmock)) db { @@ -459,19 +460,41 @@ func dbMock(t *testing.T, expectations ...func(m sqlmock.Sqlmock)) db { } return db{ mock: mock, - db: client, + db: &z_db.DB{DB: client}, } } func expectQueryErr(query string, err error, args ...driver.Value) func(m sqlmock.Sqlmock) { return func(m sqlmock.Sqlmock) { + m.ExpectBegin() m.ExpectQuery(regexp.QuoteMeta(query)).WithArgs(args...).WillReturnError(err) + m.ExpectRollback() + } +} + +func expectQueryScanErr(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) { + return func(m sqlmock.Sqlmock) { + m.ExpectBegin() + q := m.ExpectQuery(regexp.QuoteMeta(stmt)).WithArgs(args...) + m.ExpectRollback() + result := sqlmock.NewRows(cols) + count := uint64(len(rows)) + for _, row := range rows { + if cols[len(cols)-1] == "count" { + row = append(row, count) + } + result.AddRow(row...) + } + q.WillReturnRows(result) + q.RowsWillBeClosed() } } func expectQuery(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) { return func(m sqlmock.Sqlmock) { + m.ExpectBegin() q := m.ExpectQuery(regexp.QuoteMeta(stmt)).WithArgs(args...) + m.ExpectCommit() result := sqlmock.NewRows(cols) count := uint64(len(rows)) for _, row := range rows { diff --git a/internal/database/cockroach/config.go b/internal/database/cockroach/config.go index f48f026e24..ca51af16d1 100644 --- a/internal/database/cockroach/config.go +++ b/internal/database/cockroach/config.go @@ -2,7 +2,6 @@ package cockroach import ( "database/sql" - "fmt" "strconv" "strings" "time" @@ -94,12 +93,7 @@ func (c *Config) Type() string { } func (c *Config) Timetravel(d time.Duration) string { - // verify that it is at least 1 micro second - if d < time.Microsecond { - d = time.Microsecond - } - - return fmt.Sprintf(" AS OF SYSTEM TIME '-%d µs' ", d.Microseconds()) + return "" } type User struct { diff --git a/internal/database/cockroach/config_test.go b/internal/database/cockroach/config_test.go deleted file mode 100644 index ca16f98f7c..0000000000 --- a/internal/database/cockroach/config_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package cockroach - -import ( - "testing" - "time" -) - -func TestConfig_Timetravel(t *testing.T) { - type args struct { - d time.Duration - } - tests := []struct { - name string - args args - want string - }{ - { - name: "no duration", - args: args{ - d: 0, - }, - want: " AS OF SYSTEM TIME '-1 µs' ", - }, - { - name: "less than microsecond", - args: args{ - d: 100 * time.Nanosecond, - }, - want: " AS OF SYSTEM TIME '-1 µs' ", - }, - { - name: "10 microseconds", - args: args{ - d: 10 * time.Microsecond, - }, - want: " AS OF SYSTEM TIME '-10 µs' ", - }, - { - name: "10 milliseconds", - args: args{ - d: 10 * time.Millisecond, - }, - want: " AS OF SYSTEM TIME '-10000 µs' ", - }, - { - name: "1 second", - args: args{ - d: 1 * time.Second, - }, - want: " AS OF SYSTEM TIME '-1000000 µs' ", - }, - } - for _, tt := range tests { - c := &Config{} - t.Run(tt.name, func(t *testing.T) { - if got := c.Timetravel(tt.args.d); got != tt.want { - t.Errorf("Config.Timetravel() = %q, want %q", got, tt.want) - } - }) - } -} diff --git a/internal/database/database.go b/internal/database/database.go index c7507f1e68..5f4de01220 100644 --- a/internal/database/database.go +++ b/internal/database/database.go @@ -1,9 +1,12 @@ package database import ( + "context" "database/sql" "reflect" + "github.com/zitadel/logging" + _ "github.com/zitadel/zitadel/internal/database/cockroach" "github.com/zitadel/zitadel/internal/database/dialect" _ "github.com/zitadel/zitadel/internal/database/postgres" @@ -24,6 +27,66 @@ type DB struct { dialect.Database } +func (db *DB) Query(scan func(*sql.Rows) error, query string, args ...any) error { + return db.QueryContext(context.Background(), scan, query, args...) +} + +func (db *DB) QueryContext(ctx context.Context, scan func(rows *sql.Rows) error, query string, args ...any) (err error) { + tx, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) + if err != nil { + return err + } + defer func() { + if err != nil { + rollbackErr := tx.Rollback() + logging.OnError(rollbackErr).Info("commit of read only transaction failed") + return + } + err = tx.Commit() + }() + + rows, err := tx.QueryContext(ctx, query, args...) + if err != nil { + return err + } + defer func() { + closeErr := rows.Close() + logging.OnError(closeErr).Info("rows.Close failed") + }() + + if err = scan(rows); err != nil { + return err + } + return rows.Err() +} + +func (db *DB) QueryRow(scan func(*sql.Row) error, query string, args ...any) (err error) { + return db.QueryRowContext(context.Background(), scan, query, args...) +} + +func (db *DB) QueryRowContext(ctx context.Context, scan func(row *sql.Row) error, query string, args ...any) (err error) { + tx, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true}) + if err != nil { + return err + } + defer func() { + if err != nil { + rollbackErr := tx.Rollback() + logging.OnError(rollbackErr).Info("commit of read only transaction failed") + return + } + err = tx.Commit() + }() + + row := tx.QueryRowContext(ctx, query, args...) + + err = scan(row) + if err != nil { + return err + } + return row.Err() +} + func Connect(config Config, useAdmin bool) (*DB, error) { client, err := config.connector.Connect(useAdmin) if err != nil { diff --git a/internal/eventstore/handler/crdb/current_sequence.go b/internal/eventstore/handler/crdb/current_sequence.go index b11c17dd40..2b92343be8 100644 --- a/internal/eventstore/handler/crdb/current_sequence.go +++ b/internal/eventstore/handler/crdb/current_sequence.go @@ -12,9 +12,10 @@ import ( ) const ( - currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2) FOR UPDATE` - updateCurrentSequencesStmtFormat = `INSERT INTO %s (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES ` - updateCurrentSequencesConflictStmt = ` ON CONFLICT (projection_name, aggregate_type, instance_id) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp` + currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2) FOR UPDATE` + currentSequenceStmtWithoutLockFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2)` + updateCurrentSequencesStmtFormat = `INSERT INTO %s (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES ` + updateCurrentSequencesConflictStmt = ` ON CONFLICT (projection_name, aggregate_type, instance_id) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp` ) type currentSequences map[eventstore.AggregateType][]*instanceSequence @@ -24,41 +25,38 @@ type instanceSequence struct { sequence uint64 } -func (h *StatementHandler) currentSequences(ctx context.Context, query func(context.Context, string, ...interface{}) (*sql.Rows, error), instanceIDs database.StringArray) (currentSequences, error) { - rows, err := query(ctx, h.currentSequenceStmt, h.ProjectionName, instanceIDs) +func (h *StatementHandler) currentSequences(ctx context.Context, isTx bool, query func(context.Context, func(*sql.Rows) error, string, ...interface{}) error, instanceIDs database.StringArray) (currentSequences, error) { + stmt := h.currentSequenceStmt + if !isTx { + stmt = h.currentSequenceWithoutLockStmt + } + + sequences := make(currentSequences, len(h.aggregates)) + err := query(ctx, + func(rows *sql.Rows) error { + for rows.Next() { + var ( + aggregateType eventstore.AggregateType + sequence uint64 + instanceID string + ) + + err := rows.Scan(&sequence, &aggregateType, &instanceID) + if err != nil { + return errors.ThrowInternal(err, "CRDB-dbatK", "scan failed") + } + + sequences[aggregateType] = append(sequences[aggregateType], &instanceSequence{ + sequence: sequence, + instanceID: instanceID, + }) + } + return nil + }, + stmt, h.ProjectionName, instanceIDs) if err != nil { return nil, err } - - defer rows.Close() - - sequences := make(currentSequences, len(h.aggregates)) - for rows.Next() { - var ( - aggregateType eventstore.AggregateType - sequence uint64 - instanceID string - ) - - err = rows.Scan(&sequence, &aggregateType, &instanceID) - if err != nil { - return nil, errors.ThrowInternal(err, "CRDB-dbatK", "scan failed") - } - - sequences[aggregateType] = append(sequences[aggregateType], &instanceSequence{ - sequence: sequence, - instanceID: instanceID, - }) - } - - if err = rows.Close(); err != nil { - return nil, errors.ThrowInternal(err, "CRDB-h5i5m", "close rows failed") - } - - if err = rows.Err(); err != nil { - return nil, errors.ThrowInternal(err, "CRDB-O8zig", "errors in scanning rows") - } - return sequences, nil } diff --git a/internal/eventstore/handler/crdb/db_mock_test.go b/internal/eventstore/handler/crdb/db_mock_test.go index 33c7984e12..20b70cf5e8 100644 --- a/internal/eventstore/handler/crdb/db_mock_test.go +++ b/internal/eventstore/handler/crdb/db_mock_test.go @@ -124,13 +124,17 @@ func expectSavePointRelease() func(sqlmock.Sqlmock) { } } -func expectCurrentSequence(tableName, projection string, seq uint64, aggregateType string, instanceIDs []string) func(sqlmock.Sqlmock) { +func expectCurrentSequence(isTx bool, tableName, projection string, seq uint64, aggregateType string, instanceIDs []string) func(sqlmock.Sqlmock) { rows := sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}) for _, instanceID := range instanceIDs { rows.AddRow(seq, aggregateType, instanceID) } return func(m sqlmock.Sqlmock) { - m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`). + stmt := `SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\)` + if isTx { + stmt += " FOR UPDATE" + } + m.ExpectQuery(stmt). WithArgs( projection, database.StringArray(instanceIDs), @@ -141,9 +145,13 @@ func expectCurrentSequence(tableName, projection string, seq uint64, aggregateTy } } -func expectCurrentSequenceErr(tableName, projection string, instanceIDs []string, err error) func(sqlmock.Sqlmock) { +func expectCurrentSequenceErr(isTx bool, tableName, projection string, instanceIDs []string, err error) func(sqlmock.Sqlmock) { return func(m sqlmock.Sqlmock) { - m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`). + stmt := `SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\)` + if isTx { + stmt += " FOR UPDATE" + } + m.ExpectQuery(stmt). WithArgs( projection, database.StringArray(instanceIDs), diff --git a/internal/eventstore/handler/crdb/handler_stmt.go b/internal/eventstore/handler/crdb/handler_stmt.go index e49fe007be..27418bf332 100644 --- a/internal/eventstore/handler/crdb/handler_stmt.go +++ b/internal/eventstore/handler/crdb/handler_stmt.go @@ -36,13 +36,14 @@ type StatementHandler struct { *handler.ProjectionHandler Locker - client *database.DB - sequenceTable string - currentSequenceStmt string - updateSequencesBaseStmt string - maxFailureCount uint - failureCountStmt string - setFailureCountStmt string + client *database.DB + sequenceTable string + currentSequenceStmt string + currentSequenceWithoutLockStmt string + updateSequencesBaseStmt string + maxFailureCount uint + failureCountStmt string + setFailureCountStmt string aggregates []eventstore.AggregateType reduces map[eventstore.EventType]handler.Reduce @@ -77,20 +78,21 @@ func NewStatementHandler( } h := StatementHandler{ - client: config.Client, - sequenceTable: config.SequenceTable, - maxFailureCount: config.MaxFailureCount, - currentSequenceStmt: fmt.Sprintf(currentSequenceStmtFormat, config.SequenceTable), - updateSequencesBaseStmt: fmt.Sprintf(updateCurrentSequencesStmtFormat, config.SequenceTable), - failureCountStmt: fmt.Sprintf(failureCountStmtFormat, config.FailedEventsTable), - setFailureCountStmt: fmt.Sprintf(setFailureCountStmtFormat, config.FailedEventsTable), - aggregates: aggregateTypes, - reduces: reduces, - bulkLimit: config.BulkLimit, - Locker: NewLocker(config.Client.DB, config.LockTable, config.ProjectionName), - initCheck: config.InitCheck, - initialized: make(chan bool), - reduceScheduledPseudoEvent: reduceScheduledPseudoEvent, + client: config.Client, + sequenceTable: config.SequenceTable, + maxFailureCount: config.MaxFailureCount, + currentSequenceStmt: fmt.Sprintf(currentSequenceStmtFormat, config.SequenceTable), + currentSequenceWithoutLockStmt: fmt.Sprintf(currentSequenceStmtWithoutLockFormat, config.SequenceTable), + updateSequencesBaseStmt: fmt.Sprintf(updateCurrentSequencesStmtFormat, config.SequenceTable), + failureCountStmt: fmt.Sprintf(failureCountStmtFormat, config.FailedEventsTable), + setFailureCountStmt: fmt.Sprintf(setFailureCountStmtFormat, config.FailedEventsTable), + aggregates: aggregateTypes, + reduces: reduces, + bulkLimit: config.BulkLimit, + Locker: NewLocker(config.Client.DB, config.LockTable, config.ProjectionName), + initCheck: config.InitCheck, + initialized: make(chan bool), + reduceScheduledPseudoEvent: reduceScheduledPseudoEvent, } h.ProjectionHandler = handler.NewProjectionHandler(ctx, config.ProjectionHandlerConfig, h.reduce, h.Update, h.searchQuery, h.Lock, h.Unlock, h.initialized, reduceScheduledPseudoEvent) @@ -114,7 +116,7 @@ func (h *StatementHandler) searchQuery(ctx context.Context, instanceIDs []string } func (h *StatementHandler) dbSearchQuery(ctx context.Context, instanceIDs []string) (*eventstore.SearchQueryBuilder, uint64, error) { - sequences, err := h.currentSequences(ctx, h.client.QueryContext, instanceIDs) + sequences, err := h.currentSequences(ctx, false, h.client.QueryContext, instanceIDs) if err != nil { return nil, 0, err } @@ -140,6 +142,26 @@ func (h *StatementHandler) dbSearchQuery(ctx context.Context, instanceIDs []stri return queryBuilder, h.bulkLimit, nil } +type transaction struct { + *sql.Tx +} + +func (t *transaction) QueryContext(ctx context.Context, scan func(rows *sql.Rows) error, query string, args ...any) error { + rows, err := t.Tx.QueryContext(ctx, query, args...) + if err != nil { + return err + } + defer func() { + closeErr := rows.Close() + logging.OnError(closeErr).Info("rows.Close failed") + }() + + if err = scan(rows); err != nil { + return err + } + return rows.Err() +} + // Update implements handler.Update func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (index int, err error) { if len(stmts) == 0 { @@ -154,7 +176,7 @@ func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statemen return -1, errors.ThrowInternal(err, "CRDB-e89Gq", "begin failed") } - sequences, err := h.currentSequences(ctx, tx.QueryContext, instanceIDs) + sequences, err := h.currentSequences(ctx, true, (&transaction{Tx: tx}).QueryContext, instanceIDs) if err != nil { tx.Rollback() return -1, err diff --git a/internal/eventstore/handler/crdb/handler_stmt_test.go b/internal/eventstore/handler/crdb/handler_stmt_test.go index b72db40eb1..027469e71b 100644 --- a/internal/eventstore/handler/crdb/handler_stmt_test.go +++ b/internal/eventstore/handler/crdb/handler_stmt_test.go @@ -90,7 +90,9 @@ func TestProjectionHandler_SearchQuery(t *testing.T) { return errors.Is(err, sql.ErrTxDone) }, expectations: []mockExpectation{ - expectCurrentSequenceErr("my_sequences", "my_projection", []string{"instanceID1"}, sql.ErrTxDone), + expectBegin(), + expectCurrentSequenceErr(false, "my_sequences", "my_projection", []string{"instanceID1"}, sql.ErrTxDone), + expectRollback(), }, SearchQueryBuilder: nil, }, @@ -112,7 +114,9 @@ func TestProjectionHandler_SearchQuery(t *testing.T) { return err == nil }, expectations: []mockExpectation{ - expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID1"}), + expectBegin(), + expectCurrentSequence(false, "my_sequences", "my_projection", 5, "testAgg", []string{"instanceID1"}), + expectCommit(), }, SearchQueryBuilder: eventstore. NewSearchQueryBuilder(eventstore.ColumnsEvent). @@ -142,7 +146,9 @@ func TestProjectionHandler_SearchQuery(t *testing.T) { return err == nil }, expectations: []mockExpectation{ - expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID1", "instanceID2"}), + expectBegin(), + expectCurrentSequence(false, "my_sequences", "my_projection", 5, "testAgg", []string{"instanceID1", "instanceID2"}), + expectCommit(), }, SearchQueryBuilder: eventstore. NewSearchQueryBuilder(eventstore.ColumnsEvent). @@ -216,6 +222,7 @@ func TestProjectionHandler_SearchQuery(t *testing.T) { t.Errorf("ProjectionHandler.prepareBulkStmts() error = %v", err) return } + if !reflect.DeepEqual(query, tt.want.SearchQueryBuilder) { t.Errorf("unexpected query: expected %v, got %v", tt.want.SearchQueryBuilder, query) } @@ -289,7 +296,7 @@ func TestStatementHandler_Update(t *testing.T) { want: want{ expectations: []mockExpectation{ expectBegin(), - expectCurrentSequenceErr("my_sequences", "my_projection", []string{"instanceID"}, sql.ErrTxDone), + expectCurrentSequenceErr(false, "my_sequences", "my_projection", []string{"instanceID"}, sql.ErrTxDone), expectRollback(), }, isErr: func(err error) bool { @@ -321,7 +328,7 @@ func TestStatementHandler_Update(t *testing.T) { want: want{ expectations: []mockExpectation{ expectBegin(), - expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}), + expectCurrentSequence(false, "my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}), expectRollback(), }, isErr: func(err error) bool { @@ -360,7 +367,7 @@ func TestStatementHandler_Update(t *testing.T) { want: want{ expectations: []mockExpectation{ expectBegin(), - expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}), + expectCurrentSequence(false, "my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}), expectCommit(), }, isErr: func(err error) bool { @@ -399,7 +406,7 @@ func TestStatementHandler_Update(t *testing.T) { want: want{ expectations: []mockExpectation{ expectBegin(), - expectCurrentSequence("my_sequences", "my_projection", 5, "agg", []string{"instanceID"}), + expectCurrentSequence(false, "my_sequences", "my_projection", 5, "agg", []string{"instanceID"}), expectSavePoint(), expectCreate("my_projection", []string{"col"}, []string{"$1"}), expectSavePointRelease(), @@ -442,7 +449,7 @@ func TestStatementHandler_Update(t *testing.T) { want: want{ expectations: []mockExpectation{ expectBegin(), - expectCurrentSequence("my_sequences", "my_projection", 5, "agg", []string{"instanceID"}), + expectCurrentSequence(false, "my_sequences", "my_projection", 5, "agg", []string{"instanceID"}), expectSavePoint(), expectCreate("my_projection", []string{"col"}, []string{"$1"}), expectSavePointRelease(), @@ -478,7 +485,7 @@ func TestStatementHandler_Update(t *testing.T) { want: want{ expectations: []mockExpectation{ expectBegin(), - expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}), + expectCurrentSequence(false, "my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}), expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"), expectCommit(), }, @@ -511,7 +518,7 @@ func TestStatementHandler_Update(t *testing.T) { want: want{ expectations: []mockExpectation{ expectBegin(), - expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}), + expectCurrentSequence(false, "my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}), expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"), expectCommit(), }, @@ -551,7 +558,7 @@ func TestStatementHandler_Update(t *testing.T) { want: want{ expectations: []mockExpectation{ expectBegin(), - expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}), + expectCurrentSequence(false, "my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}), expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"), expectCommit(), }, @@ -1425,7 +1432,7 @@ func TestStatementHandler_currentSequence(t *testing.T) { return errors.Is(err, sql.ErrConnDone) }, expectations: []mockExpectation{ - expectCurrentSequenceErr("my_table", "my_projection", nil, sql.ErrConnDone), + expectCurrentSequenceErr(true, "my_table", "my_projection", nil, sql.ErrConnDone), }, }, }, @@ -1487,7 +1494,7 @@ func TestStatementHandler_currentSequence(t *testing.T) { return errors.Is(err, nil) }, expectations: []mockExpectation{ - expectCurrentSequence("my_table", "my_projection", 5, "agg", []string{"instanceID"}), + expectCurrentSequence(true, "my_table", "my_projection", 5, "agg", []string{"instanceID"}), }, sequences: currentSequences{ "agg": []*instanceSequence{ @@ -1515,7 +1522,7 @@ func TestStatementHandler_currentSequence(t *testing.T) { return errors.Is(err, nil) }, expectations: []mockExpectation{ - expectCurrentSequence("my_table", "my_projection", 5, "agg", []string{"instanceID1", "instanceID2"}), + expectCurrentSequence(true, "my_table", "my_projection", 5, "agg", []string{"instanceID1", "instanceID2"}), }, sequences: currentSequences{ "agg": []*instanceSequence{ @@ -1563,7 +1570,7 @@ func TestStatementHandler_currentSequence(t *testing.T) { t.Fatalf("unexpected err in begin: %v", err) } - seq, err := h.currentSequences(context.Background(), tx.QueryContext, tt.args.instanceIDs) + seq, err := h.currentSequences(context.Background(), true, (&transaction{Tx: tx}).QueryContext, tt.args.instanceIDs) if !tt.want.isErr(err) { t.Errorf("unexpected error: %v", err) } diff --git a/internal/eventstore/repository/sql/crdb.go b/internal/eventstore/repository/sql/crdb.go index 62cf1f92a5..14fab23b61 100644 --- a/internal/eventstore/repository/sql/crdb.go +++ b/internal/eventstore/repository/sql/crdb.go @@ -161,13 +161,18 @@ func (db *CRDB) Push(ctx context.Context, events []*repository.Event, uniqueCons var instanceRegexp = regexp.MustCompile(`eventstore\.i_[0-9a-zA-Z]{1,}_seq`) func (db *CRDB) CreateInstance(ctx context.Context, instanceID string) error { - row := db.QueryRowContext(ctx, "SELECT CONCAT('eventstore.i_', $1::TEXT, '_seq')", instanceID) - if row.Err() != nil { - return caos_errs.ThrowInvalidArgument(row.Err(), "SQL-7gtFA", "Errors.InvalidArgument") - } var sequenceName string - if err := row.Scan(&sequenceName); err != nil || !instanceRegexp.MatchString(sequenceName) { - return caos_errs.ThrowInvalidArgument(err, "SQL-7gtFA", "Errors.InvalidArgument") + err := db.QueryRowContext(ctx, + func(row *sql.Row) error { + if err := row.Scan(&sequenceName); err != nil || !instanceRegexp.MatchString(sequenceName) { + return caos_errs.ThrowInvalidArgument(err, "SQL-7gtFA", "Errors.InvalidArgument") + } + return nil + }, + "SELECT CONCAT('eventstore.i_', $1::TEXT, '_seq')", instanceID, + ) + if err != nil { + return err } if _, err := db.ExecContext(ctx, "CREATE SEQUENCE "+sequenceName); err != nil { @@ -220,9 +225,9 @@ func (db *CRDB) handleUniqueConstraints(ctx context.Context, tx *sql.Tx, uniqueC } // Filter returns all events matching the given search query -func (db *CRDB) Filter(ctx context.Context, searchQuery *repository.SearchQuery) (events []*repository.Event, err error) { +func (crdb *CRDB) Filter(ctx context.Context, searchQuery *repository.SearchQuery) (events []*repository.Event, err error) { events = []*repository.Event{} - err = query(ctx, db, searchQuery, &events) + err = query(ctx, crdb, searchQuery, &events) if err != nil { return nil, err } @@ -250,8 +255,8 @@ func (db *CRDB) InstanceIDs(ctx context.Context, searchQuery *repository.SearchQ return ids, nil } -func (db *CRDB) db() *sql.DB { - return db.DB.DB +func (db *CRDB) db() *database.DB { + return db.DB } func (db *CRDB) orderByEventSequence(desc bool) string { diff --git a/internal/eventstore/repository/sql/query.go b/internal/eventstore/repository/sql/query.go index 2457a0638b..8bc3b3ac97 100644 --- a/internal/eventstore/repository/sql/query.go +++ b/internal/eventstore/repository/sql/query.go @@ -11,6 +11,7 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/call" + "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/database/dialect" z_errors "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/eventstore/repository" @@ -24,13 +25,33 @@ type querier interface { eventQuery() string maxSequenceQuery() string instanceIDsQuery() string - db() *sql.DB + db() *database.DB orderByEventSequence(desc bool) string dialect.Database } type scan func(dest ...interface{}) error +type tx struct { + *sql.Tx +} + +func (t *tx) QueryContext(ctx context.Context, scan func(rows *sql.Rows) error, query string, args ...any) error { + rows, err := t.Tx.QueryContext(ctx, query, args...) + if err != nil { + return err + } + defer func() { + closeErr := rows.Close() + logging.OnError(closeErr).Info("rows.Close failed") + }() + + if err = scan(rows); err != nil { + return err + } + return rows.Err() +} + func query(ctx context.Context, criteria querier, searchQuery *repository.SearchQuery, dest interface{}) error { query, rowScanner := prepareColumns(criteria, searchQuery.Columns) where, values := prepareCondition(criteria, searchQuery.Filters) @@ -56,26 +77,27 @@ func query(ctx context.Context, criteria querier, searchQuery *repository.Search query = criteria.placeholder(query) var contextQuerier interface { - QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryContext(context.Context, func(rows *sql.Rows) error, string, ...interface{}) error } contextQuerier = criteria.db() if searchQuery.Tx != nil { - contextQuerier = searchQuery.Tx + contextQuerier = &tx{Tx: searchQuery.Tx} } - rows, err := contextQuerier.QueryContext(ctx, query, values...) + err := contextQuerier.QueryContext(ctx, + func(rows *sql.Rows) error { + for rows.Next() { + err := rowScanner(rows.Scan, dest) + if err != nil { + return err + } + } + return nil + }, query, values...) if err != nil { logging.New().WithError(err).Info("query failed") return z_errors.ThrowInternal(err, "SQL-KyeAx", "unable to filter events") } - defer rows.Close() - - for rows.Next() { - err = rowScanner(rows.Scan, dest) - if err != nil { - return err - } - } return nil } diff --git a/internal/eventstore/repository/sql/query_test.go b/internal/eventstore/repository/sql/query_test.go index af9aa9860c..d7438ed273 100644 --- a/internal/eventstore/repository/sql/query_test.go +++ b/internal/eventstore/repository/sql/query_test.go @@ -741,7 +741,7 @@ func Test_query_events_mocked(t *testing.T) { }, }, fields: fields{ - mock: newMockClient(t).expectQuery(t, + mock: newMockClient(t).expectQueryScanErr(t, `SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY creation_date DESC, event_sequence DESC`, []driver.Value{repository.AggregateType("user")}, &repository.Event{Sequence: 100}), @@ -853,7 +853,21 @@ type dbMock struct { } func (m *dbMock) expectQuery(t *testing.T, expectedQuery string, args []driver.Value, events ...*repository.Event) *dbMock { + m.mock.ExpectBegin() query := m.mock.ExpectQuery(expectedQuery).WithArgs(args...) + m.mock.ExpectCommit() + rows := sqlmock.NewRows([]string{"event_sequence"}) + for _, event := range events { + rows = rows.AddRow(event.Sequence) + } + query.WillReturnRows(rows).RowsWillBeClosed() + return m +} + +func (m *dbMock) expectQueryScanErr(t *testing.T, expectedQuery string, args []driver.Value, events ...*repository.Event) *dbMock { + m.mock.ExpectBegin() + query := m.mock.ExpectQuery(expectedQuery).WithArgs(args...) + m.mock.ExpectRollback() rows := sqlmock.NewRows([]string{"event_sequence"}) for _, event := range events { rows = rows.AddRow(event.Sequence) @@ -863,6 +877,7 @@ func (m *dbMock) expectQuery(t *testing.T, expectedQuery string, args []driver.V } func (m *dbMock) expectQueryErr(t *testing.T, expectedQuery string, args []driver.Value, err error) *dbMock { + m.mock.ExpectBegin() m.mock.ExpectQuery(expectedQuery).WithArgs(args...).WillReturnError(err) return m } diff --git a/internal/eventstore/v1/internal/repository/sql/db_mock_test.go b/internal/eventstore/v1/internal/repository/sql/db_mock_test.go index e1271fc0d5..da64dfd5d2 100644 --- a/internal/eventstore/v1/internal/repository/sql/db_mock_test.go +++ b/internal/eventstore/v1/internal/repository/sql/db_mock_test.go @@ -127,9 +127,11 @@ func (db *dbMock) expectFilterEventsLimit(aggregateType string, limit uint64, ev for i := 0; i < eventCount; i++ { rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0") } + db.mock.ExpectBegin() db.mock.ExpectQuery(expectedFilterEventsLimitFormat). WithArgs(aggregateType, limit). WillReturnRows(rows) + db.mock.ExpectCommit() return db } @@ -138,8 +140,10 @@ func (db *dbMock) expectFilterEventsDesc(aggregateType string, eventCount int) * for i := eventCount; i > 0; i-- { rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0") } + db.mock.ExpectBegin() db.mock.ExpectQuery(expectedFilterEventsDescFormat). WillReturnRows(rows) + db.mock.ExpectCommit() return db } @@ -148,9 +152,11 @@ func (db *dbMock) expectFilterEventsAggregateIDLimit(aggregateType, aggregateID for i := limit; i > 0; i-- { rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0") } + db.mock.ExpectBegin() db.mock.ExpectQuery(expectedFilterEventsAggregateIDLimit). WithArgs(aggregateType, aggregateID, limit). WillReturnRows(rows) + db.mock.ExpectCommit() return db } @@ -159,28 +165,36 @@ func (db *dbMock) expectFilterEventsAggregateIDTypeLimit(aggregateType, aggregat for i := limit; i > 0; i-- { rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0") } + db.mock.ExpectBegin() db.mock.ExpectQuery(expectedFilterEventsAggregateIDTypeLimit). WithArgs(aggregateType, aggregateID, limit). WillReturnRows(rows) + db.mock.ExpectCommit() return db } func (db *dbMock) expectFilterEventsError(returnedErr error) *dbMock { + db.mock.ExpectBegin() db.mock.ExpectQuery(expectedGetAllEvents). WillReturnError(returnedErr) + db.mock.ExpectRollback() return db } func (db *dbMock) expectLatestSequenceFilter(aggregateType string, sequence Sequence) *dbMock { + db.mock.ExpectBegin() db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events AS OF SYSTEM TIME '-1 ms' WHERE \( aggregate_type = \$1 \)`). WithArgs(aggregateType). WillReturnRows(sqlmock.NewRows([]string{"max_sequence"}).AddRow(sequence)) + db.mock.ExpectCommit() return db } func (db *dbMock) expectLatestSequenceFilterError(aggregateType string, err error) *dbMock { + db.mock.ExpectBegin() db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events AS OF SYSTEM TIME '-1 ms' WHERE \( aggregate_type = \$1 \)`). WithArgs(aggregateType).WillReturnError(err) + // db.mock.ExpectRollback() return db } diff --git a/internal/eventstore/v1/internal/repository/sql/filter.go b/internal/eventstore/v1/internal/repository/sql/filter.go index ab67730b59..850a1ee30a 100644 --- a/internal/eventstore/v1/internal/repository/sql/filter.go +++ b/internal/eventstore/v1/internal/repository/sql/filter.go @@ -3,12 +3,13 @@ package sql import ( "context" "database/sql" + "errors" "runtime/debug" "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/database" - "github.com/zitadel/zitadel/internal/errors" + errs "github.com/zitadel/zitadel/internal/errors" es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models" "github.com/zitadel/zitadel/internal/telemetry/tracing" ) @@ -24,72 +25,74 @@ func (db *SQL) Filter(ctx context.Context, searchQuery *es_models.SearchQueryFac return db.filter(ctx, db.client, searchQuery) } -func (sql *SQL) filter(ctx context.Context, db *database.DB, searchQuery *es_models.SearchQueryFactory) (events []*es_models.Event, err error) { - query, limit, values, rowScanner := sql.buildQuery(ctx, db, searchQuery) +func (server *SQL) filter(ctx context.Context, db *database.DB, searchQuery *es_models.SearchQueryFactory) (events []*es_models.Event, err error) { + query, limit, values, rowScanner := server.buildQuery(ctx, db, searchQuery) if query == "" { - return nil, errors.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory") + return nil, errs.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory") } - rows, err := db.Query(query, values...) - if err != nil { - logging.New().WithError(err).Info("query failed") - return nil, errors.ThrowInternal(err, "SQL-IJuyR", "unable to filter events") - } - defer rows.Close() - events = make([]*es_models.Event, 0, limit) + err = db.QueryContext(ctx, + func(rows *sql.Rows) error { + for rows.Next() { + event := new(es_models.Event) + err := rowScanner(rows.Scan, event) + if err != nil { + return err + } - for rows.Next() { - event := new(es_models.Event) - err := rowScanner(rows.Scan, event) - if err != nil { - return nil, err - } - - events = append(events, event) + events = append(events, event) + } + return nil + }, + query, values..., + ) + if err != nil { + logging.New().WithError(err).Info("query failed") + return nil, errs.ThrowInternal(err, "SQL-IJuyR", "unable to filter events") } - return events, nil } func (db *SQL) LatestSequence(ctx context.Context, queryFactory *es_models.SearchQueryFactory) (uint64, error) { query, _, values, rowScanner := db.buildQuery(ctx, db.client, queryFactory) if query == "" { - return 0, errors.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory") + return 0, errs.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory") } - row := db.client.QueryRow(query, values...) sequence := new(Sequence) - err := rowScanner(row.Scan, sequence) - if err != nil { + err := db.client.QueryRowContext(ctx, func(row *sql.Row) error { + return rowScanner(row.Scan, sequence) + }, query, values...) + if err != nil && !errors.Is(err, sql.ErrNoRows) { logging.New().WithError(err).WithField("traceID", tracing.TraceIDFromCtx(ctx)).Info("query failed") - return 0, errors.ThrowInternal(err, "SQL-Yczyx", "unable to filter latest sequence") + return 0, errs.ThrowInternal(err, "SQL-Yczyx", "unable to filter latest sequence") } return uint64(*sequence), nil } -func (db *SQL) InstanceIDs(ctx context.Context, queryFactory *es_models.SearchQueryFactory) ([]string, error) { +func (db *SQL) InstanceIDs(ctx context.Context, queryFactory *es_models.SearchQueryFactory) (ids []string, err error) { query, _, values, rowScanner := db.buildQuery(ctx, db.client, queryFactory) if query == "" { - return nil, errors.ThrowInvalidArgument(nil, "SQL-Sfwg2", "invalid query factory") + return nil, errs.ThrowInvalidArgument(nil, "SQL-Sfwg2", "invalid query factory") } - rows, err := db.client.Query(query, values...) + err = db.client.QueryContext(ctx, + func(rows *sql.Rows) error { + for rows.Next() { + var id string + err := rowScanner(rows.Scan, &id) + if err != nil { + return err + } + + ids = append(ids, id) + } + return nil + }, + query, values...) if err != nil { logging.New().WithError(err).Info("query failed") - return nil, errors.ThrowInternal(err, "SQL-Sfg3r", "unable to filter instance ids") - } - defer rows.Close() - - ids := make([]string, 0) - - for rows.Next() { - var id string - err := rowScanner(rows.Scan, &id) - if err != nil { - return nil, err - } - - ids = append(ids, id) + return nil, errs.ThrowInternal(err, "SQL-Sfg3r", "unable to filter instance ids") } return ids, nil diff --git a/internal/eventstore/v1/internal/repository/sql/filter_test.go b/internal/eventstore/v1/internal/repository/sql/filter_test.go index 75863fed8a..39ebbd0add 100644 --- a/internal/eventstore/v1/internal/repository/sql/filter_test.go +++ b/internal/eventstore/v1/internal/repository/sql/filter_test.go @@ -130,6 +130,7 @@ func TestSQL_Filter(t *testing.T) { if (err != nil) != tt.res.wantErr { t.Errorf("SQL.Filter() error = %v, wantErr %v", err, tt.res.wantErr) } + if tt.res.eventsLen != 0 && len(events) != tt.res.eventsLen { t.Errorf("events has wrong length got: %d want %d", len(events), tt.res.eventsLen) } @@ -221,10 +222,12 @@ func TestSQL_LatestSequence(t *testing.T) { sql := &SQL{ client: &database.DB{DB: tt.fields.client.sqlClient, Database: new(testDB)}, } + sequence, err := sql.LatestSequence(context.Background(), tt.args.searchQuery) if (err != nil) != tt.res.wantErr { t.Errorf("SQL.Filter() error = %v, wantErr %v", err, tt.res.wantErr) } + if tt.res.sequence != sequence { t.Errorf("events has wrong length got: %d want %d", sequence, tt.res.sequence) } diff --git a/internal/logstore/emitters/access/database.go b/internal/logstore/emitters/access/database.go index bc121612e3..de3afe5aee 100644 --- a/internal/logstore/emitters/access/database.go +++ b/internal/logstore/emitters/access/database.go @@ -2,6 +2,7 @@ package access import ( "context" + "database/sql" "fmt" "net/http" "strings" @@ -136,9 +137,15 @@ func (l *databaseLogStorage) QueryUsage(ctx context.Context, instanceId string, } var count uint64 - if err = l.dbClient. - QueryRowContext(ctx, stmt, args...). - Scan(&count); err != nil { + err = l.dbClient. + QueryRowContext(ctx, + func(row *sql.Row) error { + return row.Scan(&count) + }, + stmt, args..., + ) + + if err != nil { return 0, caos_errors.ThrowInternal(err, "ACCESS-pBPrM", "Errors.Logstore.Access.ScanFailed") } diff --git a/internal/logstore/emitters/execution/database.go b/internal/logstore/emitters/execution/database.go index 106b433200..396f3f2ea1 100644 --- a/internal/logstore/emitters/execution/database.go +++ b/internal/logstore/emitters/execution/database.go @@ -2,6 +2,7 @@ package execution import ( "context" + "database/sql" "fmt" "time" @@ -113,9 +114,14 @@ func (l *databaseLogStorage) QueryUsage(ctx context.Context, instanceId string, } var durationSeconds uint64 - if err = l.dbClient. - QueryRowContext(ctx, stmt, args...). - Scan(&durationSeconds); err != nil { + err = l.dbClient. + QueryRowContext(ctx, + func(row *sql.Row) error { + return row.Scan(&durationSeconds) + }, + stmt, args..., + ) + if err != nil { return 0, caos_errors.ThrowInternal(err, "EXEC-Ad8nP", "Errors.Logstore.Execution.ScanFailed") } return durationSeconds, nil diff --git a/internal/query/action.go b/internal/query/action.go index 5c5a786bbb..26aee576ec 100644 --- a/internal/query/action.go +++ b/internal/query/action.go @@ -130,19 +130,19 @@ func (q *Queries) SearchActions(ctx context.Context, queries *ActionSearchQuerie return nil, errors.ThrowInvalidArgument(err, "QUERY-SDgwg", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + actions, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-SDfr52", "Errors.Internal") } - actions, err = scan(rows) - if err != nil { - return nil, err - } + actions.LatestSequence, err = q.latestSequence(ctx, actionTable) return actions, err } -func (q *Queries) GetActionByID(ctx context.Context, id string, orgID string, withOwnerRemoved bool) (_ *Action, err error) { +func (q *Queries) GetActionByID(ctx context.Context, id string, orgID string, withOwnerRemoved bool) (action *Action, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -160,8 +160,11 @@ func (q *Queries) GetActionByID(ctx context.Context, id string, orgID string, wi return nil, errors.ThrowInternal(err, "QUERY-Dgff3", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + action, err = scan(row) + return err + }, query, args...) + return action, err } func NewActionResourceOwnerQuery(id string) (SearchQuery, error) { diff --git a/internal/query/action_flow.go b/internal/query/action_flow.go index ca6c999233..0b263041a3 100644 --- a/internal/query/action_flow.go +++ b/internal/query/action_flow.go @@ -67,7 +67,7 @@ type Flow struct { TriggerActions map[domain.TriggerType][]*Action } -func (q *Queries) GetFlow(ctx context.Context, flowType domain.FlowType, orgID string, withOwnerRemoved bool) (_ *Flow, err error) { +func (q *Queries) GetFlow(ctx context.Context, flowType domain.FlowType, orgID string, withOwnerRemoved bool) (flow *Flow, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -85,14 +85,14 @@ func (q *Queries) GetFlow(ctx context.Context, flowType domain.FlowType, orgID s return nil, errors.ThrowInvalidArgument(err, "QUERY-HBRh3", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, stmt, args...) - if err != nil { - return nil, errors.ThrowInternal(err, "QUERY-Gg42f", "Errors.Internal") - } - return scan(rows) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + flow, err = scan(rows) + return err + }, stmt, args...) + return flow, err } -func (q *Queries) GetActiveActionsByFlowAndTriggerType(ctx context.Context, flowType domain.FlowType, triggerType domain.TriggerType, orgID string, withOwnerRemoved bool) (_ []*Action, err error) { +func (q *Queries) GetActiveActionsByFlowAndTriggerType(ctx context.Context, flowType domain.FlowType, triggerType domain.TriggerType, orgID string, withOwnerRemoved bool) (actions []*Action, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -112,14 +112,14 @@ func (q *Queries) GetActiveActionsByFlowAndTriggerType(ctx context.Context, flow return nil, errors.ThrowInternal(err, "QUERY-Dgff3", "Errors.Query.SQLStatement") } - rows, err := q.client.QueryContext(ctx, query, args...) - if err != nil { - return nil, errors.ThrowInternal(err, "QUERY-SDf52", "Errors.Internal") - } - return scan(rows) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + actions, err = scan(rows) + return err + }, query, args...) + return actions, err } -func (q *Queries) GetFlowTypesOfActionID(ctx context.Context, actionID string, withOwnerRemoved bool) (_ []domain.FlowType, err error) { +func (q *Queries) GetFlowTypesOfActionID(ctx context.Context, actionID string, withOwnerRemoved bool) (types []domain.FlowType, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -136,12 +136,11 @@ func (q *Queries) GetFlowTypesOfActionID(ctx context.Context, actionID string, w return nil, errors.ThrowInvalidArgument(err, "QUERY-Dh311", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, query, args...) - if err != nil { - return nil, errors.ThrowInternal(err, "QUERY-Bhj4w", "Errors.Internal") - } - - return scan(rows) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + types, err = scan(rows) + return err + }, query, args...) + return types, err } func prepareFlowTypesQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) ([]domain.FlowType, error)) { diff --git a/internal/query/action_flow_test.go b/internal/query/action_flow_test.go index a4a9131d89..897bbd04f4 100644 --- a/internal/query/action_flow_test.go +++ b/internal/query/action_flow_test.go @@ -33,8 +33,8 @@ var ( ` projections.flow_triggers2.sequence,` + ` projections.flow_triggers2.resource_owner` + ` FROM projections.flow_triggers2` + - ` LEFT JOIN projections.actions3 ON projections.flow_triggers2.action_id = projections.actions3.id AND projections.flow_triggers2.instance_id = projections.actions3.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'` + ` LEFT JOIN projections.actions3 ON projections.flow_triggers2.action_id = projections.actions3.id AND projections.flow_triggers2.instance_id = projections.actions3.instance_id` + // ` AS OF SYSTEM TIME '-1 ms'` prepareFlowCols = []string{ "id", "creation_date", @@ -66,8 +66,8 @@ var ( ` projections.actions3.allowed_to_fail,` + ` projections.actions3.timeout` + ` FROM projections.flow_triggers2` + - ` LEFT JOIN projections.actions3 ON projections.flow_triggers2.action_id = projections.actions3.id AND projections.flow_triggers2.instance_id = projections.actions3.instance_id` + - ` AS OF SYSTEM TIME '-1 ms'` + ` LEFT JOIN projections.actions3 ON projections.flow_triggers2.action_id = projections.actions3.id AND projections.flow_triggers2.instance_id = projections.actions3.instance_id` + // ` AS OF SYSTEM TIME '-1 ms'` prepareTriggerActionCols = []string{ "id", @@ -83,8 +83,8 @@ var ( } prepareFlowTypeStmt = `SELECT projections.flow_triggers2.flow_type` + - ` FROM projections.flow_triggers2` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.flow_triggers2` + // ` AS OF SYSTEM TIME '-1 ms'` prepareFlowTypeCols = []string{ "flow_type", diff --git a/internal/query/action_test.go b/internal/query/action_test.go index 2dfe6e8e32..4c0f00d075 100644 --- a/internal/query/action_test.go +++ b/internal/query/action_test.go @@ -25,8 +25,8 @@ var ( ` projections.actions3.timeout,` + ` projections.actions3.allowed_to_fail,` + ` COUNT(*) OVER ()` + - ` FROM projections.actions3` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.actions3` + // ` AS OF SYSTEM TIME '-1 ms'` prepareActionsCols = []string{ "id", "creation_date", @@ -51,8 +51,8 @@ var ( ` projections.actions3.script,` + ` projections.actions3.timeout,` + ` projections.actions3.allowed_to_fail` + - ` FROM projections.actions3` + - ` AS OF SYSTEM TIME '-1 ms'` + ` FROM projections.actions3` + // ` AS OF SYSTEM TIME '-1 ms'` prepareActionCols = []string{ "id", "creation_date", @@ -215,13 +215,13 @@ func Test_ActionPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*Action)(nil), }, { name: "prepareActionQuery no result", prepare: prepareActionQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( regexp.QuoteMeta(prepareActionStmt), nil, nil, @@ -284,7 +284,7 @@ func Test_ActionPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*Action)(nil), }, } for _, tt := range tests { diff --git a/internal/query/app.go b/internal/query/app.go index 577f6c1814..639b2e141c 100644 --- a/internal/query/app.go +++ b/internal/query/app.go @@ -248,7 +248,7 @@ var ( } ) -func (q *Queries) AppByProjectAndAppID(ctx context.Context, shouldTriggerBulk bool, projectID, appID string, withOwnerRemoved bool) (_ *App, err error) { +func (q *Queries) AppByProjectAndAppID(ctx context.Context, shouldTriggerBulk bool, projectID, appID string, withOwnerRemoved bool) (app *App, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -270,11 +270,14 @@ func (q *Queries) AppByProjectAndAppID(ctx context.Context, shouldTriggerBulk bo return nil, errors.ThrowInternal(err, "QUERY-AFDgg", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + app, err = scan(row) + return err + }, query, args...) + return app, err } -func (q *Queries) AppByID(ctx context.Context, appID string, withOwnerRemoved bool) (_ *App, err error) { +func (q *Queries) AppByID(ctx context.Context, appID string, withOwnerRemoved bool) (app *App, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -291,11 +294,14 @@ func (q *Queries) AppByID(ctx context.Context, appID string, withOwnerRemoved bo return nil, errors.ThrowInternal(err, "QUERY-immt9", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + app, err = scan(row) + return err + }, query, args...) + return app, err } -func (q *Queries) AppBySAMLEntityID(ctx context.Context, entityID string, withOwnerRemoved bool) (_ *App, err error) { +func (q *Queries) AppBySAMLEntityID(ctx context.Context, entityID string, withOwnerRemoved bool) (app *App, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -312,11 +318,14 @@ func (q *Queries) AppBySAMLEntityID(ctx context.Context, entityID string, withOw return nil, errors.ThrowInternal(err, "QUERY-JgUop", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + app, err = scan(row) + return err + }, query, args...) + return app, err } -func (q *Queries) ProjectByClientID(ctx context.Context, appID string, withOwnerRemoved bool) (_ *Project, err error) { +func (q *Queries) ProjectByClientID(ctx context.Context, appID string, withOwnerRemoved bool) (project *Project, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -337,11 +346,14 @@ func (q *Queries) ProjectByClientID(ctx context.Context, appID string, withOwner return nil, errors.ThrowInternal(err, "QUERY-XhJi3", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + project, err = scan(row) + return err + }, query, args...) + return project, err } -func (q *Queries) ProjectIDFromOIDCClientID(ctx context.Context, appID string, withOwnerRemoved bool) (_ string, err error) { +func (q *Queries) ProjectIDFromOIDCClientID(ctx context.Context, appID string, withOwnerRemoved bool) (id string, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -358,11 +370,14 @@ func (q *Queries) ProjectIDFromOIDCClientID(ctx context.Context, appID string, w return "", errors.ThrowInternal(err, "QUERY-7d92U", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + id, err = scan(row) + return err + }, query, args...) + return id, err } -func (q *Queries) ProjectIDFromClientID(ctx context.Context, appID string, withOwnerRemoved bool) (_ string, err error) { +func (q *Queries) ProjectIDFromClientID(ctx context.Context, appID string, withOwnerRemoved bool) (id string, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -384,11 +399,14 @@ func (q *Queries) ProjectIDFromClientID(ctx context.Context, appID string, withO return "", errors.ThrowInternal(err, "QUERY-SDfg3", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + id, err = scan(row) + return err + }, query, args...) + return id, err } -func (q *Queries) ProjectByOIDCClientID(ctx context.Context, id string, withOwnerRemoved bool) (_ *Project, err error) { +func (q *Queries) ProjectByOIDCClientID(ctx context.Context, id string, withOwnerRemoved bool) (project *Project, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -405,11 +423,14 @@ func (q *Queries) ProjectByOIDCClientID(ctx context.Context, id string, withOwne return nil, errors.ThrowInternal(err, "QUERY-XhJi4", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + project, err = scan(row) + return err + }, query, args...) + return project, err } -func (q *Queries) AppByOIDCClientID(ctx context.Context, clientID string, withOwnerRemoved bool) (_ *App, err error) { +func (q *Queries) AppByOIDCClientID(ctx context.Context, clientID string, withOwnerRemoved bool) (app *App, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -426,11 +447,14 @@ func (q *Queries) AppByOIDCClientID(ctx context.Context, clientID string, withOw return nil, errors.ThrowInternal(err, "QUERY-JgVop", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + app, err = scan(row) + return err + }, query, args...) + return app, err } -func (q *Queries) AppByClientID(ctx context.Context, clientID string, withOwnerRemoved bool) (_ *App, err error) { +func (q *Queries) AppByClientID(ctx context.Context, clientID string, withOwnerRemoved bool) (app *App, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -450,11 +474,14 @@ func (q *Queries) AppByClientID(ctx context.Context, clientID string, withOwnerR return nil, errors.ThrowInternal(err, "QUERY-Dfge2", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + app, err = scan(row) + return err + }, query, args...) + return app, err } -func (q *Queries) SearchApps(ctx context.Context, queries *AppSearchQueries, withOwnerRemoved bool) (_ *Apps, err error) { +func (q *Queries) SearchApps(ctx context.Context, queries *AppSearchQueries, withOwnerRemoved bool) (apps *Apps, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -468,19 +495,18 @@ func (q *Queries) SearchApps(ctx context.Context, queries *AppSearchQueries, wit return nil, errors.ThrowInvalidArgument(err, "QUERY-fajp8", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + apps, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-aJnZL", "Errors.Internal") } - apps, err := scan(rows) - if err != nil { - return nil, err - } apps.LatestSequence, err = q.latestSequence(ctx, appsTable) return apps, err } -func (q *Queries) SearchClientIDs(ctx context.Context, queries *AppSearchQueries, withOwnerRemoved bool) (_ []string, err error) { +func (q *Queries) SearchClientIDs(ctx context.Context, queries *AppSearchQueries, withOwnerRemoved bool) (ids []string, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -494,11 +520,14 @@ func (q *Queries) SearchClientIDs(ctx context.Context, queries *AppSearchQueries return nil, errors.ThrowInvalidArgument(err, "QUERY-fajp8", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + ids, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-aJnZL", "Errors.Internal") } - return scan(rows) + return ids, nil } func NewAppNameSearchQuery(method TextComparison, value string) (SearchQuery, error) { diff --git a/internal/query/app_test.go b/internal/query/app_test.go index e75e4ce06f..ca93baddf5 100644 --- a/internal/query/app_test.go +++ b/internal/query/app_test.go @@ -1115,7 +1115,7 @@ func Test_AppsPrepare(t *testing.T) { return nil, true }, }, - object: nil, + object: (*App)(nil), }, } for _, tt := range tests { @@ -1140,7 +1140,7 @@ func Test_AppPrepare(t *testing.T) { name: "prepareAppQuery no result", prepare: prepareAppQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( expectedAppQuery, nil, nil, @@ -1747,7 +1747,7 @@ func Test_AppPrepare(t *testing.T) { return nil, true }, }, - object: nil, + object: (*App)(nil), }, } for _, tt := range tests { @@ -1833,7 +1833,7 @@ func Test_AppIDsPrepare(t *testing.T) { return nil, true }, }, - object: nil, + object: (*App)(nil), }, } for _, tt := range tests { @@ -1858,7 +1858,7 @@ func Test_ProjectIDByAppPrepare(t *testing.T) { name: "prepareProjectIDByAppQuery no result", prepare: prepareProjectIDByAppQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( expectedProjectIDByAppQuery, nil, nil, @@ -1899,7 +1899,7 @@ func Test_ProjectIDByAppPrepare(t *testing.T) { return nil, true }, }, - object: nil, + object: "", }, } for _, tt := range tests { @@ -1924,7 +1924,7 @@ func Test_ProjectByAppPrepare(t *testing.T) { name: "prepareProjectByAppQuery no result", prepare: prepareProjectByAppQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( expectedProjectByAppQuery, nil, nil, @@ -2097,7 +2097,7 @@ func Test_ProjectByAppPrepare(t *testing.T) { return nil, true }, }, - object: nil, + object: (*Project)(nil), }, } for _, tt := range tests { diff --git a/internal/query/auth_request.go b/internal/query/auth_request.go index ee91bdebea..334c101c99 100644 --- a/internal/query/auth_request.go +++ b/internal/query/auth_request.go @@ -60,12 +60,16 @@ func (q *Queries) AuthRequestByID(ctx context.Context, shouldTriggerBulk bool, i ) dst := new(AuthRequest) - err = q.client.DB.QueryRowContext( - ctx, q.authRequestByIDQuery(ctx), + err = q.client.QueryRowContext( + ctx, + func(row *sql.Row) error { + return row.Scan( + &dst.ID, &dst.CreationDate, &dst.LoginClient, &dst.ClientID, &scope, &dst.RedirectURI, + &prompt, &locales, &dst.LoginHint, &dst.MaxAge, &dst.HintUserID, + ) + }, + q.authRequestByIDQuery(ctx), id, authz.GetInstance(ctx).InstanceID(), - ).Scan( - &dst.ID, &dst.CreationDate, &dst.LoginClient, &dst.ClientID, &scope, &dst.RedirectURI, - &prompt, &locales, &dst.LoginHint, &dst.MaxAge, &dst.HintUserID, ) if errs.Is(err, sql.ErrNoRows) { return nil, errors.ThrowNotFound(err, "QUERY-Thee9", "Errors.AuthRequest.NotExisting") diff --git a/internal/query/auth_request_test.go b/internal/query/auth_request_test.go index 7348855e75..ba23b7315d 100644 --- a/internal/query/auth_request_test.go +++ b/internal/query/auth_request_test.go @@ -125,7 +125,7 @@ func TestQueries_AuthRequestByID(t *testing.T) { shouldTriggerBulk: false, id: "123", }, - expect: mockQuery(expQuery, cols, nil, "123", "instanceID"), + expect: mockQueryScanErr(expQuery, cols, nil, "123", "instanceID"), wantErr: errors.ThrowNotFound(sql.ErrNoRows, "QUERY-Thee9", "Errors.AuthRequest.NotExisting"), }, { diff --git a/internal/query/authn_key.go b/internal/query/authn_key.go index 87140b2165..2aa3576ad6 100644 --- a/internal/query/authn_key.go +++ b/internal/query/authn_key.go @@ -144,14 +144,14 @@ func (q *Queries) SearchAuthNKeys(ctx context.Context, queries *AuthNKeySearchQu return nil, errors.ThrowInvalidArgument(err, "QUERY-SAf3f", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + authNKeys, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-Dbg53", "Errors.Internal") } - authNKeys, err = scan(rows) - if err != nil { - return nil, err - } + authNKeys.LatestSequence, err = q.latestSequence(ctx, authNKeyTable) return authNKeys, err } @@ -174,19 +174,18 @@ func (q *Queries) SearchAuthNKeysData(ctx context.Context, queries *AuthNKeySear return nil, errors.ThrowInvalidArgument(err, "QUERY-SAg3f", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + authNKeys, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-Dbi53", "Errors.Internal") } - authNKeys, err = scan(rows) - if err != nil { - return nil, err - } authNKeys.LatestSequence, err = q.latestSequence(ctx, authNKeyTable) return authNKeys, err } -func (q *Queries) GetAuthNKeyByID(ctx context.Context, shouldTriggerBulk bool, id string, withOwnerRemoved bool, queries ...SearchQuery) (_ *AuthNKey, err error) { +func (q *Queries) GetAuthNKeyByID(ctx context.Context, shouldTriggerBulk bool, id string, withOwnerRemoved bool, queries ...SearchQuery) (key *AuthNKey, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -211,11 +210,14 @@ func (q *Queries) GetAuthNKeyByID(ctx context.Context, shouldTriggerBulk bool, i return nil, errors.ThrowInternal(err, "QUERY-AGhg4", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, stmt, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + key, err = scan(row) + return err + }, stmt, args...) + return key, err } -func (q *Queries) GetAuthNKeyPublicKeyByIDAndIdentifier(ctx context.Context, id string, identifier string, withOwnerRemoved bool) (_ []byte, err error) { +func (q *Queries) GetAuthNKeyPublicKeyByIDAndIdentifier(ctx context.Context, id string, identifier string, withOwnerRemoved bool) (key []byte, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -250,8 +252,11 @@ func (q *Queries) GetAuthNKeyPublicKeyByIDAndIdentifier(ctx context.Context, id return nil, errors.ThrowInternal(err, "QUERY-DAb32", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + key, err = scan(row) + return err + }, query, args...) + return key, err } func NewAuthNKeyResourceOwnerQuery(id string) (SearchQuery, error) { diff --git a/internal/query/authn_key_test.go b/internal/query/authn_key_test.go index 9deb9bfa4c..620e6a5079 100644 --- a/internal/query/authn_key_test.go +++ b/internal/query/authn_key_test.go @@ -349,13 +349,13 @@ func Test_AuthNKeyPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*AuthNKey)(nil), }, { name: "prepareAuthNKeyQuery no result", prepare: prepareAuthNKeyQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( regexp.QuoteMeta(prepareAuthNKeyStmt), nil, nil, @@ -412,13 +412,13 @@ func Test_AuthNKeyPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*AuthNKey)(nil), }, { name: "prepareAuthNKeyPublicKeyQuery no result", prepare: prepareAuthNKeyPublicKeyQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( regexp.QuoteMeta(prepareAuthNKeyPublicKeyStmt), nil, nil, @@ -461,7 +461,7 @@ func Test_AuthNKeyPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: ([]byte)(nil), }, } for _, tt := range tests { diff --git a/internal/query/certificate.go b/internal/query/certificate.go index 2856138804..6b5f06ebfa 100644 --- a/internal/query/certificate.go +++ b/internal/query/certificate.go @@ -66,7 +66,7 @@ var ( } ) -func (q *Queries) ActiveCertificates(ctx context.Context, t time.Time, usage domain.KeyUsage) (_ *Certificates, err error) { +func (q *Queries) ActiveCertificates(ctx context.Context, t time.Time, usage domain.KeyUsage) (certs *Certificates, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -88,19 +88,19 @@ func (q *Queries) ActiveCertificates(ctx context.Context, t time.Time, usage dom return nil, errors.ThrowInternal(err, "QUERY-SDfkg", "Errors.Query.SQLStatement") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + certs, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-Sgan4", "Errors.Internal") } - keys, err := scan(rows) - if err != nil { - return nil, err - } - keys.LatestSequence, err = q.latestSequence(ctx, keyTable) + + certs.LatestSequence, err = q.latestSequence(ctx, keyTable) if !errors.IsNotFound(err) { - return keys, err + return certs, err } - return keys, nil + return certs, nil } func prepareCertificateQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Certificates, error)) { diff --git a/internal/query/certificate_test.go b/internal/query/certificate_test.go index 4bd1de2be7..6e8b52abfa 100644 --- a/internal/query/certificate_test.go +++ b/internal/query/certificate_test.go @@ -138,7 +138,7 @@ func Test_CertificatePrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*Certificate)(nil), }, } for _, tt := range tests { diff --git a/internal/query/current_sequence.go b/internal/query/current_sequence.go index ba7f278713..50c18a9006 100644 --- a/internal/query/current_sequence.go +++ b/internal/query/current_sequence.go @@ -66,14 +66,17 @@ func (q *Queries) SearchCurrentSequences(ctx context.Context, queries *CurrentSe return nil, errors.ThrowInvalidArgument(err, "QUERY-MmFef", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + failedEvents, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-22H8f", "Errors.Internal") } - return scan(rows) + return failedEvents, nil } -func (q *Queries) latestSequence(ctx context.Context, projections ...table) (_ *LatestSequence, err error) { +func (q *Queries) latestSequence(ctx context.Context, projections ...table) (seq *LatestSequence, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -91,8 +94,12 @@ func (q *Queries) latestSequence(ctx context.Context, projections ...table) (_ * return nil, errors.ThrowInternal(err, "QUERY-5CfX9", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, stmt, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + seq, err = scan(row) + return err + }, stmt, args...) + + return seq, err } func (q *Queries) ClearCurrentSequence(ctx context.Context, projectionName string) (err error) { @@ -129,11 +136,17 @@ func (q *Queries) checkAndLock(ctx context.Context, projectionName string) error if err != nil { return errors.ThrowInternal(err, "QUERY-Dfwf2", "Errors.ProjectionName.Invalid") } - row := q.client.QueryRowContext(ctx, projectionQuery, args...) var count int - if err := row.Scan(&count); err != nil || count == 0 { - return errors.ThrowInternal(err, "QUERY-ej8fn", "Errors.ProjectionName.Invalid") + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + if err := row.Scan(&count); err != nil || count == 0 { + return errors.ThrowInternal(err, "QUERY-ej8fn", "Errors.ProjectionName.Invalid") + } + return err + }, projectionQuery, args...) + if err != nil { + return err } + lock := fmt.Sprintf(lockStmtFormat, locksTable.identifier()) if err != nil { return errors.ThrowInternal(err, "QUERY-DVfg3", "Errors.RemoveFailed") diff --git a/internal/query/current_sequence_test.go b/internal/query/current_sequence_test.go index cdd78a43c2..20df6aea0a 100644 --- a/internal/query/current_sequence_test.go +++ b/internal/query/current_sequence_test.go @@ -132,7 +132,7 @@ func Test_CurrentSequencesPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*CurrentSequences)(nil), }, } for _, tt := range tests { diff --git a/internal/query/custom_text.go b/internal/query/custom_text.go index 415cbe3023..5f83866805 100644 --- a/internal/query/custom_text.go +++ b/internal/query/custom_text.go @@ -104,14 +104,14 @@ func (q *Queries) CustomTextList(ctx context.Context, aggregateID, template, lan return nil, errors.ThrowInternal(err, "QUERY-M9gse", "Errors.Query.SQLStatement") } - rows, err := q.client.QueryContext(ctx, query, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + texts, err = scan(rows) + return err + }, query, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-2j00f", "Errors.Internal") } - texts, err = scan(rows) - if err != nil { - return nil, err - } + texts.LatestSequence, err = q.latestSequence(ctx, projectsTable) return texts, err } @@ -134,14 +134,14 @@ func (q *Queries) CustomTextListByTemplate(ctx context.Context, aggregateID, tem return nil, errors.ThrowInternal(err, "QUERY-M49fs", "Errors.Query.SQLStatement") } - rows, err := q.client.QueryContext(ctx, query, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + texts, err = scan(rows) + return err + }, query, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-3n9ge", "Errors.Internal") } - texts, err = scan(rows) - if err != nil { - return nil, err - } + texts.LatestSequence, err = q.latestSequence(ctx, projectsTable) return texts, err } diff --git a/internal/query/custom_text_test.go b/internal/query/custom_text_test.go index 93a9508ba4..3df40638cc 100644 --- a/internal/query/custom_text_test.go +++ b/internal/query/custom_text_test.go @@ -180,7 +180,7 @@ func Test_CustomTextPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*CustomText)(nil), }, } for _, tt := range tests { diff --git a/internal/query/device_auth.go b/internal/query/device_auth.go index 98faff200b..91d26ba6c9 100644 --- a/internal/query/device_auth.go +++ b/internal/query/device_auth.go @@ -70,7 +70,7 @@ var ( } ) -func (q *Queries) DeviceAuthByDeviceCode(ctx context.Context, clientID, deviceCode string) (_ *domain.DeviceAuth, err error) { +func (q *Queries) DeviceAuthByDeviceCode(ctx context.Context, clientID, deviceCode string) (deviceAuth *domain.DeviceAuth, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -85,10 +85,14 @@ func (q *Queries) DeviceAuthByDeviceCode(ctx context.Context, clientID, deviceCo return nil, errors.ThrowInternal(err, "QUERY-uk1Oh", "Errors.Query.SQLStatement") } - return scan(q.client.QueryRowContext(ctx, query, args...)) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + deviceAuth, err = scan(row) + return err + }, query, args...) + return deviceAuth, err } -func (q *Queries) DeviceAuthByUserCode(ctx context.Context, userCode string) (_ *domain.DeviceAuth, err error) { +func (q *Queries) DeviceAuthByUserCode(ctx context.Context, userCode string) (deviceAuth *domain.DeviceAuth, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -102,7 +106,11 @@ func (q *Queries) DeviceAuthByUserCode(ctx context.Context, userCode string) (_ return nil, errors.ThrowInternal(err, "QUERY-Axu7l", "Errors.Query.SQLStatement") } - return scan(q.client.QueryRowContext(ctx, query, args...)) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + deviceAuth, err = scan(row) + return err + }, query, args...) + return deviceAuth, err } var deviceAuthSelectColumns = []string{ diff --git a/internal/query/device_auth_test.go b/internal/query/device_auth_test.go index 938cb9f844..112032d328 100644 --- a/internal/query/device_auth_test.go +++ b/internal/query/device_auth_test.go @@ -67,9 +67,11 @@ func TestQueries_DeviceAuthByDeviceCode(t *testing.T) { } defer client.Close() + mock.ExpectBegin() mock.ExpectQuery(expectedDeviceAuthWhereDeviceCodeQuery).WillReturnRows( sqlmock.NewRows(deviceAuthSelectColumns).AddRow(expectedDeviceAuthValues...), ) + mock.ExpectCommit() q := Queries{ client: &database.DB{DB: client}, } @@ -86,9 +88,11 @@ func TestQueries_DeviceAuthByUserCode(t *testing.T) { } defer client.Close() + mock.ExpectBegin() mock.ExpectQuery(expectedDeviceAuthWhereUserCodeQuery).WillReturnRows( sqlmock.NewRows(deviceAuthSelectColumns).AddRow(expectedDeviceAuthValues...), ) + mock.ExpectCommit() q := Queries{ client: &database.DB{DB: client}, } @@ -133,6 +137,7 @@ func Test_prepareDeviceAuthQuery(t *testing.T) { return nil, true }, }, + object: (*domain.DeviceAuth)(nil), }, { name: "other error", @@ -148,6 +153,7 @@ func Test_prepareDeviceAuthQuery(t *testing.T) { return nil, true }, }, + object: (*domain.DeviceAuth)(nil), }, } for _, tt := range tests { diff --git a/internal/query/domain_policy.go b/internal/query/domain_policy.go index 2f7a596cd1..77e821d39b 100644 --- a/internal/query/domain_policy.go +++ b/internal/query/domain_policy.go @@ -86,7 +86,7 @@ var ( } ) -func (q *Queries) DomainPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (_ *DomainPolicy, err error) { +func (q *Queries) DomainPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (policy *DomainPolicy, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -120,11 +120,14 @@ func (q *Queries) DomainPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, return nil, errors.ThrowInternal(err, "QUERY-D3CqT", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + policy, err = scan(row) + return err + }, query, args...) + return policy, err } -func (q *Queries) DefaultDomainPolicy(ctx context.Context) (_ *DomainPolicy, err error) { +func (q *Queries) DefaultDomainPolicy(ctx context.Context) (policy *DomainPolicy, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -139,8 +142,11 @@ func (q *Queries) DefaultDomainPolicy(ctx context.Context) (_ *DomainPolicy, err return nil, errors.ThrowInternal(err, "QUERY-pM7lP", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + policy, err = scan(row) + return err + }, query, args...) + return policy, err } func prepareDomainPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*DomainPolicy, error)) { diff --git a/internal/query/domain_policy_test.go b/internal/query/domain_policy_test.go index 5940555b28..ef861c737b 100644 --- a/internal/query/domain_policy_test.go +++ b/internal/query/domain_policy_test.go @@ -54,7 +54,7 @@ func Test_DomainPolicyPrepares(t *testing.T) { name: "prepareDomainPolicyQuery no result", prepare: prepareDomainPolicyQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( regexp.QuoteMeta(prepareDomainPolicyStmt), nil, nil, @@ -117,7 +117,7 @@ func Test_DomainPolicyPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*DomainPolicy)(nil), }, } for _, tt := range tests { diff --git a/internal/query/failed_events.go b/internal/query/failed_events.go index 2afd28d5e4..ed335bf621 100644 --- a/internal/query/failed_events.go +++ b/internal/query/failed_events.go @@ -77,11 +77,14 @@ func (q *Queries) SearchFailedEvents(ctx context.Context, queries *FailedEventSe return nil, errors.ThrowInvalidArgument(err, "QUERY-n8rjJ", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + failedEvents, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-3j99J", "Errors.Internal") } - return scan(rows) + return failedEvents, nil } func (q *Queries) RemoveFailedEvent(ctx context.Context, projectionName, instanceID string, sequence uint64) (err error) { diff --git a/internal/query/failed_events_test.go b/internal/query/failed_events_test.go index ede972c27c..01ba2bee2f 100644 --- a/internal/query/failed_events_test.go +++ b/internal/query/failed_events_test.go @@ -146,7 +146,7 @@ func Test_FailedEventsPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*FailedEvents)(nil), }, } for _, tt := range tests { diff --git a/internal/query/iam_member.go b/internal/query/iam_member.go index ae8e601ab8..bba45fbf54 100644 --- a/internal/query/iam_member.go +++ b/internal/query/iam_member.go @@ -76,7 +76,7 @@ func addIamMemberWithoutOwnerRemoved(eq map[string]interface{}) { eq[InstanceMemberOwnerRemovedUser.identifier()] = false } -func (q *Queries) IAMMembers(ctx context.Context, queries *IAMMembersQuery, withOwnerRemoved bool) (_ *Members, err error) { +func (q *Queries) IAMMembers(ctx context.Context, queries *IAMMembersQuery, withOwnerRemoved bool) (members *Members, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -96,14 +96,13 @@ func (q *Queries) IAMMembers(ctx context.Context, queries *IAMMembersQuery, with return nil, err } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + members, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-Pdg1I", "Errors.Internal") } - members, err := scan(rows) - if err != nil { - return nil, err - } members.LatestSequence = currentSequence return members, err } diff --git a/internal/query/iam_member_test.go b/internal/query/iam_member_test.go index 849f420ceb..ed2bd6cd32 100644 --- a/internal/query/iam_member_test.go +++ b/internal/query/iam_member_test.go @@ -280,7 +280,7 @@ func Test_IAMMemberPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*IAMMembership)(nil), }, } for _, tt := range tests { diff --git a/internal/query/idp.go b/internal/query/idp.go index 35c0b7b5d7..e29077b0bb 100644 --- a/internal/query/idp.go +++ b/internal/query/idp.go @@ -188,7 +188,7 @@ var ( ) // IDPByIDAndResourceOwner searches for the requested id in the context of the resource owner and IAM -func (q *Queries) IDPByIDAndResourceOwner(ctx context.Context, shouldTriggerBulk bool, id, resourceOwner string, withOwnerRemoved bool) (_ *IDP, err error) { +func (q *Queries) IDPByIDAndResourceOwner(ctx context.Context, shouldTriggerBulk bool, id, resourceOwner string, withOwnerRemoved bool) (idp *IDP, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -216,8 +216,11 @@ func (q *Queries) IDPByIDAndResourceOwner(ctx context.Context, shouldTriggerBulk return nil, errors.ThrowInternal(err, "QUERY-0gocI", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + idp, err = scan(row) + return err + }, query, args...) + return idp, err } // IDPs searches idps matching the query @@ -237,14 +240,13 @@ func (q *Queries) IDPs(ctx context.Context, queries *IDPSearchQueries, withOwner return nil, errors.ThrowInvalidArgument(err, "QUERY-X6X7y", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + idps, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-xPlVH", "Errors.Internal") } - idps, err = scan(rows) - if err != nil { - return nil, err - } idps.LatestSequence, err = q.latestSequence(ctx, idpTable) return idps, err } diff --git a/internal/query/idp_login_policy_link.go b/internal/query/idp_login_policy_link.go index 750f2f1141..89ee4a08b6 100644 --- a/internal/query/idp_login_policy_link.go +++ b/internal/query/idp_login_policy_link.go @@ -105,13 +105,13 @@ func (q *Queries) IDPLoginPolicyLinks(ctx context.Context, resourceOwner string, if err != nil { return nil, errors.ThrowInvalidArgument(err, "QUERY-FDbKW", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, stmt, args...) - if err != nil || rows.Err() != nil { - return nil, errors.ThrowInternal(err, "QUERY-ZkKUc", "Errors.Internal") - } - idps, err = scan(rows) + + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + idps, err = scan(rows) + return err + }, stmt, args...) if err != nil { - return nil, err + return nil, errors.ThrowInternal(err, "QUERY-ZkKUc", "Errors.Internal") } idps.LatestSequence, err = q.latestSequence(ctx, idpLoginPolicyLinkTable) return idps, err diff --git a/internal/query/idp_login_policy_link_test.go b/internal/query/idp_login_policy_link_test.go index 22f07e7fa1..4f720b86af 100644 --- a/internal/query/idp_login_policy_link_test.go +++ b/internal/query/idp_login_policy_link_test.go @@ -128,7 +128,7 @@ func Test_IDPLoginPolicyLinkPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*IDPs)(nil), }, } for _, tt := range tests { diff --git a/internal/query/idp_template.go b/internal/query/idp_template.go index b83efe156e..a19d6b512a 100644 --- a/internal/query/idp_template.go +++ b/internal/query/idp_template.go @@ -606,7 +606,7 @@ var ( ) // IDPTemplateByID searches for the requested id -func (q *Queries) IDPTemplateByID(ctx context.Context, shouldTriggerBulk bool, id string, withOwnerRemoved bool, queries ...SearchQuery) (_ *IDPTemplate, err error) { +func (q *Queries) IDPTemplateByID(ctx context.Context, shouldTriggerBulk bool, id string, withOwnerRemoved bool, queries ...SearchQuery) (template *IDPTemplate, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -630,8 +630,11 @@ func (q *Queries) IDPTemplateByID(ctx context.Context, shouldTriggerBulk bool, i return nil, errors.ThrowInternal(err, "QUERY-SFefg", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, stmt, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + template, err = scan(row) + return err + }, stmt, args...) + return template, err } // IDPTemplates searches idp templates matching the query @@ -651,14 +654,13 @@ func (q *Queries) IDPTemplates(ctx context.Context, queries *IDPTemplateSearchQu return nil, errors.ThrowInvalidArgument(err, "QUERY-SAF34", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + idps, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-BDFrq", "Errors.Internal") } - idps, err = scan(rows) - if err != nil { - return nil, err - } idps.LatestSequence, err = q.latestSequence(ctx, idpTemplateTable) return idps, err } diff --git a/internal/query/idp_template_test.go b/internal/query/idp_template_test.go index a646861b7d..34c3ad3720 100644 --- a/internal/query/idp_template_test.go +++ b/internal/query/idp_template_test.go @@ -443,7 +443,7 @@ func Test_IDPTemplateTemplatesPrepares(t *testing.T) { name: "prepareIDPTemplateByIDQuery no result", prepare: prepareIDPTemplateByIDQuery, want: want{ - sqlExpectations: mockQuery( + sqlExpectations: mockQueryScanErr( regexp.QuoteMeta(idpTemplateQuery), nil, nil, @@ -1646,7 +1646,7 @@ func Test_IDPTemplateTemplatesPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*IDPTemplate)(nil), }, { name: "prepareIDPTemplatesQuery no result", @@ -2606,7 +2606,7 @@ func Test_IDPTemplateTemplatesPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*IDPTemplates)(nil), }, } for _, tt := range tests { diff --git a/internal/query/idp_test.go b/internal/query/idp_test.go index 6a692ed164..a442523696 100644 --- a/internal/query/idp_test.go +++ b/internal/query/idp_test.go @@ -144,7 +144,7 @@ func Test_IDPPrepares(t *testing.T) { name: "prepareIDPByIDQuery no result", prepare: prepareIDPByIDQuery, want: want{ - sqlExpectations: mockQuery( + sqlExpectations: mockQueryScanErr( regexp.QuoteMeta(idpQuery), nil, nil, @@ -341,7 +341,7 @@ func Test_IDPPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*IDP)(nil), }, { name: "prepareIDPsQuery no result", @@ -728,7 +728,7 @@ func Test_IDPPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*IDPs)(nil), }, } for _, tt := range tests { diff --git a/internal/query/idp_user_link.go b/internal/query/idp_user_link.go index b2b95e85ed..346ce129e6 100644 --- a/internal/query/idp_user_link.go +++ b/internal/query/idp_user_link.go @@ -103,14 +103,13 @@ func (q *Queries) IDPUserLinks(ctx context.Context, queries *IDPUserLinksSearchQ return nil, errors.ThrowInvalidArgument(err, "QUERY-4zzFK", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + idps, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-C1E4D", "Errors.Internal") } - idps, err = scan(rows) - if err != nil { - return nil, err - } idps.LatestSequence, err = q.latestSequence(ctx, idpUserLinkTable) return idps, err } diff --git a/internal/query/idp_user_link_test.go b/internal/query/idp_user_link_test.go index 652d905a5b..af4e3a54d7 100644 --- a/internal/query/idp_user_link_test.go +++ b/internal/query/idp_user_link_test.go @@ -135,7 +135,7 @@ func Test_IDPUserLinkPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*IDPUserLinks)(nil), }, } for _, tt := range tests { diff --git a/internal/query/instance.go b/internal/query/instance.go index 9040478af4..77327e736f 100644 --- a/internal/query/instance.go +++ b/internal/query/instance.go @@ -166,18 +166,17 @@ func (q *Queries) SearchInstances(ctx context.Context, queries *InstanceSearchQu return nil, errors.ThrowInvalidArgument(err, "QUERY-M9fow", "Errors.Query.SQLStatement") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + instances, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-3j98f", "Errors.Internal") } - instances, err = scan(rows) - if err != nil { - return nil, err - } return instances, err } -func (q *Queries) Instance(ctx context.Context, shouldTriggerBulk bool) (_ *Instance, err error) { +func (q *Queries) Instance(ctx context.Context, shouldTriggerBulk bool) (instance *Instance, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -193,14 +192,14 @@ func (q *Queries) Instance(ctx context.Context, shouldTriggerBulk bool) (_ *Inst return nil, errors.ThrowInternal(err, "QUERY-d9ngs", "Errors.Query.SQLStatement") } - row, err := q.client.QueryContext(ctx, query, args...) - if err != nil { - return nil, err - } - return scan(row) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + instance, err = scan(rows) + return err + }, query, args...) + return instance, err } -func (q *Queries) InstanceByHost(ctx context.Context, host string) (_ authz.Instance, err error) { +func (q *Queries) InstanceByHost(ctx context.Context, host string) (instance authz.Instance, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -213,11 +212,11 @@ func (q *Queries) InstanceByHost(ctx context.Context, host string) (_ authz.Inst return nil, errors.ThrowInternal(err, "QUERY-SAfg2", "Errors.Query.SQLStatement") } - row, err := q.client.QueryContext(ctx, query, args...) - if err != nil { - return nil, err - } - return scan(row) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + instance, err = scan(rows) + return err + }, query, args...) + return instance, err } func (q *Queries) InstanceByID(ctx context.Context) (_ authz.Instance, err error) { diff --git a/internal/query/instance_domain.go b/internal/query/instance_domain.go index 6eedd3f06f..bb11e1aa82 100644 --- a/internal/query/instance_domain.go +++ b/internal/query/instance_domain.go @@ -88,11 +88,10 @@ func (q *Queries) SearchInstanceDomainsGlobal(ctx context.Context, queries *Inst } func (q *Queries) queryInstanceDomains(ctx context.Context, stmt string, scan func(*sql.Rows) (*InstanceDomains, error), args ...interface{}) (domains *InstanceDomains, err error) { - rows, err := q.client.QueryContext(ctx, stmt, args...) - if err != nil { - return nil, errors.ThrowInternal(err, "QUERY-Dh9Ap", "Errors.Internal") - } - domains, err = scan(rows) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + domains, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, err } diff --git a/internal/query/instance_domain_test.go b/internal/query/instance_domain_test.go index 06d8e8dc04..4f72c0def4 100644 --- a/internal/query/instance_domain_test.go +++ b/internal/query/instance_domain_test.go @@ -162,7 +162,7 @@ func Test_InstanceDomainPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*Domains)(nil), }, } for _, tt := range tests { diff --git a/internal/query/instance_test.go b/internal/query/instance_test.go index 0320631283..a14b44df0d 100644 --- a/internal/query/instance_test.go +++ b/internal/query/instance_test.go @@ -97,7 +97,7 @@ func Test_InstancePrepares(t *testing.T) { additionalArgs: []reflect.Value{reflect.ValueOf("")}, prepare: prepareInstanceQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( regexp.QuoteMeta(instanceQuery), nil, nil, @@ -160,7 +160,7 @@ func Test_InstancePrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*Instance)(nil), }, { name: "prepareInstancesQuery no result", diff --git a/internal/query/key.go b/internal/query/key.go index 1c9891383c..29a39fa061 100644 --- a/internal/query/key.go +++ b/internal/query/key.go @@ -177,7 +177,7 @@ var ( } ) -func (q *Queries) ActivePublicKeys(ctx context.Context, t time.Time) (_ *PublicKeys, err error) { +func (q *Queries) ActivePublicKeys(ctx context.Context, t time.Time) (keys *PublicKeys, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -194,14 +194,14 @@ func (q *Queries) ActivePublicKeys(ctx context.Context, t time.Time) (_ *PublicK return nil, errors.ThrowInternal(err, "QUERY-SDFfg", "Errors.Query.SQLStatement") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + keys, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-Sghn4", "Errors.Internal") } - keys, err := scan(rows) - if err != nil { - return nil, err - } + keys.LatestSequence, err = q.latestSequence(ctx, keyTable) if !errors.IsNotFound(err) { return keys, err @@ -209,7 +209,7 @@ func (q *Queries) ActivePublicKeys(ctx context.Context, t time.Time) (_ *PublicK return keys, nil } -func (q *Queries) ActivePrivateSigningKey(ctx context.Context, t time.Time) (_ *PrivateKeys, err error) { +func (q *Queries) ActivePrivateSigningKey(ctx context.Context, t time.Time) (keys *PrivateKeys, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -229,14 +229,13 @@ func (q *Queries) ActivePrivateSigningKey(ctx context.Context, t time.Time) (_ * return nil, errors.ThrowInternal(err, "QUERY-SDff2", "Errors.Query.SQLStatement") } - rows, err := q.client.QueryContext(ctx, query, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + keys, err = scan(rows) + return err + }, query, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-WRFG4", "Errors.Internal") } - keys, err := scan(rows) - if err != nil { - return nil, err - } keys.LatestSequence, err = q.latestSequence(ctx, keyTable) if !errors.IsNotFound(err) { return keys, err diff --git a/internal/query/key_test.go b/internal/query/key_test.go index 26ab889719..a600b23eee 100644 --- a/internal/query/key_test.go +++ b/internal/query/key_test.go @@ -147,7 +147,7 @@ func Test_KeyPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*PublicKeys)(nil), }, { name: "preparePrivateKeysQuery no result", @@ -230,7 +230,7 @@ func Test_KeyPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*PrivateKeys)(nil), }, } for _, tt := range tests { diff --git a/internal/query/label_policy.go b/internal/query/label_policy.go index 731dbba2e0..476c8de595 100644 --- a/internal/query/label_policy.go +++ b/internal/query/label_policy.go @@ -42,7 +42,7 @@ type Theme struct { IconURL string } -func (q *Queries) ActiveLabelPolicyByOrg(ctx context.Context, orgID string, withOwnerRemoved bool) (_ *LabelPolicy, err error) { +func (q *Queries) ActiveLabelPolicyByOrg(ctx context.Context, orgID string, withOwnerRemoved bool) (policy *LabelPolicy, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -68,11 +68,14 @@ func (q *Queries) ActiveLabelPolicyByOrg(ctx context.Context, orgID string, with return nil, errors.ThrowInternal(err, "QUERY-V22un", "unable to create sql stmt") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + policy, err = scan(row) + return err + }, query, args...) + return policy, err } -func (q *Queries) PreviewLabelPolicyByOrg(ctx context.Context, orgID string) (_ *LabelPolicy, err error) { +func (q *Queries) PreviewLabelPolicyByOrg(ctx context.Context, orgID string) (policy *LabelPolicy, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -98,11 +101,14 @@ func (q *Queries) PreviewLabelPolicyByOrg(ctx context.Context, orgID string) (_ return nil, errors.ThrowInternal(err, "QUERY-AG5eq", "unable to create sql stmt") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + policy, err = scan(row) + return err + }, query, args...) + return policy, err } -func (q *Queries) DefaultActiveLabelPolicy(ctx context.Context) (_ *LabelPolicy, err error) { +func (q *Queries) DefaultActiveLabelPolicy(ctx context.Context) (policy *LabelPolicy, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -118,11 +124,14 @@ func (q *Queries) DefaultActiveLabelPolicy(ctx context.Context) (_ *LabelPolicy, return nil, errors.ThrowInternal(err, "QUERY-mN0Ci", "unable to create sql stmt") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + policy, err = scan(row) + return err + }, query, args...) + return policy, err } -func (q *Queries) DefaultPreviewLabelPolicy(ctx context.Context) (_ *LabelPolicy, err error) { +func (q *Queries) DefaultPreviewLabelPolicy(ctx context.Context) (policy *LabelPolicy, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -138,8 +147,11 @@ func (q *Queries) DefaultPreviewLabelPolicy(ctx context.Context) (_ *LabelPolicy return nil, errors.ThrowInternal(err, "QUERY-B3JQR", "unable to create sql stmt") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + policy, err = scan(row) + return err + }, query, args...) + return policy, err } var ( diff --git a/internal/query/lockout_policy.go b/internal/query/lockout_policy.go index ff9dfc663a..7e7624dc28 100644 --- a/internal/query/lockout_policy.go +++ b/internal/query/lockout_policy.go @@ -81,7 +81,7 @@ var ( } ) -func (q *Queries) LockoutPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (_ *LockoutPolicy, err error) { +func (q *Queries) LockoutPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (policy *LockoutPolicy, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -110,11 +110,14 @@ func (q *Queries) LockoutPolicyByOrg(ctx context.Context, shouldTriggerBulk bool return nil, errors.ThrowInternal(err, "QUERY-SKR6X", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + policy, err = scan(row) + return err + }, query, args...) + return policy, err } -func (q *Queries) DefaultLockoutPolicy(ctx context.Context) (_ *LockoutPolicy, err error) { +func (q *Queries) DefaultLockoutPolicy(ctx context.Context) (policy *LockoutPolicy, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -129,8 +132,11 @@ func (q *Queries) DefaultLockoutPolicy(ctx context.Context) (_ *LockoutPolicy, e return nil, errors.ThrowInternal(err, "QUERY-mN0Ci", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + policy, err = scan(row) + return err + }, query, args...) + return policy, err } func prepareLockoutPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*LockoutPolicy, error)) { diff --git a/internal/query/lockout_policy_test.go b/internal/query/lockout_policy_test.go index 5c1833182b..0fefcf386e 100644 --- a/internal/query/lockout_policy_test.go +++ b/internal/query/lockout_policy_test.go @@ -53,7 +53,7 @@ func Test_LockoutPolicyPrepares(t *testing.T) { name: "prepareLockoutPolicyQuery no result", prepare: prepareLockoutPolicyQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( regexp.QuoteMeta(prepareLockoutPolicyStmt), nil, nil, @@ -114,7 +114,7 @@ func Test_LockoutPolicyPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*LockoutPolicy)(nil), }, } for _, tt := range tests { diff --git a/internal/query/login_policy.go b/internal/query/login_policy.go index dcf168d944..5d0a7e668f 100644 --- a/internal/query/login_policy.go +++ b/internal/query/login_policy.go @@ -166,7 +166,7 @@ var ( } ) -func (q *Queries) LoginPolicyByID(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (_ *LoginPolicy, err error) { +func (q *Queries) LoginPolicyByID(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (policy *LoginPolicy, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -191,11 +191,14 @@ func (q *Queries) LoginPolicyByID(ctx context.Context, shouldTriggerBulk bool, o return nil, errors.ThrowInternal(err, "QUERY-scVHo", "Errors.Query.SQLStatement") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + policy, err = q.scanAndAddLinksToLoginPolicy(ctx, rows, scan) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-SWgr3", "Errors.Internal") } - return q.scanAndAddLinksToLoginPolicy(ctx, rows, scan) + return policy, nil } func (q *Queries) scanAndAddLinksToLoginPolicy(ctx context.Context, rows *sql.Rows, scan func(*sql.Rows) (*LoginPolicy, error)) (*LoginPolicy, error) { @@ -214,7 +217,7 @@ func (q *Queries) scanAndAddLinksToLoginPolicy(ctx context.Context, rows *sql.Ro return policy, nil } -func (q *Queries) DefaultLoginPolicy(ctx context.Context) (_ *LoginPolicy, err error) { +func (q *Queries) DefaultLoginPolicy(ctx context.Context) (policy *LoginPolicy, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -227,14 +230,17 @@ func (q *Queries) DefaultLoginPolicy(ctx context.Context) (_ *LoginPolicy, err e return nil, errors.ThrowInternal(err, "QUERY-t4TBK", "Errors.Query.SQLStatement") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + policy, err = q.scanAndAddLinksToLoginPolicy(ctx, rows, scan) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-SArt2", "Errors.Internal") } - return q.scanAndAddLinksToLoginPolicy(ctx, rows, scan) + return policy, nil } -func (q *Queries) SecondFactorsByOrg(ctx context.Context, orgID string) (_ *SecondFactors, err error) { +func (q *Queries) SecondFactorsByOrg(ctx context.Context, orgID string) (factors *SecondFactors, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -259,8 +265,10 @@ func (q *Queries) SecondFactorsByOrg(ctx context.Context, orgID string) (_ *Seco return nil, errors.ThrowInternal(err, "QUERY-scVHo", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, stmt, args...) - factors, err := scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + factors, err = scan(row) + return err + }, stmt, args...) if err != nil { return nil, err } @@ -268,7 +276,7 @@ func (q *Queries) SecondFactorsByOrg(ctx context.Context, orgID string) (_ *Seco return factors, err } -func (q *Queries) DefaultSecondFactors(ctx context.Context) (_ *SecondFactors, err error) { +func (q *Queries) DefaultSecondFactors(ctx context.Context) (factors *SecondFactors, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -281,8 +289,10 @@ func (q *Queries) DefaultSecondFactors(ctx context.Context) (_ *SecondFactors, e return nil, errors.ThrowInternal(err, "QUERY-CZ2Nv", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, stmt, args...) - factors, err := scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + factors, err = scan(row) + return err + }, stmt, args...) if err != nil { return nil, err } @@ -290,7 +300,7 @@ func (q *Queries) DefaultSecondFactors(ctx context.Context) (_ *SecondFactors, e return factors, err } -func (q *Queries) MultiFactorsByOrg(ctx context.Context, orgID string) (_ *MultiFactors, err error) { +func (q *Queries) MultiFactorsByOrg(ctx context.Context, orgID string) (factors *MultiFactors, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -315,8 +325,10 @@ func (q *Queries) MultiFactorsByOrg(ctx context.Context, orgID string) (_ *Multi return nil, errors.ThrowInternal(err, "QUERY-B4o7h", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, stmt, args...) - factors, err := scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + factors, err = scan(row) + return err + }, stmt, args...) if err != nil { return nil, err } @@ -324,7 +336,7 @@ func (q *Queries) MultiFactorsByOrg(ctx context.Context, orgID string) (_ *Multi return factors, err } -func (q *Queries) DefaultMultiFactors(ctx context.Context) (_ *MultiFactors, err error) { +func (q *Queries) DefaultMultiFactors(ctx context.Context) (factors *MultiFactors, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -337,8 +349,10 @@ func (q *Queries) DefaultMultiFactors(ctx context.Context) (_ *MultiFactors, err return nil, errors.ThrowInternal(err, "QUERY-WxYjr", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, stmt, args...) - factors, err := scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + factors, err = scan(row) + return err + }, stmt, args...) if err != nil { return nil, err } diff --git a/internal/query/login_policy_test.go b/internal/query/login_policy_test.go index 9aa398d81d..6da7e34874 100644 --- a/internal/query/login_policy_test.go +++ b/internal/query/login_policy_test.go @@ -98,7 +98,7 @@ func Test_LoginPolicyPrepares(t *testing.T) { name: "prepareLoginPolicyQuery no result", prepare: prepareLoginPolicyQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( regexp.QuoteMeta(loginPolicyQuery), nil, nil, @@ -189,13 +189,13 @@ func Test_LoginPolicyPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*LoginPolicy)(nil), }, { name: "prepareLoginPolicy2FAsQuery no result", prepare: prepareLoginPolicy2FAsQuery, want: want{ - sqlExpectations: mockQuery( + sqlExpectations: mockQueryScanErr( regexp.QuoteMeta(prepareLoginPolicy2FAsStmt), prepareLoginPolicy2FAsCols, nil, @@ -257,13 +257,13 @@ func Test_LoginPolicyPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*SecondFactors)(nil), }, { name: "prepareLoginPolicyMFAsQuery no result", prepare: prepareLoginPolicyMFAsQuery, want: want{ - sqlExpectations: mockQuery( + sqlExpectations: mockQueryScanErr( regexp.QuoteMeta(prepareLoginPolicyMFAsStmt), prepareLoginPolicyMFAsCols, nil, @@ -325,7 +325,7 @@ func Test_LoginPolicyPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*MultiFactors)(nil), }, } for _, tt := range tests { diff --git a/internal/query/mail_template.go b/internal/query/mail_template.go index 13459a4a9a..f0872cf290 100644 --- a/internal/query/mail_template.go +++ b/internal/query/mail_template.go @@ -70,7 +70,7 @@ var ( } ) -func (q *Queries) MailTemplateByOrg(ctx context.Context, orgID string, withOwnerRemoved bool) (_ *MailTemplate, err error) { +func (q *Queries) MailTemplateByOrg(ctx context.Context, orgID string, withOwnerRemoved bool) (template *MailTemplate, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -93,11 +93,14 @@ func (q *Queries) MailTemplateByOrg(ctx context.Context, orgID string, withOwner return nil, errors.ThrowInternal(err, "QUERY-m0sJg", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + template, err = scan(row) + return err + }, query, args...) + return template, err } -func (q *Queries) DefaultMailTemplate(ctx context.Context) (_ *MailTemplate, err error) { +func (q *Queries) DefaultMailTemplate(ctx context.Context) (template *MailTemplate, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -112,8 +115,11 @@ func (q *Queries) DefaultMailTemplate(ctx context.Context) (_ *MailTemplate, err return nil, errors.ThrowInternal(err, "QUERY-2m0fH", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + template, err = scan(row) + return err + }, query, args...) + return template, err } func prepareMailTemplateQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*MailTemplate, error)) { diff --git a/internal/query/message_text.go b/internal/query/message_text.go index fc6b445694..4bc656c0cc 100644 --- a/internal/query/message_text.go +++ b/internal/query/message_text.go @@ -126,7 +126,7 @@ var ( } ) -func (q *Queries) DefaultMessageText(ctx context.Context) (_ *MessageText, err error) { +func (q *Queries) DefaultMessageText(ctx context.Context) (text *MessageText, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -140,8 +140,11 @@ func (q *Queries) DefaultMessageText(ctx context.Context) (_ *MessageText, err e return nil, errors.ThrowInternal(err, "QUERY-1b9mf", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + text, err = scan(row) + return err + }, query, args...) + return text, err } func (q *Queries) DefaultMessageTextByTypeAndLanguageFromFileSystem(ctx context.Context, messageType, language string) (_ *MessageText, err error) { @@ -159,7 +162,7 @@ func (q *Queries) DefaultMessageTextByTypeAndLanguageFromFileSystem(ctx context. return messageTexts.GetMessageTextByType(messageType), nil } -func (q *Queries) CustomMessageTextByTypeAndLanguage(ctx context.Context, aggregateID, messageType, language string, withOwnerRemoved bool) (_ *MessageText, err error) { +func (q *Queries) CustomMessageTextByTypeAndLanguage(ctx context.Context, aggregateID, messageType, language string, withOwnerRemoved bool) (msg *MessageText, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -179,8 +182,10 @@ func (q *Queries) CustomMessageTextByTypeAndLanguage(ctx context.Context, aggreg return nil, errors.ThrowInternal(err, "QUERY-1b9mf", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - msg, err := scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + msg, err = scan(row) + return err + }, query, args...) if errors.IsNotFound(err) { return q.IAMMessageTextByTypeAndLanguage(ctx, messageType, language) } diff --git a/internal/query/message_text_test.go b/internal/query/message_text_test.go index d371a2f171..713066512f 100644 --- a/internal/query/message_text_test.go +++ b/internal/query/message_text_test.go @@ -64,7 +64,7 @@ func Test_MessageTextPrepares(t *testing.T) { name: "prepareMessageTextQuery no result", prepare: prepareMessageTextQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( regexp.QuoteMeta(prepareMessageTextStmt), nil, nil, @@ -135,7 +135,7 @@ func Test_MessageTextPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*MessageText)(nil), }, } for _, tt := range tests { diff --git a/internal/query/milestone.go b/internal/query/milestone.go index b4767495b8..e781151d7e 100644 --- a/internal/query/milestone.go +++ b/internal/query/milestone.go @@ -69,7 +69,7 @@ var ( ) // SearchMilestones tries to defer the instanceID from the passed context if no instanceIDs are passed -func (q *Queries) SearchMilestones(ctx context.Context, instanceIDs []string, queries *MilestonesSearchQueries) (_ *Milestones, err error) { +func (q *Queries) SearchMilestones(ctx context.Context, instanceIDs []string, queries *MilestonesSearchQueries) (milestones *Milestones, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() query, scan := prepareMilestonesQuery(ctx, q.client) @@ -80,22 +80,14 @@ func (q *Queries) SearchMilestones(ctx context.Context, instanceIDs []string, qu if err != nil { return nil, errors.ThrowInternal(err, "QUERY-A9i5k", "Errors.Query.SQLStatement") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + milestones, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, err } - defer func() { - if closeErr := rows.Close(); closeErr != nil && err == nil { - err = errors.ThrowInternal(closeErr, "QUERY-CK9mI", "Errors.Query.CloseRows") - } - }() - milestones, err := scan(rows) - if err != nil { - return nil, err - } - if err = rows.Err(); err != nil { - return nil, errors.ThrowInternal(err, "QUERY-asLsI", "Errors.Internal") - } + milestones.LatestSequence, err = q.latestSequence(ctx, milestonesTable) return milestones, err diff --git a/internal/query/milestone_test.go b/internal/query/milestone_test.go index 1c1852a59d..b0b7ec8b5a 100644 --- a/internal/query/milestone_test.go +++ b/internal/query/milestone_test.go @@ -178,7 +178,7 @@ func Test_MilestonesPrepare(t *testing.T) { return nil, true }, }, - object: nil, + object: (*Milestones)(nil), }, } for _, tt := range tests { diff --git a/internal/query/notification_policy.go b/internal/query/notification_policy.go index 64873e1ca9..c015c1c963 100644 --- a/internal/query/notification_policy.go +++ b/internal/query/notification_policy.go @@ -76,7 +76,7 @@ var ( } ) -func (q *Queries) NotificationPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (_ *NotificationPolicy, err error) { +func (q *Queries) NotificationPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (policy *NotificationPolicy, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -104,11 +104,14 @@ func (q *Queries) NotificationPolicyByOrg(ctx context.Context, shouldTriggerBulk return nil, errors.ThrowInternal(err, "QUERY-Xuoapqm", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + policy, err = scan(row) + return err + }, query, args...) + return policy, err } -func (q *Queries) DefaultNotificationPolicy(ctx context.Context, shouldTriggerBulk bool) (_ *NotificationPolicy, err error) { +func (q *Queries) DefaultNotificationPolicy(ctx context.Context, shouldTriggerBulk bool) (policy *NotificationPolicy, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -130,8 +133,11 @@ func (q *Queries) DefaultNotificationPolicy(ctx context.Context, shouldTriggerBu return nil, errors.ThrowInternal(err, "QUERY-xlqp209", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + policy, err = scan(row) + return err + }, query, args...) + return policy, err } func prepareNotificationPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*NotificationPolicy, error)) { diff --git a/internal/query/notification_policy_test.go b/internal/query/notification_policy_test.go index b318531fb5..3c5de860fa 100644 --- a/internal/query/notification_policy_test.go +++ b/internal/query/notification_policy_test.go @@ -50,7 +50,7 @@ func Test_NotificationPolicyPrepares(t *testing.T) { name: "prepareNotificationPolicyQuery no result", prepare: prepareNotificationPolicyQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( notificationPolicyStmt, nil, nil, @@ -109,7 +109,7 @@ func Test_NotificationPolicyPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*NotificationPolicy)(nil), }, } for _, tt := range tests { diff --git a/internal/query/notification_provider.go b/internal/query/notification_provider.go index bbc0c50a08..7ba7320e2a 100644 --- a/internal/query/notification_provider.go +++ b/internal/query/notification_provider.go @@ -70,7 +70,7 @@ var ( } ) -func (q *Queries) NotificationProviderByIDAndType(ctx context.Context, aggID string, providerType domain.NotificationProviderType) (_ *DebugNotificationProvider, err error) { +func (q *Queries) NotificationProviderByIDAndType(ctx context.Context, aggID string, providerType domain.NotificationProviderType) (provider *DebugNotificationProvider, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -90,8 +90,11 @@ func (q *Queries) NotificationProviderByIDAndType(ctx context.Context, aggID str return nil, errors.ThrowInternal(err, "QUERY-f9jSf", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, stmt, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + provider, err = scan(row) + return err + }, stmt, args...) + return provider, err } func prepareDebugNotificationProviderQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*DebugNotificationProvider, error)) { diff --git a/internal/query/notification_provider_test.go b/internal/query/notification_provider_test.go index 2fafe45ea1..b0bfbe5115 100644 --- a/internal/query/notification_provider_test.go +++ b/internal/query/notification_provider_test.go @@ -50,7 +50,7 @@ func Test_NotificationProviderPrepares(t *testing.T) { name: "prepareNotificationProviderQuery no result", prepare: prepareDebugNotificationProviderQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( regexp.QuoteMeta(prepareNotificationProviderStmt), nil, nil, @@ -109,7 +109,7 @@ func Test_NotificationProviderPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*DebugNotificationProvider)(nil), }, } for _, tt := range tests { diff --git a/internal/query/oidc_settings.go b/internal/query/oidc_settings.go index af26406cc4..7e48ff43a7 100644 --- a/internal/query/oidc_settings.go +++ b/internal/query/oidc_settings.go @@ -75,7 +75,7 @@ type OIDCSettings struct { RefreshTokenExpiration time.Duration } -func (q *Queries) OIDCSettingsByAggID(ctx context.Context, aggregateID string) (_ *OIDCSettings, err error) { +func (q *Queries) OIDCSettingsByAggID(ctx context.Context, aggregateID string) (settings *OIDCSettings, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -88,8 +88,11 @@ func (q *Queries) OIDCSettingsByAggID(ctx context.Context, aggregateID string) ( return nil, errors.ThrowInternal(err, "QUERY-s9nle", "Errors.Query.SQLStatment") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + settings, err = scan(row) + return err + }, query, args...) + return settings, err } func prepareOIDCSettingsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*OIDCSettings, error)) { diff --git a/internal/query/oidc_settings_test.go b/internal/query/oidc_settings_test.go index 40f544795b..d5aa653160 100644 --- a/internal/query/oidc_settings_test.go +++ b/internal/query/oidc_settings_test.go @@ -52,7 +52,7 @@ func Test_OIDCConfigsPrepares(t *testing.T) { name: "prepareOIDCSettingsQuery no result", prepare: prepareOIDCSettingsQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( prepareOIDCSettingsStmt, nil, nil, @@ -113,7 +113,7 @@ func Test_OIDCConfigsPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*OIDCSettings)(nil), }, } for _, tt := range tests { diff --git a/internal/query/org.go b/internal/query/org.go index 40bcbe4005..6ba4b54dc6 100644 --- a/internal/query/org.go +++ b/internal/query/org.go @@ -89,7 +89,7 @@ func (q *OrgSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuilder { return query } -func (q *Queries) OrgByID(ctx context.Context, shouldTriggerBulk bool, id string) (_ *Org, err error) { +func (q *Queries) OrgByID(ctx context.Context, shouldTriggerBulk bool, id string) (org *Org, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -106,11 +106,14 @@ func (q *Queries) OrgByID(ctx context.Context, shouldTriggerBulk bool, id string return nil, errors.ThrowInternal(err, "QUERY-AWx52", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + org, err = scan(row) + return err + }, query, args...) + return org, err } -func (q *Queries) OrgByPrimaryDomain(ctx context.Context, domain string) (_ *Org, err error) { +func (q *Queries) OrgByPrimaryDomain(ctx context.Context, domain string) (org *Org, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -124,11 +127,14 @@ func (q *Queries) OrgByPrimaryDomain(ctx context.Context, domain string) (_ *Org return nil, errors.ThrowInternal(err, "QUERY-TYUCE", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + org, err = scan(row) + return err + }, query, args...) + return org, err } -func (q *Queries) OrgByVerifiedDomain(ctx context.Context, domain string) (_ *Org, err error) { +func (q *Queries) OrgByVerifiedDomain(ctx context.Context, domain string) (org *Org, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -142,8 +148,11 @@ func (q *Queries) OrgByVerifiedDomain(ctx context.Context, domain string) (_ *Or return nil, errors.ThrowInternal(err, "QUERY-TYUCE", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + org, err = scan(row) + return err + }, query, args...) + return org, err } func (q *Queries) IsOrgUnique(ctx context.Context, name, domain string) (isUnique bool, err error) { @@ -176,8 +185,11 @@ func (q *Queries) IsOrgUnique(ctx context.Context, name, domain string) (isUniqu return false, errors.ThrowInternal(err, "QUERY-Dgbe2", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, stmt, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + isUnique, err = scan(row) + return err + }, stmt, args...) + return isUnique, err } func (q *Queries) ExistsOrg(ctx context.Context, id, domain string) (verifiedID string, err error) { @@ -209,14 +221,14 @@ func (q *Queries) SearchOrgs(ctx context.Context, queries *OrgSearchQueries) (or return nil, errors.ThrowInvalidArgument(err, "QUERY-wQ3by", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + orgs, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-M6mYN", "Errors.Internal") } - orgs, err = scan(rows) - if err != nil { - return nil, err - } + orgs.LatestSequence, err = q.latestSequence(ctx, orgsTable) return orgs, err } diff --git a/internal/query/org_domain.go b/internal/query/org_domain.go index a901e2d4c6..ee89efba96 100644 --- a/internal/query/org_domain.go +++ b/internal/query/org_domain.go @@ -70,14 +70,14 @@ func (q *Queries) SearchOrgDomains(ctx context.Context, queries *OrgDomainSearch return nil, errors.ThrowInvalidArgument(err, "QUERY-ZRfj1", "Errors.Query.SQLStatement") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + domains, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-M6mYN", "Errors.Internal") } - domains, err = scan(rows) - if err != nil { - return nil, err - } + domains.LatestSequence, err = q.latestSequence(ctx, orgDomainsTable) return domains, err } diff --git a/internal/query/org_domain_test.go b/internal/query/org_domain_test.go index 18bd817046..5757eda657 100644 --- a/internal/query/org_domain_test.go +++ b/internal/query/org_domain_test.go @@ -172,7 +172,7 @@ func Test_OrgDomainPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*Domains)(nil), }, } for _, tt := range tests { diff --git a/internal/query/org_member.go b/internal/query/org_member.go index 91bd2d30f3..3b5eb4a41a 100644 --- a/internal/query/org_member.go +++ b/internal/query/org_member.go @@ -78,7 +78,7 @@ func addOrgMemberWithoutOwnerRemoved(eq map[string]interface{}) { eq[OrgMemberOwnerRemovedUser.identifier()] = false } -func (q *Queries) OrgMembers(ctx context.Context, queries *OrgMembersQuery, withOwnerRemoved bool) (_ *Members, err error) { +func (q *Queries) OrgMembers(ctx context.Context, queries *OrgMembersQuery, withOwnerRemoved bool) (members *Members, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -98,14 +98,14 @@ func (q *Queries) OrgMembers(ctx context.Context, queries *OrgMembersQuery, with return nil, err } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + members, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-5g4yV", "Errors.Internal") } - members, err := scan(rows) - if err != nil { - return nil, err - } + members.LatestSequence = currentSequence return members, err } diff --git a/internal/query/org_member_test.go b/internal/query/org_member_test.go index a8bb7739da..57e9f84b6a 100644 --- a/internal/query/org_member_test.go +++ b/internal/query/org_member_test.go @@ -284,7 +284,7 @@ func Test_OrgMemberPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*OrgMembership)(nil), }, } for _, tt := range tests { diff --git a/internal/query/org_metadata.go b/internal/query/org_metadata.go index 5ce36369ed..517d2d634d 100644 --- a/internal/query/org_metadata.go +++ b/internal/query/org_metadata.go @@ -77,7 +77,7 @@ var ( } ) -func (q *Queries) GetOrgMetadataByKey(ctx context.Context, shouldTriggerBulk bool, orgID string, key string, withOwnerRemoved bool, queries ...SearchQuery) (_ *OrgMetadata, err error) { +func (q *Queries) GetOrgMetadataByKey(ctx context.Context, shouldTriggerBulk bool, orgID string, key string, withOwnerRemoved bool, queries ...SearchQuery) (metadata *OrgMetadata, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -102,11 +102,14 @@ func (q *Queries) GetOrgMetadataByKey(ctx context.Context, shouldTriggerBulk boo return nil, errors.ThrowInternal(err, "QUERY-aDaG2", "Errors.Query.SQLStatment") } - row := q.client.QueryRowContext(ctx, stmt, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + metadata, err = scan(row) + return err + }, stmt, args...) + return metadata, err } -func (q *Queries) SearchOrgMetadata(ctx context.Context, shouldTriggerBulk bool, orgID string, queries *OrgMetadataSearchQueries, withOwnerRemoved bool) (_ *OrgMetadataList, err error) { +func (q *Queries) SearchOrgMetadata(ctx context.Context, shouldTriggerBulk bool, orgID string, queries *OrgMetadataSearchQueries, withOwnerRemoved bool) (metadata *OrgMetadataList, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -126,14 +129,14 @@ func (q *Queries) SearchOrgMetadata(ctx context.Context, shouldTriggerBulk bool, return nil, errors.ThrowInternal(err, "QUERY-Egbld", "Errors.Query.SQLStatment") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + metadata, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-Ho2wf", "Errors.Internal") } - metadata, err := scan(rows) - if err != nil { - return nil, err - } + metadata.LatestSequence, err = q.latestSequence(ctx, orgMetadataTable) return metadata, err } diff --git a/internal/query/org_metadata_test.go b/internal/query/org_metadata_test.go index 4bfee48412..600a2cca48 100644 --- a/internal/query/org_metadata_test.go +++ b/internal/query/org_metadata_test.go @@ -63,7 +63,7 @@ func Test_OrgMetadataPrepares(t *testing.T) { name: "prepareOrgMetadataQuery no result", prepare: prepareOrgMetadataQuery, want: want{ - sqlExpectations: mockQuery( + sqlExpectations: mockQueryScanErr( regexp.QuoteMeta(orgMetadataQuery), nil, nil, @@ -118,7 +118,7 @@ func Test_OrgMetadataPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*OrgMetadata)(nil), }, { name: "prepareOrgMetadataListQuery no result", @@ -239,7 +239,7 @@ func Test_OrgMetadataPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*OrgMetadataList)(nil), }, } for _, tt := range tests { diff --git a/internal/query/org_test.go b/internal/query/org_test.go index 194c7bd13e..c4c2c21ad5 100644 --- a/internal/query/org_test.go +++ b/internal/query/org_test.go @@ -209,13 +209,13 @@ func Test_OrgPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*Orgs)(nil), }, { name: "prepareOrgQuery no result", prepare: prepareOrgQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( regexp.QuoteMeta(prepareOrgQueryStmt), nil, nil, @@ -274,13 +274,13 @@ func Test_OrgPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*Org)(nil), }, { name: "prepareOrgUniqueQuery no result", prepare: prepareOrgUniqueQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( regexp.QuoteMeta(prepareOrgUniqueStmt), nil, nil, @@ -323,7 +323,7 @@ func Test_OrgPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: false, }, } for _, tt := range tests { diff --git a/internal/query/password_age_policy.go b/internal/query/password_age_policy.go index fda59e88e0..e2130f7767 100644 --- a/internal/query/password_age_policy.go +++ b/internal/query/password_age_policy.go @@ -81,7 +81,7 @@ var ( } ) -func (q *Queries) PasswordAgePolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (_ *PasswordAgePolicy, err error) { +func (q *Queries) PasswordAgePolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (policy *PasswordAgePolicy, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -107,11 +107,14 @@ func (q *Queries) PasswordAgePolicyByOrg(ctx context.Context, shouldTriggerBulk return nil, errors.ThrowInternal(err, "QUERY-SKR6X", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + policy, err = scan(row) + return err + }, query, args...) + return policy, err } -func (q *Queries) DefaultPasswordAgePolicy(ctx context.Context, shouldTriggerBulk bool) (_ *PasswordAgePolicy, err error) { +func (q *Queries) DefaultPasswordAgePolicy(ctx context.Context, shouldTriggerBulk bool) (policy *PasswordAgePolicy, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -129,8 +132,11 @@ func (q *Queries) DefaultPasswordAgePolicy(ctx context.Context, shouldTriggerBul return nil, errors.ThrowInternal(err, "QUERY-mN0Ci", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + policy, err = scan(row) + return err + }, query, args...) + return policy, err } func preparePasswordAgePolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*PasswordAgePolicy, error)) { diff --git a/internal/query/password_age_policy_test.go b/internal/query/password_age_policy_test.go index 15fbfe5935..738a8bb825 100644 --- a/internal/query/password_age_policy_test.go +++ b/internal/query/password_age_policy_test.go @@ -52,7 +52,7 @@ func Test_PasswordAgePolicyPrepares(t *testing.T) { name: "preparePasswordAgePolicyQuery no result", prepare: preparePasswordAgePolicyQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( regexp.QuoteMeta(preparePasswordAgePolicyStmt), nil, nil, @@ -113,7 +113,7 @@ func Test_PasswordAgePolicyPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*PasswordAgePolicy)(nil), }, } for _, tt := range tests { diff --git a/internal/query/password_complexity_policy.go b/internal/query/password_complexity_policy.go index e5f15f1810..a61048b990 100644 --- a/internal/query/password_complexity_policy.go +++ b/internal/query/password_complexity_policy.go @@ -33,7 +33,7 @@ type PasswordComplexityPolicy struct { IsDefault bool } -func (q *Queries) PasswordComplexityPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (_ *PasswordComplexityPolicy, err error) { +func (q *Queries) PasswordComplexityPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (policy *PasswordComplexityPolicy, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -59,11 +59,14 @@ func (q *Queries) PasswordComplexityPolicyByOrg(ctx context.Context, shouldTrigg return nil, errors.ThrowInternal(err, "QUERY-lDnrk", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + policy, err = scan(row) + return err + }, query, args...) + return policy, err } -func (q *Queries) DefaultPasswordComplexityPolicy(ctx context.Context, shouldTriggerBulk bool) (_ *PasswordComplexityPolicy, err error) { +func (q *Queries) DefaultPasswordComplexityPolicy(ctx context.Context, shouldTriggerBulk bool) (policy *PasswordComplexityPolicy, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -82,8 +85,11 @@ func (q *Queries) DefaultPasswordComplexityPolicy(ctx context.Context, shouldTri return nil, errors.ThrowInternal(err, "QUERY-h4Uyr", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + policy, err = scan(row) + return err + }, query, args...) + return policy, err } var ( diff --git a/internal/query/password_complexity_policy_test.go b/internal/query/password_complexity_policy_test.go index 7629940fc1..3f3743831a 100644 --- a/internal/query/password_complexity_policy_test.go +++ b/internal/query/password_complexity_policy_test.go @@ -58,7 +58,7 @@ func Test_PasswordComplexityPolicyPrepares(t *testing.T) { name: "preparePasswordComplexityPolicyQuery no result", prepare: preparePasswordComplexityPolicyQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( regexp.QuoteMeta(preparePasswordComplexityPolicyStmt), nil, nil, @@ -125,7 +125,7 @@ func Test_PasswordComplexityPolicyPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*PasswordComplexityPolicy)(nil), }, } for _, tt := range tests { diff --git a/internal/query/prepare_test.go b/internal/query/prepare_test.go index cb69fb21e5..242b387408 100644 --- a/internal/query/prepare_test.go +++ b/internal/query/prepare_test.go @@ -15,6 +15,7 @@ import ( sq "github.com/Masterminds/squirrel" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" + "github.com/zitadel/zitadel/internal/database" ) var ( @@ -53,14 +54,15 @@ func assertPrepare(t *testing.T, prepareFunc, expectedObject interface{}, sqlExp } return isErr(err) } - object, ok := execScan(client, builder, scan, errCheck) + object, ok, didScan := execScan(&database.DB{DB: client}, builder, scan, errCheck) if !ok { t.Error(object) return false } - - if !assert.Equal(t, expectedObject, object) { - return false + if didScan { + if !assert.Equal(t, expectedObject, object) { + return false + } } if err := mock.ExpectationsWereMet(); err != nil { @@ -77,7 +79,23 @@ type sqlExpectation func(sqlmock.Sqlmock) sqlmock.Sqlmock func mockQuery(stmt string, cols []string, row []driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock { return func(m sqlmock.Sqlmock) sqlmock.Sqlmock { + m.ExpectBegin() q := m.ExpectQuery(stmt).WithArgs(args...) + m.ExpectCommit() + result := sqlmock.NewRows(cols) + if len(row) > 0 { + result.AddRow(row...) + } + q.WillReturnRows(result) + return m + } +} + +func mockQueryScanErr(stmt string, cols []string, row []driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock { + return func(m sqlmock.Sqlmock) sqlmock.Sqlmock { + m.ExpectBegin() + q := m.ExpectQuery(stmt).WithArgs(args...) + m.ExpectRollback() result := sqlmock.NewRows(cols) if len(row) > 0 { result.AddRow(row...) @@ -89,7 +107,28 @@ func mockQuery(stmt string, cols []string, row []driver.Value, args ...driver.Va func mockQueries(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock { return func(m sqlmock.Sqlmock) sqlmock.Sqlmock { + m.ExpectBegin() q := m.ExpectQuery(stmt).WithArgs(args...) + m.ExpectCommit() + result := sqlmock.NewRows(cols) + count := uint64(len(rows)) + for _, row := range rows { + if cols[len(cols)-1] == "count" { + row = append(row, count) + } + result.AddRow(row...) + } + q.WillReturnRows(result) + q.RowsWillBeClosed() + return m + } +} + +func mockQueriesScanErr(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock { + return func(m sqlmock.Sqlmock) sqlmock.Sqlmock { + m.ExpectBegin() + q := m.ExpectQuery(stmt).WithArgs(args...) + m.ExpectRollback() result := sqlmock.NewRows(cols) count := uint64(len(rows)) for _, row := range rows { @@ -106,8 +145,10 @@ func mockQueries(stmt string, cols []string, rows [][]driver.Value, args ...driv func mockQueryErr(stmt string, err error, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock { return func(m sqlmock.Sqlmock) sqlmock.Sqlmock { + m.ExpectBegin() q := m.ExpectQuery(stmt).WithArgs(args...) q.WillReturnError(err) + m.ExpectRollback() return m } } @@ -127,52 +168,65 @@ var ( selectBuilderType = reflect.TypeOf(sq.SelectBuilder{}) ) -func execScan(client *sql.DB, builder sq.SelectBuilder, scan interface{}, errCheck checkErr) (interface{}, bool) { +func execScan(client *database.DB, builder sq.SelectBuilder, scan interface{}, errCheck checkErr) (object interface{}, ok bool, didScan bool) { scanType := reflect.TypeOf(scan) err := validateScan(scanType) if err != nil { - return err, false + return err, false, false } stmt, args, err := builder.ToSql() if err != nil { - return fmt.Errorf("unexpeted error from sql builder: %w", err), false + return fmt.Errorf("unexpeted error from sql builder: %w", err), false, false } //resultSet represents *sql.Row or *sql.Rows, // depending on whats assignable to the scan function - var resultSet interface{} + var res []reflect.Value //execute sql stmt // if scan(*sql.Rows)... if scanType.In(0).AssignableTo(rowsType) { - resultSet, err = client.Query(stmt, args...) - if err != nil { - return errCheck(err) - } + err = client.Query(func(rows *sql.Rows) error { + didScan = true + res = reflect.ValueOf(scan).Call([]reflect.Value{reflect.ValueOf(rows)}) + if err, ok := res[1].Interface().(error); ok { + return err + } + return nil + }, stmt, args...) // if scan(*sql.Row)... } else if scanType.In(0).AssignableTo(rowType) { - row := client.QueryRow(stmt, args...) - if row.Err() != nil { - return errCheck(row.Err()) - } - resultSet = row + err = client.QueryRow(func(r *sql.Row) error { + didScan = true + res = reflect.ValueOf(scan).Call([]reflect.Value{reflect.ValueOf(r)}) + if err, ok := res[1].Interface().(error); ok { + return err + } + return nil + }, stmt, args...) + } else { - return errors.New("scan: parameter must be *sql.Row or *sql.Rows"), false + return errors.New("scan: parameter must be *sql.Row or *sql.Rows"), false, false } - // res contains object and error - res := reflect.ValueOf(scan).Call([]reflect.Value{reflect.ValueOf(resultSet)}) + if err != nil { + err, ok := errCheck(err) + if didScan { + return res[0].Interface(), ok, didScan + } + return err, ok, didScan + } //check for error if res[1].Interface() != nil { if err, ok := errCheck(res[1].Interface().(error)); !ok { - return fmt.Errorf("scan failed: %w", err), false + return fmt.Errorf("scan failed: %w", err), false, didScan } } - return res[0].Interface(), true + return res[0].Interface(), true, didScan } func validateScan(scanType reflect.Type) error { diff --git a/internal/query/privacy_policy.go b/internal/query/privacy_policy.go index e207c3e9b2..8ff9f0edbf 100644 --- a/internal/query/privacy_policy.go +++ b/internal/query/privacy_policy.go @@ -91,7 +91,7 @@ var ( } ) -func (q *Queries) PrivacyPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (_ *PrivacyPolicy, err error) { +func (q *Queries) PrivacyPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (policy *PrivacyPolicy, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -116,11 +116,14 @@ func (q *Queries) PrivacyPolicyByOrg(ctx context.Context, shouldTriggerBulk bool return nil, errors.ThrowInternal(err, "QUERY-UXuPI", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + policy, err = scan(row) + return err + }, query, args...) + return policy, err } -func (q *Queries) DefaultPrivacyPolicy(ctx context.Context, shouldTriggerBulk bool) (_ *PrivacyPolicy, err error) { +func (q *Queries) DefaultPrivacyPolicy(ctx context.Context, shouldTriggerBulk bool) (policy *PrivacyPolicy, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -139,8 +142,11 @@ func (q *Queries) DefaultPrivacyPolicy(ctx context.Context, shouldTriggerBulk bo return nil, errors.ThrowInternal(err, "QUERY-LkFZ7", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + policy, err = scan(row) + return err + }, query, args...) + return policy, err } func preparePrivacyPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*PrivacyPolicy, error)) { diff --git a/internal/query/privacy_policy_test.go b/internal/query/privacy_policy_test.go index ea68af11b2..70b32723d9 100644 --- a/internal/query/privacy_policy_test.go +++ b/internal/query/privacy_policy_test.go @@ -56,7 +56,7 @@ func Test_PrivacyPolicyPrepares(t *testing.T) { name: "preparePrivacyPolicyQuery no result", prepare: preparePrivacyPolicyQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( regexp.QuoteMeta(preparePrivacyPolicyStmt), nil, nil, @@ -121,7 +121,7 @@ func Test_PrivacyPolicyPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*PrivacyPolicy)(nil), }, } for _, tt := range tests { diff --git a/internal/query/project.go b/internal/query/project.go index 065a3e5c5b..0e628d7b1d 100644 --- a/internal/query/project.go +++ b/internal/query/project.go @@ -100,7 +100,7 @@ type ProjectSearchQueries struct { Queries []SearchQuery } -func (q *Queries) ProjectByID(ctx context.Context, shouldTriggerBulk bool, id string, withOwnerRemoved bool) (_ *Project, err error) { +func (q *Queries) ProjectByID(ctx context.Context, shouldTriggerBulk bool, id string, withOwnerRemoved bool) (project *Project, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -121,8 +121,11 @@ func (q *Queries) ProjectByID(ctx context.Context, shouldTriggerBulk bool, id st return nil, errors.ThrowInternal(err, "QUERY-2m00Q", "Errors.Query.SQLStatment") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + project, err = scan(row) + return err + }, query, args...) + return project, err } func (q *Queries) SearchProjects(ctx context.Context, queries *ProjectSearchQueries, withOwnerRemoved bool) (projects *Projects, err error) { @@ -139,14 +142,13 @@ func (q *Queries) SearchProjects(ctx context.Context, queries *ProjectSearchQuer return nil, errors.ThrowInvalidArgument(err, "QUERY-fn9ew", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + projects, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-2j00f", "Errors.Internal") } - projects, err = scan(rows) - if err != nil { - return nil, err - } projects.LatestSequence, err = q.latestSequence(ctx, projectsTable) return projects, err } diff --git a/internal/query/project_grant.go b/internal/query/project_grant.go index fee8ae8369..d7ee5ce1a1 100644 --- a/internal/query/project_grant.go +++ b/internal/query/project_grant.go @@ -111,7 +111,7 @@ type ProjectGrantSearchQueries struct { Queries []SearchQuery } -func (q *Queries) ProjectGrantByID(ctx context.Context, shouldTriggerBulk bool, id string, withOwnerRemoved bool) (_ *ProjectGrant, err error) { +func (q *Queries) ProjectGrantByID(ctx context.Context, shouldTriggerBulk bool, id string, withOwnerRemoved bool) (grant *ProjectGrant, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -133,11 +133,14 @@ func (q *Queries) ProjectGrantByID(ctx context.Context, shouldTriggerBulk bool, return nil, errors.ThrowInternal(err, "QUERY-Nf93d", "Errors.Query.SQLStatment") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + grant, err = scan(row) + return err + }, query, args...) + return grant, err } -func (q *Queries) ProjectGrantByIDAndGrantedOrg(ctx context.Context, id, grantedOrg string, withOwnerRemoved bool) (_ *ProjectGrant, err error) { +func (q *Queries) ProjectGrantByIDAndGrantedOrg(ctx context.Context, id, grantedOrg string, withOwnerRemoved bool) (grant *ProjectGrant, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -156,11 +159,14 @@ func (q *Queries) ProjectGrantByIDAndGrantedOrg(ctx context.Context, id, granted return nil, errors.ThrowInternal(err, "QUERY-MO9fs", "Errors.Query.SQLStatment") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + grant, err = scan(row) + return err + }, query, args...) + return grant, err } -func (q *Queries) SearchProjectGrants(ctx context.Context, queries *ProjectGrantSearchQueries, withOwnerRemoved bool) (projects *ProjectGrants, err error) { +func (q *Queries) SearchProjectGrants(ctx context.Context, queries *ProjectGrantSearchQueries, withOwnerRemoved bool) (grants *ProjectGrants, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -177,16 +183,16 @@ func (q *Queries) SearchProjectGrants(ctx context.Context, queries *ProjectGrant return nil, errors.ThrowInvalidArgument(err, "QUERY-N9fsg", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + grants, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-PP02n", "Errors.Internal") } - projects, err = scan(rows) - if err != nil { - return nil, err - } - projects.LatestSequence, err = q.latestSequence(ctx, projectGrantsTable) - return projects, err + + grants.LatestSequence, err = q.latestSequence(ctx, projectGrantsTable) + return grants, err } func (q *Queries) SearchProjectGrantsByProjectIDAndRoleKey(ctx context.Context, projectID, roleKey string, withOwnerRemoved bool) (projects *ProjectGrants, err error) { diff --git a/internal/query/project_grant_member.go b/internal/query/project_grant_member.go index 0f21679080..11de7cd0ea 100644 --- a/internal/query/project_grant_member.go +++ b/internal/query/project_grant_member.go @@ -95,7 +95,7 @@ func addProjectGrantMemberWithoutOwnerRemoved(eq map[string]interface{}) { eq[ProjectGrantMemberGrantedOrgRemoved.identifier()] = false } -func (q *Queries) ProjectGrantMembers(ctx context.Context, queries *ProjectGrantMembersQuery, withOwnerRemoved bool) (*Members, error) { +func (q *Queries) ProjectGrantMembers(ctx context.Context, queries *ProjectGrantMembersQuery, withOwnerRemoved bool) (members *Members, err error) { query, scan := prepareProjectGrantMembersQuery(ctx, q.client) eq := sq.Eq{ProjectGrantMemberInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()} if !withOwnerRemoved { @@ -112,14 +112,14 @@ func (q *Queries) ProjectGrantMembers(ctx context.Context, queries *ProjectGrant return nil, err } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + members, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-Pdg1I", "Errors.Internal") } - members, err := scan(rows) - if err != nil { - return nil, err - } + members.LatestSequence = currentSequence return members, err } diff --git a/internal/query/project_grant_member_test.go b/internal/query/project_grant_member_test.go index 686fe778a6..99e7c5320e 100644 --- a/internal/query/project_grant_member_test.go +++ b/internal/query/project_grant_member_test.go @@ -287,7 +287,7 @@ func Test_ProjectGrantMemberPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*ProjectGrantMembership)(nil), }, } for _, tt := range tests { diff --git a/internal/query/project_grant_test.go b/internal/query/project_grant_test.go index 65de48a66c..b504cf1fb4 100644 --- a/internal/query/project_grant_test.go +++ b/internal/query/project_grant_test.go @@ -381,13 +381,13 @@ func Test_ProjectGrantPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*ProjectGrants)(nil), }, { name: "prepareProjectGrantQuery no result", prepare: prepareProjectGrantQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( regexp.QuoteMeta(projectGrantQuery), nil, nil, @@ -568,7 +568,7 @@ func Test_ProjectGrantPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*ProjectGrant)(nil), }, } for _, tt := range tests { diff --git a/internal/query/project_member.go b/internal/query/project_member.go index 15d5f83251..3cdb300869 100644 --- a/internal/query/project_member.go +++ b/internal/query/project_member.go @@ -78,7 +78,7 @@ func addProjectMemberWithoutOwnerRemoved(eq map[string]interface{}) { eq[ProjectMemberOwnerRemovedUser.identifier()] = false } -func (q *Queries) ProjectMembers(ctx context.Context, queries *ProjectMembersQuery, withOwnerRemoved bool) (_ *Members, err error) { +func (q *Queries) ProjectMembers(ctx context.Context, queries *ProjectMembersQuery, withOwnerRemoved bool) (members *Members, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -98,14 +98,14 @@ func (q *Queries) ProjectMembers(ctx context.Context, queries *ProjectMembersQue return nil, err } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + members, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-uh6pj", "Errors.Internal") } - members, err := scan(rows) - if err != nil { - return nil, err - } + members.LatestSequence = currentSequence return members, err } diff --git a/internal/query/project_member_test.go b/internal/query/project_member_test.go index fa9181d359..f917718897 100644 --- a/internal/query/project_member_test.go +++ b/internal/query/project_member_test.go @@ -284,7 +284,7 @@ func Test_ProjectMemberPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*ProjectMembership)(nil), }, } for _, tt := range tests { diff --git a/internal/query/project_role.go b/internal/query/project_role.go index 7365b12471..fbe8027149 100644 --- a/internal/query/project_role.go +++ b/internal/query/project_role.go @@ -83,7 +83,7 @@ type ProjectRoleSearchQueries struct { Queries []SearchQuery } -func (q *Queries) SearchProjectRoles(ctx context.Context, shouldTriggerBulk bool, queries *ProjectRoleSearchQueries, withOwnerRemoved bool) (projects *ProjectRoles, err error) { +func (q *Queries) SearchProjectRoles(ctx context.Context, shouldTriggerBulk bool, queries *ProjectRoleSearchQueries, withOwnerRemoved bool) (roles *ProjectRoles, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -102,19 +102,18 @@ func (q *Queries) SearchProjectRoles(ctx context.Context, shouldTriggerBulk bool return nil, errors.ThrowInvalidArgument(err, "QUERY-3N9ff", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + roles, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-5Ngd9", "Errors.Internal") } - projects, err = scan(rows) - if err != nil { - return nil, err - } - projects.LatestSequence, err = q.latestSequence(ctx, projectRolesTable) - return projects, err + roles.LatestSequence, err = q.latestSequence(ctx, projectRolesTable) + return roles, err } -func (q *Queries) SearchGrantedProjectRoles(ctx context.Context, grantID, grantedOrg string, queries *ProjectRoleSearchQueries, withOwnerRemoved bool) (projects *ProjectRoles, err error) { +func (q *Queries) SearchGrantedProjectRoles(ctx context.Context, grantID, grantedOrg string, queries *ProjectRoleSearchQueries, withOwnerRemoved bool) (roles *ProjectRoles, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -138,16 +137,16 @@ func (q *Queries) SearchGrantedProjectRoles(ctx context.Context, grantID, grante return nil, errors.ThrowInvalidArgument(err, "QUERY-3N9ff", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + roles, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-5Ngd9", "Errors.Internal") } - projects, err = scan(rows) - if err != nil { - return nil, err - } - projects.LatestSequence, err = q.latestSequence(ctx, projectRolesTable) - return projects, err + + roles.LatestSequence, err = q.latestSequence(ctx, projectRolesTable) + return roles, err } func NewProjectRoleProjectIDSearchQuery(value string) (SearchQuery, error) { diff --git a/internal/query/project_role_test.go b/internal/query/project_role_test.go index 9324cf2929..dbc646d4d3 100644 --- a/internal/query/project_role_test.go +++ b/internal/query/project_role_test.go @@ -170,7 +170,7 @@ func Test_ProjectRolePrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*ProjectRoles)(nil), }, } for _, tt := range tests { diff --git a/internal/query/project_test.go b/internal/query/project_test.go index ea96f62c4f..4d4bd72ea4 100644 --- a/internal/query/project_test.go +++ b/internal/query/project_test.go @@ -238,13 +238,13 @@ func Test_ProjectPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*Projects)(nil), }, { name: "prepareProjectQuery no result", prepare: prepareProjectQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( prepareProjectStmt, nil, nil, @@ -309,7 +309,7 @@ func Test_ProjectPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*Project)(nil), }, } for _, tt := range tests { diff --git a/internal/query/secret_generator_test.go b/internal/query/secret_generator_test.go index 2d84c9b1c7..44d2015c19 100644 --- a/internal/query/secret_generator_test.go +++ b/internal/query/secret_generator_test.go @@ -234,13 +234,13 @@ func Test_SecretGeneratorsPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*SecretGenerators)(nil), }, { name: "prepareSecretGeneratorQuery no result", prepare: prepareSecretGeneratorQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( prepareSecretGeneratorStmt, nil, nil, @@ -307,7 +307,7 @@ func Test_SecretGeneratorsPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*SecretGenerator)(nil), }, } for _, tt := range tests { diff --git a/internal/query/secret_generators.go b/internal/query/secret_generators.go index a4859279f6..35c9ccbd81 100644 --- a/internal/query/secret_generators.go +++ b/internal/query/secret_generators.go @@ -134,7 +134,7 @@ func (q *Queries) InitHashGenerator(ctx context.Context, generatorType domain.Se return crypto.NewHashGenerator(cryptoConfig, algorithm), nil } -func (q *Queries) SecretGeneratorByType(ctx context.Context, generatorType domain.SecretGeneratorType) (_ *SecretGenerator, err error) { +func (q *Queries) SecretGeneratorByType(ctx context.Context, generatorType domain.SecretGeneratorType) (generator *SecretGenerator, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -147,8 +147,11 @@ func (q *Queries) SecretGeneratorByType(ctx context.Context, generatorType domai return nil, errors.ThrowInternal(err, "QUERY-3k99f", "Errors.Query.SQLStatment") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + generator, err = scan(row) + return err + }, query, args...) + return generator, err } func (q *Queries) SearchSecretGenerators(ctx context.Context, queries *SecretGeneratorSearchQueries) (secretGenerators *SecretGenerators, err error) { @@ -164,14 +167,13 @@ func (q *Queries) SearchSecretGenerators(ctx context.Context, queries *SecretGen return nil, errors.ThrowInvalidArgument(err, "QUERY-sn9lw", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + secretGenerators, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-4miii", "Errors.Internal") } - secretGenerators, err = scan(rows) - if err != nil { - return nil, err - } secretGenerators.LatestSequence, err = q.latestSequence(ctx, secretGeneratorsTable) return secretGenerators, err } diff --git a/internal/query/security_policy.go b/internal/query/security_policy.go index 29e15de21b..5a1b23dfb5 100644 --- a/internal/query/security_policy.go +++ b/internal/query/security_policy.go @@ -57,7 +57,7 @@ type SecurityPolicy struct { AllowedOrigins database.StringArray } -func (q *Queries) SecurityPolicy(ctx context.Context) (*SecurityPolicy, error) { +func (q *Queries) SecurityPolicy(ctx context.Context) (policy *SecurityPolicy, err error) { stmt, scan := prepareSecurityPolicyQuery(ctx, q.client) query, args, err := stmt.Where(sq.Eq{ SecurityPolicyColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), @@ -66,8 +66,11 @@ func (q *Queries) SecurityPolicy(ctx context.Context) (*SecurityPolicy, error) { return nil, errors.ThrowInternal(err, "QUERY-Sf6d1", "Errors.Query.SQLStatment") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + policy, err = scan(row) + return err + }, query, args...) + return policy, err } func prepareSecurityPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*SecurityPolicy, error)) { diff --git a/internal/query/session.go b/internal/query/session.go index 0c7d67dbf8..2a1672a3fa 100644 --- a/internal/query/session.go +++ b/internal/query/session.go @@ -151,7 +151,7 @@ var ( } ) -func (q *Queries) SessionByID(ctx context.Context, shouldTriggerBulk bool, id, sessionToken string) (_ *Session, err error) { +func (q *Queries) SessionByID(ctx context.Context, shouldTriggerBulk bool, id, sessionToken string) (session *Session, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -170,8 +170,11 @@ func (q *Queries) SessionByID(ctx context.Context, shouldTriggerBulk bool, id, s return nil, errors.ThrowInternal(err, "QUERY-dn9JW", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, stmt, args...) - session, tokenID, err := scan(row) + var tokenID string + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + session, tokenID, err = scan(row) + return err + }, stmt, args...) if err != nil { return nil, err } @@ -184,7 +187,7 @@ func (q *Queries) SessionByID(ctx context.Context, shouldTriggerBulk bool, id, s return session, nil } -func (q *Queries) SearchSessions(ctx context.Context, queries *SessionsSearchQueries) (_ *Sessions, err error) { +func (q *Queries) SearchSessions(ctx context.Context, queries *SessionsSearchQueries) (sessions *Sessions, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -197,14 +200,14 @@ func (q *Queries) SearchSessions(ctx context.Context, queries *SessionsSearchQue return nil, errors.ThrowInvalidArgument(err, "QUERY-sn9Jf", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, stmt, args...) - if err != nil || rows.Err() != nil { + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + sessions, err = scan(rows) + return err + }, stmt, args...) + if err != nil { return nil, errors.ThrowInternal(err, "QUERY-Sfg42", "Errors.Internal") } - sessions, err := scan(rows) - if err != nil { - return nil, err - } + sessions.LatestSequence, err = q.latestSequence(ctx, sessionsTable) return sessions, err } diff --git a/internal/query/sessions_test.go b/internal/query/sessions_test.go index bc9e8bc6a0..c66868a9d6 100644 --- a/internal/query/sessions_test.go +++ b/internal/query/sessions_test.go @@ -343,7 +343,7 @@ func Test_SessionsPrepare(t *testing.T) { return nil, true }, }, - object: nil, + object: (*Sessions)(nil), }, } for _, tt := range tests { @@ -368,7 +368,7 @@ func Test_SessionPrepare(t *testing.T) { name: "prepareSessionQuery no result", prepare: prepareSessionQueryTesting(t, ""), want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( expectedSessionQuery, nil, nil, @@ -460,7 +460,7 @@ func Test_SessionPrepare(t *testing.T) { return nil, true }, }, - object: nil, + object: (*Session)(nil), }, } for _, tt := range tests { diff --git a/internal/query/sms.go b/internal/query/sms.go index 6b0273ed26..956328b64a 100644 --- a/internal/query/sms.go +++ b/internal/query/sms.go @@ -115,7 +115,7 @@ var ( } ) -func (q *Queries) SMSProviderConfigByID(ctx context.Context, id string) (_ *SMSConfig, err error) { +func (q *Queries) SMSProviderConfigByID(ctx context.Context, id string) (config *SMSConfig, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -130,11 +130,14 @@ func (q *Queries) SMSProviderConfigByID(ctx context.Context, id string) (_ *SMSC return nil, errors.ThrowInternal(err, "QUERY-dn9JW", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, stmt, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + config, err = scan(row) + return err + }, stmt, args...) + return config, err } -func (q *Queries) SMSProviderConfig(ctx context.Context, queries ...SearchQuery) (_ *SMSConfig, err error) { +func (q *Queries) SMSProviderConfig(ctx context.Context, queries ...SearchQuery) (config *SMSConfig, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -151,11 +154,14 @@ func (q *Queries) SMSProviderConfig(ctx context.Context, queries ...SearchQuery) return nil, errors.ThrowInternal(err, "QUERY-dn9JW", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, stmt, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + config, err = scan(row) + return err + }, stmt, args...) + return config, err } -func (q *Queries) SearchSMSConfigs(ctx context.Context, queries *SMSConfigsSearchQueries) (_ *SMSConfigs, err error) { +func (q *Queries) SearchSMSConfigs(ctx context.Context, queries *SMSConfigsSearchQueries) (configs *SMSConfigs, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -168,16 +174,15 @@ func (q *Queries) SearchSMSConfigs(ctx context.Context, queries *SMSConfigsSearc return nil, errors.ThrowInvalidArgument(err, "QUERY-sn9Jf", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + configs, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-aJnZL", "Errors.Internal") } - apps, err := scan(rows) - if err != nil { - return nil, err - } - apps.LatestSequence, err = q.latestSequence(ctx, smsConfigsTable) - return apps, err + configs.LatestSequence, err = q.latestSequence(ctx, smsConfigsTable) + return configs, err } func NewSMSProviderStateQuery(state domain.SMSConfigState) (SearchQuery, error) { diff --git a/internal/query/sms_test.go b/internal/query/sms_test.go index 88b0516de0..9f6c906c77 100644 --- a/internal/query/sms_test.go +++ b/internal/query/sms_test.go @@ -225,7 +225,7 @@ func Test_SMSConfigssPrepare(t *testing.T) { return nil, true }, }, - object: nil, + object: (*SMSConfigs)(nil), }, } for _, tt := range tests { @@ -250,7 +250,7 @@ func Test_SMSConfigPrepare(t *testing.T) { name: "prepareSMSConfigQuery no result", prepare: prepareSMSConfigQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( expectedSMSConfigQuery, nil, nil, @@ -317,7 +317,7 @@ func Test_SMSConfigPrepare(t *testing.T) { return nil, true }, }, - object: nil, + object: (*SMSConfig)(nil), }, } for _, tt := range tests { diff --git a/internal/query/smtp.go b/internal/query/smtp.go index 4b0f0a019d..53db6f7b2a 100644 --- a/internal/query/smtp.go +++ b/internal/query/smtp.go @@ -91,7 +91,7 @@ type SMTPConfig struct { Password *crypto.CryptoValue } -func (q *Queries) SMTPConfigByAggregateID(ctx context.Context, aggregateID string) (_ *SMTPConfig, err error) { +func (q *Queries) SMTPConfigByAggregateID(ctx context.Context, aggregateID string) (config *SMTPConfig, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -104,8 +104,11 @@ func (q *Queries) SMTPConfigByAggregateID(ctx context.Context, aggregateID strin return nil, errors.ThrowInternal(err, "QUERY-3m9sl", "Errors.Query.SQLStatment") } - row := q.client.QueryRowContext(ctx, query, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + config, err = scan(row) + return err + }, query, args...) + return config, err } func prepareSMTPConfigQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*SMTPConfig, error)) { diff --git a/internal/query/smtp_test.go b/internal/query/smtp_test.go index 77ef763ed0..899c3ba4e2 100644 --- a/internal/query/smtp_test.go +++ b/internal/query/smtp_test.go @@ -56,7 +56,7 @@ func Test_SMTPConfigsPrepares(t *testing.T) { name: "prepareSMTPConfigQuery no result", prepare: prepareSMTPConfigQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( prepareSMTPConfigStmt, nil, nil, @@ -121,7 +121,7 @@ func Test_SMTPConfigsPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*SMTPConfig)(nil), }, } for _, tt := range tests { diff --git a/internal/query/user.go b/internal/query/user.go index c335129ed0..9b421bdc92 100644 --- a/internal/query/user.go +++ b/internal/query/user.go @@ -333,7 +333,7 @@ func addUserWithoutOwnerRemoved(eq map[string]interface{}) { eq[userPreferredLoginNameOwnerRemovedDomainCol.identifier()] = false } -func (q *Queries) GetUserByID(ctx context.Context, shouldTriggerBulk bool, userID string, withOwnerRemoved bool, queries ...SearchQuery) (_ *User, err error) { +func (q *Queries) GetUserByID(ctx context.Context, shouldTriggerBulk bool, userID string, withOwnerRemoved bool, queries ...SearchQuery) (user *User, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -358,11 +358,14 @@ func (q *Queries) GetUserByID(ctx context.Context, shouldTriggerBulk bool, userI return nil, errors.ThrowInternal(err, "QUERY-FBg21", "Errors.Query.SQLStatment") } - row := q.client.QueryRowContext(ctx, stmt, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + user, err = scan(row) + return err + }, stmt, args...) + return user, err } -func (q *Queries) GetUser(ctx context.Context, shouldTriggerBulk bool, withOwnerRemoved bool, queries ...SearchQuery) (_ *User, err error) { +func (q *Queries) GetUser(ctx context.Context, shouldTriggerBulk bool, withOwnerRemoved bool, queries ...SearchQuery) (user *User, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -386,11 +389,14 @@ func (q *Queries) GetUser(ctx context.Context, shouldTriggerBulk bool, withOwner return nil, errors.ThrowInternal(err, "QUERY-Dnhr2", "Errors.Query.SQLStatment") } - row := q.client.QueryRowContext(ctx, stmt, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + user, err = scan(row) + return err + }, stmt, args...) + return user, err } -func (q *Queries) GetHumanProfile(ctx context.Context, userID string, withOwnerRemoved bool, queries ...SearchQuery) (_ *Profile, err error) { +func (q *Queries) GetHumanProfile(ctx context.Context, userID string, withOwnerRemoved bool, queries ...SearchQuery) (profile *Profile, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -410,11 +416,14 @@ func (q *Queries) GetHumanProfile(ctx context.Context, userID string, withOwnerR return nil, errors.ThrowInternal(err, "QUERY-Dgbg2", "Errors.Query.SQLStatment") } - row := q.client.QueryRowContext(ctx, stmt, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + profile, err = scan(row) + return err + }, stmt, args...) + return profile, err } -func (q *Queries) GetHumanEmail(ctx context.Context, userID string, withOwnerRemoved bool, queries ...SearchQuery) (_ *Email, err error) { +func (q *Queries) GetHumanEmail(ctx context.Context, userID string, withOwnerRemoved bool, queries ...SearchQuery) (email *Email, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -434,11 +443,14 @@ func (q *Queries) GetHumanEmail(ctx context.Context, userID string, withOwnerRem return nil, errors.ThrowInternal(err, "QUERY-BHhj3", "Errors.Query.SQLStatment") } - row := q.client.QueryRowContext(ctx, stmt, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + email, err = scan(row) + return err + }, stmt, args...) + return email, err } -func (q *Queries) GetHumanPhone(ctx context.Context, userID string, withOwnerRemoved bool, queries ...SearchQuery) (_ *Phone, err error) { +func (q *Queries) GetHumanPhone(ctx context.Context, userID string, withOwnerRemoved bool, queries ...SearchQuery) (phone *Phone, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -458,11 +470,14 @@ func (q *Queries) GetHumanPhone(ctx context.Context, userID string, withOwnerRem return nil, errors.ThrowInternal(err, "QUERY-Dg43g", "Errors.Query.SQLStatment") } - row := q.client.QueryRowContext(ctx, stmt, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + phone, err = scan(row) + return err + }, stmt, args...) + return phone, err } -func (q *Queries) GetNotifyUserByID(ctx context.Context, shouldTriggered bool, userID string, withOwnerRemoved bool, queries ...SearchQuery) (_ *NotifyUser, err error) { +func (q *Queries) GetNotifyUserByID(ctx context.Context, shouldTriggered bool, userID string, withOwnerRemoved bool, queries ...SearchQuery) (user *NotifyUser, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -487,11 +502,14 @@ func (q *Queries) GetNotifyUserByID(ctx context.Context, shouldTriggered bool, u return nil, errors.ThrowInternal(err, "QUERY-Err3g", "Errors.Query.SQLStatment") } - row := q.client.QueryRowContext(ctx, stmt, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + user, err = scan(row) + return err + }, stmt, args...) + return user, err } -func (q *Queries) GetNotifyUser(ctx context.Context, shouldTriggered bool, withOwnerRemoved bool, queries ...SearchQuery) (_ *NotifyUser, err error) { +func (q *Queries) GetNotifyUser(ctx context.Context, shouldTriggered bool, withOwnerRemoved bool, queries ...SearchQuery) (user *NotifyUser, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -515,11 +533,14 @@ func (q *Queries) GetNotifyUser(ctx context.Context, shouldTriggered bool, withO return nil, errors.ThrowInternal(err, "QUERY-Err3g", "Errors.Query.SQLStatment") } - row := q.client.QueryRowContext(ctx, stmt, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + user, err = scan(row) + return err + }, stmt, args...) + return user, err } -func (q *Queries) SearchUsers(ctx context.Context, queries *UserSearchQueries, withOwnerRemoved bool) (_ *Users, err error) { +func (q *Queries) SearchUsers(ctx context.Context, queries *UserSearchQueries, withOwnerRemoved bool) (users *Users, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -534,19 +555,19 @@ func (q *Queries) SearchUsers(ctx context.Context, queries *UserSearchQueries, w return nil, errors.ThrowInternal(err, "QUERY-Dgbg2", "Errors.Query.SQLStatment") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + users, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-AG4gs", "Errors.Internal") } - users, err := scan(rows) - if err != nil { - return nil, err - } + users.LatestSequence, err = q.latestSequence(ctx, userTable) return users, err } -func (q *Queries) IsUserUnique(ctx context.Context, username, email, resourceOwner string, withOwnerRemoved bool) (_ bool, err error) { +func (q *Queries) IsUserUnique(ctx context.Context, username, email, resourceOwner string, withOwnerRemoved bool) (isUnique bool, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -584,8 +605,12 @@ func (q *Queries) IsUserUnique(ctx context.Context, username, email, resourceOwn if err != nil { return false, errors.ThrowInternal(err, "QUERY-Dg43g", "Errors.Query.SQLStatment") } - row := q.client.QueryRowContext(ctx, stmt, args...) - return scan(row) + + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + isUnique, err = scan(row) + return err + }, stmt, args...) + return isUnique, err } func (q *UserSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuilder { diff --git a/internal/query/user_auth_method.go b/internal/query/user_auth_method.go index fc80685bc8..0881c253c6 100644 --- a/internal/query/user_auth_method.go +++ b/internal/query/user_auth_method.go @@ -130,11 +130,10 @@ func (q *Queries) SearchUserAuthMethods(ctx context.Context, queries *UserAuthMe return nil, errors.ThrowInvalidArgument(err, "QUERY-j9NJd", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, stmt, args...) - if err != nil { - return nil, errors.ThrowInternal(err, "QUERY-3n99f", "Errors.Internal") - } - userAuthMethods, err = scan(rows) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + userAuthMethods, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, err } @@ -165,11 +164,10 @@ func (q *Queries) ListActiveUserAuthMethodTypes(ctx context.Context, userID stri return nil, errors.ThrowInvalidArgument(err, "QUERY-Sfdrg", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, stmt, args...) - if err != nil || rows.Err() != nil { - return nil, errors.ThrowInternal(err, "QUERY-SDgr3", "Errors.Internal") - } - userAuthMethodTypes, err = scan(rows) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + userAuthMethodTypes, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, err } @@ -200,11 +198,14 @@ func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID st return nil, false, false, errors.ThrowInvalidArgument(err, "QUERY-E5ut4", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, stmt, args...) - if err != nil || rows.Err() != nil { + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + userAuthMethodTypes, forceMFA, forceMFALocalOnly, err = scan(rows) + return err + }, stmt, args...) + if err != nil { return nil, false, false, errors.ThrowInternal(err, "QUERY-Dun75", "Errors.Internal") } - return scan(rows) + return userAuthMethodTypes, forceMFA, forceMFALocalOnly, nil } func NewUserAuthMethodUserIDSearchQuery(value string) (SearchQuery, error) { diff --git a/internal/query/user_auth_method_test.go b/internal/query/user_auth_method_test.go index 2cfe553e06..978e0e81a2 100644 --- a/internal/query/user_auth_method_test.go +++ b/internal/query/user_auth_method_test.go @@ -224,7 +224,7 @@ func Test_UserAuthMethodPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*AuthMethodTypes)(nil), }, { name: "prepareActiveUserAuthMethodTypesQuery no result", @@ -313,7 +313,7 @@ func Test_UserAuthMethodPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*AuthMethodTypes)(nil), }, { name: "prepareUserAuthMethodTypesRequiredQuery no result", diff --git a/internal/query/user_grant.go b/internal/query/user_grant.go index f78cd99d86..9643b3d93a 100644 --- a/internal/query/user_grant.go +++ b/internal/query/user_grant.go @@ -232,7 +232,7 @@ func addUserGrantWithoutOwnerRemoved(eq map[string]interface{}) { addLoginNameWithoutOwnerRemoved(eq) } -func (q *Queries) UserGrant(ctx context.Context, shouldTriggerBulk bool, withOwnerRemoved bool, queries ...SearchQuery) (_ *UserGrant, err error) { +func (q *Queries) UserGrant(ctx context.Context, shouldTriggerBulk bool, withOwnerRemoved bool, queries ...SearchQuery) (grant *UserGrant, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -253,11 +253,14 @@ func (q *Queries) UserGrant(ctx context.Context, shouldTriggerBulk bool, withOwn return nil, errors.ThrowInternal(err, "QUERY-Fa1KW", "Errors.Query.SQLStatement") } - row := q.client.QueryRowContext(ctx, stmt, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + grant, err = scan(row) + return err + }, stmt, args...) + return grant, err } -func (q *Queries) UserGrants(ctx context.Context, queries *UserGrantsQueries, shouldTriggerBulk, withOwnerRemoved bool) (_ *UserGrants, err error) { +func (q *Queries) UserGrants(ctx context.Context, queries *UserGrantsQueries, shouldTriggerBulk, withOwnerRemoved bool) (grants *UserGrants, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -280,11 +283,10 @@ func (q *Queries) UserGrants(ctx context.Context, queries *UserGrantsQueries, sh return nil, err } - rows, err := q.client.QueryContext(ctx, stmt, args...) - if err != nil { - return nil, err - } - grants, err := scan(rows) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + grants, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, err } diff --git a/internal/query/user_grant_test.go b/internal/query/user_grant_test.go index d7fccea10c..da54f2008a 100644 --- a/internal/query/user_grant_test.go +++ b/internal/query/user_grant_test.go @@ -122,7 +122,7 @@ func Test_UserGrantPrepares(t *testing.T) { name: "prepareUserGrantQuery no result", prepare: prepareUserGrantQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( userGrantStmt, nil, nil, @@ -441,7 +441,7 @@ func Test_UserGrantPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*UserGrant)(nil), }, { name: "prepareUserGrantsQuery no result", @@ -920,7 +920,7 @@ func Test_UserGrantPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*UserGrants)(nil), }, } for _, tt := range tests { diff --git a/internal/query/user_membership.go b/internal/query/user_membership.go index f72ef4f085..f2eae22a02 100644 --- a/internal/query/user_membership.go +++ b/internal/query/user_membership.go @@ -105,7 +105,7 @@ func (q *MembershipSearchQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder return query } -func (q *Queries) Memberships(ctx context.Context, queries *MembershipSearchQuery, withOwnerRemoved bool) (_ *Memberships, err error) { +func (q *Queries) Memberships(ctx context.Context, queries *MembershipSearchQuery, withOwnerRemoved bool) (memberships *Memberships, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -121,11 +121,10 @@ func (q *Queries) Memberships(ctx context.Context, queries *MembershipSearchQuer } queryArgs = append(queryArgs, args...) - rows, err := q.client.QueryContext(ctx, stmt, queryArgs...) - if err != nil { - return nil, errors.ThrowInternal(err, "QUERY-eAV2x", "Errors.Internal") - } - memberships, err := scan(rows) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + memberships, err = scan(rows) + return err + }, stmt, queryArgs...) if err != nil { return nil, err } diff --git a/internal/query/user_membership_test.go b/internal/query/user_membership_test.go index 60708637ff..237e21359d 100644 --- a/internal/query/user_membership_test.go +++ b/internal/query/user_membership_test.go @@ -444,7 +444,7 @@ func Test_MembershipPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*Memberships)(nil), }, } for _, tt := range tests { diff --git a/internal/query/user_metadata.go b/internal/query/user_metadata.go index d0a9202b8b..6a2be6a7e4 100644 --- a/internal/query/user_metadata.go +++ b/internal/query/user_metadata.go @@ -77,7 +77,7 @@ var ( } ) -func (q *Queries) GetUserMetadataByKey(ctx context.Context, shouldTriggerBulk bool, userID, key string, withOwnerRemoved bool, queries ...SearchQuery) (_ *UserMetadata, err error) { +func (q *Queries) GetUserMetadataByKey(ctx context.Context, shouldTriggerBulk bool, userID, key string, withOwnerRemoved bool, queries ...SearchQuery) (metadata *UserMetadata, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -102,11 +102,14 @@ func (q *Queries) GetUserMetadataByKey(ctx context.Context, shouldTriggerBulk bo return nil, errors.ThrowInternal(err, "QUERY-aDGG2", "Errors.Query.SQLStatment") } - row := q.client.QueryRowContext(ctx, stmt, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + metadata, err = scan(row) + return err + }, stmt, args...) + return metadata, err } -func (q *Queries) SearchUserMetadata(ctx context.Context, shouldTriggerBulk bool, userID string, queries *UserMetadataSearchQueries, withOwnerRemoved bool) (_ *UserMetadataList, err error) { +func (q *Queries) SearchUserMetadata(ctx context.Context, shouldTriggerBulk bool, userID string, queries *UserMetadataSearchQueries, withOwnerRemoved bool) (metadata *UserMetadataList, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -127,11 +130,10 @@ func (q *Queries) SearchUserMetadata(ctx context.Context, shouldTriggerBulk bool return nil, errors.ThrowInternal(err, "QUERY-Egbgd", "Errors.Query.SQLStatment") } - rows, err := q.client.QueryContext(ctx, stmt, args...) - if err != nil { - return nil, errors.ThrowInternal(err, "QUERY-Hr2wf", "Errors.Internal") - } - metadata, err := scan(rows) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + metadata, err = scan(rows) + return err + }, stmt, args...) if err != nil { return nil, err } diff --git a/internal/query/user_metadata_test.go b/internal/query/user_metadata_test.go index 781caf732d..62f7cfb6a3 100644 --- a/internal/query/user_metadata_test.go +++ b/internal/query/user_metadata_test.go @@ -62,7 +62,7 @@ func Test_UserMetadataPrepares(t *testing.T) { name: "prepareUserMetadataQuery no result", prepare: prepareUserMetadataQuery, want: want{ - sqlExpectations: mockQuery( + sqlExpectations: mockQueryScanErr( regexp.QuoteMeta(userMetadataQuery), nil, nil, @@ -117,7 +117,7 @@ func Test_UserMetadataPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*UserMetadata)(nil), }, { name: "prepareUserMetadataListQuery no result", @@ -238,7 +238,7 @@ func Test_UserMetadataPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*UserMetadataList)(nil), }, } for _, tt := range tests { diff --git a/internal/query/user_personal_access_token.go b/internal/query/user_personal_access_token.go index 5558eed3c2..1024fae5be 100644 --- a/internal/query/user_personal_access_token.go +++ b/internal/query/user_personal_access_token.go @@ -85,7 +85,7 @@ type PersonalAccessTokenSearchQueries struct { Queries []SearchQuery } -func (q *Queries) PersonalAccessTokenByID(ctx context.Context, shouldTriggerBulk bool, id string, withOwnerRemoved bool, queries ...SearchQuery) (_ *PersonalAccessToken, err error) { +func (q *Queries) PersonalAccessTokenByID(ctx context.Context, shouldTriggerBulk bool, id string, withOwnerRemoved bool, queries ...SearchQuery) (pat *PersonalAccessToken, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -109,8 +109,14 @@ func (q *Queries) PersonalAccessTokenByID(ctx context.Context, shouldTriggerBulk return nil, errors.ThrowInternal(err, "QUERY-Dgfb4", "Errors.Query.SQLStatment") } - row := q.client.QueryRowContext(ctx, stmt, args...) - return scan(row) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + pat, err = scan(row) + return err + }, stmt, args...) + if err != nil { + return nil, err + } + return pat, nil } func (q *Queries) SearchPersonalAccessTokens(ctx context.Context, queries *PersonalAccessTokenSearchQueries, withOwnerRemoved bool) (personalAccessTokens *PersonalAccessTokens, err error) { @@ -129,14 +135,15 @@ func (q *Queries) SearchPersonalAccessTokens(ctx context.Context, queries *Perso return nil, errors.ThrowInvalidArgument(err, "QUERY-Hjw2w", "Errors.Query.InvalidRequest") } - rows, err := q.client.QueryContext(ctx, stmt, args...) + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + personalAccessTokens, err = scan(rows) + return err + + }, stmt, args...) if err != nil { return nil, errors.ThrowInternal(err, "QUERY-Bmz63", "Errors.Internal") } - personalAccessTokens, err = scan(rows) - if err != nil { - return nil, err - } + personalAccessTokens.LatestSequence, err = q.latestSequence(ctx, personalAccessTokensTable) return personalAccessTokens, err } diff --git a/internal/query/user_personal_access_token_test.go b/internal/query/user_personal_access_token_test.go index 6a74e04de2..c12a151ed2 100644 --- a/internal/query/user_personal_access_token_test.go +++ b/internal/query/user_personal_access_token_test.go @@ -75,7 +75,7 @@ func Test_PersonalAccessTokenPrepares(t *testing.T) { name: "preparePersonalAccessTokenQuery no result", prepare: preparePersonalAccessTokenQuery, want: want{ - sqlExpectations: mockQuery( + sqlExpectations: mockQueryScanErr( personalAccessTokenStmt, nil, nil, @@ -134,7 +134,7 @@ func Test_PersonalAccessTokenPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*PersonalAccessToken)(nil), }, { name: "preparePersonalAccessTokensQuery no result", @@ -261,7 +261,7 @@ func Test_PersonalAccessTokenPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*PersonalAccessTokens)(nil), }, } for _, tt := range tests { diff --git a/internal/query/user_test.go b/internal/query/user_test.go index d90e17b901..986d3c1e9c 100644 --- a/internal/query/user_test.go +++ b/internal/query/user_test.go @@ -333,7 +333,7 @@ func Test_UserPrepares(t *testing.T) { name: "prepareUserQuery no result", prepare: prepareUserQuery, want: want{ - sqlExpectations: mockQuery( + sqlExpectations: mockQueryScanErr( regexp.QuoteMeta(userQuery), nil, nil, @@ -489,13 +489,13 @@ func Test_UserPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*User)(nil), }, { name: "prepareProfileQuery no result", prepare: prepareProfileQuery, want: want{ - sqlExpectations: mockQuery( + sqlExpectations: mockQueryScanErr( regexp.QuoteMeta(profileQuery), nil, nil, @@ -552,7 +552,7 @@ func Test_UserPrepares(t *testing.T) { name: "prepareProfileQuery not human found (error)", prepare: prepareProfileQuery, want: want{ - sqlExpectations: mockQuery( + sqlExpectations: mockQueryScanErr( regexp.QuoteMeta(profileQuery), profileCols, []driver.Value{ @@ -595,13 +595,13 @@ func Test_UserPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*Profile)(nil), }, { name: "prepareEmailQuery no result", prepare: prepareEmailQuery, want: want{ - sqlExpectations: mockQuery( + sqlExpectations: mockQueryScanErr( regexp.QuoteMeta(emailQuery), nil, nil, @@ -650,7 +650,7 @@ func Test_UserPrepares(t *testing.T) { name: "prepareEmailQuery not human found (error)", prepare: prepareEmailQuery, want: want{ - sqlExpectations: mockQuery( + sqlExpectations: mockQueryScanErr( regexp.QuoteMeta(emailQuery), emailCols, []driver.Value{ @@ -689,13 +689,13 @@ func Test_UserPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*Email)(nil), }, { name: "preparePhoneQuery no result", prepare: preparePhoneQuery, want: want{ - sqlExpectations: mockQuery( + sqlExpectations: mockQueryScanErr( regexp.QuoteMeta(phoneQuery), nil, nil, @@ -744,7 +744,7 @@ func Test_UserPrepares(t *testing.T) { name: "preparePhoneQuery not human found (error)", prepare: preparePhoneQuery, want: want{ - sqlExpectations: mockQuery( + sqlExpectations: mockQueryScanErr( regexp.QuoteMeta(phoneQuery), phoneCols, []driver.Value{ @@ -783,7 +783,7 @@ func Test_UserPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*Phone)(nil), }, { name: "prepareUserUniqueQuery no result", @@ -837,13 +837,13 @@ func Test_UserPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: false, }, { name: "prepareNotifyUserQuery no result", prepare: prepareNotifyUserQuery, want: want{ - sqlExpectations: mockQuery( + sqlExpectations: mockQueryScanErr( regexp.QuoteMeta(notifyUserQuery), nil, nil, @@ -924,7 +924,7 @@ func Test_UserPrepares(t *testing.T) { name: "prepareNotifyUserQuery not notify found (error)", prepare: prepareNotifyUserQuery, want: want{ - sqlExpectations: mockQuery( + sqlExpectations: mockQueryScanErr( regexp.QuoteMeta(notifyUserQuery), notifyUserCols, []driver.Value{ @@ -980,13 +980,13 @@ func Test_UserPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*NotifyUser)(nil), }, { name: "prepareUsersQuery no result", prepare: prepareUsersQuery, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQuery( regexp.QuoteMeta(usersQuery), nil, nil, @@ -1214,7 +1214,7 @@ func Test_UserPrepares(t *testing.T) { return nil, true }, }, - object: nil, + object: (*Users)(nil), }, } for _, tt := range tests { diff --git a/internal/view/repository/db_mock_test.go b/internal/view/repository/db_mock_test.go index 5fc8e66771..7abbabd986 100644 --- a/internal/view/repository/db_mock_test.go +++ b/internal/view/repository/db_mock_test.go @@ -175,48 +175,58 @@ func (db *dbMock) expectRollback(err error) *dbMock { func (db *dbMock) expectGetByID(table, key, value string) *dbMock { query := fmt.Sprintf(expectedGetByID, table, key) + db.mock.ExpectBegin() db.mock.ExpectQuery(query). WithArgs(value). WillReturnRows(sqlmock.NewRows([]string{key}). AddRow(key)) + db.mock.ExpectCommit() return db } func (db *dbMock) expectGetByIDErr(table, key, value string, err error) *dbMock { query := fmt.Sprintf(expectedGetByID, table, key) + db.mock.ExpectBegin() db.mock.ExpectQuery(query). WithArgs(value). WillReturnError(err) + db.mock.ExpectCommit() return db } func (db *dbMock) expectGetByQuery(table, key, method, value string) *dbMock { query := fmt.Sprintf(expectedGetByQuery, table, key, method) + db.mock.ExpectBegin() db.mock.ExpectQuery(query). WithArgs(value). WillReturnRows(sqlmock.NewRows([]string{key}). AddRow(key)) + db.mock.ExpectCommit() return db } func (db *dbMock) expectGetByQueryCaseSensitive(table, key, method, value string) *dbMock { query := fmt.Sprintf(expectedGetByQueryCaseSensitive, table, key, method) + db.mock.ExpectBegin() db.mock.ExpectQuery(query). WithArgs(value). WillReturnRows(sqlmock.NewRows([]string{key}). AddRow(key)) + db.mock.ExpectCommit() return db } func (db *dbMock) expectGetByQueryErr(table, key, method, value string, err error) *dbMock { query := fmt.Sprintf(expectedGetByQuery, table, key, method) + db.mock.ExpectBegin() db.mock.ExpectQuery(query). WithArgs(value). WillReturnError(err) + db.mock.ExpectCommit() return db } @@ -313,10 +323,14 @@ func (db *dbMock) expectGetSearchRequestNoParams(table string, resultAmount, tot rows.AddRow(fmt.Sprintf("hodor-%d", i)) } + db.mock.ExpectBegin() + db.mock.ExpectQuery(queryCount). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(total)) db.mock.ExpectQuery(query). WillReturnRows(rows) + + db.mock.ExpectCommit() return db } @@ -329,10 +343,12 @@ func (db *dbMock) expectGetSearchRequestWithLimit(table string, limit, resultAmo rows.AddRow(fmt.Sprintf("hodor-%d", i)) } + db.mock.ExpectBegin() db.mock.ExpectQuery(queryCount). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(total)) db.mock.ExpectQuery(query). WillReturnRows(rows) + db.mock.ExpectCommit() return db } @@ -345,10 +361,12 @@ func (db *dbMock) expectGetSearchRequestWithOffset(table string, offset, resultA rows.AddRow(fmt.Sprintf("hodor-%d", i)) } + db.mock.ExpectBegin() db.mock.ExpectQuery(queryCount). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(total)) db.mock.ExpectQuery(query). WillReturnRows(rows) + db.mock.ExpectCommit() return db } @@ -361,10 +379,12 @@ func (db *dbMock) expectGetSearchRequestWithSorting(table, sorting string, sorti rows.AddRow(fmt.Sprintf("hodor-%d", i)) } + db.mock.ExpectBegin() db.mock.ExpectQuery(queryCount). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(total)) db.mock.ExpectQuery(query). WillReturnRows(rows) + db.mock.ExpectCommit() return db } @@ -377,12 +397,14 @@ func (db *dbMock) expectGetSearchRequestWithSearchQuery(table, key, method, valu rows.AddRow(fmt.Sprintf("hodor-%d", i)) } + db.mock.ExpectBegin() db.mock.ExpectQuery(queryCount). WithArgs(value). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(total)) db.mock.ExpectQuery(query). WithArgs(value). WillReturnRows(rows) + db.mock.ExpectCommit() return db } @@ -395,12 +417,14 @@ func (db *dbMock) expectGetSearchRequestWithAllParams(table, key, method, value, rows.AddRow(fmt.Sprintf("hodor-%d", i)) } + db.mock.ExpectBegin() db.mock.ExpectQuery(queryCount). WithArgs(value). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(total)) db.mock.ExpectQuery(query). WithArgs(value). WillReturnRows(rows) + db.mock.ExpectCommit() return db } @@ -413,9 +437,11 @@ func (db *dbMock) expectGetSearchRequestErr(table string, resultAmount, total in rows.AddRow(fmt.Sprintf("hodor-%d", i)) } + db.mock.ExpectBegin() db.mock.ExpectQuery(queryCount). WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(total)) db.mock.ExpectQuery(query). WillReturnError(err) + db.mock.ExpectCommit() return db } diff --git a/internal/view/repository/query.go b/internal/view/repository/query.go index 36d41142a6..77d27c75be 100644 --- a/internal/view/repository/query.go +++ b/internal/view/repository/query.go @@ -1,10 +1,14 @@ package repository import ( + "context" + "database/sql" "fmt" "github.com/jinzhu/gorm" + "github.com/zitadel/logging" + "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" caos_errs "github.com/zitadel/zitadel/internal/errors" @@ -48,6 +52,13 @@ func PrepareSearchQuery(table string, request SearchRequest) func(db *gorm.DB, r } } + query = query.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) + defer func() { + if err := query.Commit().Error; err != nil { + logging.OnError(err).Info("commit failed") + } + }() + query = query.Count(&count) if res == nil { return count, nil diff --git a/internal/view/repository/requests.go b/internal/view/repository/requests.go index a624643c46..74921fbcbb 100644 --- a/internal/view/repository/requests.go +++ b/internal/view/repository/requests.go @@ -1,6 +1,8 @@ package repository import ( + "context" + "database/sql" "errors" "fmt" "strings" @@ -13,7 +15,14 @@ import ( func PrepareGetByKey(table string, key ColumnKey, id string) func(db *gorm.DB, res interface{}) error { return func(db *gorm.DB, res interface{}) error { - err := db.Table(table). + tx := db.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) + defer func() { + if err := tx.Commit().Error; err != nil { + logging.OnError(err).Info("commit failed") + } + }() + + err := tx.Table(table). Where(fmt.Sprintf("%s = ?", key.ToColumnName()), id). Take(res). Error @@ -39,7 +48,14 @@ func PrepareGetByQuery(table string, queries ...SearchQuery) func(db *gorm.DB, r } } - err := query.Take(res).Error + tx := query.BeginTx(context.Background(), &sql.TxOptions{ReadOnly: true}) + defer func() { + if err := tx.Commit().Error; err != nil { + logging.OnError(err).Info("commit failed") + } + }() + + err := tx.Take(res).Error if err == nil { return nil }