2021-11-15 10:04:46 +01:00
|
|
|
package query
|
|
|
|
|
|
|
|
import (
|
2023-02-27 22:36:43 +01:00
|
|
|
"context"
|
2021-11-15 10:04:46 +01:00
|
|
|
"database/sql"
|
|
|
|
"database/sql/driver"
|
|
|
|
"errors"
|
|
|
|
"fmt"
|
|
|
|
"log"
|
|
|
|
"reflect"
|
|
|
|
"testing"
|
2021-11-15 13:52:49 +01:00
|
|
|
"time"
|
2021-11-15 10:04:46 +01:00
|
|
|
|
|
|
|
"github.com/DATA-DOG/go-sqlmock"
|
|
|
|
sq "github.com/Masterminds/squirrel"
|
2022-01-12 13:22:04 +01:00
|
|
|
"github.com/stretchr/testify/assert"
|
2023-07-10 15:27:00 +02:00
|
|
|
"github.com/stretchr/testify/require"
|
2023-08-22 14:49:02 +02:00
|
|
|
"github.com/zitadel/zitadel/internal/database"
|
2021-11-15 10:04:46 +01:00
|
|
|
)
|
|
|
|
|
2021-11-15 13:52:49 +01:00
|
|
|
// testNow is a timestamp captured once, when the package's variables are
// initialized. Presumably shared by tests that need a stable "now" value —
// its users are not visible in this part of the file.
var testNow = time.Now()
|
|
|
|
|
2023-02-27 22:36:43 +01:00
|
|
|
// assertPrepare checks if the prepare func executes the correct sql query and returns the correct object
|
|
|
|
// prepareFunc must be of type
|
2021-11-15 10:04:46 +01:00
|
|
|
// func() (sq.SelectBuilder, func(*sql.Rows) (*struct, error))
|
|
|
|
// or
|
|
|
|
// func() (sq.SelectBuilder, func(*sql.Row) (*struct, error))
|
2023-02-27 22:36:43 +01:00
|
|
|
// expectedObject represents the return value of scan
|
|
|
|
// sqlExpectation represents the query executed on the database
|
|
|
|
func assertPrepare(t *testing.T, prepareFunc, expectedObject interface{}, sqlExpectation sqlExpectation, isErr checkErr, prepareArgs ...reflect.Value) bool {
|
2021-11-15 10:04:46 +01:00
|
|
|
t.Helper()
|
|
|
|
|
|
|
|
client, mock, err := sqlmock.New()
|
|
|
|
if err != nil {
|
|
|
|
t.Fatalf("failed to build mock client: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
mock = sqlExpectation(mock)
|
|
|
|
|
2023-02-27 22:36:43 +01:00
|
|
|
builder, scan, err := execPrepare(prepareFunc, prepareArgs)
|
2021-11-15 10:04:46 +01:00
|
|
|
if err != nil {
|
|
|
|
t.Error(err)
|
|
|
|
return false
|
|
|
|
}
|
|
|
|
errCheck := func(err error) (error, bool) {
|
|
|
|
if isErr == nil {
|
|
|
|
if err == nil {
|
|
|
|
return nil, true
|
|
|
|
} else {
|
|
|
|
return fmt.Errorf("no error expected got: %w", err), false
|
|
|
|
}
|
|
|
|
}
|
|
|
|
return isErr(err)
|
|
|
|
}
|
2023-08-22 14:49:02 +02:00
|
|
|
object, ok, didScan := execScan(&database.DB{DB: client}, builder, scan, errCheck)
|
2021-11-15 10:04:46 +01:00
|
|
|
if !ok {
|
|
|
|
t.Error(object)
|
|
|
|
return false
|
|
|
|
}
|
2023-08-22 14:49:02 +02:00
|
|
|
if didScan {
|
|
|
|
if !assert.Equal(t, expectedObject, object) {
|
|
|
|
return false
|
|
|
|
}
|
2021-11-15 10:04:46 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
if err := mock.ExpectationsWereMet(); err != nil {
|
|
|
|
t.Errorf("sql expectations not met: %v", err)
|
|
|
|
return false
|
|
|
|
}
|
|
|
|
|
|
|
|
return true
|
|
|
|
}
|
|
|
|
|
|
|
|
// checkErr judges an error produced while executing or scanning a query.
// ok reports whether the error is acceptable for the test case; err is the
// error value handed back for reporting when it is not.
type checkErr func(error) (err error, ok bool)
|
|
|
|
|
|
|
|
// sqlExpectation registers the expected database interactions on the given
// mock and returns the mock to use from then on.
type sqlExpectation func(sqlmock.Sqlmock) sqlmock.Sqlmock
|
|
|
|
|
2023-07-10 15:27:00 +02:00
|
|
|
func mockQuery(stmt string, cols []string, row []driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
2021-11-15 10:04:46 +01:00
|
|
|
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
2023-08-22 14:49:02 +02:00
|
|
|
m.ExpectBegin()
|
2023-07-10 15:27:00 +02:00
|
|
|
q := m.ExpectQuery(stmt).WithArgs(args...)
|
2023-08-22 14:49:02 +02:00
|
|
|
m.ExpectCommit()
|
|
|
|
result := sqlmock.NewRows(cols)
|
|
|
|
if len(row) > 0 {
|
|
|
|
result.AddRow(row...)
|
|
|
|
}
|
|
|
|
q.WillReturnRows(result)
|
|
|
|
return m
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func mockQueryScanErr(stmt string, cols []string, row []driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
|
|
|
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
|
|
|
m.ExpectBegin()
|
|
|
|
q := m.ExpectQuery(stmt).WithArgs(args...)
|
|
|
|
m.ExpectRollback()
|
2021-11-15 10:04:46 +01:00
|
|
|
result := sqlmock.NewRows(cols)
|
2021-11-15 16:04:08 +01:00
|
|
|
if len(row) > 0 {
|
|
|
|
result.AddRow(row...)
|
|
|
|
}
|
2021-11-15 10:04:46 +01:00
|
|
|
q.WillReturnRows(result)
|
|
|
|
return m
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func mockQueries(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
|
|
|
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
2023-08-22 14:49:02 +02:00
|
|
|
m.ExpectBegin()
|
|
|
|
q := m.ExpectQuery(stmt).WithArgs(args...)
|
|
|
|
m.ExpectCommit()
|
|
|
|
result := sqlmock.NewRows(cols)
|
|
|
|
count := uint64(len(rows))
|
|
|
|
for _, row := range rows {
|
|
|
|
if cols[len(cols)-1] == "count" {
|
|
|
|
row = append(row, count)
|
|
|
|
}
|
|
|
|
result.AddRow(row...)
|
|
|
|
}
|
|
|
|
q.WillReturnRows(result)
|
|
|
|
q.RowsWillBeClosed()
|
|
|
|
return m
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func mockQueriesScanErr(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
|
|
|
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
|
|
|
m.ExpectBegin()
|
2021-11-15 10:04:46 +01:00
|
|
|
q := m.ExpectQuery(stmt).WithArgs(args...)
|
2023-08-22 14:49:02 +02:00
|
|
|
m.ExpectRollback()
|
2021-11-15 10:04:46 +01:00
|
|
|
result := sqlmock.NewRows(cols)
|
|
|
|
count := uint64(len(rows))
|
|
|
|
for _, row := range rows {
|
2021-11-16 14:04:22 +01:00
|
|
|
if cols[len(cols)-1] == "count" {
|
|
|
|
row = append(row, count)
|
|
|
|
}
|
2021-11-15 10:04:46 +01:00
|
|
|
result.AddRow(row...)
|
|
|
|
}
|
|
|
|
q.WillReturnRows(result)
|
|
|
|
q.RowsWillBeClosed()
|
|
|
|
return m
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func mockQueryErr(stmt string, err error, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
|
|
|
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
2023-08-22 14:49:02 +02:00
|
|
|
m.ExpectBegin()
|
2021-11-15 10:04:46 +01:00
|
|
|
q := m.ExpectQuery(stmt).WithArgs(args...)
|
|
|
|
q.WillReturnError(err)
|
2023-08-22 14:49:02 +02:00
|
|
|
m.ExpectRollback()
|
2021-11-15 10:04:46 +01:00
|
|
|
return m
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-07-10 15:27:00 +02:00
|
|
|
func execMock(t testing.TB, exp sqlExpectation, run func(db *sql.DB)) {
|
|
|
|
db, mock, err := sqlmock.New()
|
|
|
|
require.NoError(t, err)
|
|
|
|
defer db.Close()
|
|
|
|
mock = exp(mock)
|
|
|
|
run(db)
|
|
|
|
assert.NoError(t, mock.ExpectationsWereMet())
|
|
|
|
}
|
|
|
|
|
2021-11-15 10:04:46 +01:00
|
|
|
var (
|
|
|
|
rowType = reflect.TypeOf(&sql.Row{})
|
|
|
|
rowsType = reflect.TypeOf(&sql.Rows{})
|
|
|
|
selectBuilderType = reflect.TypeOf(sq.SelectBuilder{})
|
|
|
|
)
|
|
|
|
|
2023-08-22 14:49:02 +02:00
|
|
|
func execScan(client *database.DB, builder sq.SelectBuilder, scan interface{}, errCheck checkErr) (object interface{}, ok bool, didScan bool) {
|
2021-11-15 10:04:46 +01:00
|
|
|
scanType := reflect.TypeOf(scan)
|
|
|
|
err := validateScan(scanType)
|
|
|
|
if err != nil {
|
2023-08-22 14:49:02 +02:00
|
|
|
return err, false, false
|
2021-11-15 10:04:46 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
stmt, args, err := builder.ToSql()
|
|
|
|
if err != nil {
|
2023-08-22 14:49:02 +02:00
|
|
|
return fmt.Errorf("unexpeted error from sql builder: %w", err), false, false
|
2021-11-15 10:04:46 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
//resultSet represents *sql.Row or *sql.Rows,
|
|
|
|
// depending on whats assignable to the scan function
|
2023-08-22 14:49:02 +02:00
|
|
|
var res []reflect.Value
|
2021-11-15 10:04:46 +01:00
|
|
|
|
|
|
|
//execute sql stmt
|
|
|
|
// if scan(*sql.Rows)...
|
|
|
|
if scanType.In(0).AssignableTo(rowsType) {
|
2023-08-22 14:49:02 +02:00
|
|
|
err = client.Query(func(rows *sql.Rows) error {
|
|
|
|
didScan = true
|
|
|
|
res = reflect.ValueOf(scan).Call([]reflect.Value{reflect.ValueOf(rows)})
|
|
|
|
if err, ok := res[1].Interface().(error); ok {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}, stmt, args...)
|
2021-11-15 10:04:46 +01:00
|
|
|
|
|
|
|
// if scan(*sql.Row)...
|
|
|
|
} else if scanType.In(0).AssignableTo(rowType) {
|
2023-08-22 14:49:02 +02:00
|
|
|
err = client.QueryRow(func(r *sql.Row) error {
|
|
|
|
didScan = true
|
|
|
|
res = reflect.ValueOf(scan).Call([]reflect.Value{reflect.ValueOf(r)})
|
|
|
|
if err, ok := res[1].Interface().(error); ok {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}, stmt, args...)
|
|
|
|
|
2021-11-15 10:04:46 +01:00
|
|
|
} else {
|
2023-08-22 14:49:02 +02:00
|
|
|
return errors.New("scan: parameter must be *sql.Row or *sql.Rows"), false, false
|
2021-11-15 10:04:46 +01:00
|
|
|
}
|
|
|
|
|
2023-08-22 14:49:02 +02:00
|
|
|
if err != nil {
|
|
|
|
err, ok := errCheck(err)
|
|
|
|
if didScan {
|
|
|
|
return res[0].Interface(), ok, didScan
|
|
|
|
}
|
|
|
|
return err, ok, didScan
|
|
|
|
}
|
2021-11-15 10:04:46 +01:00
|
|
|
|
|
|
|
//check for error
|
|
|
|
if res[1].Interface() != nil {
|
|
|
|
if err, ok := errCheck(res[1].Interface().(error)); !ok {
|
2023-08-22 14:49:02 +02:00
|
|
|
return fmt.Errorf("scan failed: %w", err), false, didScan
|
2021-11-15 10:04:46 +01:00
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2023-08-22 14:49:02 +02:00
|
|
|
return res[0].Interface(), true, didScan
|
2021-11-15 10:04:46 +01:00
|
|
|
}
|
|
|
|
|
|
|
|
// validateScan verifies that scanType describes a func with exactly one
// parameter and exactly two return values. It returns a descriptive error
// otherwise; a nil type (as produced by reflect.TypeOf(nil)) is reported as
// "not a function" instead of causing a panic.
func validateScan(scanType reflect.Type) error {
	if scanType == nil || scanType.Kind() != reflect.Func {
		return errors.New("scan is not a function")
	}
	if scanType.NumIn() != 1 {
		return fmt.Errorf("scan: invalid number of inputs: want: 1 got %d", scanType.NumIn())
	}
	if scanType.NumOut() != 2 {
		return fmt.Errorf("scan: invalid number of outputs: want: 2 got %d", scanType.NumOut())
	}
	return nil
}
|
|
|
|
|
2023-02-27 22:36:43 +01:00
|
|
|
func execPrepare(prepare interface{}, args []reflect.Value) (builder sq.SelectBuilder, scan interface{}, err error) {
|
2021-11-15 10:04:46 +01:00
|
|
|
prepareVal := reflect.ValueOf(prepare)
|
|
|
|
if err := validatePrepare(prepareVal.Type()); err != nil {
|
|
|
|
return sq.SelectBuilder{}, nil, err
|
|
|
|
}
|
2023-02-27 22:36:43 +01:00
|
|
|
res := prepareVal.Call(args)
|
2021-11-15 10:04:46 +01:00
|
|
|
|
|
|
|
return res[0].Interface().(sq.SelectBuilder), res[1].Interface(), nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func validatePrepare(prepareType reflect.Type) error {
|
|
|
|
if prepareType.Kind() != reflect.Func {
|
|
|
|
return errors.New("prepare is not a function")
|
|
|
|
}
|
2023-02-27 22:36:43 +01:00
|
|
|
if prepareType.NumIn() < 2 {
|
2021-11-15 10:04:46 +01:00
|
|
|
return fmt.Errorf("prepare: invalid number of inputs: want: 0 got %d", prepareType.NumIn())
|
|
|
|
}
|
|
|
|
if prepareType.NumOut() != 2 {
|
|
|
|
return fmt.Errorf("prepare: invalid number of outputs: want: 2 got %d", prepareType.NumOut())
|
|
|
|
}
|
|
|
|
if prepareType.Out(0) != selectBuilderType {
|
|
|
|
return fmt.Errorf("prepare: first return value must be: %s got %s", selectBuilderType, prepareType.Out(0))
|
|
|
|
}
|
|
|
|
if prepareType.Out(1).Kind() != reflect.Func {
|
|
|
|
return fmt.Errorf("prepare: second return value must be: %s got %s", reflect.Func, prepareType.Out(1))
|
|
|
|
}
|
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestValidateScan(t *testing.T) {
|
|
|
|
tests := []struct {
|
|
|
|
name string
|
|
|
|
t reflect.Type
|
|
|
|
expectErr bool
|
|
|
|
}{
|
|
|
|
{
|
|
|
|
name: "not a func",
|
|
|
|
t: reflect.TypeOf(&struct{}{}),
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "wong input count",
|
|
|
|
t: reflect.TypeOf(func() (*struct{}, error) {
|
|
|
|
log.Fatal("should not be executed")
|
|
|
|
return nil, nil
|
|
|
|
}),
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "wrong output count",
|
|
|
|
t: reflect.TypeOf(func(interface{}) error {
|
|
|
|
log.Fatal("should not be executed")
|
|
|
|
return nil
|
|
|
|
}),
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "correct",
|
|
|
|
t: reflect.TypeOf(func(interface{}) (*struct{}, error) {
|
|
|
|
log.Fatal("should not be executed")
|
|
|
|
return nil, nil
|
|
|
|
}),
|
|
|
|
expectErr: false,
|
|
|
|
},
|
|
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
err := validateScan(tt.t)
|
|
|
|
if (err != nil) != tt.expectErr {
|
|
|
|
t.Errorf("unexpected err: %v", err)
|
|
|
|
}
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func TestValidatePrepare(t *testing.T) {
|
|
|
|
tests := []struct {
|
|
|
|
name string
|
|
|
|
t reflect.Type
|
|
|
|
expectErr bool
|
|
|
|
}{
|
|
|
|
{
|
|
|
|
name: "not a func",
|
|
|
|
t: reflect.TypeOf(&struct{}{}),
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "wong input count",
|
|
|
|
t: reflect.TypeOf(func(int) (sq.SelectBuilder, func(*sql.Rows) (interface{}, error)) {
|
|
|
|
log.Fatal("should not be executed")
|
|
|
|
return sq.SelectBuilder{}, nil
|
|
|
|
}),
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "wrong output count",
|
|
|
|
t: reflect.TypeOf(func() sq.SelectBuilder {
|
|
|
|
log.Fatal("should not be executed")
|
|
|
|
return sq.SelectBuilder{}
|
|
|
|
}),
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "first output type wrong",
|
|
|
|
t: reflect.TypeOf(func() (*struct{}, func(*sql.Rows) (interface{}, error)) {
|
|
|
|
log.Fatal("should not be executed")
|
|
|
|
return nil, nil
|
|
|
|
}),
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "second output type wrong",
|
|
|
|
t: reflect.TypeOf(func() (sq.SelectBuilder, *struct{}) {
|
|
|
|
log.Fatal("should not be executed")
|
|
|
|
return sq.SelectBuilder{}, nil
|
|
|
|
}),
|
|
|
|
expectErr: true,
|
|
|
|
},
|
|
|
|
{
|
|
|
|
name: "correct",
|
2023-02-27 22:36:43 +01:00
|
|
|
t: reflect.TypeOf(func(context.Context, prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (interface{}, error)) {
|
2021-11-15 10:04:46 +01:00
|
|
|
log.Fatal("should not be executed")
|
|
|
|
return sq.SelectBuilder{}, nil
|
|
|
|
}),
|
|
|
|
expectErr: false,
|
|
|
|
},
|
|
|
|
}
|
|
|
|
for _, tt := range tests {
|
|
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
|
|
err := validatePrepare(tt.t)
|
|
|
|
if (err != nil) != tt.expectErr {
|
|
|
|
t.Errorf("unexpected err: %v", err)
|
|
|
|
}
|
|
|
|
})
|
|
|
|
}
|
|
|
|
}
|
2023-02-27 22:36:43 +01:00
|
|
|
|
|
|
|
// asOfSystemTime is the clause prepareDB.Timetravel returns.
const asOfSystemTime = " AS OF SYSTEM TIME '-1 ms' "

// prepareDB is a stateless stub whose methods return fixed values. It is the
// second entry of defaultPrepareArgs; presumably it satisfies the
// prepareDatabase interface declared elsewhere in the package — verify there.
type prepareDB struct{}

// defaultPrepareArgs are the reflected arguments for a prepare func: a
// background context followed by a *prepareDB.
var defaultPrepareArgs = []reflect.Value{
	reflect.ValueOf(context.Background()),
	reflect.ValueOf(&prepareDB{}),
}

// Timetravel ignores the given duration and always returns asOfSystemTime.
func (*prepareDB) Timetravel(time.Duration) string {
	return asOfSystemTime
}

// DatabaseName returns the fixed name "db".
func (*prepareDB) DatabaseName() string {
	return "db"
}

// Username returns the fixed name "user".
func (*prepareDB) Username() string {
	return "user"
}

// Type returns the fixed value "type".
func (*prepareDB) Type() string {
	return "type"
}
|