zitadel/internal/query/prepare_test.go
Livio Spring 7494a7b6d9
feat(api): add possibility to retrieve user schemas (#7614)
This PR extends the user schema service (V3 API) with the ability to list user schemas (ListUserSchemas) and to retrieve a single schema by its ID (GetUserSchemaByID).
The previously started guide is extended to demonstrate how to retrieve the schema(s) and to point out the generated revision property.
2024-03-22 13:26:13 +00:00

413 lines
10 KiB
Go

package query
import (
"context"
"database/sql"
"database/sql/driver"
"errors"
"fmt"
"log"
"reflect"
"testing"
"time"
"github.com/DATA-DOG/go-sqlmock"
sq "github.com/Masterminds/squirrel"
"github.com/jackc/pgtype"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/database"
)
// testNow is the reference instant shared by the tests of this package;
// dayNow is testNow truncated to a multiple of 24h since the zero time,
// i.e. midnight UTC of the day containing testNow.
var (
	testNow = time.Now()
	dayNow  = testNow.Truncate(time.Hour * 24)
)
// assertPrepare checks that prepareFunc builds the expected SQL statement and
// that its scan function returns expectedObject.
//
// prepareFunc must return (sq.SelectBuilder, scan), where scan is one of
//
//	func(*sql.Rows) (*T, error)
//	func(*sql.Row) (*T, error)
//
// prepareFunc takes either no parameters or the values given in prepareArgs
// (e.g. defaultPrepareArgs: a context.Context and a prepareDatabase).
// sqlExpectation registers the statements expected on the mocked database.
// isErr validates the error returned by the query or scan; pass nil when no
// error is expected. It reports whether all checks passed.
func assertPrepare(t *testing.T, prepareFunc, expectedObject interface{}, sqlExpectation sqlExpectation, isErr checkErr, prepareArgs ...reflect.Value) bool {
	t.Helper()
	client, mock, err := sqlmock.New()
	if err != nil {
		t.Fatalf("failed to build mock client: %v", err)
	}
	// release the mocked connection once all assertions ran; the close error is
	// irrelevant because expectations are verified explicitly below
	defer client.Close()
	mock = sqlExpectation(mock)

	builder, scan, err := execPrepare(prepareFunc, prepareArgs)
	if err != nil {
		t.Error(err)
		return false
	}

	errCheck := func(err error) (error, bool) {
		if isErr != nil {
			return isErr(err)
		}
		if err != nil {
			return fmt.Errorf("no error expected got: %w", err), false
		}
		return nil, true
	}

	object, ok, didScan := execScan(t, &database.DB{DB: client}, builder, scan, errCheck)
	if !ok {
		// on failure execScan returns the reason in place of the object
		t.Error(object)
		return false
	}
	if didScan && !assert.Equal(t, expectedObject, object) {
		return false
	}
	if err := mock.ExpectationsWereMet(); err != nil {
		t.Errorf("sql expectations not met: %v", err)
		return false
	}
	return true
}
type (
	// checkErr validates an error returned by a query or scan. It returns the
	// error to report and whether err matched the expectation.
	checkErr func(error) (err error, ok bool)
	// sqlExpectation registers expected calls on the given mock and returns it.
	sqlExpectation func(sqlmock.Sqlmock) sqlmock.Sqlmock
)
// mockQuery returns an sqlExpectation for a single statement that runs inside
// a committed transaction. The statement must be called with args. It answers
// with cols and, if row is non-empty, exactly that one row.
func mockQuery(stmt string, cols []string, row []driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
	return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
		rows := sqlmock.NewRows(cols)
		if len(row) != 0 {
			rows.AddRow(row...)
		}
		m.ExpectBegin()
		m.ExpectQuery(stmt).WithArgs(args...).WillReturnRows(rows)
		m.ExpectCommit()
		return m
	}
}
// mockQueryScanErr behaves like mockQuery, except that the transaction is
// expected to be rolled back. It is used when scanning the returned row fails.
func mockQueryScanErr(stmt string, cols []string, row []driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
	return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
		rows := sqlmock.NewRows(cols)
		if len(row) != 0 {
			rows.AddRow(row...)
		}
		m.ExpectBegin()
		m.ExpectQuery(stmt).WithArgs(args...).WillReturnRows(rows)
		m.ExpectRollback()
		return m
	}
}
// mockQueries returns an sqlExpectation for a statement that runs inside a
// committed transaction and yields all of rows. If the last column is named
// "count", the total number of rows is appended to every row, so fixtures do
// not have to repeat it. The returned rows are expected to be closed.
func mockQueries(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
	return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
		m.ExpectBegin()
		q := m.ExpectQuery(stmt).WithArgs(args...)
		m.ExpectCommit()
		result := sqlmock.NewRows(cols)
		// depends only on cols, so evaluate once; guard against empty cols
		withCount := len(cols) > 0 && cols[len(cols)-1] == "count"
		count := uint64(len(rows))
		for _, row := range rows {
			if withCount {
				// the full slice expression forces append to copy, so the
				// caller's backing array is never written to
				row = append(row[:len(row):len(row)], count)
			}
			result.AddRow(row...)
		}
		q.WillReturnRows(result)
		q.RowsWillBeClosed()
		return m
	}
}
// mockQueriesScanErr returns an sqlExpectation for a statement that yields
// all of rows and whose transaction is expected to be rolled back. It is used
// when scanning the rows fails. If the last column is named "count", the total
// number of rows is appended to every row. The returned rows are expected to
// be closed.
func mockQueriesScanErr(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
	return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
		m.ExpectBegin()
		q := m.ExpectQuery(stmt).WithArgs(args...)
		m.ExpectRollback()
		result := sqlmock.NewRows(cols)
		// depends only on cols, so evaluate once; guard against empty cols
		withCount := len(cols) > 0 && cols[len(cols)-1] == "count"
		count := uint64(len(rows))
		for _, row := range rows {
			if withCount {
				// the full slice expression forces append to copy, so the
				// caller's backing array is never written to
				row = append(row[:len(row):len(row)], count)
			}
			result.AddRow(row...)
		}
		q.WillReturnRows(result)
		q.RowsWillBeClosed()
		return m
	}
}
// mockQueryErr returns an sqlExpectation for a statement, called with args,
// that fails with err. The surrounding transaction is expected to be rolled
// back.
func mockQueryErr(stmt string, err error, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
	return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
		m.ExpectBegin()
		m.ExpectQuery(stmt).WithArgs(args...).WillReturnError(err)
		m.ExpectRollback()
		return m
	}
}
// execMock creates a fresh sqlmock database, registers the expectations from
// exp and hands the database to run. Afterwards it asserts that every
// expectation was met. Failures are reported at the caller's line.
func execMock(t testing.TB, exp sqlExpectation, run func(db *sql.DB)) {
	t.Helper()
	db, mock, err := sqlmock.New()
	require.NoError(t, err)
	defer db.Close()
	mock = exp(mock)
	run(db)
	assert.NoError(t, mock.ExpectationsWereMet())
}
// Reflected types used to validate prepare functions and to choose between
// the single-row and multi-row scan paths.
var (
	selectBuilderType = reflect.TypeOf(sq.SelectBuilder{})
	rowType           = reflect.TypeOf((*sql.Row)(nil))
	rowsType          = reflect.TypeOf((*sql.Rows)(nil))
)
// execScan runs the statement built by builder against client and passes the
// result to scan. scan must be func(*sql.Rows) (T, error) or
// func(*sql.Row) (T, error).
//
// It returns:
//   - object: the first result of scan if scan ran; otherwise an error that
//     describes what went wrong (or the error returned by errCheck)
//   - ok: whether the outcome matched the expectation encoded in errCheck
//   - didScan: whether scan was invoked
//
// If errCheck rejects an error returned by the query, the test fails
// immediately.
func execScan(t testing.TB, client *database.DB, builder sq.SelectBuilder, scan interface{}, errCheck checkErr) (object interface{}, ok bool, didScan bool) {
	t.Helper()
	scanType := reflect.TypeOf(scan)
	if err := validateScan(scanType); err != nil {
		return err, false, false
	}
	stmt, args, err := builder.ToSql()
	if err != nil {
		return fmt.Errorf("unexpected error from sql builder: %w", err), false, false
	}

	scanFunc := reflect.ValueOf(scan)
	// res holds the results of the scan call: (object, error)
	var res []reflect.Value
	// callScan invokes scan with the *sql.Row or *sql.Rows handed out by the
	// database and returns the error scan reported, if any.
	callScan := func(arg interface{}) error {
		didScan = true
		res = scanFunc.Call([]reflect.Value{reflect.ValueOf(arg)})
		if scanErr, isErr := res[1].Interface().(error); isErr {
			return scanErr
		}
		return nil
	}

	// reflect's Call requires the argument to be assignable to scan's
	// parameter, so assignability is checked from the row type to the
	// parameter type.
	switch param := scanType.In(0); {
	case rowsType.AssignableTo(param):
		err = client.Query(func(rows *sql.Rows) error {
			return callScan(rows)
		}, stmt, args...)
	case rowType.AssignableTo(param):
		err = client.QueryRow(func(r *sql.Row) error {
			if r.Err() != nil {
				return r.Err()
			}
			return callScan(r)
		}, stmt, args...)
	default:
		return errors.New("scan: parameter must be *sql.Row or *sql.Rows"), false, false
	}

	if err != nil {
		checkedErr, matched := errCheck(err)
		if !matched {
			t.Fatal(checkedErr)
		}
		if didScan {
			return res[0].Interface(), matched, didScan
		}
		return checkedErr, matched, didScan
	}
	if len(res) == 0 {
		// the query succeeded but the database layer never invoked the callback
		return errors.New("scan: scan function was not called"), false, didScan
	}
	// an error reported by scan but not propagated by the database layer
	if res[1].Interface() != nil {
		if checkedErr, matched := errCheck(res[1].Interface().(error)); !matched {
			return fmt.Errorf("scan failed: %w", checkedErr), false, didScan
		}
	}
	return res[0].Interface(), true, didScan
}
// validateScan checks that scanType describes a function with exactly one
// parameter and two results, the second of which implements error. This is
// the shape func(X) (T, error) that scan functions must have.
func validateScan(scanType reflect.Type) error {
	// a nil type (e.g. from reflect.TypeOf(nil)) would panic on Kind()
	if scanType == nil || scanType.Kind() != reflect.Func {
		return errors.New("scan is not a function")
	}
	if scanType.NumIn() != 1 {
		return fmt.Errorf("scan: invalid number of inputs: want: 1 got %d", scanType.NumIn())
	}
	if scanType.NumOut() != 2 {
		return fmt.Errorf("scan: invalid number of outputs: want: 2 got %d", scanType.NumOut())
	}
	// the second result is later asserted to be an error; reject other types
	// up front instead of panicking during the scan
	errorType := reflect.TypeOf((*error)(nil)).Elem()
	if !scanType.Out(1).Implements(errorType) {
		return fmt.Errorf("scan: second return value must be: %s got %s", errorType, scanType.Out(1))
	}
	return nil
}
// execPrepare validates prepare with validatePrepare, calls it with args and
// returns the resulting select builder and scan function. An invalid prepare
// value or a wrong number of args is reported as an error instead of causing a
// reflect panic.
func execPrepare(prepare interface{}, args []reflect.Value) (builder sq.SelectBuilder, scan interface{}, err error) {
	prepareVal := reflect.ValueOf(prepare)
	if !prepareVal.IsValid() {
		return sq.SelectBuilder{}, nil, errors.New("prepare is not a function")
	}
	prepareType := prepareVal.Type()
	if err := validatePrepare(prepareType); err != nil {
		return sq.SelectBuilder{}, nil, err
	}
	if !prepareType.IsVariadic() && len(args) != prepareType.NumIn() {
		return sq.SelectBuilder{}, nil, fmt.Errorf("prepare: invalid number of arguments: want: %d got %d", prepareType.NumIn(), len(args))
	}
	res := prepareVal.Call(args)
	return res[0].Interface().(sq.SelectBuilder), res[1].Interface(), nil
}
// validatePrepare checks that prepareType is a function that takes either no
// parameters or two, and that returns a sq.SelectBuilder followed by a
// function (the scan function).
func validatePrepare(prepareType reflect.Type) error {
	if prepareType.Kind() != reflect.Func {
		return errors.New("prepare is not a function")
	}
	switch in, out := prepareType.NumIn(), prepareType.NumOut(); {
	case !(in == 0 || in == 2):
		return fmt.Errorf("prepare: invalid number of inputs: want: 0 or 2 got %d", in)
	case out != 2:
		return fmt.Errorf("prepare: invalid number of outputs: want: 2 got %d", out)
	}
	if first := prepareType.Out(0); first != selectBuilderType {
		return fmt.Errorf("prepare: first return value must be: %s got %s", selectBuilderType, first)
	}
	if second := prepareType.Out(1); second.Kind() != reflect.Func {
		return fmt.Errorf("prepare: second return value must be: %s got %s", reflect.Func, second)
	}
	return nil
}
// TestValidateScan checks that validateScan rejects non-functions and functions
// with the wrong number of inputs or outputs, and accepts func(X) (T, error).
func TestValidateScan(t *testing.T) {
	tests := []struct {
		name      string
		t         reflect.Type
		expectErr bool
	}{
		{
			name:      "not a func",
			t:         reflect.TypeOf(&struct{}{}),
			expectErr: true,
		},
		{
			name: "wrong input count",
			t: reflect.TypeOf(func() (*struct{}, error) {
				log.Fatal("should not be executed")
				return nil, nil
			}),
			expectErr: true,
		},
		{
			name: "wrong output count",
			t: reflect.TypeOf(func(interface{}) error {
				log.Fatal("should not be executed")
				return nil
			}),
			expectErr: true,
		},
		{
			name: "correct",
			t: reflect.TypeOf(func(interface{}) (*struct{}, error) {
				log.Fatal("should not be executed")
				return nil, nil
			}),
			expectErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := validateScan(tt.t)
			if (err != nil) != tt.expectErr {
				t.Errorf("validateScan() error = %v, expectErr %v", err, tt.expectErr)
			}
		})
	}
}
// TestValidatePrepare checks that validatePrepare rejects functions with the
// wrong shape and accepts func(context.Context, prepareDatabase) returning a
// select builder and a scan function.
func TestValidatePrepare(t *testing.T) {
	tests := []struct {
		name      string
		t         reflect.Type
		expectErr bool
	}{
		{
			name:      "not a func",
			t:         reflect.TypeOf(&struct{}{}),
			expectErr: true,
		},
		{
			name: "wrong input count",
			t: reflect.TypeOf(func(int) (sq.SelectBuilder, func(*sql.Rows) (interface{}, error)) {
				log.Fatal("should not be executed")
				return sq.SelectBuilder{}, nil
			}),
			expectErr: true,
		},
		{
			name: "wrong output count",
			t: reflect.TypeOf(func() sq.SelectBuilder {
				log.Fatal("should not be executed")
				return sq.SelectBuilder{}
			}),
			expectErr: true,
		},
		{
			name: "first output type wrong",
			t: reflect.TypeOf(func() (*struct{}, func(*sql.Rows) (interface{}, error)) {
				log.Fatal("should not be executed")
				return nil, nil
			}),
			expectErr: true,
		},
		{
			name: "second output type wrong",
			t: reflect.TypeOf(func() (sq.SelectBuilder, *struct{}) {
				log.Fatal("should not be executed")
				return sq.SelectBuilder{}, nil
			}),
			expectErr: true,
		},
		{
			name: "correct",
			t: reflect.TypeOf(func(context.Context, prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (interface{}, error)) {
				log.Fatal("should not be executed")
				return sq.SelectBuilder{}, nil
			}),
			expectErr: false,
		},
	}
	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			err := validatePrepare(tt.t)
			if (err != nil) != tt.expectErr {
				t.Errorf("validatePrepare() error = %v, expectErr %v", err, tt.expectErr)
			}
		})
	}
}
// intervalDriverValue converts src into a pgtype.Interval. It fails the test
// (reported at the caller's line) if Set rejects the value.
func intervalDriverValue(t *testing.T, src time.Duration) pgtype.Interval {
	t.Helper()
	interval := pgtype.Interval{}
	if err := interval.Set(src); err != nil {
		t.Fatal(err)
	}
	return interval
}
// asOfSystemTime is the clause returned by prepareDB.Timetravel, whatever
// duration is requested.
const asOfSystemTime = " AS OF SYSTEM TIME '-1 ms' "

// prepareDB is a stub for the database handed to prepare functions; each of
// its methods returns a fixed value.
type prepareDB struct{}

func (*prepareDB) Timetravel(time.Duration) string { return asOfSystemTime }

func (*prepareDB) DatabaseName() string { return "db" }

func (*prepareDB) Username() string { return "user" }

func (*prepareDB) Type() string { return "type" }

// defaultPrepareArgs are the arguments passed to prepare functions that take
// a context and a prepareDatabase: a background context and a *prepareDB.
var defaultPrepareArgs = []reflect.Value{
	reflect.ValueOf(context.Background()),
	reflect.ValueOf(&prepareDB{}),
}