feat(storage): read only transactions for queries (#6415)

* fix: tests

* bastle wie en grosse

* fix(database): scan as callback

* fix tests

* fix merge failures

* remove as of system time

* refactor: remove unused test

* refacotr: remove unused lines
This commit is contained in:
Silvan 2023-08-22 12:49:22 +02:00 committed by GitHub
parent a9fb2a6e5c
commit 99e1c654a3
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
128 changed files with 1355 additions and 897 deletions

View File

@ -128,5 +128,5 @@ func keyStorage(config database.Config, masterKey string) (crypto.KeyStorage, er
if err != nil { if err != nil {
return nil, err return nil, err
} }
return cryptoDB.NewKeyStorage(db.DB, masterKey) return cryptoDB.NewKeyStorage(db, masterKey)
} }

View File

@ -2,7 +2,6 @@ package setup
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"os" "os"
"strings" "strings"
@ -14,6 +13,7 @@ import (
"github.com/zitadel/zitadel/internal/config/systemdefaults" "github.com/zitadel/zitadel/internal/config/systemdefaults"
"github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/crypto"
crypto_db "github.com/zitadel/zitadel/internal/crypto/database" crypto_db "github.com/zitadel/zitadel/internal/crypto/database"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
) )
@ -30,7 +30,7 @@ type FirstInstance struct {
smtpEncryptionKey *crypto.KeyConfig smtpEncryptionKey *crypto.KeyConfig
oidcEncryptionKey *crypto.KeyConfig oidcEncryptionKey *crypto.KeyConfig
masterKey string masterKey string
db *sql.DB db *database.DB
es *eventstore.Eventstore es *eventstore.Eventstore
defaults systemdefaults.SystemDefaults defaults systemdefaults.SystemDefaults
zitadelRoles []authz.RoleMapping zitadelRoles []authz.RoleMapping

View File

@ -77,7 +77,7 @@ func Setup(config *Config, steps *Steps, masterKey string) {
steps.FirstInstance.smtpEncryptionKey = config.EncryptionKeys.SMTP steps.FirstInstance.smtpEncryptionKey = config.EncryptionKeys.SMTP
steps.FirstInstance.oidcEncryptionKey = config.EncryptionKeys.OIDC steps.FirstInstance.oidcEncryptionKey = config.EncryptionKeys.OIDC
steps.FirstInstance.masterKey = masterKey steps.FirstInstance.masterKey = masterKey
steps.FirstInstance.db = dbClient.DB steps.FirstInstance.db = dbClient
steps.FirstInstance.es = eventstoreClient steps.FirstInstance.es = eventstoreClient
steps.FirstInstance.defaults = config.SystemDefaults steps.FirstInstance.defaults = config.SystemDefaults
steps.FirstInstance.zitadelRoles = config.InternalAuthZ.RolePermissionMappings steps.FirstInstance.zitadelRoles = config.InternalAuthZ.RolePermissionMappings

View File

@ -124,7 +124,7 @@ func startZitadel(config *Config, masterKey string, server chan<- *Server) error
return fmt.Errorf("cannot start client for projection: %w", err) return fmt.Errorf("cannot start client for projection: %w", err)
} }
keyStorage, err := cryptoDB.NewKeyStorage(dbClient.DB, masterKey) keyStorage, err := cryptoDB.NewKeyStorage(dbClient, masterKey)
if err != nil { if err != nil {
return fmt.Errorf("cannot start key storage: %w", err) return fmt.Errorf("cannot start key storage: %w", err)
} }

View File

@ -15,7 +15,7 @@ type View struct {
} }
func StartView(sqlClient *database.DB) (*View, error) { func StartView(sqlClient *database.DB) (*View, error) {
gorm, err := gorm.Open("postgres", sqlClient) gorm, err := gorm.Open("postgres", sqlClient.DB)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -23,7 +23,7 @@ type View struct {
} }
func StartView(sqlClient *database.DB, keyAlgorithm crypto.EncryptionAlgorithm, queries *query.Queries, idGenerator id.Generator, es eventstore.Eventstore) (*View, error) { func StartView(sqlClient *database.DB, keyAlgorithm crypto.EncryptionAlgorithm, queries *query.Queries, idGenerator id.Generator, es eventstore.Eventstore) (*View, error) {
gorm, err := gorm.Open("postgres", sqlClient) gorm, err := gorm.Open("postgres", sqlClient.DB)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -59,7 +59,12 @@ func (c *AuthRequestCache) getAuthRequest(key, value, instanceID string) (*domai
var b []byte var b []byte
var requestType domain.AuthRequestType var requestType domain.AuthRequestType
query := fmt.Sprintf("SELECT request, request_type FROM auth.auth_requests WHERE instance_id = $1 and %s = $2", key) query := fmt.Sprintf("SELECT request, request_type FROM auth.auth_requests WHERE instance_id = $1 and %s = $2", key)
err := c.client.QueryRow(query, instanceID, value).Scan(&b, &requestType) err := c.client.QueryRow(
func(row *sql.Row) error {
return row.Scan(&b, &requestType)
},
query, instanceID, value)
if err != nil { if err != nil {
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, caos_errs.ThrowNotFound(err, "CACHE-d24aD", "Errors.AuthRequest.NotFound") return nil, caos_errs.ThrowNotFound(err, "CACHE-d24aD", "Errors.AuthRequest.NotFound")

View File

@ -19,7 +19,7 @@ type View struct {
} }
func StartView(sqlClient *database.DB, idGenerator id.Generator, queries *query.Queries) (*View, error) { func StartView(sqlClient *database.DB, idGenerator id.Generator, queries *query.Queries) (*View, error) {
gorm, err := gorm.Open("postgres", sqlClient) gorm, err := gorm.Open("postgres", sqlClient.DB)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -6,11 +6,12 @@ import (
sq "github.com/Masterminds/squirrel" sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/crypto"
z_db "github.com/zitadel/zitadel/internal/database"
caos_errs "github.com/zitadel/zitadel/internal/errors" caos_errs "github.com/zitadel/zitadel/internal/errors"
) )
type database struct { type database struct {
client *sql.DB client *z_db.DB
masterKey string masterKey string
encrypt func(key, masterKey string) (encryptedKey string, err error) encrypt func(key, masterKey string) (encryptedKey string, err error)
decrypt func(encryptedKey, masterKey string) (key string, err error) decrypt func(encryptedKey, masterKey string) (key string, err error)
@ -22,7 +23,7 @@ const (
encryptionKeysKeyCol = "key" encryptionKeysKeyCol = "key"
) )
func NewKeyStorage(client *sql.DB, masterKey string) (*database, error) { func NewKeyStorage(client *z_db.DB, masterKey string) (*database, error) {
if err := checkMasterKeyLength(masterKey); err != nil { if err := checkMasterKeyLength(masterKey); err != nil {
return nil, err return nil, err
} }
@ -42,32 +43,30 @@ func (d *database) ReadKeys() (crypto.Keys, error) {
if err != nil { if err != nil {
return nil, caos_errs.ThrowInternal(err, "", "unable to read keys") return nil, caos_errs.ThrowInternal(err, "", "unable to read keys")
} }
rows, err := d.client.Query(stmt, args...) err = d.client.Query(func(rows *sql.Rows) error {
if err != nil {
return nil, caos_errs.ThrowInternal(err, "", "unable to read keys")
}
for rows.Next() { for rows.Next() {
var id, encryptionKey string var id, encryptionKey string
err = rows.Scan(&id, &encryptionKey) err = rows.Scan(&id, &encryptionKey)
if err != nil { if err != nil {
return nil, caos_errs.ThrowInternal(err, "", "unable to read keys") return caos_errs.ThrowInternal(err, "", "unable to read keys")
} }
key, err := d.decrypt(encryptionKey, d.masterKey) key, err := d.decrypt(encryptionKey, d.masterKey)
if err != nil { if err != nil {
if err := rows.Close(); err != nil { return caos_errs.ThrowInternal(err, "", "unable to decrypt key")
return nil, caos_errs.ThrowInternal(err, "", "unable to close rows")
}
return nil, caos_errs.ThrowInternal(err, "", "unable to decrypt key")
} }
keys[id] = key keys[id] = key
} }
if err := rows.Close(); err != nil { return nil
return nil, caos_errs.ThrowInternal(err, "", "unable to close rows") }, stmt, args...)
if err != nil {
return nil, caos_errs.ThrowInternal(err, "", "unable to read keys")
} }
return keys, err
return keys, nil
} }
func (d *database) ReadKey(id string) (*crypto.Key, error) { func (d *database) ReadKey(id string) (_ *crypto.Key, err error) {
stmt, args, err := sq.Select(encryptionKeysKeyCol). stmt, args, err := sq.Select(encryptionKeysKeyCol).
From(EncryptionKeysTable). From(EncryptionKeysTable).
Where(sq.Eq{encryptionKeysIDCol: id}). Where(sq.Eq{encryptionKeysIDCol: id}).
@ -76,19 +75,23 @@ func (d *database) ReadKey(id string) (*crypto.Key, error) {
if err != nil { if err != nil {
return nil, caos_errs.ThrowInternal(err, "", "unable to read key") return nil, caos_errs.ThrowInternal(err, "", "unable to read key")
} }
row := d.client.QueryRow(stmt, args...) var key string
if err != nil { err = d.client.QueryRow(func(row *sql.Row) error {
return nil, caos_errs.ThrowInternal(err, "", "unable to read key")
}
var encryptionKey string var encryptionKey string
err = row.Scan(&encryptionKey) err = row.Scan(&encryptionKey)
if err != nil {
return caos_errs.ThrowInternal(err, "", "unable to read key")
}
key, err = d.decrypt(encryptionKey, d.masterKey)
if err != nil {
return caos_errs.ThrowInternal(err, "", "unable to decrypt key")
}
return nil
}, stmt, args...)
if err != nil { if err != nil {
return nil, caos_errs.ThrowInternal(err, "", "unable to read key") return nil, caos_errs.ThrowInternal(err, "", "unable to read key")
} }
key, err := d.decrypt(encryptionKey, d.masterKey)
if err != nil {
return nil, caos_errs.ThrowInternal(err, "", "unable to decrypt key")
}
return &crypto.Key{ return &crypto.Key{
ID: id, ID: id,
Value: key, Value: key,

View File

@ -12,6 +12,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/crypto"
z_db "github.com/zitadel/zitadel/internal/database"
caos_errs "github.com/zitadel/zitadel/internal/errors" caos_errs "github.com/zitadel/zitadel/internal/errors"
) )
@ -46,7 +47,7 @@ func Test_database_ReadKeys(t *testing.T) {
{ {
"decryption error", "decryption error",
fields{ fields{
client: dbMock(t, expectQuery( client: dbMock(t, expectQueryScanErr(
"SELECT id, key FROM system.encryption_keys", "SELECT id, key FROM system.encryption_keys",
[]string{"id", "key"}, []string{"id", "key"},
[][]driver.Value{ [][]driver.Value{
@ -172,7 +173,7 @@ func Test_database_ReadKey(t *testing.T) {
{ {
"key not found err", "key not found err",
fields{ fields{
client: dbMock(t, expectQuery( client: dbMock(t, expectQueryScanErr(
"SELECT key FROM system.encryption_keys WHERE id = $1", "SELECT key FROM system.encryption_keys WHERE id = $1",
nil, nil,
nil, nil,
@ -192,7 +193,7 @@ func Test_database_ReadKey(t *testing.T) {
{ {
"decryption error", "decryption error",
fields{ fields{
client: dbMock(t, expectQuery( client: dbMock(t, expectQueryScanErr(
"SELECT key FROM system.encryption_keys WHERE id = $1", "SELECT key FROM system.encryption_keys WHERE id = $1",
[]string{"key"}, []string{"key"},
[][]driver.Value{ [][]driver.Value{
@ -445,7 +446,7 @@ func Test_checkMasterKeyLength(t *testing.T) {
type db struct { type db struct {
mock sqlmock.Sqlmock mock sqlmock.Sqlmock
db *sql.DB db *z_db.DB
} }
func dbMock(t *testing.T, expectations ...func(m sqlmock.Sqlmock)) db { func dbMock(t *testing.T, expectations ...func(m sqlmock.Sqlmock)) db {
@ -459,19 +460,41 @@ func dbMock(t *testing.T, expectations ...func(m sqlmock.Sqlmock)) db {
} }
return db{ return db{
mock: mock, mock: mock,
db: client, db: &z_db.DB{DB: client},
} }
} }
func expectQueryErr(query string, err error, args ...driver.Value) func(m sqlmock.Sqlmock) { func expectQueryErr(query string, err error, args ...driver.Value) func(m sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) { return func(m sqlmock.Sqlmock) {
m.ExpectBegin()
m.ExpectQuery(regexp.QuoteMeta(query)).WithArgs(args...).WillReturnError(err) m.ExpectQuery(regexp.QuoteMeta(query)).WithArgs(args...).WillReturnError(err)
m.ExpectRollback()
}
}
func expectQueryScanErr(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
m.ExpectBegin()
q := m.ExpectQuery(regexp.QuoteMeta(stmt)).WithArgs(args...)
m.ExpectRollback()
result := sqlmock.NewRows(cols)
count := uint64(len(rows))
for _, row := range rows {
if cols[len(cols)-1] == "count" {
row = append(row, count)
}
result.AddRow(row...)
}
q.WillReturnRows(result)
q.RowsWillBeClosed()
} }
} }
func expectQuery(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) { func expectQuery(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) { return func(m sqlmock.Sqlmock) {
m.ExpectBegin()
q := m.ExpectQuery(regexp.QuoteMeta(stmt)).WithArgs(args...) q := m.ExpectQuery(regexp.QuoteMeta(stmt)).WithArgs(args...)
m.ExpectCommit()
result := sqlmock.NewRows(cols) result := sqlmock.NewRows(cols)
count := uint64(len(rows)) count := uint64(len(rows))
for _, row := range rows { for _, row := range rows {

View File

@ -2,7 +2,6 @@ package cockroach
import ( import (
"database/sql" "database/sql"
"fmt"
"strconv" "strconv"
"strings" "strings"
"time" "time"
@ -94,12 +93,7 @@ func (c *Config) Type() string {
} }
func (c *Config) Timetravel(d time.Duration) string { func (c *Config) Timetravel(d time.Duration) string {
// verify that it is at least 1 micro second return ""
if d < time.Microsecond {
d = time.Microsecond
}
return fmt.Sprintf(" AS OF SYSTEM TIME '-%d µs' ", d.Microseconds())
} }
type User struct { type User struct {

View File

@ -1,61 +0,0 @@
package cockroach
import (
"testing"
"time"
)
func TestConfig_Timetravel(t *testing.T) {
type args struct {
d time.Duration
}
tests := []struct {
name string
args args
want string
}{
{
name: "no duration",
args: args{
d: 0,
},
want: " AS OF SYSTEM TIME '-1 µs' ",
},
{
name: "less than microsecond",
args: args{
d: 100 * time.Nanosecond,
},
want: " AS OF SYSTEM TIME '-1 µs' ",
},
{
name: "10 microseconds",
args: args{
d: 10 * time.Microsecond,
},
want: " AS OF SYSTEM TIME '-10 µs' ",
},
{
name: "10 milliseconds",
args: args{
d: 10 * time.Millisecond,
},
want: " AS OF SYSTEM TIME '-10000 µs' ",
},
{
name: "1 second",
args: args{
d: 1 * time.Second,
},
want: " AS OF SYSTEM TIME '-1000000 µs' ",
},
}
for _, tt := range tests {
c := &Config{}
t.Run(tt.name, func(t *testing.T) {
if got := c.Timetravel(tt.args.d); got != tt.want {
t.Errorf("Config.Timetravel() = %q, want %q", got, tt.want)
}
})
}
}

View File

@ -1,9 +1,12 @@
package database package database
import ( import (
"context"
"database/sql" "database/sql"
"reflect" "reflect"
"github.com/zitadel/logging"
_ "github.com/zitadel/zitadel/internal/database/cockroach" _ "github.com/zitadel/zitadel/internal/database/cockroach"
"github.com/zitadel/zitadel/internal/database/dialect" "github.com/zitadel/zitadel/internal/database/dialect"
_ "github.com/zitadel/zitadel/internal/database/postgres" _ "github.com/zitadel/zitadel/internal/database/postgres"
@ -24,6 +27,66 @@ type DB struct {
dialect.Database dialect.Database
} }
func (db *DB) Query(scan func(*sql.Rows) error, query string, args ...any) error {
return db.QueryContext(context.Background(), scan, query, args...)
}
func (db *DB) QueryContext(ctx context.Context, scan func(rows *sql.Rows) error, query string, args ...any) (err error) {
tx, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
if err != nil {
return err
}
defer func() {
if err != nil {
rollbackErr := tx.Rollback()
logging.OnError(rollbackErr).Info("commit of read only transaction failed")
return
}
err = tx.Commit()
}()
rows, err := tx.QueryContext(ctx, query, args...)
if err != nil {
return err
}
defer func() {
closeErr := rows.Close()
logging.OnError(closeErr).Info("rows.Close failed")
}()
if err = scan(rows); err != nil {
return err
}
return rows.Err()
}
func (db *DB) QueryRow(scan func(*sql.Row) error, query string, args ...any) (err error) {
return db.QueryRowContext(context.Background(), scan, query, args...)
}
func (db *DB) QueryRowContext(ctx context.Context, scan func(row *sql.Row) error, query string, args ...any) (err error) {
tx, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
if err != nil {
return err
}
defer func() {
if err != nil {
rollbackErr := tx.Rollback()
logging.OnError(rollbackErr).Info("commit of read only transaction failed")
return
}
err = tx.Commit()
}()
row := tx.QueryRowContext(ctx, query, args...)
err = scan(row)
if err != nil {
return err
}
return row.Err()
}
func Connect(config Config, useAdmin bool) (*DB, error) { func Connect(config Config, useAdmin bool) (*DB, error) {
client, err := config.connector.Connect(useAdmin) client, err := config.connector.Connect(useAdmin)
if err != nil { if err != nil {

View File

@ -13,6 +13,7 @@ import (
const ( const (
currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2) FOR UPDATE` currentSequenceStmtFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2) FOR UPDATE`
currentSequenceStmtWithoutLockFormat = `SELECT current_sequence, aggregate_type, instance_id FROM %s WHERE projection_name = $1 AND instance_id = ANY ($2)`
updateCurrentSequencesStmtFormat = `INSERT INTO %s (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES ` updateCurrentSequencesStmtFormat = `INSERT INTO %s (projection_name, aggregate_type, current_sequence, instance_id, timestamp) VALUES `
updateCurrentSequencesConflictStmt = ` ON CONFLICT (projection_name, aggregate_type, instance_id) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp` updateCurrentSequencesConflictStmt = ` ON CONFLICT (projection_name, aggregate_type, instance_id) DO UPDATE SET current_sequence = EXCLUDED.current_sequence, timestamp = EXCLUDED.timestamp`
) )
@ -24,15 +25,15 @@ type instanceSequence struct {
sequence uint64 sequence uint64
} }
func (h *StatementHandler) currentSequences(ctx context.Context, query func(context.Context, string, ...interface{}) (*sql.Rows, error), instanceIDs database.StringArray) (currentSequences, error) { func (h *StatementHandler) currentSequences(ctx context.Context, isTx bool, query func(context.Context, func(*sql.Rows) error, string, ...interface{}) error, instanceIDs database.StringArray) (currentSequences, error) {
rows, err := query(ctx, h.currentSequenceStmt, h.ProjectionName, instanceIDs) stmt := h.currentSequenceStmt
if err != nil { if !isTx {
return nil, err stmt = h.currentSequenceWithoutLockStmt
} }
defer rows.Close()
sequences := make(currentSequences, len(h.aggregates)) sequences := make(currentSequences, len(h.aggregates))
err := query(ctx,
func(rows *sql.Rows) error {
for rows.Next() { for rows.Next() {
var ( var (
aggregateType eventstore.AggregateType aggregateType eventstore.AggregateType
@ -40,9 +41,9 @@ func (h *StatementHandler) currentSequences(ctx context.Context, query func(cont
instanceID string instanceID string
) )
err = rows.Scan(&sequence, &aggregateType, &instanceID) err := rows.Scan(&sequence, &aggregateType, &instanceID)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "CRDB-dbatK", "scan failed") return errors.ThrowInternal(err, "CRDB-dbatK", "scan failed")
} }
sequences[aggregateType] = append(sequences[aggregateType], &instanceSequence{ sequences[aggregateType] = append(sequences[aggregateType], &instanceSequence{
@ -50,15 +51,12 @@ func (h *StatementHandler) currentSequences(ctx context.Context, query func(cont
instanceID: instanceID, instanceID: instanceID,
}) })
} }
return nil
if err = rows.Close(); err != nil { },
return nil, errors.ThrowInternal(err, "CRDB-h5i5m", "close rows failed") stmt, h.ProjectionName, instanceIDs)
if err != nil {
return nil, err
} }
if err = rows.Err(); err != nil {
return nil, errors.ThrowInternal(err, "CRDB-O8zig", "errors in scanning rows")
}
return sequences, nil return sequences, nil
} }

View File

@ -124,13 +124,17 @@ func expectSavePointRelease() func(sqlmock.Sqlmock) {
} }
} }
func expectCurrentSequence(tableName, projection string, seq uint64, aggregateType string, instanceIDs []string) func(sqlmock.Sqlmock) { func expectCurrentSequence(isTx bool, tableName, projection string, seq uint64, aggregateType string, instanceIDs []string) func(sqlmock.Sqlmock) {
rows := sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"}) rows := sqlmock.NewRows([]string{"current_sequence", "aggregate_type", "instance_id"})
for _, instanceID := range instanceIDs { for _, instanceID := range instanceIDs {
rows.AddRow(seq, aggregateType, instanceID) rows.AddRow(seq, aggregateType, instanceID)
} }
return func(m sqlmock.Sqlmock) { return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`). stmt := `SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\)`
if isTx {
stmt += " FOR UPDATE"
}
m.ExpectQuery(stmt).
WithArgs( WithArgs(
projection, projection,
database.StringArray(instanceIDs), database.StringArray(instanceIDs),
@ -141,9 +145,13 @@ func expectCurrentSequence(tableName, projection string, seq uint64, aggregateTy
} }
} }
func expectCurrentSequenceErr(tableName, projection string, instanceIDs []string, err error) func(sqlmock.Sqlmock) { func expectCurrentSequenceErr(isTx bool, tableName, projection string, instanceIDs []string, err error) func(sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) { return func(m sqlmock.Sqlmock) {
m.ExpectQuery(`SELECT current_sequence, aggregate_type, instance_id FROM `+tableName+` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\) FOR UPDATE`). stmt := `SELECT current_sequence, aggregate_type, instance_id FROM ` + tableName + ` WHERE projection_name = \$1 AND instance_id = ANY \(\$2\)`
if isTx {
stmt += " FOR UPDATE"
}
m.ExpectQuery(stmt).
WithArgs( WithArgs(
projection, projection,
database.StringArray(instanceIDs), database.StringArray(instanceIDs),

View File

@ -39,6 +39,7 @@ type StatementHandler struct {
client *database.DB client *database.DB
sequenceTable string sequenceTable string
currentSequenceStmt string currentSequenceStmt string
currentSequenceWithoutLockStmt string
updateSequencesBaseStmt string updateSequencesBaseStmt string
maxFailureCount uint maxFailureCount uint
failureCountStmt string failureCountStmt string
@ -81,6 +82,7 @@ func NewStatementHandler(
sequenceTable: config.SequenceTable, sequenceTable: config.SequenceTable,
maxFailureCount: config.MaxFailureCount, maxFailureCount: config.MaxFailureCount,
currentSequenceStmt: fmt.Sprintf(currentSequenceStmtFormat, config.SequenceTable), currentSequenceStmt: fmt.Sprintf(currentSequenceStmtFormat, config.SequenceTable),
currentSequenceWithoutLockStmt: fmt.Sprintf(currentSequenceStmtWithoutLockFormat, config.SequenceTable),
updateSequencesBaseStmt: fmt.Sprintf(updateCurrentSequencesStmtFormat, config.SequenceTable), updateSequencesBaseStmt: fmt.Sprintf(updateCurrentSequencesStmtFormat, config.SequenceTable),
failureCountStmt: fmt.Sprintf(failureCountStmtFormat, config.FailedEventsTable), failureCountStmt: fmt.Sprintf(failureCountStmtFormat, config.FailedEventsTable),
setFailureCountStmt: fmt.Sprintf(setFailureCountStmtFormat, config.FailedEventsTable), setFailureCountStmt: fmt.Sprintf(setFailureCountStmtFormat, config.FailedEventsTable),
@ -114,7 +116,7 @@ func (h *StatementHandler) searchQuery(ctx context.Context, instanceIDs []string
} }
func (h *StatementHandler) dbSearchQuery(ctx context.Context, instanceIDs []string) (*eventstore.SearchQueryBuilder, uint64, error) { func (h *StatementHandler) dbSearchQuery(ctx context.Context, instanceIDs []string) (*eventstore.SearchQueryBuilder, uint64, error) {
sequences, err := h.currentSequences(ctx, h.client.QueryContext, instanceIDs) sequences, err := h.currentSequences(ctx, false, h.client.QueryContext, instanceIDs)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
@ -140,6 +142,26 @@ func (h *StatementHandler) dbSearchQuery(ctx context.Context, instanceIDs []stri
return queryBuilder, h.bulkLimit, nil return queryBuilder, h.bulkLimit, nil
} }
type transaction struct {
*sql.Tx
}
func (t *transaction) QueryContext(ctx context.Context, scan func(rows *sql.Rows) error, query string, args ...any) error {
rows, err := t.Tx.QueryContext(ctx, query, args...)
if err != nil {
return err
}
defer func() {
closeErr := rows.Close()
logging.OnError(closeErr).Info("rows.Close failed")
}()
if err = scan(rows); err != nil {
return err
}
return rows.Err()
}
// Update implements handler.Update // Update implements handler.Update
func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (index int, err error) { func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statement, reduce handler.Reduce) (index int, err error) {
if len(stmts) == 0 { if len(stmts) == 0 {
@ -154,7 +176,7 @@ func (h *StatementHandler) Update(ctx context.Context, stmts []*handler.Statemen
return -1, errors.ThrowInternal(err, "CRDB-e89Gq", "begin failed") return -1, errors.ThrowInternal(err, "CRDB-e89Gq", "begin failed")
} }
sequences, err := h.currentSequences(ctx, tx.QueryContext, instanceIDs) sequences, err := h.currentSequences(ctx, true, (&transaction{Tx: tx}).QueryContext, instanceIDs)
if err != nil { if err != nil {
tx.Rollback() tx.Rollback()
return -1, err return -1, err

View File

@ -90,7 +90,9 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
return errors.Is(err, sql.ErrTxDone) return errors.Is(err, sql.ErrTxDone)
}, },
expectations: []mockExpectation{ expectations: []mockExpectation{
expectCurrentSequenceErr("my_sequences", "my_projection", []string{"instanceID1"}, sql.ErrTxDone), expectBegin(),
expectCurrentSequenceErr(false, "my_sequences", "my_projection", []string{"instanceID1"}, sql.ErrTxDone),
expectRollback(),
}, },
SearchQueryBuilder: nil, SearchQueryBuilder: nil,
}, },
@ -112,7 +114,9 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
return err == nil return err == nil
}, },
expectations: []mockExpectation{ expectations: []mockExpectation{
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID1"}), expectBegin(),
expectCurrentSequence(false, "my_sequences", "my_projection", 5, "testAgg", []string{"instanceID1"}),
expectCommit(),
}, },
SearchQueryBuilder: eventstore. SearchQueryBuilder: eventstore.
NewSearchQueryBuilder(eventstore.ColumnsEvent). NewSearchQueryBuilder(eventstore.ColumnsEvent).
@ -142,7 +146,9 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
return err == nil return err == nil
}, },
expectations: []mockExpectation{ expectations: []mockExpectation{
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID1", "instanceID2"}), expectBegin(),
expectCurrentSequence(false, "my_sequences", "my_projection", 5, "testAgg", []string{"instanceID1", "instanceID2"}),
expectCommit(),
}, },
SearchQueryBuilder: eventstore. SearchQueryBuilder: eventstore.
NewSearchQueryBuilder(eventstore.ColumnsEvent). NewSearchQueryBuilder(eventstore.ColumnsEvent).
@ -216,6 +222,7 @@ func TestProjectionHandler_SearchQuery(t *testing.T) {
t.Errorf("ProjectionHandler.prepareBulkStmts() error = %v", err) t.Errorf("ProjectionHandler.prepareBulkStmts() error = %v", err)
return return
} }
if !reflect.DeepEqual(query, tt.want.SearchQueryBuilder) { if !reflect.DeepEqual(query, tt.want.SearchQueryBuilder) {
t.Errorf("unexpected query: expected %v, got %v", tt.want.SearchQueryBuilder, query) t.Errorf("unexpected query: expected %v, got %v", tt.want.SearchQueryBuilder, query)
} }
@ -289,7 +296,7 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{ want: want{
expectations: []mockExpectation{ expectations: []mockExpectation{
expectBegin(), expectBegin(),
expectCurrentSequenceErr("my_sequences", "my_projection", []string{"instanceID"}, sql.ErrTxDone), expectCurrentSequenceErr(false, "my_sequences", "my_projection", []string{"instanceID"}, sql.ErrTxDone),
expectRollback(), expectRollback(),
}, },
isErr: func(err error) bool { isErr: func(err error) bool {
@ -321,7 +328,7 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{ want: want{
expectations: []mockExpectation{ expectations: []mockExpectation{
expectBegin(), expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}), expectCurrentSequence(false, "my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
expectRollback(), expectRollback(),
}, },
isErr: func(err error) bool { isErr: func(err error) bool {
@ -360,7 +367,7 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{ want: want{
expectations: []mockExpectation{ expectations: []mockExpectation{
expectBegin(), expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}), expectCurrentSequence(false, "my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
expectCommit(), expectCommit(),
}, },
isErr: func(err error) bool { isErr: func(err error) bool {
@ -399,7 +406,7 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{ want: want{
expectations: []mockExpectation{ expectations: []mockExpectation{
expectBegin(), expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "agg", []string{"instanceID"}), expectCurrentSequence(false, "my_sequences", "my_projection", 5, "agg", []string{"instanceID"}),
expectSavePoint(), expectSavePoint(),
expectCreate("my_projection", []string{"col"}, []string{"$1"}), expectCreate("my_projection", []string{"col"}, []string{"$1"}),
expectSavePointRelease(), expectSavePointRelease(),
@ -442,7 +449,7 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{ want: want{
expectations: []mockExpectation{ expectations: []mockExpectation{
expectBegin(), expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "agg", []string{"instanceID"}), expectCurrentSequence(false, "my_sequences", "my_projection", 5, "agg", []string{"instanceID"}),
expectSavePoint(), expectSavePoint(),
expectCreate("my_projection", []string{"col"}, []string{"$1"}), expectCreate("my_projection", []string{"col"}, []string{"$1"}),
expectSavePointRelease(), expectSavePointRelease(),
@ -478,7 +485,7 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{ want: want{
expectations: []mockExpectation{ expectations: []mockExpectation{
expectBegin(), expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}), expectCurrentSequence(false, "my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"), expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"),
expectCommit(), expectCommit(),
}, },
@ -511,7 +518,7 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{ want: want{
expectations: []mockExpectation{ expectations: []mockExpectation{
expectBegin(), expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}), expectCurrentSequence(false, "my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"), expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"),
expectCommit(), expectCommit(),
}, },
@ -551,7 +558,7 @@ func TestStatementHandler_Update(t *testing.T) {
want: want{ want: want{
expectations: []mockExpectation{ expectations: []mockExpectation{
expectBegin(), expectBegin(),
expectCurrentSequence("my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}), expectCurrentSequence(false, "my_sequences", "my_projection", 5, "testAgg", []string{"instanceID"}),
expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"), expectUpdateCurrentSequence("my_sequences", "my_projection", 7, "testAgg", "instanceID"),
expectCommit(), expectCommit(),
}, },
@ -1425,7 +1432,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
return errors.Is(err, sql.ErrConnDone) return errors.Is(err, sql.ErrConnDone)
}, },
expectations: []mockExpectation{ expectations: []mockExpectation{
expectCurrentSequenceErr("my_table", "my_projection", nil, sql.ErrConnDone), expectCurrentSequenceErr(true, "my_table", "my_projection", nil, sql.ErrConnDone),
}, },
}, },
}, },
@ -1487,7 +1494,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
return errors.Is(err, nil) return errors.Is(err, nil)
}, },
expectations: []mockExpectation{ expectations: []mockExpectation{
expectCurrentSequence("my_table", "my_projection", 5, "agg", []string{"instanceID"}), expectCurrentSequence(true, "my_table", "my_projection", 5, "agg", []string{"instanceID"}),
}, },
sequences: currentSequences{ sequences: currentSequences{
"agg": []*instanceSequence{ "agg": []*instanceSequence{
@ -1515,7 +1522,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
return errors.Is(err, nil) return errors.Is(err, nil)
}, },
expectations: []mockExpectation{ expectations: []mockExpectation{
expectCurrentSequence("my_table", "my_projection", 5, "agg", []string{"instanceID1", "instanceID2"}), expectCurrentSequence(true, "my_table", "my_projection", 5, "agg", []string{"instanceID1", "instanceID2"}),
}, },
sequences: currentSequences{ sequences: currentSequences{
"agg": []*instanceSequence{ "agg": []*instanceSequence{
@ -1563,7 +1570,7 @@ func TestStatementHandler_currentSequence(t *testing.T) {
t.Fatalf("unexpected err in begin: %v", err) t.Fatalf("unexpected err in begin: %v", err)
} }
seq, err := h.currentSequences(context.Background(), tx.QueryContext, tt.args.instanceIDs) seq, err := h.currentSequences(context.Background(), true, (&transaction{Tx: tx}).QueryContext, tt.args.instanceIDs)
if !tt.want.isErr(err) { if !tt.want.isErr(err) {
t.Errorf("unexpected error: %v", err) t.Errorf("unexpected error: %v", err)
} }

View File

@ -161,14 +161,19 @@ func (db *CRDB) Push(ctx context.Context, events []*repository.Event, uniqueCons
var instanceRegexp = regexp.MustCompile(`eventstore\.i_[0-9a-zA-Z]{1,}_seq`) var instanceRegexp = regexp.MustCompile(`eventstore\.i_[0-9a-zA-Z]{1,}_seq`)
func (db *CRDB) CreateInstance(ctx context.Context, instanceID string) error { func (db *CRDB) CreateInstance(ctx context.Context, instanceID string) error {
row := db.QueryRowContext(ctx, "SELECT CONCAT('eventstore.i_', $1::TEXT, '_seq')", instanceID)
if row.Err() != nil {
return caos_errs.ThrowInvalidArgument(row.Err(), "SQL-7gtFA", "Errors.InvalidArgument")
}
var sequenceName string var sequenceName string
err := db.QueryRowContext(ctx,
func(row *sql.Row) error {
if err := row.Scan(&sequenceName); err != nil || !instanceRegexp.MatchString(sequenceName) { if err := row.Scan(&sequenceName); err != nil || !instanceRegexp.MatchString(sequenceName) {
return caos_errs.ThrowInvalidArgument(err, "SQL-7gtFA", "Errors.InvalidArgument") return caos_errs.ThrowInvalidArgument(err, "SQL-7gtFA", "Errors.InvalidArgument")
} }
return nil
},
"SELECT CONCAT('eventstore.i_', $1::TEXT, '_seq')", instanceID,
)
if err != nil {
return err
}
if _, err := db.ExecContext(ctx, "CREATE SEQUENCE "+sequenceName); err != nil { if _, err := db.ExecContext(ctx, "CREATE SEQUENCE "+sequenceName); err != nil {
return caos_errs.ThrowInternal(err, "SQL-7gtFA", "Errors.Internal") return caos_errs.ThrowInternal(err, "SQL-7gtFA", "Errors.Internal")
@ -220,9 +225,9 @@ func (db *CRDB) handleUniqueConstraints(ctx context.Context, tx *sql.Tx, uniqueC
} }
// Filter returns all events matching the given search query // Filter returns all events matching the given search query
func (db *CRDB) Filter(ctx context.Context, searchQuery *repository.SearchQuery) (events []*repository.Event, err error) { func (crdb *CRDB) Filter(ctx context.Context, searchQuery *repository.SearchQuery) (events []*repository.Event, err error) {
events = []*repository.Event{} events = []*repository.Event{}
err = query(ctx, db, searchQuery, &events) err = query(ctx, crdb, searchQuery, &events)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -250,8 +255,8 @@ func (db *CRDB) InstanceIDs(ctx context.Context, searchQuery *repository.SearchQ
return ids, nil return ids, nil
} }
func (db *CRDB) db() *sql.DB { func (db *CRDB) db() *database.DB {
return db.DB.DB return db.DB
} }
func (db *CRDB) orderByEventSequence(desc bool) string { func (db *CRDB) orderByEventSequence(desc bool) string {

View File

@ -11,6 +11,7 @@ import (
"github.com/zitadel/logging" "github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/dialect" "github.com/zitadel/zitadel/internal/database/dialect"
z_errors "github.com/zitadel/zitadel/internal/errors" z_errors "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore/repository" "github.com/zitadel/zitadel/internal/eventstore/repository"
@ -24,13 +25,33 @@ type querier interface {
eventQuery() string eventQuery() string
maxSequenceQuery() string maxSequenceQuery() string
instanceIDsQuery() string instanceIDsQuery() string
db() *sql.DB db() *database.DB
orderByEventSequence(desc bool) string orderByEventSequence(desc bool) string
dialect.Database dialect.Database
} }
type scan func(dest ...interface{}) error type scan func(dest ...interface{}) error
type tx struct {
*sql.Tx
}
func (t *tx) QueryContext(ctx context.Context, scan func(rows *sql.Rows) error, query string, args ...any) error {
rows, err := t.Tx.QueryContext(ctx, query, args...)
if err != nil {
return err
}
defer func() {
closeErr := rows.Close()
logging.OnError(closeErr).Info("rows.Close failed")
}()
if err = scan(rows); err != nil {
return err
}
return rows.Err()
}
func query(ctx context.Context, criteria querier, searchQuery *repository.SearchQuery, dest interface{}) error { func query(ctx context.Context, criteria querier, searchQuery *repository.SearchQuery, dest interface{}) error {
query, rowScanner := prepareColumns(criteria, searchQuery.Columns) query, rowScanner := prepareColumns(criteria, searchQuery.Columns)
where, values := prepareCondition(criteria, searchQuery.Filters) where, values := prepareCondition(criteria, searchQuery.Filters)
@ -56,26 +77,27 @@ func query(ctx context.Context, criteria querier, searchQuery *repository.Search
query = criteria.placeholder(query) query = criteria.placeholder(query)
var contextQuerier interface { var contextQuerier interface {
QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) QueryContext(context.Context, func(rows *sql.Rows) error, string, ...interface{}) error
} }
contextQuerier = criteria.db() contextQuerier = criteria.db()
if searchQuery.Tx != nil { if searchQuery.Tx != nil {
contextQuerier = searchQuery.Tx contextQuerier = &tx{Tx: searchQuery.Tx}
} }
rows, err := contextQuerier.QueryContext(ctx, query, values...) err := contextQuerier.QueryContext(ctx,
if err != nil { func(rows *sql.Rows) error {
logging.New().WithError(err).Info("query failed")
return z_errors.ThrowInternal(err, "SQL-KyeAx", "unable to filter events")
}
defer rows.Close()
for rows.Next() { for rows.Next() {
err = rowScanner(rows.Scan, dest) err := rowScanner(rows.Scan, dest)
if err != nil { if err != nil {
return err return err
} }
} }
return nil
}, query, values...)
if err != nil {
logging.New().WithError(err).Info("query failed")
return z_errors.ThrowInternal(err, "SQL-KyeAx", "unable to filter events")
}
return nil return nil
} }

View File

@ -741,7 +741,7 @@ func Test_query_events_mocked(t *testing.T) {
}, },
}, },
fields: fields{ fields: fields{
mock: newMockClient(t).expectQuery(t, mock: newMockClient(t).expectQueryScanErr(t,
`SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY creation_date DESC, event_sequence DESC`, `SELECT creation_date, event_type, event_sequence, previous_aggregate_sequence, previous_aggregate_type_sequence, event_data, editor_service, editor_user, resource_owner, instance_id, aggregate_type, aggregate_id, aggregate_version FROM eventstore.events WHERE \( aggregate_type = \$1 \) ORDER BY creation_date DESC, event_sequence DESC`,
[]driver.Value{repository.AggregateType("user")}, []driver.Value{repository.AggregateType("user")},
&repository.Event{Sequence: 100}), &repository.Event{Sequence: 100}),
@ -853,7 +853,21 @@ type dbMock struct {
} }
func (m *dbMock) expectQuery(t *testing.T, expectedQuery string, args []driver.Value, events ...*repository.Event) *dbMock { func (m *dbMock) expectQuery(t *testing.T, expectedQuery string, args []driver.Value, events ...*repository.Event) *dbMock {
m.mock.ExpectBegin()
query := m.mock.ExpectQuery(expectedQuery).WithArgs(args...) query := m.mock.ExpectQuery(expectedQuery).WithArgs(args...)
m.mock.ExpectCommit()
rows := sqlmock.NewRows([]string{"event_sequence"})
for _, event := range events {
rows = rows.AddRow(event.Sequence)
}
query.WillReturnRows(rows).RowsWillBeClosed()
return m
}
func (m *dbMock) expectQueryScanErr(t *testing.T, expectedQuery string, args []driver.Value, events ...*repository.Event) *dbMock {
m.mock.ExpectBegin()
query := m.mock.ExpectQuery(expectedQuery).WithArgs(args...)
m.mock.ExpectRollback()
rows := sqlmock.NewRows([]string{"event_sequence"}) rows := sqlmock.NewRows([]string{"event_sequence"})
for _, event := range events { for _, event := range events {
rows = rows.AddRow(event.Sequence) rows = rows.AddRow(event.Sequence)
@ -863,6 +877,7 @@ func (m *dbMock) expectQuery(t *testing.T, expectedQuery string, args []driver.V
} }
func (m *dbMock) expectQueryErr(t *testing.T, expectedQuery string, args []driver.Value, err error) *dbMock { func (m *dbMock) expectQueryErr(t *testing.T, expectedQuery string, args []driver.Value, err error) *dbMock {
m.mock.ExpectBegin()
m.mock.ExpectQuery(expectedQuery).WithArgs(args...).WillReturnError(err) m.mock.ExpectQuery(expectedQuery).WithArgs(args...).WillReturnError(err)
return m return m
} }

View File

@ -127,9 +127,11 @@ func (db *dbMock) expectFilterEventsLimit(aggregateType string, limit uint64, ev
for i := 0; i < eventCount; i++ { for i := 0; i < eventCount; i++ {
rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0") rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0")
} }
db.mock.ExpectBegin()
db.mock.ExpectQuery(expectedFilterEventsLimitFormat). db.mock.ExpectQuery(expectedFilterEventsLimitFormat).
WithArgs(aggregateType, limit). WithArgs(aggregateType, limit).
WillReturnRows(rows) WillReturnRows(rows)
db.mock.ExpectCommit()
return db return db
} }
@ -138,8 +140,10 @@ func (db *dbMock) expectFilterEventsDesc(aggregateType string, eventCount int) *
for i := eventCount; i > 0; i-- { for i := eventCount; i > 0; i-- {
rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0") rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0")
} }
db.mock.ExpectBegin()
db.mock.ExpectQuery(expectedFilterEventsDescFormat). db.mock.ExpectQuery(expectedFilterEventsDescFormat).
WillReturnRows(rows) WillReturnRows(rows)
db.mock.ExpectCommit()
return db return db
} }
@ -148,9 +152,11 @@ func (db *dbMock) expectFilterEventsAggregateIDLimit(aggregateType, aggregateID
for i := limit; i > 0; i-- { for i := limit; i > 0; i-- {
rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0") rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0")
} }
db.mock.ExpectBegin()
db.mock.ExpectQuery(expectedFilterEventsAggregateIDLimit). db.mock.ExpectQuery(expectedFilterEventsAggregateIDLimit).
WithArgs(aggregateType, aggregateID, limit). WithArgs(aggregateType, aggregateID, limit).
WillReturnRows(rows) WillReturnRows(rows)
db.mock.ExpectCommit()
return db return db
} }
@ -159,28 +165,36 @@ func (db *dbMock) expectFilterEventsAggregateIDTypeLimit(aggregateType, aggregat
for i := limit; i > 0; i-- { for i := limit; i > 0; i-- {
rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0") rows.AddRow(time.Now(), "eventType", Sequence(i+1), Sequence(i), nil, "svc", "hodor", "org", "instanceID", "aggType", "aggID", "v1.0.0")
} }
db.mock.ExpectBegin()
db.mock.ExpectQuery(expectedFilterEventsAggregateIDTypeLimit). db.mock.ExpectQuery(expectedFilterEventsAggregateIDTypeLimit).
WithArgs(aggregateType, aggregateID, limit). WithArgs(aggregateType, aggregateID, limit).
WillReturnRows(rows) WillReturnRows(rows)
db.mock.ExpectCommit()
return db return db
} }
func (db *dbMock) expectFilterEventsError(returnedErr error) *dbMock { func (db *dbMock) expectFilterEventsError(returnedErr error) *dbMock {
db.mock.ExpectBegin()
db.mock.ExpectQuery(expectedGetAllEvents). db.mock.ExpectQuery(expectedGetAllEvents).
WillReturnError(returnedErr) WillReturnError(returnedErr)
db.mock.ExpectRollback()
return db return db
} }
func (db *dbMock) expectLatestSequenceFilter(aggregateType string, sequence Sequence) *dbMock { func (db *dbMock) expectLatestSequenceFilter(aggregateType string, sequence Sequence) *dbMock {
db.mock.ExpectBegin()
db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events AS OF SYSTEM TIME '-1 ms' WHERE \( aggregate_type = \$1 \)`). db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events AS OF SYSTEM TIME '-1 ms' WHERE \( aggregate_type = \$1 \)`).
WithArgs(aggregateType). WithArgs(aggregateType).
WillReturnRows(sqlmock.NewRows([]string{"max_sequence"}).AddRow(sequence)) WillReturnRows(sqlmock.NewRows([]string{"max_sequence"}).AddRow(sequence))
db.mock.ExpectCommit()
return db return db
} }
func (db *dbMock) expectLatestSequenceFilterError(aggregateType string, err error) *dbMock { func (db *dbMock) expectLatestSequenceFilterError(aggregateType string, err error) *dbMock {
db.mock.ExpectBegin()
db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events AS OF SYSTEM TIME '-1 ms' WHERE \( aggregate_type = \$1 \)`). db.mock.ExpectQuery(`SELECT MAX\(event_sequence\) FROM eventstore\.events AS OF SYSTEM TIME '-1 ms' WHERE \( aggregate_type = \$1 \)`).
WithArgs(aggregateType).WillReturnError(err) WithArgs(aggregateType).WillReturnError(err)
// db.mock.ExpectRollback()
return db return db
} }

View File

@ -3,12 +3,13 @@ package sql
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"runtime/debug" "runtime/debug"
"github.com/zitadel/logging" "github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/errors" errs "github.com/zitadel/zitadel/internal/errors"
es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models" es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models"
"github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/telemetry/tracing"
) )
@ -24,73 +25,75 @@ func (db *SQL) Filter(ctx context.Context, searchQuery *es_models.SearchQueryFac
return db.filter(ctx, db.client, searchQuery) return db.filter(ctx, db.client, searchQuery)
} }
func (sql *SQL) filter(ctx context.Context, db *database.DB, searchQuery *es_models.SearchQueryFactory) (events []*es_models.Event, err error) { func (server *SQL) filter(ctx context.Context, db *database.DB, searchQuery *es_models.SearchQueryFactory) (events []*es_models.Event, err error) {
query, limit, values, rowScanner := sql.buildQuery(ctx, db, searchQuery) query, limit, values, rowScanner := server.buildQuery(ctx, db, searchQuery)
if query == "" { if query == "" {
return nil, errors.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory") return nil, errs.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
} }
rows, err := db.Query(query, values...)
if err != nil {
logging.New().WithError(err).Info("query failed")
return nil, errors.ThrowInternal(err, "SQL-IJuyR", "unable to filter events")
}
defer rows.Close()
events = make([]*es_models.Event, 0, limit) events = make([]*es_models.Event, 0, limit)
err = db.QueryContext(ctx,
func(rows *sql.Rows) error {
for rows.Next() { for rows.Next() {
event := new(es_models.Event) event := new(es_models.Event)
err := rowScanner(rows.Scan, event) err := rowScanner(rows.Scan, event)
if err != nil { if err != nil {
return nil, err return err
} }
events = append(events, event) events = append(events, event)
} }
return nil
},
query, values...,
)
if err != nil {
logging.New().WithError(err).Info("query failed")
return nil, errs.ThrowInternal(err, "SQL-IJuyR", "unable to filter events")
}
return events, nil return events, nil
} }
func (db *SQL) LatestSequence(ctx context.Context, queryFactory *es_models.SearchQueryFactory) (uint64, error) { func (db *SQL) LatestSequence(ctx context.Context, queryFactory *es_models.SearchQueryFactory) (uint64, error) {
query, _, values, rowScanner := db.buildQuery(ctx, db.client, queryFactory) query, _, values, rowScanner := db.buildQuery(ctx, db.client, queryFactory)
if query == "" { if query == "" {
return 0, errors.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory") return 0, errs.ThrowInvalidArgument(nil, "SQL-rWeBw", "invalid query factory")
} }
row := db.client.QueryRow(query, values...)
sequence := new(Sequence) sequence := new(Sequence)
err := rowScanner(row.Scan, sequence) err := db.client.QueryRowContext(ctx, func(row *sql.Row) error {
if err != nil { return rowScanner(row.Scan, sequence)
}, query, values...)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
logging.New().WithError(err).WithField("traceID", tracing.TraceIDFromCtx(ctx)).Info("query failed") logging.New().WithError(err).WithField("traceID", tracing.TraceIDFromCtx(ctx)).Info("query failed")
return 0, errors.ThrowInternal(err, "SQL-Yczyx", "unable to filter latest sequence") return 0, errs.ThrowInternal(err, "SQL-Yczyx", "unable to filter latest sequence")
} }
return uint64(*sequence), nil return uint64(*sequence), nil
} }
func (db *SQL) InstanceIDs(ctx context.Context, queryFactory *es_models.SearchQueryFactory) ([]string, error) { func (db *SQL) InstanceIDs(ctx context.Context, queryFactory *es_models.SearchQueryFactory) (ids []string, err error) {
query, _, values, rowScanner := db.buildQuery(ctx, db.client, queryFactory) query, _, values, rowScanner := db.buildQuery(ctx, db.client, queryFactory)
if query == "" { if query == "" {
return nil, errors.ThrowInvalidArgument(nil, "SQL-Sfwg2", "invalid query factory") return nil, errs.ThrowInvalidArgument(nil, "SQL-Sfwg2", "invalid query factory")
} }
rows, err := db.client.Query(query, values...) err = db.client.QueryContext(ctx,
if err != nil { func(rows *sql.Rows) error {
logging.New().WithError(err).Info("query failed")
return nil, errors.ThrowInternal(err, "SQL-Sfg3r", "unable to filter instance ids")
}
defer rows.Close()
ids := make([]string, 0)
for rows.Next() { for rows.Next() {
var id string var id string
err := rowScanner(rows.Scan, &id) err := rowScanner(rows.Scan, &id)
if err != nil { if err != nil {
return nil, err return err
} }
ids = append(ids, id) ids = append(ids, id)
} }
return nil
},
query, values...)
if err != nil {
logging.New().WithError(err).Info("query failed")
return nil, errs.ThrowInternal(err, "SQL-Sfg3r", "unable to filter instance ids")
}
return ids, nil return ids, nil
} }

View File

@ -130,6 +130,7 @@ func TestSQL_Filter(t *testing.T) {
if (err != nil) != tt.res.wantErr { if (err != nil) != tt.res.wantErr {
t.Errorf("SQL.Filter() error = %v, wantErr %v", err, tt.res.wantErr) t.Errorf("SQL.Filter() error = %v, wantErr %v", err, tt.res.wantErr)
} }
if tt.res.eventsLen != 0 && len(events) != tt.res.eventsLen { if tt.res.eventsLen != 0 && len(events) != tt.res.eventsLen {
t.Errorf("events has wrong length got: %d want %d", len(events), tt.res.eventsLen) t.Errorf("events has wrong length got: %d want %d", len(events), tt.res.eventsLen)
} }
@ -221,10 +222,12 @@ func TestSQL_LatestSequence(t *testing.T) {
sql := &SQL{ sql := &SQL{
client: &database.DB{DB: tt.fields.client.sqlClient, Database: new(testDB)}, client: &database.DB{DB: tt.fields.client.sqlClient, Database: new(testDB)},
} }
sequence, err := sql.LatestSequence(context.Background(), tt.args.searchQuery) sequence, err := sql.LatestSequence(context.Background(), tt.args.searchQuery)
if (err != nil) != tt.res.wantErr { if (err != nil) != tt.res.wantErr {
t.Errorf("SQL.Filter() error = %v, wantErr %v", err, tt.res.wantErr) t.Errorf("SQL.Filter() error = %v, wantErr %v", err, tt.res.wantErr)
} }
if tt.res.sequence != sequence { if tt.res.sequence != sequence {
t.Errorf("events has wrong length got: %d want %d", sequence, tt.res.sequence) t.Errorf("events has wrong length got: %d want %d", sequence, tt.res.sequence)
} }

View File

@ -2,6 +2,7 @@ package access
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
@ -136,9 +137,15 @@ func (l *databaseLogStorage) QueryUsage(ctx context.Context, instanceId string,
} }
var count uint64 var count uint64
if err = l.dbClient. err = l.dbClient.
QueryRowContext(ctx, stmt, args...). QueryRowContext(ctx,
Scan(&count); err != nil { func(row *sql.Row) error {
return row.Scan(&count)
},
stmt, args...,
)
if err != nil {
return 0, caos_errors.ThrowInternal(err, "ACCESS-pBPrM", "Errors.Logstore.Access.ScanFailed") return 0, caos_errors.ThrowInternal(err, "ACCESS-pBPrM", "Errors.Logstore.Access.ScanFailed")
} }

View File

@ -2,6 +2,7 @@ package execution
import ( import (
"context" "context"
"database/sql"
"fmt" "fmt"
"time" "time"
@ -113,9 +114,14 @@ func (l *databaseLogStorage) QueryUsage(ctx context.Context, instanceId string,
} }
var durationSeconds uint64 var durationSeconds uint64
if err = l.dbClient. err = l.dbClient.
QueryRowContext(ctx, stmt, args...). QueryRowContext(ctx,
Scan(&durationSeconds); err != nil { func(row *sql.Row) error {
return row.Scan(&durationSeconds)
},
stmt, args...,
)
if err != nil {
return 0, caos_errors.ThrowInternal(err, "EXEC-Ad8nP", "Errors.Logstore.Execution.ScanFailed") return 0, caos_errors.ThrowInternal(err, "EXEC-Ad8nP", "Errors.Logstore.Execution.ScanFailed")
} }
return durationSeconds, nil return durationSeconds, nil

View File

@ -130,19 +130,19 @@ func (q *Queries) SearchActions(ctx context.Context, queries *ActionSearchQuerie
return nil, errors.ThrowInvalidArgument(err, "QUERY-SDgwg", "Errors.Query.InvalidRequest") return nil, errors.ThrowInvalidArgument(err, "QUERY-SDgwg", "Errors.Query.InvalidRequest")
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
actions, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-SDfr52", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-SDfr52", "Errors.Internal")
} }
actions, err = scan(rows)
if err != nil {
return nil, err
}
actions.LatestSequence, err = q.latestSequence(ctx, actionTable) actions.LatestSequence, err = q.latestSequence(ctx, actionTable)
return actions, err return actions, err
} }
func (q *Queries) GetActionByID(ctx context.Context, id string, orgID string, withOwnerRemoved bool) (_ *Action, err error) { func (q *Queries) GetActionByID(ctx context.Context, id string, orgID string, withOwnerRemoved bool) (action *Action, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -160,8 +160,11 @@ func (q *Queries) GetActionByID(ctx context.Context, id string, orgID string, wi
return nil, errors.ThrowInternal(err, "QUERY-Dgff3", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-Dgff3", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) action, err = scan(row)
return err
}, query, args...)
return action, err
} }
func NewActionResourceOwnerQuery(id string) (SearchQuery, error) { func NewActionResourceOwnerQuery(id string) (SearchQuery, error) {

View File

@ -67,7 +67,7 @@ type Flow struct {
TriggerActions map[domain.TriggerType][]*Action TriggerActions map[domain.TriggerType][]*Action
} }
func (q *Queries) GetFlow(ctx context.Context, flowType domain.FlowType, orgID string, withOwnerRemoved bool) (_ *Flow, err error) { func (q *Queries) GetFlow(ctx context.Context, flowType domain.FlowType, orgID string, withOwnerRemoved bool) (flow *Flow, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -85,14 +85,14 @@ func (q *Queries) GetFlow(ctx context.Context, flowType domain.FlowType, orgID s
return nil, errors.ThrowInvalidArgument(err, "QUERY-HBRh3", "Errors.Query.InvalidRequest") return nil, errors.ThrowInvalidArgument(err, "QUERY-HBRh3", "Errors.Query.InvalidRequest")
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
if err != nil { flow, err = scan(rows)
return nil, errors.ThrowInternal(err, "QUERY-Gg42f", "Errors.Internal") return err
} }, stmt, args...)
return scan(rows) return flow, err
} }
func (q *Queries) GetActiveActionsByFlowAndTriggerType(ctx context.Context, flowType domain.FlowType, triggerType domain.TriggerType, orgID string, withOwnerRemoved bool) (_ []*Action, err error) { func (q *Queries) GetActiveActionsByFlowAndTriggerType(ctx context.Context, flowType domain.FlowType, triggerType domain.TriggerType, orgID string, withOwnerRemoved bool) (actions []*Action, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -112,14 +112,14 @@ func (q *Queries) GetActiveActionsByFlowAndTriggerType(ctx context.Context, flow
return nil, errors.ThrowInternal(err, "QUERY-Dgff3", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-Dgff3", "Errors.Query.SQLStatement")
} }
rows, err := q.client.QueryContext(ctx, query, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
if err != nil { actions, err = scan(rows)
return nil, errors.ThrowInternal(err, "QUERY-SDf52", "Errors.Internal") return err
} }, query, args...)
return scan(rows) return actions, err
} }
func (q *Queries) GetFlowTypesOfActionID(ctx context.Context, actionID string, withOwnerRemoved bool) (_ []domain.FlowType, err error) { func (q *Queries) GetFlowTypesOfActionID(ctx context.Context, actionID string, withOwnerRemoved bool) (types []domain.FlowType, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -136,12 +136,11 @@ func (q *Queries) GetFlowTypesOfActionID(ctx context.Context, actionID string, w
return nil, errors.ThrowInvalidArgument(err, "QUERY-Dh311", "Errors.Query.InvalidRequest") return nil, errors.ThrowInvalidArgument(err, "QUERY-Dh311", "Errors.Query.InvalidRequest")
} }
rows, err := q.client.QueryContext(ctx, query, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
if err != nil { types, err = scan(rows)
return nil, errors.ThrowInternal(err, "QUERY-Bhj4w", "Errors.Internal") return err
} }, query, args...)
return types, err
return scan(rows)
} }
func prepareFlowTypesQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) ([]domain.FlowType, error)) { func prepareFlowTypesQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) ([]domain.FlowType, error)) {

View File

@ -33,8 +33,8 @@ var (
` projections.flow_triggers2.sequence,` + ` projections.flow_triggers2.sequence,` +
` projections.flow_triggers2.resource_owner` + ` projections.flow_triggers2.resource_owner` +
` FROM projections.flow_triggers2` + ` FROM projections.flow_triggers2` +
` LEFT JOIN projections.actions3 ON projections.flow_triggers2.action_id = projections.actions3.id AND projections.flow_triggers2.instance_id = projections.actions3.instance_id` + ` LEFT JOIN projections.actions3 ON projections.flow_triggers2.action_id = projections.actions3.id AND projections.flow_triggers2.instance_id = projections.actions3.instance_id`
` AS OF SYSTEM TIME '-1 ms'` // ` AS OF SYSTEM TIME '-1 ms'`
prepareFlowCols = []string{ prepareFlowCols = []string{
"id", "id",
"creation_date", "creation_date",
@ -66,8 +66,8 @@ var (
` projections.actions3.allowed_to_fail,` + ` projections.actions3.allowed_to_fail,` +
` projections.actions3.timeout` + ` projections.actions3.timeout` +
` FROM projections.flow_triggers2` + ` FROM projections.flow_triggers2` +
` LEFT JOIN projections.actions3 ON projections.flow_triggers2.action_id = projections.actions3.id AND projections.flow_triggers2.instance_id = projections.actions3.instance_id` + ` LEFT JOIN projections.actions3 ON projections.flow_triggers2.action_id = projections.actions3.id AND projections.flow_triggers2.instance_id = projections.actions3.instance_id`
` AS OF SYSTEM TIME '-1 ms'` // ` AS OF SYSTEM TIME '-1 ms'`
prepareTriggerActionCols = []string{ prepareTriggerActionCols = []string{
"id", "id",
@ -83,8 +83,8 @@ var (
} }
prepareFlowTypeStmt = `SELECT projections.flow_triggers2.flow_type` + prepareFlowTypeStmt = `SELECT projections.flow_triggers2.flow_type` +
` FROM projections.flow_triggers2` + ` FROM projections.flow_triggers2`
` AS OF SYSTEM TIME '-1 ms'` // ` AS OF SYSTEM TIME '-1 ms'`
prepareFlowTypeCols = []string{ prepareFlowTypeCols = []string{
"flow_type", "flow_type",

View File

@ -25,8 +25,8 @@ var (
` projections.actions3.timeout,` + ` projections.actions3.timeout,` +
` projections.actions3.allowed_to_fail,` + ` projections.actions3.allowed_to_fail,` +
` COUNT(*) OVER ()` + ` COUNT(*) OVER ()` +
` FROM projections.actions3` + ` FROM projections.actions3`
` AS OF SYSTEM TIME '-1 ms'` // ` AS OF SYSTEM TIME '-1 ms'`
prepareActionsCols = []string{ prepareActionsCols = []string{
"id", "id",
"creation_date", "creation_date",
@ -51,8 +51,8 @@ var (
` projections.actions3.script,` + ` projections.actions3.script,` +
` projections.actions3.timeout,` + ` projections.actions3.timeout,` +
` projections.actions3.allowed_to_fail` + ` projections.actions3.allowed_to_fail` +
` FROM projections.actions3` + ` FROM projections.actions3`
` AS OF SYSTEM TIME '-1 ms'` // ` AS OF SYSTEM TIME '-1 ms'`
prepareActionCols = []string{ prepareActionCols = []string{
"id", "id",
"creation_date", "creation_date",
@ -215,13 +215,13 @@ func Test_ActionPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*Action)(nil),
}, },
{ {
name: "prepareActionQuery no result", name: "prepareActionQuery no result",
prepare: prepareActionQuery, prepare: prepareActionQuery,
want: want{ want: want{
sqlExpectations: mockQueries( sqlExpectations: mockQueriesScanErr(
regexp.QuoteMeta(prepareActionStmt), regexp.QuoteMeta(prepareActionStmt),
nil, nil,
nil, nil,
@ -284,7 +284,7 @@ func Test_ActionPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*Action)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -248,7 +248,7 @@ var (
} }
) )
func (q *Queries) AppByProjectAndAppID(ctx context.Context, shouldTriggerBulk bool, projectID, appID string, withOwnerRemoved bool) (_ *App, err error) { func (q *Queries) AppByProjectAndAppID(ctx context.Context, shouldTriggerBulk bool, projectID, appID string, withOwnerRemoved bool) (app *App, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -270,11 +270,14 @@ func (q *Queries) AppByProjectAndAppID(ctx context.Context, shouldTriggerBulk bo
return nil, errors.ThrowInternal(err, "QUERY-AFDgg", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-AFDgg", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) app, err = scan(row)
return err
}, query, args...)
return app, err
} }
func (q *Queries) AppByID(ctx context.Context, appID string, withOwnerRemoved bool) (_ *App, err error) { func (q *Queries) AppByID(ctx context.Context, appID string, withOwnerRemoved bool) (app *App, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -291,11 +294,14 @@ func (q *Queries) AppByID(ctx context.Context, appID string, withOwnerRemoved bo
return nil, errors.ThrowInternal(err, "QUERY-immt9", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-immt9", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) app, err = scan(row)
return err
}, query, args...)
return app, err
} }
func (q *Queries) AppBySAMLEntityID(ctx context.Context, entityID string, withOwnerRemoved bool) (_ *App, err error) { func (q *Queries) AppBySAMLEntityID(ctx context.Context, entityID string, withOwnerRemoved bool) (app *App, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -312,11 +318,14 @@ func (q *Queries) AppBySAMLEntityID(ctx context.Context, entityID string, withOw
return nil, errors.ThrowInternal(err, "QUERY-JgUop", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-JgUop", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) app, err = scan(row)
return err
}, query, args...)
return app, err
} }
func (q *Queries) ProjectByClientID(ctx context.Context, appID string, withOwnerRemoved bool) (_ *Project, err error) { func (q *Queries) ProjectByClientID(ctx context.Context, appID string, withOwnerRemoved bool) (project *Project, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -337,11 +346,14 @@ func (q *Queries) ProjectByClientID(ctx context.Context, appID string, withOwner
return nil, errors.ThrowInternal(err, "QUERY-XhJi3", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-XhJi3", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) project, err = scan(row)
return err
}, query, args...)
return project, err
} }
func (q *Queries) ProjectIDFromOIDCClientID(ctx context.Context, appID string, withOwnerRemoved bool) (_ string, err error) { func (q *Queries) ProjectIDFromOIDCClientID(ctx context.Context, appID string, withOwnerRemoved bool) (id string, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -358,11 +370,14 @@ func (q *Queries) ProjectIDFromOIDCClientID(ctx context.Context, appID string, w
return "", errors.ThrowInternal(err, "QUERY-7d92U", "Errors.Query.SQLStatement") return "", errors.ThrowInternal(err, "QUERY-7d92U", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) id, err = scan(row)
return err
}, query, args...)
return id, err
} }
func (q *Queries) ProjectIDFromClientID(ctx context.Context, appID string, withOwnerRemoved bool) (_ string, err error) { func (q *Queries) ProjectIDFromClientID(ctx context.Context, appID string, withOwnerRemoved bool) (id string, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -384,11 +399,14 @@ func (q *Queries) ProjectIDFromClientID(ctx context.Context, appID string, withO
return "", errors.ThrowInternal(err, "QUERY-SDfg3", "Errors.Query.SQLStatement") return "", errors.ThrowInternal(err, "QUERY-SDfg3", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) id, err = scan(row)
return err
}, query, args...)
return id, err
} }
func (q *Queries) ProjectByOIDCClientID(ctx context.Context, id string, withOwnerRemoved bool) (_ *Project, err error) { func (q *Queries) ProjectByOIDCClientID(ctx context.Context, id string, withOwnerRemoved bool) (project *Project, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -405,11 +423,14 @@ func (q *Queries) ProjectByOIDCClientID(ctx context.Context, id string, withOwne
return nil, errors.ThrowInternal(err, "QUERY-XhJi4", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-XhJi4", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) project, err = scan(row)
return err
}, query, args...)
return project, err
} }
func (q *Queries) AppByOIDCClientID(ctx context.Context, clientID string, withOwnerRemoved bool) (_ *App, err error) { func (q *Queries) AppByOIDCClientID(ctx context.Context, clientID string, withOwnerRemoved bool) (app *App, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -426,11 +447,14 @@ func (q *Queries) AppByOIDCClientID(ctx context.Context, clientID string, withOw
return nil, errors.ThrowInternal(err, "QUERY-JgVop", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-JgVop", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) app, err = scan(row)
return err
}, query, args...)
return app, err
} }
func (q *Queries) AppByClientID(ctx context.Context, clientID string, withOwnerRemoved bool) (_ *App, err error) { func (q *Queries) AppByClientID(ctx context.Context, clientID string, withOwnerRemoved bool) (app *App, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -450,11 +474,14 @@ func (q *Queries) AppByClientID(ctx context.Context, clientID string, withOwnerR
return nil, errors.ThrowInternal(err, "QUERY-Dfge2", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-Dfge2", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) app, err = scan(row)
return err
}, query, args...)
return app, err
} }
func (q *Queries) SearchApps(ctx context.Context, queries *AppSearchQueries, withOwnerRemoved bool) (_ *Apps, err error) { func (q *Queries) SearchApps(ctx context.Context, queries *AppSearchQueries, withOwnerRemoved bool) (apps *Apps, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -468,19 +495,18 @@ func (q *Queries) SearchApps(ctx context.Context, queries *AppSearchQueries, wit
return nil, errors.ThrowInvalidArgument(err, "QUERY-fajp8", "Errors.Query.InvalidRequest") return nil, errors.ThrowInvalidArgument(err, "QUERY-fajp8", "Errors.Query.InvalidRequest")
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
apps, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-aJnZL", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-aJnZL", "Errors.Internal")
} }
apps, err := scan(rows)
if err != nil {
return nil, err
}
apps.LatestSequence, err = q.latestSequence(ctx, appsTable) apps.LatestSequence, err = q.latestSequence(ctx, appsTable)
return apps, err return apps, err
} }
func (q *Queries) SearchClientIDs(ctx context.Context, queries *AppSearchQueries, withOwnerRemoved bool) (_ []string, err error) { func (q *Queries) SearchClientIDs(ctx context.Context, queries *AppSearchQueries, withOwnerRemoved bool) (ids []string, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -494,11 +520,14 @@ func (q *Queries) SearchClientIDs(ctx context.Context, queries *AppSearchQueries
return nil, errors.ThrowInvalidArgument(err, "QUERY-fajp8", "Errors.Query.InvalidRequest") return nil, errors.ThrowInvalidArgument(err, "QUERY-fajp8", "Errors.Query.InvalidRequest")
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
ids, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-aJnZL", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-aJnZL", "Errors.Internal")
} }
return scan(rows) return ids, nil
} }
func NewAppNameSearchQuery(method TextComparison, value string) (SearchQuery, error) { func NewAppNameSearchQuery(method TextComparison, value string) (SearchQuery, error) {

View File

@ -1115,7 +1115,7 @@ func Test_AppsPrepare(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*App)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {
@ -1140,7 +1140,7 @@ func Test_AppPrepare(t *testing.T) {
name: "prepareAppQuery no result", name: "prepareAppQuery no result",
prepare: prepareAppQuery, prepare: prepareAppQuery,
want: want{ want: want{
sqlExpectations: mockQueries( sqlExpectations: mockQueriesScanErr(
expectedAppQuery, expectedAppQuery,
nil, nil,
nil, nil,
@ -1747,7 +1747,7 @@ func Test_AppPrepare(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*App)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {
@ -1833,7 +1833,7 @@ func Test_AppIDsPrepare(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*App)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {
@ -1858,7 +1858,7 @@ func Test_ProjectIDByAppPrepare(t *testing.T) {
name: "prepareProjectIDByAppQuery no result", name: "prepareProjectIDByAppQuery no result",
prepare: prepareProjectIDByAppQuery, prepare: prepareProjectIDByAppQuery,
want: want{ want: want{
sqlExpectations: mockQueries( sqlExpectations: mockQueriesScanErr(
expectedProjectIDByAppQuery, expectedProjectIDByAppQuery,
nil, nil,
nil, nil,
@ -1899,7 +1899,7 @@ func Test_ProjectIDByAppPrepare(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: "",
}, },
} }
for _, tt := range tests { for _, tt := range tests {
@ -1924,7 +1924,7 @@ func Test_ProjectByAppPrepare(t *testing.T) {
name: "prepareProjectByAppQuery no result", name: "prepareProjectByAppQuery no result",
prepare: prepareProjectByAppQuery, prepare: prepareProjectByAppQuery,
want: want{ want: want{
sqlExpectations: mockQueries( sqlExpectations: mockQueriesScanErr(
expectedProjectByAppQuery, expectedProjectByAppQuery,
nil, nil,
nil, nil,
@ -2097,7 +2097,7 @@ func Test_ProjectByAppPrepare(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*Project)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -60,13 +60,17 @@ func (q *Queries) AuthRequestByID(ctx context.Context, shouldTriggerBulk bool, i
) )
dst := new(AuthRequest) dst := new(AuthRequest)
err = q.client.DB.QueryRowContext( err = q.client.QueryRowContext(
ctx, q.authRequestByIDQuery(ctx), ctx,
id, authz.GetInstance(ctx).InstanceID(), func(row *sql.Row) error {
).Scan( return row.Scan(
&dst.ID, &dst.CreationDate, &dst.LoginClient, &dst.ClientID, &scope, &dst.RedirectURI, &dst.ID, &dst.CreationDate, &dst.LoginClient, &dst.ClientID, &scope, &dst.RedirectURI,
&prompt, &locales, &dst.LoginHint, &dst.MaxAge, &dst.HintUserID, &prompt, &locales, &dst.LoginHint, &dst.MaxAge, &dst.HintUserID,
) )
},
q.authRequestByIDQuery(ctx),
id, authz.GetInstance(ctx).InstanceID(),
)
if errs.Is(err, sql.ErrNoRows) { if errs.Is(err, sql.ErrNoRows) {
return nil, errors.ThrowNotFound(err, "QUERY-Thee9", "Errors.AuthRequest.NotExisting") return nil, errors.ThrowNotFound(err, "QUERY-Thee9", "Errors.AuthRequest.NotExisting")
} }

View File

@ -125,7 +125,7 @@ func TestQueries_AuthRequestByID(t *testing.T) {
shouldTriggerBulk: false, shouldTriggerBulk: false,
id: "123", id: "123",
}, },
expect: mockQuery(expQuery, cols, nil, "123", "instanceID"), expect: mockQueryScanErr(expQuery, cols, nil, "123", "instanceID"),
wantErr: errors.ThrowNotFound(sql.ErrNoRows, "QUERY-Thee9", "Errors.AuthRequest.NotExisting"), wantErr: errors.ThrowNotFound(sql.ErrNoRows, "QUERY-Thee9", "Errors.AuthRequest.NotExisting"),
}, },
{ {

View File

@ -144,14 +144,14 @@ func (q *Queries) SearchAuthNKeys(ctx context.Context, queries *AuthNKeySearchQu
return nil, errors.ThrowInvalidArgument(err, "QUERY-SAf3f", "Errors.Query.InvalidRequest") return nil, errors.ThrowInvalidArgument(err, "QUERY-SAf3f", "Errors.Query.InvalidRequest")
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
authNKeys, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-Dbg53", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-Dbg53", "Errors.Internal")
} }
authNKeys, err = scan(rows)
if err != nil {
return nil, err
}
authNKeys.LatestSequence, err = q.latestSequence(ctx, authNKeyTable) authNKeys.LatestSequence, err = q.latestSequence(ctx, authNKeyTable)
return authNKeys, err return authNKeys, err
} }
@ -174,19 +174,18 @@ func (q *Queries) SearchAuthNKeysData(ctx context.Context, queries *AuthNKeySear
return nil, errors.ThrowInvalidArgument(err, "QUERY-SAg3f", "Errors.Query.InvalidRequest") return nil, errors.ThrowInvalidArgument(err, "QUERY-SAg3f", "Errors.Query.InvalidRequest")
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
authNKeys, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-Dbi53", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-Dbi53", "Errors.Internal")
} }
authNKeys, err = scan(rows)
if err != nil {
return nil, err
}
authNKeys.LatestSequence, err = q.latestSequence(ctx, authNKeyTable) authNKeys.LatestSequence, err = q.latestSequence(ctx, authNKeyTable)
return authNKeys, err return authNKeys, err
} }
func (q *Queries) GetAuthNKeyByID(ctx context.Context, shouldTriggerBulk bool, id string, withOwnerRemoved bool, queries ...SearchQuery) (_ *AuthNKey, err error) { func (q *Queries) GetAuthNKeyByID(ctx context.Context, shouldTriggerBulk bool, id string, withOwnerRemoved bool, queries ...SearchQuery) (key *AuthNKey, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -211,11 +210,14 @@ func (q *Queries) GetAuthNKeyByID(ctx context.Context, shouldTriggerBulk bool, i
return nil, errors.ThrowInternal(err, "QUERY-AGhg4", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-AGhg4", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, stmt, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) key, err = scan(row)
return err
}, stmt, args...)
return key, err
} }
func (q *Queries) GetAuthNKeyPublicKeyByIDAndIdentifier(ctx context.Context, id string, identifier string, withOwnerRemoved bool) (_ []byte, err error) { func (q *Queries) GetAuthNKeyPublicKeyByIDAndIdentifier(ctx context.Context, id string, identifier string, withOwnerRemoved bool) (key []byte, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -250,8 +252,11 @@ func (q *Queries) GetAuthNKeyPublicKeyByIDAndIdentifier(ctx context.Context, id
return nil, errors.ThrowInternal(err, "QUERY-DAb32", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-DAb32", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) key, err = scan(row)
return err
}, query, args...)
return key, err
} }
func NewAuthNKeyResourceOwnerQuery(id string) (SearchQuery, error) { func NewAuthNKeyResourceOwnerQuery(id string) (SearchQuery, error) {

View File

@ -349,13 +349,13 @@ func Test_AuthNKeyPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*AuthNKey)(nil),
}, },
{ {
name: "prepareAuthNKeyQuery no result", name: "prepareAuthNKeyQuery no result",
prepare: prepareAuthNKeyQuery, prepare: prepareAuthNKeyQuery,
want: want{ want: want{
sqlExpectations: mockQueries( sqlExpectations: mockQueriesScanErr(
regexp.QuoteMeta(prepareAuthNKeyStmt), regexp.QuoteMeta(prepareAuthNKeyStmt),
nil, nil,
nil, nil,
@ -412,13 +412,13 @@ func Test_AuthNKeyPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*AuthNKey)(nil),
}, },
{ {
name: "prepareAuthNKeyPublicKeyQuery no result", name: "prepareAuthNKeyPublicKeyQuery no result",
prepare: prepareAuthNKeyPublicKeyQuery, prepare: prepareAuthNKeyPublicKeyQuery,
want: want{ want: want{
sqlExpectations: mockQueries( sqlExpectations: mockQueriesScanErr(
regexp.QuoteMeta(prepareAuthNKeyPublicKeyStmt), regexp.QuoteMeta(prepareAuthNKeyPublicKeyStmt),
nil, nil,
nil, nil,
@ -461,7 +461,7 @@ func Test_AuthNKeyPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: ([]byte)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -66,7 +66,7 @@ var (
} }
) )
func (q *Queries) ActiveCertificates(ctx context.Context, t time.Time, usage domain.KeyUsage) (_ *Certificates, err error) { func (q *Queries) ActiveCertificates(ctx context.Context, t time.Time, usage domain.KeyUsage) (certs *Certificates, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -88,19 +88,19 @@ func (q *Queries) ActiveCertificates(ctx context.Context, t time.Time, usage dom
return nil, errors.ThrowInternal(err, "QUERY-SDfkg", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-SDfkg", "Errors.Query.SQLStatement")
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
certs, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-Sgan4", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-Sgan4", "Errors.Internal")
} }
keys, err := scan(rows)
if err != nil { certs.LatestSequence, err = q.latestSequence(ctx, keyTable)
return nil, err
}
keys.LatestSequence, err = q.latestSequence(ctx, keyTable)
if !errors.IsNotFound(err) { if !errors.IsNotFound(err) {
return keys, err return certs, err
} }
return keys, nil return certs, nil
} }
func prepareCertificateQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Certificates, error)) { func prepareCertificateQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*Certificates, error)) {

View File

@ -138,7 +138,7 @@ func Test_CertificatePrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*Certificate)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -66,14 +66,17 @@ func (q *Queries) SearchCurrentSequences(ctx context.Context, queries *CurrentSe
return nil, errors.ThrowInvalidArgument(err, "QUERY-MmFef", "Errors.Query.InvalidRequest") return nil, errors.ThrowInvalidArgument(err, "QUERY-MmFef", "Errors.Query.InvalidRequest")
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
failedEvents, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-22H8f", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-22H8f", "Errors.Internal")
} }
return scan(rows) return failedEvents, nil
} }
func (q *Queries) latestSequence(ctx context.Context, projections ...table) (_ *LatestSequence, err error) { func (q *Queries) latestSequence(ctx context.Context, projections ...table) (seq *LatestSequence, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -91,8 +94,12 @@ func (q *Queries) latestSequence(ctx context.Context, projections ...table) (_ *
return nil, errors.ThrowInternal(err, "QUERY-5CfX9", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-5CfX9", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, stmt, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) seq, err = scan(row)
return err
}, stmt, args...)
return seq, err
} }
func (q *Queries) ClearCurrentSequence(ctx context.Context, projectionName string) (err error) { func (q *Queries) ClearCurrentSequence(ctx context.Context, projectionName string) (err error) {
@ -129,11 +136,17 @@ func (q *Queries) checkAndLock(ctx context.Context, projectionName string) error
if err != nil { if err != nil {
return errors.ThrowInternal(err, "QUERY-Dfwf2", "Errors.ProjectionName.Invalid") return errors.ThrowInternal(err, "QUERY-Dfwf2", "Errors.ProjectionName.Invalid")
} }
row := q.client.QueryRowContext(ctx, projectionQuery, args...)
var count int var count int
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
if err := row.Scan(&count); err != nil || count == 0 { if err := row.Scan(&count); err != nil || count == 0 {
return errors.ThrowInternal(err, "QUERY-ej8fn", "Errors.ProjectionName.Invalid") return errors.ThrowInternal(err, "QUERY-ej8fn", "Errors.ProjectionName.Invalid")
} }
return err
}, projectionQuery, args...)
if err != nil {
return err
}
lock := fmt.Sprintf(lockStmtFormat, locksTable.identifier()) lock := fmt.Sprintf(lockStmtFormat, locksTable.identifier())
if err != nil { if err != nil {
return errors.ThrowInternal(err, "QUERY-DVfg3", "Errors.RemoveFailed") return errors.ThrowInternal(err, "QUERY-DVfg3", "Errors.RemoveFailed")

View File

@ -132,7 +132,7 @@ func Test_CurrentSequencesPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*CurrentSequences)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -104,14 +104,14 @@ func (q *Queries) CustomTextList(ctx context.Context, aggregateID, template, lan
return nil, errors.ThrowInternal(err, "QUERY-M9gse", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-M9gse", "Errors.Query.SQLStatement")
} }
rows, err := q.client.QueryContext(ctx, query, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
texts, err = scan(rows)
return err
}, query, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-2j00f", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-2j00f", "Errors.Internal")
} }
texts, err = scan(rows)
if err != nil {
return nil, err
}
texts.LatestSequence, err = q.latestSequence(ctx, projectsTable) texts.LatestSequence, err = q.latestSequence(ctx, projectsTable)
return texts, err return texts, err
} }
@ -134,14 +134,14 @@ func (q *Queries) CustomTextListByTemplate(ctx context.Context, aggregateID, tem
return nil, errors.ThrowInternal(err, "QUERY-M49fs", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-M49fs", "Errors.Query.SQLStatement")
} }
rows, err := q.client.QueryContext(ctx, query, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
texts, err = scan(rows)
return err
}, query, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-3n9ge", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-3n9ge", "Errors.Internal")
} }
texts, err = scan(rows)
if err != nil {
return nil, err
}
texts.LatestSequence, err = q.latestSequence(ctx, projectsTable) texts.LatestSequence, err = q.latestSequence(ctx, projectsTable)
return texts, err return texts, err
} }

View File

@ -180,7 +180,7 @@ func Test_CustomTextPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*CustomText)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -70,7 +70,7 @@ var (
} }
) )
func (q *Queries) DeviceAuthByDeviceCode(ctx context.Context, clientID, deviceCode string) (_ *domain.DeviceAuth, err error) { func (q *Queries) DeviceAuthByDeviceCode(ctx context.Context, clientID, deviceCode string) (deviceAuth *domain.DeviceAuth, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -85,10 +85,14 @@ func (q *Queries) DeviceAuthByDeviceCode(ctx context.Context, clientID, deviceCo
return nil, errors.ThrowInternal(err, "QUERY-uk1Oh", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-uk1Oh", "Errors.Query.SQLStatement")
} }
return scan(q.client.QueryRowContext(ctx, query, args...)) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
deviceAuth, err = scan(row)
return err
}, query, args...)
return deviceAuth, err
} }
func (q *Queries) DeviceAuthByUserCode(ctx context.Context, userCode string) (_ *domain.DeviceAuth, err error) { func (q *Queries) DeviceAuthByUserCode(ctx context.Context, userCode string) (deviceAuth *domain.DeviceAuth, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -102,7 +106,11 @@ func (q *Queries) DeviceAuthByUserCode(ctx context.Context, userCode string) (_
return nil, errors.ThrowInternal(err, "QUERY-Axu7l", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-Axu7l", "Errors.Query.SQLStatement")
} }
return scan(q.client.QueryRowContext(ctx, query, args...)) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
deviceAuth, err = scan(row)
return err
}, query, args...)
return deviceAuth, err
} }
var deviceAuthSelectColumns = []string{ var deviceAuthSelectColumns = []string{

View File

@ -67,9 +67,11 @@ func TestQueries_DeviceAuthByDeviceCode(t *testing.T) {
} }
defer client.Close() defer client.Close()
mock.ExpectBegin()
mock.ExpectQuery(expectedDeviceAuthWhereDeviceCodeQuery).WillReturnRows( mock.ExpectQuery(expectedDeviceAuthWhereDeviceCodeQuery).WillReturnRows(
sqlmock.NewRows(deviceAuthSelectColumns).AddRow(expectedDeviceAuthValues...), sqlmock.NewRows(deviceAuthSelectColumns).AddRow(expectedDeviceAuthValues...),
) )
mock.ExpectCommit()
q := Queries{ q := Queries{
client: &database.DB{DB: client}, client: &database.DB{DB: client},
} }
@ -86,9 +88,11 @@ func TestQueries_DeviceAuthByUserCode(t *testing.T) {
} }
defer client.Close() defer client.Close()
mock.ExpectBegin()
mock.ExpectQuery(expectedDeviceAuthWhereUserCodeQuery).WillReturnRows( mock.ExpectQuery(expectedDeviceAuthWhereUserCodeQuery).WillReturnRows(
sqlmock.NewRows(deviceAuthSelectColumns).AddRow(expectedDeviceAuthValues...), sqlmock.NewRows(deviceAuthSelectColumns).AddRow(expectedDeviceAuthValues...),
) )
mock.ExpectCommit()
q := Queries{ q := Queries{
client: &database.DB{DB: client}, client: &database.DB{DB: client},
} }
@ -133,6 +137,7 @@ func Test_prepareDeviceAuthQuery(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: (*domain.DeviceAuth)(nil),
}, },
{ {
name: "other error", name: "other error",
@ -148,6 +153,7 @@ func Test_prepareDeviceAuthQuery(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: (*domain.DeviceAuth)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -86,7 +86,7 @@ var (
} }
) )
func (q *Queries) DomainPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (_ *DomainPolicy, err error) { func (q *Queries) DomainPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (policy *DomainPolicy, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -120,11 +120,14 @@ func (q *Queries) DomainPolicyByOrg(ctx context.Context, shouldTriggerBulk bool,
return nil, errors.ThrowInternal(err, "QUERY-D3CqT", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-D3CqT", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) policy, err = scan(row)
return err
}, query, args...)
return policy, err
} }
func (q *Queries) DefaultDomainPolicy(ctx context.Context) (_ *DomainPolicy, err error) { func (q *Queries) DefaultDomainPolicy(ctx context.Context) (policy *DomainPolicy, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -139,8 +142,11 @@ func (q *Queries) DefaultDomainPolicy(ctx context.Context) (_ *DomainPolicy, err
return nil, errors.ThrowInternal(err, "QUERY-pM7lP", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-pM7lP", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) policy, err = scan(row)
return err
}, query, args...)
return policy, err
} }
func prepareDomainPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*DomainPolicy, error)) { func prepareDomainPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*DomainPolicy, error)) {

View File

@ -54,7 +54,7 @@ func Test_DomainPolicyPrepares(t *testing.T) {
name: "prepareDomainPolicyQuery no result", name: "prepareDomainPolicyQuery no result",
prepare: prepareDomainPolicyQuery, prepare: prepareDomainPolicyQuery,
want: want{ want: want{
sqlExpectations: mockQueries( sqlExpectations: mockQueriesScanErr(
regexp.QuoteMeta(prepareDomainPolicyStmt), regexp.QuoteMeta(prepareDomainPolicyStmt),
nil, nil,
nil, nil,
@ -117,7 +117,7 @@ func Test_DomainPolicyPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*DomainPolicy)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -77,11 +77,14 @@ func (q *Queries) SearchFailedEvents(ctx context.Context, queries *FailedEventSe
return nil, errors.ThrowInvalidArgument(err, "QUERY-n8rjJ", "Errors.Query.InvalidRequest") return nil, errors.ThrowInvalidArgument(err, "QUERY-n8rjJ", "Errors.Query.InvalidRequest")
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
failedEvents, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-3j99J", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-3j99J", "Errors.Internal")
} }
return scan(rows) return failedEvents, nil
} }
func (q *Queries) RemoveFailedEvent(ctx context.Context, projectionName, instanceID string, sequence uint64) (err error) { func (q *Queries) RemoveFailedEvent(ctx context.Context, projectionName, instanceID string, sequence uint64) (err error) {

View File

@ -146,7 +146,7 @@ func Test_FailedEventsPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*FailedEvents)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -76,7 +76,7 @@ func addIamMemberWithoutOwnerRemoved(eq map[string]interface{}) {
eq[InstanceMemberOwnerRemovedUser.identifier()] = false eq[InstanceMemberOwnerRemovedUser.identifier()] = false
} }
func (q *Queries) IAMMembers(ctx context.Context, queries *IAMMembersQuery, withOwnerRemoved bool) (_ *Members, err error) { func (q *Queries) IAMMembers(ctx context.Context, queries *IAMMembersQuery, withOwnerRemoved bool) (members *Members, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -96,14 +96,13 @@ func (q *Queries) IAMMembers(ctx context.Context, queries *IAMMembersQuery, with
return nil, err return nil, err
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
members, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-Pdg1I", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-Pdg1I", "Errors.Internal")
} }
members, err := scan(rows)
if err != nil {
return nil, err
}
members.LatestSequence = currentSequence members.LatestSequence = currentSequence
return members, err return members, err
} }

View File

@ -280,7 +280,7 @@ func Test_IAMMemberPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*IAMMembership)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -188,7 +188,7 @@ var (
) )
// IDPByIDAndResourceOwner searches for the requested id in the context of the resource owner and IAM // IDPByIDAndResourceOwner searches for the requested id in the context of the resource owner and IAM
func (q *Queries) IDPByIDAndResourceOwner(ctx context.Context, shouldTriggerBulk bool, id, resourceOwner string, withOwnerRemoved bool) (_ *IDP, err error) { func (q *Queries) IDPByIDAndResourceOwner(ctx context.Context, shouldTriggerBulk bool, id, resourceOwner string, withOwnerRemoved bool) (idp *IDP, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -216,8 +216,11 @@ func (q *Queries) IDPByIDAndResourceOwner(ctx context.Context, shouldTriggerBulk
return nil, errors.ThrowInternal(err, "QUERY-0gocI", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-0gocI", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) idp, err = scan(row)
return err
}, query, args...)
return idp, err
} }
// IDPs searches idps matching the query // IDPs searches idps matching the query
@ -237,14 +240,13 @@ func (q *Queries) IDPs(ctx context.Context, queries *IDPSearchQueries, withOwner
return nil, errors.ThrowInvalidArgument(err, "QUERY-X6X7y", "Errors.Query.InvalidRequest") return nil, errors.ThrowInvalidArgument(err, "QUERY-X6X7y", "Errors.Query.InvalidRequest")
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
idps, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-xPlVH", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-xPlVH", "Errors.Internal")
} }
idps, err = scan(rows)
if err != nil {
return nil, err
}
idps.LatestSequence, err = q.latestSequence(ctx, idpTable) idps.LatestSequence, err = q.latestSequence(ctx, idpTable)
return idps, err return idps, err
} }

View File

@ -105,13 +105,13 @@ func (q *Queries) IDPLoginPolicyLinks(ctx context.Context, resourceOwner string,
if err != nil { if err != nil {
return nil, errors.ThrowInvalidArgument(err, "QUERY-FDbKW", "Errors.Query.InvalidRequest") return nil, errors.ThrowInvalidArgument(err, "QUERY-FDbKW", "Errors.Query.InvalidRequest")
} }
rows, err := q.client.QueryContext(ctx, stmt, args...)
if err != nil || rows.Err() != nil { err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
return nil, errors.ThrowInternal(err, "QUERY-ZkKUc", "Errors.Internal")
}
idps, err = scan(rows) idps, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, err return nil, errors.ThrowInternal(err, "QUERY-ZkKUc", "Errors.Internal")
} }
idps.LatestSequence, err = q.latestSequence(ctx, idpLoginPolicyLinkTable) idps.LatestSequence, err = q.latestSequence(ctx, idpLoginPolicyLinkTable)
return idps, err return idps, err

View File

@ -128,7 +128,7 @@ func Test_IDPLoginPolicyLinkPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*IDPs)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -606,7 +606,7 @@ var (
) )
// IDPTemplateByID searches for the requested id // IDPTemplateByID searches for the requested id
func (q *Queries) IDPTemplateByID(ctx context.Context, shouldTriggerBulk bool, id string, withOwnerRemoved bool, queries ...SearchQuery) (_ *IDPTemplate, err error) { func (q *Queries) IDPTemplateByID(ctx context.Context, shouldTriggerBulk bool, id string, withOwnerRemoved bool, queries ...SearchQuery) (template *IDPTemplate, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -630,8 +630,11 @@ func (q *Queries) IDPTemplateByID(ctx context.Context, shouldTriggerBulk bool, i
return nil, errors.ThrowInternal(err, "QUERY-SFefg", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-SFefg", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, stmt, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) template, err = scan(row)
return err
}, stmt, args...)
return template, err
} }
// IDPTemplates searches idp templates matching the query // IDPTemplates searches idp templates matching the query
@ -651,14 +654,13 @@ func (q *Queries) IDPTemplates(ctx context.Context, queries *IDPTemplateSearchQu
return nil, errors.ThrowInvalidArgument(err, "QUERY-SAF34", "Errors.Query.InvalidRequest") return nil, errors.ThrowInvalidArgument(err, "QUERY-SAF34", "Errors.Query.InvalidRequest")
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
idps, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-BDFrq", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-BDFrq", "Errors.Internal")
} }
idps, err = scan(rows)
if err != nil {
return nil, err
}
idps.LatestSequence, err = q.latestSequence(ctx, idpTemplateTable) idps.LatestSequence, err = q.latestSequence(ctx, idpTemplateTable)
return idps, err return idps, err
} }

View File

@ -443,7 +443,7 @@ func Test_IDPTemplateTemplatesPrepares(t *testing.T) {
name: "prepareIDPTemplateByIDQuery no result", name: "prepareIDPTemplateByIDQuery no result",
prepare: prepareIDPTemplateByIDQuery, prepare: prepareIDPTemplateByIDQuery,
want: want{ want: want{
sqlExpectations: mockQuery( sqlExpectations: mockQueryScanErr(
regexp.QuoteMeta(idpTemplateQuery), regexp.QuoteMeta(idpTemplateQuery),
nil, nil,
nil, nil,
@ -1646,7 +1646,7 @@ func Test_IDPTemplateTemplatesPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*IDPTemplate)(nil),
}, },
{ {
name: "prepareIDPTemplatesQuery no result", name: "prepareIDPTemplatesQuery no result",
@ -2606,7 +2606,7 @@ func Test_IDPTemplateTemplatesPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*IDPTemplates)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -144,7 +144,7 @@ func Test_IDPPrepares(t *testing.T) {
name: "prepareIDPByIDQuery no result", name: "prepareIDPByIDQuery no result",
prepare: prepareIDPByIDQuery, prepare: prepareIDPByIDQuery,
want: want{ want: want{
sqlExpectations: mockQuery( sqlExpectations: mockQueryScanErr(
regexp.QuoteMeta(idpQuery), regexp.QuoteMeta(idpQuery),
nil, nil,
nil, nil,
@ -341,7 +341,7 @@ func Test_IDPPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*IDP)(nil),
}, },
{ {
name: "prepareIDPsQuery no result", name: "prepareIDPsQuery no result",
@ -728,7 +728,7 @@ func Test_IDPPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*IDPs)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -103,14 +103,13 @@ func (q *Queries) IDPUserLinks(ctx context.Context, queries *IDPUserLinksSearchQ
return nil, errors.ThrowInvalidArgument(err, "QUERY-4zzFK", "Errors.Query.InvalidRequest") return nil, errors.ThrowInvalidArgument(err, "QUERY-4zzFK", "Errors.Query.InvalidRequest")
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
idps, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-C1E4D", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-C1E4D", "Errors.Internal")
} }
idps, err = scan(rows)
if err != nil {
return nil, err
}
idps.LatestSequence, err = q.latestSequence(ctx, idpUserLinkTable) idps.LatestSequence, err = q.latestSequence(ctx, idpUserLinkTable)
return idps, err return idps, err
} }

View File

@ -135,7 +135,7 @@ func Test_IDPUserLinkPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*IDPUserLinks)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -166,18 +166,17 @@ func (q *Queries) SearchInstances(ctx context.Context, queries *InstanceSearchQu
return nil, errors.ThrowInvalidArgument(err, "QUERY-M9fow", "Errors.Query.SQLStatement") return nil, errors.ThrowInvalidArgument(err, "QUERY-M9fow", "Errors.Query.SQLStatement")
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
instances, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-3j98f", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-3j98f", "Errors.Internal")
} }
instances, err = scan(rows)
if err != nil {
return nil, err
}
return instances, err return instances, err
} }
func (q *Queries) Instance(ctx context.Context, shouldTriggerBulk bool) (_ *Instance, err error) { func (q *Queries) Instance(ctx context.Context, shouldTriggerBulk bool) (instance *Instance, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -193,14 +192,14 @@ func (q *Queries) Instance(ctx context.Context, shouldTriggerBulk bool) (_ *Inst
return nil, errors.ThrowInternal(err, "QUERY-d9ngs", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-d9ngs", "Errors.Query.SQLStatement")
} }
row, err := q.client.QueryContext(ctx, query, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
if err != nil { instance, err = scan(rows)
return nil, err return err
} }, query, args...)
return scan(row) return instance, err
} }
func (q *Queries) InstanceByHost(ctx context.Context, host string) (_ authz.Instance, err error) { func (q *Queries) InstanceByHost(ctx context.Context, host string) (instance authz.Instance, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -213,11 +212,11 @@ func (q *Queries) InstanceByHost(ctx context.Context, host string) (_ authz.Inst
return nil, errors.ThrowInternal(err, "QUERY-SAfg2", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-SAfg2", "Errors.Query.SQLStatement")
} }
row, err := q.client.QueryContext(ctx, query, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
if err != nil { instance, err = scan(rows)
return nil, err return err
} }, query, args...)
return scan(row) return instance, err
} }
func (q *Queries) InstanceByID(ctx context.Context) (_ authz.Instance, err error) { func (q *Queries) InstanceByID(ctx context.Context) (_ authz.Instance, err error) {

View File

@ -88,11 +88,10 @@ func (q *Queries) SearchInstanceDomainsGlobal(ctx context.Context, queries *Inst
} }
func (q *Queries) queryInstanceDomains(ctx context.Context, stmt string, scan func(*sql.Rows) (*InstanceDomains, error), args ...interface{}) (domains *InstanceDomains, err error) { func (q *Queries) queryInstanceDomains(ctx context.Context, stmt string, scan func(*sql.Rows) (*InstanceDomains, error), args ...interface{}) (domains *InstanceDomains, err error) {
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-Dh9Ap", "Errors.Internal")
}
domains, err = scan(rows) domains, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -162,7 +162,7 @@ func Test_InstanceDomainPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*Domains)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -97,7 +97,7 @@ func Test_InstancePrepares(t *testing.T) {
additionalArgs: []reflect.Value{reflect.ValueOf("")}, additionalArgs: []reflect.Value{reflect.ValueOf("")},
prepare: prepareInstanceQuery, prepare: prepareInstanceQuery,
want: want{ want: want{
sqlExpectations: mockQueries( sqlExpectations: mockQueriesScanErr(
regexp.QuoteMeta(instanceQuery), regexp.QuoteMeta(instanceQuery),
nil, nil,
nil, nil,
@ -160,7 +160,7 @@ func Test_InstancePrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*Instance)(nil),
}, },
{ {
name: "prepareInstancesQuery no result", name: "prepareInstancesQuery no result",

View File

@ -177,7 +177,7 @@ var (
} }
) )
func (q *Queries) ActivePublicKeys(ctx context.Context, t time.Time) (_ *PublicKeys, err error) { func (q *Queries) ActivePublicKeys(ctx context.Context, t time.Time) (keys *PublicKeys, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -194,14 +194,14 @@ func (q *Queries) ActivePublicKeys(ctx context.Context, t time.Time) (_ *PublicK
return nil, errors.ThrowInternal(err, "QUERY-SDFfg", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-SDFfg", "Errors.Query.SQLStatement")
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
keys, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-Sghn4", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-Sghn4", "Errors.Internal")
} }
keys, err := scan(rows)
if err != nil {
return nil, err
}
keys.LatestSequence, err = q.latestSequence(ctx, keyTable) keys.LatestSequence, err = q.latestSequence(ctx, keyTable)
if !errors.IsNotFound(err) { if !errors.IsNotFound(err) {
return keys, err return keys, err
@ -209,7 +209,7 @@ func (q *Queries) ActivePublicKeys(ctx context.Context, t time.Time) (_ *PublicK
return keys, nil return keys, nil
} }
func (q *Queries) ActivePrivateSigningKey(ctx context.Context, t time.Time) (_ *PrivateKeys, err error) { func (q *Queries) ActivePrivateSigningKey(ctx context.Context, t time.Time) (keys *PrivateKeys, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -229,14 +229,13 @@ func (q *Queries) ActivePrivateSigningKey(ctx context.Context, t time.Time) (_ *
return nil, errors.ThrowInternal(err, "QUERY-SDff2", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-SDff2", "Errors.Query.SQLStatement")
} }
rows, err := q.client.QueryContext(ctx, query, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
keys, err = scan(rows)
return err
}, query, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-WRFG4", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-WRFG4", "Errors.Internal")
} }
keys, err := scan(rows)
if err != nil {
return nil, err
}
keys.LatestSequence, err = q.latestSequence(ctx, keyTable) keys.LatestSequence, err = q.latestSequence(ctx, keyTable)
if !errors.IsNotFound(err) { if !errors.IsNotFound(err) {
return keys, err return keys, err

View File

@ -147,7 +147,7 @@ func Test_KeyPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*PublicKeys)(nil),
}, },
{ {
name: "preparePrivateKeysQuery no result", name: "preparePrivateKeysQuery no result",
@ -230,7 +230,7 @@ func Test_KeyPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*PrivateKeys)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -42,7 +42,7 @@ type Theme struct {
IconURL string IconURL string
} }
func (q *Queries) ActiveLabelPolicyByOrg(ctx context.Context, orgID string, withOwnerRemoved bool) (_ *LabelPolicy, err error) { func (q *Queries) ActiveLabelPolicyByOrg(ctx context.Context, orgID string, withOwnerRemoved bool) (policy *LabelPolicy, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -68,11 +68,14 @@ func (q *Queries) ActiveLabelPolicyByOrg(ctx context.Context, orgID string, with
return nil, errors.ThrowInternal(err, "QUERY-V22un", "unable to create sql stmt") return nil, errors.ThrowInternal(err, "QUERY-V22un", "unable to create sql stmt")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) policy, err = scan(row)
return err
}, query, args...)
return policy, err
} }
func (q *Queries) PreviewLabelPolicyByOrg(ctx context.Context, orgID string) (_ *LabelPolicy, err error) { func (q *Queries) PreviewLabelPolicyByOrg(ctx context.Context, orgID string) (policy *LabelPolicy, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -98,11 +101,14 @@ func (q *Queries) PreviewLabelPolicyByOrg(ctx context.Context, orgID string) (_
return nil, errors.ThrowInternal(err, "QUERY-AG5eq", "unable to create sql stmt") return nil, errors.ThrowInternal(err, "QUERY-AG5eq", "unable to create sql stmt")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) policy, err = scan(row)
return err
}, query, args...)
return policy, err
} }
func (q *Queries) DefaultActiveLabelPolicy(ctx context.Context) (_ *LabelPolicy, err error) { func (q *Queries) DefaultActiveLabelPolicy(ctx context.Context) (policy *LabelPolicy, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -118,11 +124,14 @@ func (q *Queries) DefaultActiveLabelPolicy(ctx context.Context) (_ *LabelPolicy,
return nil, errors.ThrowInternal(err, "QUERY-mN0Ci", "unable to create sql stmt") return nil, errors.ThrowInternal(err, "QUERY-mN0Ci", "unable to create sql stmt")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) policy, err = scan(row)
return err
}, query, args...)
return policy, err
} }
func (q *Queries) DefaultPreviewLabelPolicy(ctx context.Context) (_ *LabelPolicy, err error) { func (q *Queries) DefaultPreviewLabelPolicy(ctx context.Context) (policy *LabelPolicy, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -138,8 +147,11 @@ func (q *Queries) DefaultPreviewLabelPolicy(ctx context.Context) (_ *LabelPolicy
return nil, errors.ThrowInternal(err, "QUERY-B3JQR", "unable to create sql stmt") return nil, errors.ThrowInternal(err, "QUERY-B3JQR", "unable to create sql stmt")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) policy, err = scan(row)
return err
}, query, args...)
return policy, err
} }
var ( var (

View File

@ -81,7 +81,7 @@ var (
} }
) )
func (q *Queries) LockoutPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (_ *LockoutPolicy, err error) { func (q *Queries) LockoutPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (policy *LockoutPolicy, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -110,11 +110,14 @@ func (q *Queries) LockoutPolicyByOrg(ctx context.Context, shouldTriggerBulk bool
return nil, errors.ThrowInternal(err, "QUERY-SKR6X", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-SKR6X", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) policy, err = scan(row)
return err
}, query, args...)
return policy, err
} }
func (q *Queries) DefaultLockoutPolicy(ctx context.Context) (_ *LockoutPolicy, err error) { func (q *Queries) DefaultLockoutPolicy(ctx context.Context) (policy *LockoutPolicy, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -129,8 +132,11 @@ func (q *Queries) DefaultLockoutPolicy(ctx context.Context) (_ *LockoutPolicy, e
return nil, errors.ThrowInternal(err, "QUERY-mN0Ci", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-mN0Ci", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) policy, err = scan(row)
return err
}, query, args...)
return policy, err
} }
func prepareLockoutPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*LockoutPolicy, error)) { func prepareLockoutPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*LockoutPolicy, error)) {

View File

@ -53,7 +53,7 @@ func Test_LockoutPolicyPrepares(t *testing.T) {
name: "prepareLockoutPolicyQuery no result", name: "prepareLockoutPolicyQuery no result",
prepare: prepareLockoutPolicyQuery, prepare: prepareLockoutPolicyQuery,
want: want{ want: want{
sqlExpectations: mockQueries( sqlExpectations: mockQueriesScanErr(
regexp.QuoteMeta(prepareLockoutPolicyStmt), regexp.QuoteMeta(prepareLockoutPolicyStmt),
nil, nil,
nil, nil,
@ -114,7 +114,7 @@ func Test_LockoutPolicyPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*LockoutPolicy)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -166,7 +166,7 @@ var (
} }
) )
func (q *Queries) LoginPolicyByID(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (_ *LoginPolicy, err error) { func (q *Queries) LoginPolicyByID(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (policy *LoginPolicy, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -191,11 +191,14 @@ func (q *Queries) LoginPolicyByID(ctx context.Context, shouldTriggerBulk bool, o
return nil, errors.ThrowInternal(err, "QUERY-scVHo", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-scVHo", "Errors.Query.SQLStatement")
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
policy, err = q.scanAndAddLinksToLoginPolicy(ctx, rows, scan)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-SWgr3", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-SWgr3", "Errors.Internal")
} }
return q.scanAndAddLinksToLoginPolicy(ctx, rows, scan) return policy, nil
} }
func (q *Queries) scanAndAddLinksToLoginPolicy(ctx context.Context, rows *sql.Rows, scan func(*sql.Rows) (*LoginPolicy, error)) (*LoginPolicy, error) { func (q *Queries) scanAndAddLinksToLoginPolicy(ctx context.Context, rows *sql.Rows, scan func(*sql.Rows) (*LoginPolicy, error)) (*LoginPolicy, error) {
@ -214,7 +217,7 @@ func (q *Queries) scanAndAddLinksToLoginPolicy(ctx context.Context, rows *sql.Ro
return policy, nil return policy, nil
} }
func (q *Queries) DefaultLoginPolicy(ctx context.Context) (_ *LoginPolicy, err error) { func (q *Queries) DefaultLoginPolicy(ctx context.Context) (policy *LoginPolicy, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -227,14 +230,17 @@ func (q *Queries) DefaultLoginPolicy(ctx context.Context) (_ *LoginPolicy, err e
return nil, errors.ThrowInternal(err, "QUERY-t4TBK", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-t4TBK", "Errors.Query.SQLStatement")
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
policy, err = q.scanAndAddLinksToLoginPolicy(ctx, rows, scan)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-SArt2", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-SArt2", "Errors.Internal")
} }
return q.scanAndAddLinksToLoginPolicy(ctx, rows, scan) return policy, nil
} }
func (q *Queries) SecondFactorsByOrg(ctx context.Context, orgID string) (_ *SecondFactors, err error) { func (q *Queries) SecondFactorsByOrg(ctx context.Context, orgID string) (factors *SecondFactors, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -259,8 +265,10 @@ func (q *Queries) SecondFactorsByOrg(ctx context.Context, orgID string) (_ *Seco
return nil, errors.ThrowInternal(err, "QUERY-scVHo", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-scVHo", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, stmt, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
factors, err := scan(row) factors, err = scan(row)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -268,7 +276,7 @@ func (q *Queries) SecondFactorsByOrg(ctx context.Context, orgID string) (_ *Seco
return factors, err return factors, err
} }
func (q *Queries) DefaultSecondFactors(ctx context.Context) (_ *SecondFactors, err error) { func (q *Queries) DefaultSecondFactors(ctx context.Context) (factors *SecondFactors, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -281,8 +289,10 @@ func (q *Queries) DefaultSecondFactors(ctx context.Context) (_ *SecondFactors, e
return nil, errors.ThrowInternal(err, "QUERY-CZ2Nv", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-CZ2Nv", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, stmt, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
factors, err := scan(row) factors, err = scan(row)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -290,7 +300,7 @@ func (q *Queries) DefaultSecondFactors(ctx context.Context) (_ *SecondFactors, e
return factors, err return factors, err
} }
func (q *Queries) MultiFactorsByOrg(ctx context.Context, orgID string) (_ *MultiFactors, err error) { func (q *Queries) MultiFactorsByOrg(ctx context.Context, orgID string) (factors *MultiFactors, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -315,8 +325,10 @@ func (q *Queries) MultiFactorsByOrg(ctx context.Context, orgID string) (_ *Multi
return nil, errors.ThrowInternal(err, "QUERY-B4o7h", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-B4o7h", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, stmt, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
factors, err := scan(row) factors, err = scan(row)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -324,7 +336,7 @@ func (q *Queries) MultiFactorsByOrg(ctx context.Context, orgID string) (_ *Multi
return factors, err return factors, err
} }
func (q *Queries) DefaultMultiFactors(ctx context.Context) (_ *MultiFactors, err error) { func (q *Queries) DefaultMultiFactors(ctx context.Context) (factors *MultiFactors, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -337,8 +349,10 @@ func (q *Queries) DefaultMultiFactors(ctx context.Context) (_ *MultiFactors, err
return nil, errors.ThrowInternal(err, "QUERY-WxYjr", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-WxYjr", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, stmt, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
factors, err := scan(row) factors, err = scan(row)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -98,7 +98,7 @@ func Test_LoginPolicyPrepares(t *testing.T) {
name: "prepareLoginPolicyQuery no result", name: "prepareLoginPolicyQuery no result",
prepare: prepareLoginPolicyQuery, prepare: prepareLoginPolicyQuery,
want: want{ want: want{
sqlExpectations: mockQueries( sqlExpectations: mockQueriesScanErr(
regexp.QuoteMeta(loginPolicyQuery), regexp.QuoteMeta(loginPolicyQuery),
nil, nil,
nil, nil,
@ -189,13 +189,13 @@ func Test_LoginPolicyPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*LoginPolicy)(nil),
}, },
{ {
name: "prepareLoginPolicy2FAsQuery no result", name: "prepareLoginPolicy2FAsQuery no result",
prepare: prepareLoginPolicy2FAsQuery, prepare: prepareLoginPolicy2FAsQuery,
want: want{ want: want{
sqlExpectations: mockQuery( sqlExpectations: mockQueryScanErr(
regexp.QuoteMeta(prepareLoginPolicy2FAsStmt), regexp.QuoteMeta(prepareLoginPolicy2FAsStmt),
prepareLoginPolicy2FAsCols, prepareLoginPolicy2FAsCols,
nil, nil,
@ -257,13 +257,13 @@ func Test_LoginPolicyPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*SecondFactors)(nil),
}, },
{ {
name: "prepareLoginPolicyMFAsQuery no result", name: "prepareLoginPolicyMFAsQuery no result",
prepare: prepareLoginPolicyMFAsQuery, prepare: prepareLoginPolicyMFAsQuery,
want: want{ want: want{
sqlExpectations: mockQuery( sqlExpectations: mockQueryScanErr(
regexp.QuoteMeta(prepareLoginPolicyMFAsStmt), regexp.QuoteMeta(prepareLoginPolicyMFAsStmt),
prepareLoginPolicyMFAsCols, prepareLoginPolicyMFAsCols,
nil, nil,
@ -325,7 +325,7 @@ func Test_LoginPolicyPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*MultiFactors)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -70,7 +70,7 @@ var (
} }
) )
func (q *Queries) MailTemplateByOrg(ctx context.Context, orgID string, withOwnerRemoved bool) (_ *MailTemplate, err error) { func (q *Queries) MailTemplateByOrg(ctx context.Context, orgID string, withOwnerRemoved bool) (template *MailTemplate, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -93,11 +93,14 @@ func (q *Queries) MailTemplateByOrg(ctx context.Context, orgID string, withOwner
return nil, errors.ThrowInternal(err, "QUERY-m0sJg", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-m0sJg", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) template, err = scan(row)
return err
}, query, args...)
return template, err
} }
func (q *Queries) DefaultMailTemplate(ctx context.Context) (_ *MailTemplate, err error) { func (q *Queries) DefaultMailTemplate(ctx context.Context) (template *MailTemplate, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -112,8 +115,11 @@ func (q *Queries) DefaultMailTemplate(ctx context.Context) (_ *MailTemplate, err
return nil, errors.ThrowInternal(err, "QUERY-2m0fH", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-2m0fH", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) template, err = scan(row)
return err
}, query, args...)
return template, err
} }
func prepareMailTemplateQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*MailTemplate, error)) { func prepareMailTemplateQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*MailTemplate, error)) {

View File

@ -126,7 +126,7 @@ var (
} }
) )
func (q *Queries) DefaultMessageText(ctx context.Context) (_ *MessageText, err error) { func (q *Queries) DefaultMessageText(ctx context.Context) (text *MessageText, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -140,8 +140,11 @@ func (q *Queries) DefaultMessageText(ctx context.Context) (_ *MessageText, err e
return nil, errors.ThrowInternal(err, "QUERY-1b9mf", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-1b9mf", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) text, err = scan(row)
return err
}, query, args...)
return text, err
} }
func (q *Queries) DefaultMessageTextByTypeAndLanguageFromFileSystem(ctx context.Context, messageType, language string) (_ *MessageText, err error) { func (q *Queries) DefaultMessageTextByTypeAndLanguageFromFileSystem(ctx context.Context, messageType, language string) (_ *MessageText, err error) {
@ -159,7 +162,7 @@ func (q *Queries) DefaultMessageTextByTypeAndLanguageFromFileSystem(ctx context.
return messageTexts.GetMessageTextByType(messageType), nil return messageTexts.GetMessageTextByType(messageType), nil
} }
func (q *Queries) CustomMessageTextByTypeAndLanguage(ctx context.Context, aggregateID, messageType, language string, withOwnerRemoved bool) (_ *MessageText, err error) { func (q *Queries) CustomMessageTextByTypeAndLanguage(ctx context.Context, aggregateID, messageType, language string, withOwnerRemoved bool) (msg *MessageText, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -179,8 +182,10 @@ func (q *Queries) CustomMessageTextByTypeAndLanguage(ctx context.Context, aggreg
return nil, errors.ThrowInternal(err, "QUERY-1b9mf", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-1b9mf", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
msg, err := scan(row) msg, err = scan(row)
return err
}, query, args...)
if errors.IsNotFound(err) { if errors.IsNotFound(err) {
return q.IAMMessageTextByTypeAndLanguage(ctx, messageType, language) return q.IAMMessageTextByTypeAndLanguage(ctx, messageType, language)
} }

View File

@ -64,7 +64,7 @@ func Test_MessageTextPrepares(t *testing.T) {
name: "prepareMessageTextQuery no result", name: "prepareMessageTextQuery no result",
prepare: prepareMessageTextQuery, prepare: prepareMessageTextQuery,
want: want{ want: want{
sqlExpectations: mockQueries( sqlExpectations: mockQueriesScanErr(
regexp.QuoteMeta(prepareMessageTextStmt), regexp.QuoteMeta(prepareMessageTextStmt),
nil, nil,
nil, nil,
@ -135,7 +135,7 @@ func Test_MessageTextPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*MessageText)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -69,7 +69,7 @@ var (
) )
// SearchMilestones tries to defer the instanceID from the passed context if no instanceIDs are passed // SearchMilestones tries to defer the instanceID from the passed context if no instanceIDs are passed
func (q *Queries) SearchMilestones(ctx context.Context, instanceIDs []string, queries *MilestonesSearchQueries) (_ *Milestones, err error) { func (q *Queries) SearchMilestones(ctx context.Context, instanceIDs []string, queries *MilestonesSearchQueries) (milestones *Milestones, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
query, scan := prepareMilestonesQuery(ctx, q.client) query, scan := prepareMilestonesQuery(ctx, q.client)
@ -80,22 +80,14 @@ func (q *Queries) SearchMilestones(ctx context.Context, instanceIDs []string, qu
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-A9i5k", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-A9i5k", "Errors.Query.SQLStatement")
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
milestones, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, err return nil, err
} }
defer func() {
if closeErr := rows.Close(); closeErr != nil && err == nil {
err = errors.ThrowInternal(closeErr, "QUERY-CK9mI", "Errors.Query.CloseRows")
}
}()
milestones, err := scan(rows)
if err != nil {
return nil, err
}
if err = rows.Err(); err != nil {
return nil, errors.ThrowInternal(err, "QUERY-asLsI", "Errors.Internal")
}
milestones.LatestSequence, err = q.latestSequence(ctx, milestonesTable) milestones.LatestSequence, err = q.latestSequence(ctx, milestonesTable)
return milestones, err return milestones, err

View File

@ -178,7 +178,7 @@ func Test_MilestonesPrepare(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*Milestones)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -76,7 +76,7 @@ var (
} }
) )
func (q *Queries) NotificationPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (_ *NotificationPolicy, err error) { func (q *Queries) NotificationPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (policy *NotificationPolicy, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -104,11 +104,14 @@ func (q *Queries) NotificationPolicyByOrg(ctx context.Context, shouldTriggerBulk
return nil, errors.ThrowInternal(err, "QUERY-Xuoapqm", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-Xuoapqm", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) policy, err = scan(row)
return err
}, query, args...)
return policy, err
} }
func (q *Queries) DefaultNotificationPolicy(ctx context.Context, shouldTriggerBulk bool) (_ *NotificationPolicy, err error) { func (q *Queries) DefaultNotificationPolicy(ctx context.Context, shouldTriggerBulk bool) (policy *NotificationPolicy, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -130,8 +133,11 @@ func (q *Queries) DefaultNotificationPolicy(ctx context.Context, shouldTriggerBu
return nil, errors.ThrowInternal(err, "QUERY-xlqp209", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-xlqp209", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) policy, err = scan(row)
return err
}, query, args...)
return policy, err
} }
func prepareNotificationPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*NotificationPolicy, error)) { func prepareNotificationPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*NotificationPolicy, error)) {

View File

@ -50,7 +50,7 @@ func Test_NotificationPolicyPrepares(t *testing.T) {
name: "prepareNotificationPolicyQuery no result", name: "prepareNotificationPolicyQuery no result",
prepare: prepareNotificationPolicyQuery, prepare: prepareNotificationPolicyQuery,
want: want{ want: want{
sqlExpectations: mockQueries( sqlExpectations: mockQueriesScanErr(
notificationPolicyStmt, notificationPolicyStmt,
nil, nil,
nil, nil,
@ -109,7 +109,7 @@ func Test_NotificationPolicyPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*NotificationPolicy)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -70,7 +70,7 @@ var (
} }
) )
func (q *Queries) NotificationProviderByIDAndType(ctx context.Context, aggID string, providerType domain.NotificationProviderType) (_ *DebugNotificationProvider, err error) { func (q *Queries) NotificationProviderByIDAndType(ctx context.Context, aggID string, providerType domain.NotificationProviderType) (provider *DebugNotificationProvider, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -90,8 +90,11 @@ func (q *Queries) NotificationProviderByIDAndType(ctx context.Context, aggID str
return nil, errors.ThrowInternal(err, "QUERY-f9jSf", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-f9jSf", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, stmt, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) provider, err = scan(row)
return err
}, stmt, args...)
return provider, err
} }
func prepareDebugNotificationProviderQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*DebugNotificationProvider, error)) { func prepareDebugNotificationProviderQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*DebugNotificationProvider, error)) {

View File

@ -50,7 +50,7 @@ func Test_NotificationProviderPrepares(t *testing.T) {
name: "prepareNotificationProviderQuery no result", name: "prepareNotificationProviderQuery no result",
prepare: prepareDebugNotificationProviderQuery, prepare: prepareDebugNotificationProviderQuery,
want: want{ want: want{
sqlExpectations: mockQueries( sqlExpectations: mockQueriesScanErr(
regexp.QuoteMeta(prepareNotificationProviderStmt), regexp.QuoteMeta(prepareNotificationProviderStmt),
nil, nil,
nil, nil,
@ -109,7 +109,7 @@ func Test_NotificationProviderPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*DebugNotificationProvider)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -75,7 +75,7 @@ type OIDCSettings struct {
RefreshTokenExpiration time.Duration RefreshTokenExpiration time.Duration
} }
func (q *Queries) OIDCSettingsByAggID(ctx context.Context, aggregateID string) (_ *OIDCSettings, err error) { func (q *Queries) OIDCSettingsByAggID(ctx context.Context, aggregateID string) (settings *OIDCSettings, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -88,8 +88,11 @@ func (q *Queries) OIDCSettingsByAggID(ctx context.Context, aggregateID string) (
return nil, errors.ThrowInternal(err, "QUERY-s9nle", "Errors.Query.SQLStatment") return nil, errors.ThrowInternal(err, "QUERY-s9nle", "Errors.Query.SQLStatment")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) settings, err = scan(row)
return err
}, query, args...)
return settings, err
} }
func prepareOIDCSettingsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*OIDCSettings, error)) { func prepareOIDCSettingsQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*OIDCSettings, error)) {

View File

@ -52,7 +52,7 @@ func Test_OIDCConfigsPrepares(t *testing.T) {
name: "prepareOIDCSettingsQuery no result", name: "prepareOIDCSettingsQuery no result",
prepare: prepareOIDCSettingsQuery, prepare: prepareOIDCSettingsQuery,
want: want{ want: want{
sqlExpectations: mockQueries( sqlExpectations: mockQueriesScanErr(
prepareOIDCSettingsStmt, prepareOIDCSettingsStmt,
nil, nil,
nil, nil,
@ -113,7 +113,7 @@ func Test_OIDCConfigsPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*OIDCSettings)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -89,7 +89,7 @@ func (q *OrgSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
return query return query
} }
func (q *Queries) OrgByID(ctx context.Context, shouldTriggerBulk bool, id string) (_ *Org, err error) { func (q *Queries) OrgByID(ctx context.Context, shouldTriggerBulk bool, id string) (org *Org, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -106,11 +106,14 @@ func (q *Queries) OrgByID(ctx context.Context, shouldTriggerBulk bool, id string
return nil, errors.ThrowInternal(err, "QUERY-AWx52", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-AWx52", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) org, err = scan(row)
return err
}, query, args...)
return org, err
} }
func (q *Queries) OrgByPrimaryDomain(ctx context.Context, domain string) (_ *Org, err error) { func (q *Queries) OrgByPrimaryDomain(ctx context.Context, domain string) (org *Org, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -124,11 +127,14 @@ func (q *Queries) OrgByPrimaryDomain(ctx context.Context, domain string) (_ *Org
return nil, errors.ThrowInternal(err, "QUERY-TYUCE", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-TYUCE", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) org, err = scan(row)
return err
}, query, args...)
return org, err
} }
func (q *Queries) OrgByVerifiedDomain(ctx context.Context, domain string) (_ *Org, err error) { func (q *Queries) OrgByVerifiedDomain(ctx context.Context, domain string) (org *Org, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -142,8 +148,11 @@ func (q *Queries) OrgByVerifiedDomain(ctx context.Context, domain string) (_ *Or
return nil, errors.ThrowInternal(err, "QUERY-TYUCE", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-TYUCE", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) org, err = scan(row)
return err
}, query, args...)
return org, err
} }
func (q *Queries) IsOrgUnique(ctx context.Context, name, domain string) (isUnique bool, err error) { func (q *Queries) IsOrgUnique(ctx context.Context, name, domain string) (isUnique bool, err error) {
@ -176,8 +185,11 @@ func (q *Queries) IsOrgUnique(ctx context.Context, name, domain string) (isUniqu
return false, errors.ThrowInternal(err, "QUERY-Dgbe2", "Errors.Query.SQLStatement") return false, errors.ThrowInternal(err, "QUERY-Dgbe2", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, stmt, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) isUnique, err = scan(row)
return err
}, stmt, args...)
return isUnique, err
} }
func (q *Queries) ExistsOrg(ctx context.Context, id, domain string) (verifiedID string, err error) { func (q *Queries) ExistsOrg(ctx context.Context, id, domain string) (verifiedID string, err error) {
@ -209,14 +221,14 @@ func (q *Queries) SearchOrgs(ctx context.Context, queries *OrgSearchQueries) (or
return nil, errors.ThrowInvalidArgument(err, "QUERY-wQ3by", "Errors.Query.InvalidRequest") return nil, errors.ThrowInvalidArgument(err, "QUERY-wQ3by", "Errors.Query.InvalidRequest")
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
orgs, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-M6mYN", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-M6mYN", "Errors.Internal")
} }
orgs, err = scan(rows)
if err != nil {
return nil, err
}
orgs.LatestSequence, err = q.latestSequence(ctx, orgsTable) orgs.LatestSequence, err = q.latestSequence(ctx, orgsTable)
return orgs, err return orgs, err
} }

View File

@ -70,14 +70,14 @@ func (q *Queries) SearchOrgDomains(ctx context.Context, queries *OrgDomainSearch
return nil, errors.ThrowInvalidArgument(err, "QUERY-ZRfj1", "Errors.Query.SQLStatement") return nil, errors.ThrowInvalidArgument(err, "QUERY-ZRfj1", "Errors.Query.SQLStatement")
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
domains, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-M6mYN", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-M6mYN", "Errors.Internal")
} }
domains, err = scan(rows)
if err != nil {
return nil, err
}
domains.LatestSequence, err = q.latestSequence(ctx, orgDomainsTable) domains.LatestSequence, err = q.latestSequence(ctx, orgDomainsTable)
return domains, err return domains, err
} }

View File

@ -172,7 +172,7 @@ func Test_OrgDomainPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*Domains)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -78,7 +78,7 @@ func addOrgMemberWithoutOwnerRemoved(eq map[string]interface{}) {
eq[OrgMemberOwnerRemovedUser.identifier()] = false eq[OrgMemberOwnerRemovedUser.identifier()] = false
} }
func (q *Queries) OrgMembers(ctx context.Context, queries *OrgMembersQuery, withOwnerRemoved bool) (_ *Members, err error) { func (q *Queries) OrgMembers(ctx context.Context, queries *OrgMembersQuery, withOwnerRemoved bool) (members *Members, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -98,14 +98,14 @@ func (q *Queries) OrgMembers(ctx context.Context, queries *OrgMembersQuery, with
return nil, err return nil, err
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
members, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-5g4yV", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-5g4yV", "Errors.Internal")
} }
members, err := scan(rows)
if err != nil {
return nil, err
}
members.LatestSequence = currentSequence members.LatestSequence = currentSequence
return members, err return members, err
} }

View File

@ -284,7 +284,7 @@ func Test_OrgMemberPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*OrgMembership)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -77,7 +77,7 @@ var (
} }
) )
func (q *Queries) GetOrgMetadataByKey(ctx context.Context, shouldTriggerBulk bool, orgID string, key string, withOwnerRemoved bool, queries ...SearchQuery) (_ *OrgMetadata, err error) { func (q *Queries) GetOrgMetadataByKey(ctx context.Context, shouldTriggerBulk bool, orgID string, key string, withOwnerRemoved bool, queries ...SearchQuery) (metadata *OrgMetadata, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -102,11 +102,14 @@ func (q *Queries) GetOrgMetadataByKey(ctx context.Context, shouldTriggerBulk boo
return nil, errors.ThrowInternal(err, "QUERY-aDaG2", "Errors.Query.SQLStatment") return nil, errors.ThrowInternal(err, "QUERY-aDaG2", "Errors.Query.SQLStatment")
} }
row := q.client.QueryRowContext(ctx, stmt, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) metadata, err = scan(row)
return err
}, stmt, args...)
return metadata, err
} }
func (q *Queries) SearchOrgMetadata(ctx context.Context, shouldTriggerBulk bool, orgID string, queries *OrgMetadataSearchQueries, withOwnerRemoved bool) (_ *OrgMetadataList, err error) { func (q *Queries) SearchOrgMetadata(ctx context.Context, shouldTriggerBulk bool, orgID string, queries *OrgMetadataSearchQueries, withOwnerRemoved bool) (metadata *OrgMetadataList, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -126,14 +129,14 @@ func (q *Queries) SearchOrgMetadata(ctx context.Context, shouldTriggerBulk bool,
return nil, errors.ThrowInternal(err, "QUERY-Egbld", "Errors.Query.SQLStatment") return nil, errors.ThrowInternal(err, "QUERY-Egbld", "Errors.Query.SQLStatment")
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
metadata, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-Ho2wf", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-Ho2wf", "Errors.Internal")
} }
metadata, err := scan(rows)
if err != nil {
return nil, err
}
metadata.LatestSequence, err = q.latestSequence(ctx, orgMetadataTable) metadata.LatestSequence, err = q.latestSequence(ctx, orgMetadataTable)
return metadata, err return metadata, err
} }

View File

@ -63,7 +63,7 @@ func Test_OrgMetadataPrepares(t *testing.T) {
name: "prepareOrgMetadataQuery no result", name: "prepareOrgMetadataQuery no result",
prepare: prepareOrgMetadataQuery, prepare: prepareOrgMetadataQuery,
want: want{ want: want{
sqlExpectations: mockQuery( sqlExpectations: mockQueryScanErr(
regexp.QuoteMeta(orgMetadataQuery), regexp.QuoteMeta(orgMetadataQuery),
nil, nil,
nil, nil,
@ -118,7 +118,7 @@ func Test_OrgMetadataPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*OrgMetadata)(nil),
}, },
{ {
name: "prepareOrgMetadataListQuery no result", name: "prepareOrgMetadataListQuery no result",
@ -239,7 +239,7 @@ func Test_OrgMetadataPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*OrgMetadataList)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -209,13 +209,13 @@ func Test_OrgPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*Orgs)(nil),
}, },
{ {
name: "prepareOrgQuery no result", name: "prepareOrgQuery no result",
prepare: prepareOrgQuery, prepare: prepareOrgQuery,
want: want{ want: want{
sqlExpectations: mockQueries( sqlExpectations: mockQueriesScanErr(
regexp.QuoteMeta(prepareOrgQueryStmt), regexp.QuoteMeta(prepareOrgQueryStmt),
nil, nil,
nil, nil,
@ -274,13 +274,13 @@ func Test_OrgPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*Org)(nil),
}, },
{ {
name: "prepareOrgUniqueQuery no result", name: "prepareOrgUniqueQuery no result",
prepare: prepareOrgUniqueQuery, prepare: prepareOrgUniqueQuery,
want: want{ want: want{
sqlExpectations: mockQueries( sqlExpectations: mockQueriesScanErr(
regexp.QuoteMeta(prepareOrgUniqueStmt), regexp.QuoteMeta(prepareOrgUniqueStmt),
nil, nil,
nil, nil,
@ -323,7 +323,7 @@ func Test_OrgPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: false,
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -81,7 +81,7 @@ var (
} }
) )
func (q *Queries) PasswordAgePolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (_ *PasswordAgePolicy, err error) { func (q *Queries) PasswordAgePolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (policy *PasswordAgePolicy, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -107,11 +107,14 @@ func (q *Queries) PasswordAgePolicyByOrg(ctx context.Context, shouldTriggerBulk
return nil, errors.ThrowInternal(err, "QUERY-SKR6X", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-SKR6X", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) policy, err = scan(row)
return err
}, query, args...)
return policy, err
} }
func (q *Queries) DefaultPasswordAgePolicy(ctx context.Context, shouldTriggerBulk bool) (_ *PasswordAgePolicy, err error) { func (q *Queries) DefaultPasswordAgePolicy(ctx context.Context, shouldTriggerBulk bool) (policy *PasswordAgePolicy, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -129,8 +132,11 @@ func (q *Queries) DefaultPasswordAgePolicy(ctx context.Context, shouldTriggerBul
return nil, errors.ThrowInternal(err, "QUERY-mN0Ci", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-mN0Ci", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) policy, err = scan(row)
return err
}, query, args...)
return policy, err
} }
func preparePasswordAgePolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*PasswordAgePolicy, error)) { func preparePasswordAgePolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*PasswordAgePolicy, error)) {

View File

@ -52,7 +52,7 @@ func Test_PasswordAgePolicyPrepares(t *testing.T) {
name: "preparePasswordAgePolicyQuery no result", name: "preparePasswordAgePolicyQuery no result",
prepare: preparePasswordAgePolicyQuery, prepare: preparePasswordAgePolicyQuery,
want: want{ want: want{
sqlExpectations: mockQueries( sqlExpectations: mockQueriesScanErr(
regexp.QuoteMeta(preparePasswordAgePolicyStmt), regexp.QuoteMeta(preparePasswordAgePolicyStmt),
nil, nil,
nil, nil,
@ -113,7 +113,7 @@ func Test_PasswordAgePolicyPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*PasswordAgePolicy)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -33,7 +33,7 @@ type PasswordComplexityPolicy struct {
IsDefault bool IsDefault bool
} }
func (q *Queries) PasswordComplexityPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (_ *PasswordComplexityPolicy, err error) { func (q *Queries) PasswordComplexityPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (policy *PasswordComplexityPolicy, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -59,11 +59,14 @@ func (q *Queries) PasswordComplexityPolicyByOrg(ctx context.Context, shouldTrigg
return nil, errors.ThrowInternal(err, "QUERY-lDnrk", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-lDnrk", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) policy, err = scan(row)
return err
}, query, args...)
return policy, err
} }
func (q *Queries) DefaultPasswordComplexityPolicy(ctx context.Context, shouldTriggerBulk bool) (_ *PasswordComplexityPolicy, err error) { func (q *Queries) DefaultPasswordComplexityPolicy(ctx context.Context, shouldTriggerBulk bool) (policy *PasswordComplexityPolicy, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -82,8 +85,11 @@ func (q *Queries) DefaultPasswordComplexityPolicy(ctx context.Context, shouldTri
return nil, errors.ThrowInternal(err, "QUERY-h4Uyr", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-h4Uyr", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) policy, err = scan(row)
return err
}, query, args...)
return policy, err
} }
var ( var (

View File

@ -58,7 +58,7 @@ func Test_PasswordComplexityPolicyPrepares(t *testing.T) {
name: "preparePasswordComplexityPolicyQuery no result", name: "preparePasswordComplexityPolicyQuery no result",
prepare: preparePasswordComplexityPolicyQuery, prepare: preparePasswordComplexityPolicyQuery,
want: want{ want: want{
sqlExpectations: mockQueries( sqlExpectations: mockQueriesScanErr(
regexp.QuoteMeta(preparePasswordComplexityPolicyStmt), regexp.QuoteMeta(preparePasswordComplexityPolicyStmt),
nil, nil,
nil, nil,
@ -125,7 +125,7 @@ func Test_PasswordComplexityPolicyPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*PasswordComplexityPolicy)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -15,6 +15,7 @@ import (
sq "github.com/Masterminds/squirrel" sq "github.com/Masterminds/squirrel"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/database"
) )
var ( var (
@ -53,15 +54,16 @@ func assertPrepare(t *testing.T, prepareFunc, expectedObject interface{}, sqlExp
} }
return isErr(err) return isErr(err)
} }
object, ok := execScan(client, builder, scan, errCheck) object, ok, didScan := execScan(&database.DB{DB: client}, builder, scan, errCheck)
if !ok { if !ok {
t.Error(object) t.Error(object)
return false return false
} }
if didScan {
if !assert.Equal(t, expectedObject, object) { if !assert.Equal(t, expectedObject, object) {
return false return false
} }
}
if err := mock.ExpectationsWereMet(); err != nil { if err := mock.ExpectationsWereMet(); err != nil {
t.Errorf("sql expectations not met: %v", err) t.Errorf("sql expectations not met: %v", err)
@ -77,7 +79,23 @@ type sqlExpectation func(sqlmock.Sqlmock) sqlmock.Sqlmock
func mockQuery(stmt string, cols []string, row []driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock { func mockQuery(stmt string, cols []string, row []driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock { return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
m.ExpectBegin()
q := m.ExpectQuery(stmt).WithArgs(args...) q := m.ExpectQuery(stmt).WithArgs(args...)
m.ExpectCommit()
result := sqlmock.NewRows(cols)
if len(row) > 0 {
result.AddRow(row...)
}
q.WillReturnRows(result)
return m
}
}
func mockQueryScanErr(stmt string, cols []string, row []driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
m.ExpectBegin()
q := m.ExpectQuery(stmt).WithArgs(args...)
m.ExpectRollback()
result := sqlmock.NewRows(cols) result := sqlmock.NewRows(cols)
if len(row) > 0 { if len(row) > 0 {
result.AddRow(row...) result.AddRow(row...)
@ -89,7 +107,28 @@ func mockQuery(stmt string, cols []string, row []driver.Value, args ...driver.Va
func mockQueries(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock { func mockQueries(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock { return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
m.ExpectBegin()
q := m.ExpectQuery(stmt).WithArgs(args...) q := m.ExpectQuery(stmt).WithArgs(args...)
m.ExpectCommit()
result := sqlmock.NewRows(cols)
count := uint64(len(rows))
for _, row := range rows {
if cols[len(cols)-1] == "count" {
row = append(row, count)
}
result.AddRow(row...)
}
q.WillReturnRows(result)
q.RowsWillBeClosed()
return m
}
}
func mockQueriesScanErr(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
m.ExpectBegin()
q := m.ExpectQuery(stmt).WithArgs(args...)
m.ExpectRollback()
result := sqlmock.NewRows(cols) result := sqlmock.NewRows(cols)
count := uint64(len(rows)) count := uint64(len(rows))
for _, row := range rows { for _, row := range rows {
@ -106,8 +145,10 @@ func mockQueries(stmt string, cols []string, rows [][]driver.Value, args ...driv
func mockQueryErr(stmt string, err error, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock { func mockQueryErr(stmt string, err error, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock { return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
m.ExpectBegin()
q := m.ExpectQuery(stmt).WithArgs(args...) q := m.ExpectQuery(stmt).WithArgs(args...)
q.WillReturnError(err) q.WillReturnError(err)
m.ExpectRollback()
return m return m
} }
} }
@ -127,52 +168,65 @@ var (
selectBuilderType = reflect.TypeOf(sq.SelectBuilder{}) selectBuilderType = reflect.TypeOf(sq.SelectBuilder{})
) )
func execScan(client *sql.DB, builder sq.SelectBuilder, scan interface{}, errCheck checkErr) (interface{}, bool) { func execScan(client *database.DB, builder sq.SelectBuilder, scan interface{}, errCheck checkErr) (object interface{}, ok bool, didScan bool) {
scanType := reflect.TypeOf(scan) scanType := reflect.TypeOf(scan)
err := validateScan(scanType) err := validateScan(scanType)
if err != nil { if err != nil {
return err, false return err, false, false
} }
stmt, args, err := builder.ToSql() stmt, args, err := builder.ToSql()
if err != nil { if err != nil {
return fmt.Errorf("unexpeted error from sql builder: %w", err), false return fmt.Errorf("unexpeted error from sql builder: %w", err), false, false
} }
//resultSet represents *sql.Row or *sql.Rows, //resultSet represents *sql.Row or *sql.Rows,
// depending on whats assignable to the scan function // depending on whats assignable to the scan function
var resultSet interface{} var res []reflect.Value
//execute sql stmt //execute sql stmt
// if scan(*sql.Rows)... // if scan(*sql.Rows)...
if scanType.In(0).AssignableTo(rowsType) { if scanType.In(0).AssignableTo(rowsType) {
resultSet, err = client.Query(stmt, args...) err = client.Query(func(rows *sql.Rows) error {
if err != nil { didScan = true
return errCheck(err) res = reflect.ValueOf(scan).Call([]reflect.Value{reflect.ValueOf(rows)})
if err, ok := res[1].Interface().(error); ok {
return err
} }
return nil
}, stmt, args...)
// if scan(*sql.Row)... // if scan(*sql.Row)...
} else if scanType.In(0).AssignableTo(rowType) { } else if scanType.In(0).AssignableTo(rowType) {
row := client.QueryRow(stmt, args...) err = client.QueryRow(func(r *sql.Row) error {
if row.Err() != nil { didScan = true
return errCheck(row.Err()) res = reflect.ValueOf(scan).Call([]reflect.Value{reflect.ValueOf(r)})
if err, ok := res[1].Interface().(error); ok {
return err
} }
resultSet = row return nil
}, stmt, args...)
} else { } else {
return errors.New("scan: parameter must be *sql.Row or *sql.Rows"), false return errors.New("scan: parameter must be *sql.Row or *sql.Rows"), false, false
} }
// res contains object and error if err != nil {
res := reflect.ValueOf(scan).Call([]reflect.Value{reflect.ValueOf(resultSet)}) err, ok := errCheck(err)
if didScan {
return res[0].Interface(), ok, didScan
}
return err, ok, didScan
}
//check for error //check for error
if res[1].Interface() != nil { if res[1].Interface() != nil {
if err, ok := errCheck(res[1].Interface().(error)); !ok { if err, ok := errCheck(res[1].Interface().(error)); !ok {
return fmt.Errorf("scan failed: %w", err), false return fmt.Errorf("scan failed: %w", err), false, didScan
} }
} }
return res[0].Interface(), true return res[0].Interface(), true, didScan
} }
func validateScan(scanType reflect.Type) error { func validateScan(scanType reflect.Type) error {

View File

@ -91,7 +91,7 @@ var (
} }
) )
func (q *Queries) PrivacyPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (_ *PrivacyPolicy, err error) { func (q *Queries) PrivacyPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (policy *PrivacyPolicy, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -116,11 +116,14 @@ func (q *Queries) PrivacyPolicyByOrg(ctx context.Context, shouldTriggerBulk bool
return nil, errors.ThrowInternal(err, "QUERY-UXuPI", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-UXuPI", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) policy, err = scan(row)
return err
}, query, args...)
return policy, err
} }
func (q *Queries) DefaultPrivacyPolicy(ctx context.Context, shouldTriggerBulk bool) (_ *PrivacyPolicy, err error) { func (q *Queries) DefaultPrivacyPolicy(ctx context.Context, shouldTriggerBulk bool) (policy *PrivacyPolicy, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -139,8 +142,11 @@ func (q *Queries) DefaultPrivacyPolicy(ctx context.Context, shouldTriggerBulk bo
return nil, errors.ThrowInternal(err, "QUERY-LkFZ7", "Errors.Query.SQLStatement") return nil, errors.ThrowInternal(err, "QUERY-LkFZ7", "Errors.Query.SQLStatement")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) policy, err = scan(row)
return err
}, query, args...)
return policy, err
} }
func preparePrivacyPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*PrivacyPolicy, error)) { func preparePrivacyPolicyQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*PrivacyPolicy, error)) {

View File

@ -56,7 +56,7 @@ func Test_PrivacyPolicyPrepares(t *testing.T) {
name: "preparePrivacyPolicyQuery no result", name: "preparePrivacyPolicyQuery no result",
prepare: preparePrivacyPolicyQuery, prepare: preparePrivacyPolicyQuery,
want: want{ want: want{
sqlExpectations: mockQueries( sqlExpectations: mockQueriesScanErr(
regexp.QuoteMeta(preparePrivacyPolicyStmt), regexp.QuoteMeta(preparePrivacyPolicyStmt),
nil, nil,
nil, nil,
@ -121,7 +121,7 @@ func Test_PrivacyPolicyPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*PrivacyPolicy)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -100,7 +100,7 @@ type ProjectSearchQueries struct {
Queries []SearchQuery Queries []SearchQuery
} }
func (q *Queries) ProjectByID(ctx context.Context, shouldTriggerBulk bool, id string, withOwnerRemoved bool) (_ *Project, err error) { func (q *Queries) ProjectByID(ctx context.Context, shouldTriggerBulk bool, id string, withOwnerRemoved bool) (project *Project, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -121,8 +121,11 @@ func (q *Queries) ProjectByID(ctx context.Context, shouldTriggerBulk bool, id st
return nil, errors.ThrowInternal(err, "QUERY-2m00Q", "Errors.Query.SQLStatment") return nil, errors.ThrowInternal(err, "QUERY-2m00Q", "Errors.Query.SQLStatment")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) project, err = scan(row)
return err
}, query, args...)
return project, err
} }
func (q *Queries) SearchProjects(ctx context.Context, queries *ProjectSearchQueries, withOwnerRemoved bool) (projects *Projects, err error) { func (q *Queries) SearchProjects(ctx context.Context, queries *ProjectSearchQueries, withOwnerRemoved bool) (projects *Projects, err error) {
@ -139,14 +142,13 @@ func (q *Queries) SearchProjects(ctx context.Context, queries *ProjectSearchQuer
return nil, errors.ThrowInvalidArgument(err, "QUERY-fn9ew", "Errors.Query.InvalidRequest") return nil, errors.ThrowInvalidArgument(err, "QUERY-fn9ew", "Errors.Query.InvalidRequest")
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
projects, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-2j00f", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-2j00f", "Errors.Internal")
} }
projects, err = scan(rows)
if err != nil {
return nil, err
}
projects.LatestSequence, err = q.latestSequence(ctx, projectsTable) projects.LatestSequence, err = q.latestSequence(ctx, projectsTable)
return projects, err return projects, err
} }

View File

@ -111,7 +111,7 @@ type ProjectGrantSearchQueries struct {
Queries []SearchQuery Queries []SearchQuery
} }
func (q *Queries) ProjectGrantByID(ctx context.Context, shouldTriggerBulk bool, id string, withOwnerRemoved bool) (_ *ProjectGrant, err error) { func (q *Queries) ProjectGrantByID(ctx context.Context, shouldTriggerBulk bool, id string, withOwnerRemoved bool) (grant *ProjectGrant, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -133,11 +133,14 @@ func (q *Queries) ProjectGrantByID(ctx context.Context, shouldTriggerBulk bool,
return nil, errors.ThrowInternal(err, "QUERY-Nf93d", "Errors.Query.SQLStatment") return nil, errors.ThrowInternal(err, "QUERY-Nf93d", "Errors.Query.SQLStatment")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) grant, err = scan(row)
return err
}, query, args...)
return grant, err
} }
func (q *Queries) ProjectGrantByIDAndGrantedOrg(ctx context.Context, id, grantedOrg string, withOwnerRemoved bool) (_ *ProjectGrant, err error) { func (q *Queries) ProjectGrantByIDAndGrantedOrg(ctx context.Context, id, grantedOrg string, withOwnerRemoved bool) (grant *ProjectGrant, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -156,11 +159,14 @@ func (q *Queries) ProjectGrantByIDAndGrantedOrg(ctx context.Context, id, granted
return nil, errors.ThrowInternal(err, "QUERY-MO9fs", "Errors.Query.SQLStatment") return nil, errors.ThrowInternal(err, "QUERY-MO9fs", "Errors.Query.SQLStatment")
} }
row := q.client.QueryRowContext(ctx, query, args...) err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
return scan(row) grant, err = scan(row)
return err
}, query, args...)
return grant, err
} }
func (q *Queries) SearchProjectGrants(ctx context.Context, queries *ProjectGrantSearchQueries, withOwnerRemoved bool) (projects *ProjectGrants, err error) { func (q *Queries) SearchProjectGrants(ctx context.Context, queries *ProjectGrantSearchQueries, withOwnerRemoved bool) (grants *ProjectGrants, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -177,16 +183,16 @@ func (q *Queries) SearchProjectGrants(ctx context.Context, queries *ProjectGrant
return nil, errors.ThrowInvalidArgument(err, "QUERY-N9fsg", "Errors.Query.InvalidRequest") return nil, errors.ThrowInvalidArgument(err, "QUERY-N9fsg", "Errors.Query.InvalidRequest")
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
grants, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-PP02n", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-PP02n", "Errors.Internal")
} }
projects, err = scan(rows)
if err != nil { grants.LatestSequence, err = q.latestSequence(ctx, projectGrantsTable)
return nil, err return grants, err
}
projects.LatestSequence, err = q.latestSequence(ctx, projectGrantsTable)
return projects, err
} }
func (q *Queries) SearchProjectGrantsByProjectIDAndRoleKey(ctx context.Context, projectID, roleKey string, withOwnerRemoved bool) (projects *ProjectGrants, err error) { func (q *Queries) SearchProjectGrantsByProjectIDAndRoleKey(ctx context.Context, projectID, roleKey string, withOwnerRemoved bool) (projects *ProjectGrants, err error) {

View File

@ -95,7 +95,7 @@ func addProjectGrantMemberWithoutOwnerRemoved(eq map[string]interface{}) {
eq[ProjectGrantMemberGrantedOrgRemoved.identifier()] = false eq[ProjectGrantMemberGrantedOrgRemoved.identifier()] = false
} }
func (q *Queries) ProjectGrantMembers(ctx context.Context, queries *ProjectGrantMembersQuery, withOwnerRemoved bool) (*Members, error) { func (q *Queries) ProjectGrantMembers(ctx context.Context, queries *ProjectGrantMembersQuery, withOwnerRemoved bool) (members *Members, err error) {
query, scan := prepareProjectGrantMembersQuery(ctx, q.client) query, scan := prepareProjectGrantMembersQuery(ctx, q.client)
eq := sq.Eq{ProjectGrantMemberInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()} eq := sq.Eq{ProjectGrantMemberInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()}
if !withOwnerRemoved { if !withOwnerRemoved {
@ -112,14 +112,14 @@ func (q *Queries) ProjectGrantMembers(ctx context.Context, queries *ProjectGrant
return nil, err return nil, err
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
members, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-Pdg1I", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-Pdg1I", "Errors.Internal")
} }
members, err := scan(rows)
if err != nil {
return nil, err
}
members.LatestSequence = currentSequence members.LatestSequence = currentSequence
return members, err return members, err
} }

View File

@ -287,7 +287,7 @@ func Test_ProjectGrantMemberPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*ProjectGrantMembership)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -381,13 +381,13 @@ func Test_ProjectGrantPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*ProjectGrants)(nil),
}, },
{ {
name: "prepareProjectGrantQuery no result", name: "prepareProjectGrantQuery no result",
prepare: prepareProjectGrantQuery, prepare: prepareProjectGrantQuery,
want: want{ want: want{
sqlExpectations: mockQueries( sqlExpectations: mockQueriesScanErr(
regexp.QuoteMeta(projectGrantQuery), regexp.QuoteMeta(projectGrantQuery),
nil, nil,
nil, nil,
@ -568,7 +568,7 @@ func Test_ProjectGrantPrepares(t *testing.T) {
return nil, true return nil, true
}, },
}, },
object: nil, object: (*ProjectGrant)(nil),
}, },
} }
for _, tt := range tests { for _, tt := range tests {

View File

@ -78,7 +78,7 @@ func addProjectMemberWithoutOwnerRemoved(eq map[string]interface{}) {
eq[ProjectMemberOwnerRemovedUser.identifier()] = false eq[ProjectMemberOwnerRemovedUser.identifier()] = false
} }
func (q *Queries) ProjectMembers(ctx context.Context, queries *ProjectMembersQuery, withOwnerRemoved bool) (_ *Members, err error) { func (q *Queries) ProjectMembers(ctx context.Context, queries *ProjectMembersQuery, withOwnerRemoved bool) (members *Members, err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
@ -98,14 +98,14 @@ func (q *Queries) ProjectMembers(ctx context.Context, queries *ProjectMembersQue
return nil, err return nil, err
} }
rows, err := q.client.QueryContext(ctx, stmt, args...) err = q.client.QueryContext(ctx, func(rows *sql.Rows) error {
members, err = scan(rows)
return err
}, stmt, args...)
if err != nil { if err != nil {
return nil, errors.ThrowInternal(err, "QUERY-uh6pj", "Errors.Internal") return nil, errors.ThrowInternal(err, "QUERY-uh6pj", "Errors.Internal")
} }
members, err := scan(rows)
if err != nil {
return nil, err
}
members.LatestSequence = currentSequence members.LatestSequence = currentSequence
return members, err return members, err
} }

Some files were not shown because too many files have changed in this diff Show More