feat(storage): read only transactions for queries (#6415)

* fix: tests

* bastle wie en grosse

* fix(database): scan as callback

* fix tests

* fix merge failures

* remove as of system time

* refactor: remove unused test

* refacotr: remove unused lines
This commit is contained in:
Silvan
2023-08-22 12:49:22 +02:00
committed by GitHub
parent a9fb2a6e5c
commit 99e1c654a3
128 changed files with 1355 additions and 897 deletions

View File

@@ -15,6 +15,7 @@ import (
sq "github.com/Masterminds/squirrel"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/database"
)
var (
@@ -53,14 +54,15 @@ func assertPrepare(t *testing.T, prepareFunc, expectedObject interface{}, sqlExp
}
return isErr(err)
}
object, ok := execScan(client, builder, scan, errCheck)
object, ok, didScan := execScan(&database.DB{DB: client}, builder, scan, errCheck)
if !ok {
t.Error(object)
return false
}
if !assert.Equal(t, expectedObject, object) {
return false
if didScan {
if !assert.Equal(t, expectedObject, object) {
return false
}
}
if err := mock.ExpectationsWereMet(); err != nil {
@@ -77,7 +79,23 @@ type sqlExpectation func(sqlmock.Sqlmock) sqlmock.Sqlmock
func mockQuery(stmt string, cols []string, row []driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
m.ExpectBegin()
q := m.ExpectQuery(stmt).WithArgs(args...)
m.ExpectCommit()
result := sqlmock.NewRows(cols)
if len(row) > 0 {
result.AddRow(row...)
}
q.WillReturnRows(result)
return m
}
}
func mockQueryScanErr(stmt string, cols []string, row []driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
m.ExpectBegin()
q := m.ExpectQuery(stmt).WithArgs(args...)
m.ExpectRollback()
result := sqlmock.NewRows(cols)
if len(row) > 0 {
result.AddRow(row...)
@@ -89,7 +107,28 @@ func mockQuery(stmt string, cols []string, row []driver.Value, args ...driver.Va
func mockQueries(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
m.ExpectBegin()
q := m.ExpectQuery(stmt).WithArgs(args...)
m.ExpectCommit()
result := sqlmock.NewRows(cols)
count := uint64(len(rows))
for _, row := range rows {
if cols[len(cols)-1] == "count" {
row = append(row, count)
}
result.AddRow(row...)
}
q.WillReturnRows(result)
q.RowsWillBeClosed()
return m
}
}
func mockQueriesScanErr(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
m.ExpectBegin()
q := m.ExpectQuery(stmt).WithArgs(args...)
m.ExpectRollback()
result := sqlmock.NewRows(cols)
count := uint64(len(rows))
for _, row := range rows {
@@ -106,8 +145,10 @@ func mockQueries(stmt string, cols []string, rows [][]driver.Value, args ...driv
func mockQueryErr(stmt string, err error, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
m.ExpectBegin()
q := m.ExpectQuery(stmt).WithArgs(args...)
q.WillReturnError(err)
m.ExpectRollback()
return m
}
}
@@ -127,52 +168,65 @@ var (
selectBuilderType = reflect.TypeOf(sq.SelectBuilder{})
)
func execScan(client *sql.DB, builder sq.SelectBuilder, scan interface{}, errCheck checkErr) (interface{}, bool) {
func execScan(client *database.DB, builder sq.SelectBuilder, scan interface{}, errCheck checkErr) (object interface{}, ok bool, didScan bool) {
scanType := reflect.TypeOf(scan)
err := validateScan(scanType)
if err != nil {
return err, false
return err, false, false
}
stmt, args, err := builder.ToSql()
if err != nil {
return fmt.Errorf("unexpeted error from sql builder: %w", err), false
return fmt.Errorf("unexpeted error from sql builder: %w", err), false, false
}
//resultSet represents *sql.Row or *sql.Rows,
// depending on whats assignable to the scan function
var resultSet interface{}
var res []reflect.Value
//execute sql stmt
// if scan(*sql.Rows)...
if scanType.In(0).AssignableTo(rowsType) {
resultSet, err = client.Query(stmt, args...)
if err != nil {
return errCheck(err)
}
err = client.Query(func(rows *sql.Rows) error {
didScan = true
res = reflect.ValueOf(scan).Call([]reflect.Value{reflect.ValueOf(rows)})
if err, ok := res[1].Interface().(error); ok {
return err
}
return nil
}, stmt, args...)
// if scan(*sql.Row)...
} else if scanType.In(0).AssignableTo(rowType) {
row := client.QueryRow(stmt, args...)
if row.Err() != nil {
return errCheck(row.Err())
}
resultSet = row
err = client.QueryRow(func(r *sql.Row) error {
didScan = true
res = reflect.ValueOf(scan).Call([]reflect.Value{reflect.ValueOf(r)})
if err, ok := res[1].Interface().(error); ok {
return err
}
return nil
}, stmt, args...)
} else {
return errors.New("scan: parameter must be *sql.Row or *sql.Rows"), false
return errors.New("scan: parameter must be *sql.Row or *sql.Rows"), false, false
}
// res contains object and error
res := reflect.ValueOf(scan).Call([]reflect.Value{reflect.ValueOf(resultSet)})
if err != nil {
err, ok := errCheck(err)
if didScan {
return res[0].Interface(), ok, didScan
}
return err, ok, didScan
}
//check for error
if res[1].Interface() != nil {
if err, ok := errCheck(res[1].Interface().(error)); !ok {
return fmt.Errorf("scan failed: %w", err), false
return fmt.Errorf("scan failed: %w", err), false, didScan
}
}
return res[0].Interface(), true
return res[0].Interface(), true, didScan
}
func validateScan(scanType reflect.Type) error {