mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 21:17:32 +00:00
feat(storage): read only transactions for queries (#6415)
* fix: tests * bastle wie en grosse * fix(database): scan as callback * fix tests * fix merge failures * remove as of system time * refactor: remove unused test * refacotr: remove unused lines
This commit is contained in:
@@ -15,6 +15,7 @@ import (
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -53,14 +54,15 @@ func assertPrepare(t *testing.T, prepareFunc, expectedObject interface{}, sqlExp
|
||||
}
|
||||
return isErr(err)
|
||||
}
|
||||
object, ok := execScan(client, builder, scan, errCheck)
|
||||
object, ok, didScan := execScan(&database.DB{DB: client}, builder, scan, errCheck)
|
||||
if !ok {
|
||||
t.Error(object)
|
||||
return false
|
||||
}
|
||||
|
||||
if !assert.Equal(t, expectedObject, object) {
|
||||
return false
|
||||
if didScan {
|
||||
if !assert.Equal(t, expectedObject, object) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
if err := mock.ExpectationsWereMet(); err != nil {
|
||||
@@ -77,7 +79,23 @@ type sqlExpectation func(sqlmock.Sqlmock) sqlmock.Sqlmock
|
||||
|
||||
func mockQuery(stmt string, cols []string, row []driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
||||
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
||||
m.ExpectBegin()
|
||||
q := m.ExpectQuery(stmt).WithArgs(args...)
|
||||
m.ExpectCommit()
|
||||
result := sqlmock.NewRows(cols)
|
||||
if len(row) > 0 {
|
||||
result.AddRow(row...)
|
||||
}
|
||||
q.WillReturnRows(result)
|
||||
return m
|
||||
}
|
||||
}
|
||||
|
||||
func mockQueryScanErr(stmt string, cols []string, row []driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
||||
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
||||
m.ExpectBegin()
|
||||
q := m.ExpectQuery(stmt).WithArgs(args...)
|
||||
m.ExpectRollback()
|
||||
result := sqlmock.NewRows(cols)
|
||||
if len(row) > 0 {
|
||||
result.AddRow(row...)
|
||||
@@ -89,7 +107,28 @@ func mockQuery(stmt string, cols []string, row []driver.Value, args ...driver.Va
|
||||
|
||||
func mockQueries(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
||||
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
||||
m.ExpectBegin()
|
||||
q := m.ExpectQuery(stmt).WithArgs(args...)
|
||||
m.ExpectCommit()
|
||||
result := sqlmock.NewRows(cols)
|
||||
count := uint64(len(rows))
|
||||
for _, row := range rows {
|
||||
if cols[len(cols)-1] == "count" {
|
||||
row = append(row, count)
|
||||
}
|
||||
result.AddRow(row...)
|
||||
}
|
||||
q.WillReturnRows(result)
|
||||
q.RowsWillBeClosed()
|
||||
return m
|
||||
}
|
||||
}
|
||||
|
||||
func mockQueriesScanErr(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
||||
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
||||
m.ExpectBegin()
|
||||
q := m.ExpectQuery(stmt).WithArgs(args...)
|
||||
m.ExpectRollback()
|
||||
result := sqlmock.NewRows(cols)
|
||||
count := uint64(len(rows))
|
||||
for _, row := range rows {
|
||||
@@ -106,8 +145,10 @@ func mockQueries(stmt string, cols []string, rows [][]driver.Value, args ...driv
|
||||
|
||||
func mockQueryErr(stmt string, err error, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
||||
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
||||
m.ExpectBegin()
|
||||
q := m.ExpectQuery(stmt).WithArgs(args...)
|
||||
q.WillReturnError(err)
|
||||
m.ExpectRollback()
|
||||
return m
|
||||
}
|
||||
}
|
||||
@@ -127,52 +168,65 @@ var (
|
||||
selectBuilderType = reflect.TypeOf(sq.SelectBuilder{})
|
||||
)
|
||||
|
||||
func execScan(client *sql.DB, builder sq.SelectBuilder, scan interface{}, errCheck checkErr) (interface{}, bool) {
|
||||
func execScan(client *database.DB, builder sq.SelectBuilder, scan interface{}, errCheck checkErr) (object interface{}, ok bool, didScan bool) {
|
||||
scanType := reflect.TypeOf(scan)
|
||||
err := validateScan(scanType)
|
||||
if err != nil {
|
||||
return err, false
|
||||
return err, false, false
|
||||
}
|
||||
|
||||
stmt, args, err := builder.ToSql()
|
||||
if err != nil {
|
||||
return fmt.Errorf("unexpeted error from sql builder: %w", err), false
|
||||
return fmt.Errorf("unexpeted error from sql builder: %w", err), false, false
|
||||
}
|
||||
|
||||
//resultSet represents *sql.Row or *sql.Rows,
|
||||
// depending on whats assignable to the scan function
|
||||
var resultSet interface{}
|
||||
var res []reflect.Value
|
||||
|
||||
//execute sql stmt
|
||||
// if scan(*sql.Rows)...
|
||||
if scanType.In(0).AssignableTo(rowsType) {
|
||||
resultSet, err = client.Query(stmt, args...)
|
||||
if err != nil {
|
||||
return errCheck(err)
|
||||
}
|
||||
err = client.Query(func(rows *sql.Rows) error {
|
||||
didScan = true
|
||||
res = reflect.ValueOf(scan).Call([]reflect.Value{reflect.ValueOf(rows)})
|
||||
if err, ok := res[1].Interface().(error); ok {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}, stmt, args...)
|
||||
|
||||
// if scan(*sql.Row)...
|
||||
} else if scanType.In(0).AssignableTo(rowType) {
|
||||
row := client.QueryRow(stmt, args...)
|
||||
if row.Err() != nil {
|
||||
return errCheck(row.Err())
|
||||
}
|
||||
resultSet = row
|
||||
err = client.QueryRow(func(r *sql.Row) error {
|
||||
didScan = true
|
||||
res = reflect.ValueOf(scan).Call([]reflect.Value{reflect.ValueOf(r)})
|
||||
if err, ok := res[1].Interface().(error); ok {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}, stmt, args...)
|
||||
|
||||
} else {
|
||||
return errors.New("scan: parameter must be *sql.Row or *sql.Rows"), false
|
||||
return errors.New("scan: parameter must be *sql.Row or *sql.Rows"), false, false
|
||||
}
|
||||
|
||||
// res contains object and error
|
||||
res := reflect.ValueOf(scan).Call([]reflect.Value{reflect.ValueOf(resultSet)})
|
||||
if err != nil {
|
||||
err, ok := errCheck(err)
|
||||
if didScan {
|
||||
return res[0].Interface(), ok, didScan
|
||||
}
|
||||
return err, ok, didScan
|
||||
}
|
||||
|
||||
//check for error
|
||||
if res[1].Interface() != nil {
|
||||
if err, ok := errCheck(res[1].Interface().(error)); !ok {
|
||||
return fmt.Errorf("scan failed: %w", err), false
|
||||
return fmt.Errorf("scan failed: %w", err), false, didScan
|
||||
}
|
||||
}
|
||||
|
||||
return res[0].Interface(), true
|
||||
return res[0].Interface(), true, didScan
|
||||
}
|
||||
|
||||
func validateScan(scanType reflect.Type) error {
|
||||
|
Reference in New Issue
Block a user