mirror of
https://github.com/zitadel/zitadel.git
synced 2025-01-11 15:43:39 +00:00
7494a7b6d9
This PR extends the user schema service (V3 API) with the possibility to ListUserSchemas and GetUserSchemaByID. The previously started guide is extended to demonstrate how to retrieve the schema(s) and notes the generated revision property.
413 lines
10 KiB
Go
413 lines
10 KiB
Go
package query
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"database/sql/driver"
|
|
"errors"
|
|
"fmt"
|
|
"log"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/DATA-DOG/go-sqlmock"
|
|
sq "github.com/Masterminds/squirrel"
|
|
"github.com/jackc/pgtype"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/zitadel/zitadel/internal/database"
|
|
)
|
|
|
|
// Shared timestamps for the tests of this package.
var (
	// testNow is captured once when the test binary initializes, so every
	// test compares against the same instant.
	testNow = time.Now()
	// dayNow is testNow rounded down to a multiple of 24 hours.
	dayNow = testNow.Truncate(time.Hour * 24)
)
|
|
|
|
// assertPrepare checks if the prepare func executes the correct sql query and returns the correct object
|
|
// prepareFunc must be of type
|
|
// func() (sq.SelectBuilder, func(*sql.Rows) (*struct, error))
|
|
// or
|
|
// func() (sq.SelectBuilder, func(*sql.Row) (*struct, error))
|
|
// expectedObject represents the return value of scan
|
|
// sqlExpectation represents the query executed on the database
|
|
func assertPrepare(t *testing.T, prepareFunc, expectedObject interface{}, sqlExpectation sqlExpectation, isErr checkErr, prepareArgs ...reflect.Value) bool {
|
|
t.Helper()
|
|
|
|
client, mock, err := sqlmock.New()
|
|
if err != nil {
|
|
t.Fatalf("failed to build mock client: %v", err)
|
|
}
|
|
|
|
mock = sqlExpectation(mock)
|
|
|
|
builder, scan, err := execPrepare(prepareFunc, prepareArgs)
|
|
if err != nil {
|
|
t.Error(err)
|
|
return false
|
|
}
|
|
errCheck := func(err error) (error, bool) {
|
|
if isErr == nil {
|
|
if err == nil {
|
|
return nil, true
|
|
} else {
|
|
return fmt.Errorf("no error expected got: %w", err), false
|
|
}
|
|
}
|
|
return isErr(err)
|
|
}
|
|
object, ok, didScan := execScan(t, &database.DB{DB: client}, builder, scan, errCheck)
|
|
if !ok {
|
|
t.Error(object)
|
|
return false
|
|
}
|
|
if didScan {
|
|
if !assert.Equal(t, expectedObject, object) {
|
|
return false
|
|
}
|
|
}
|
|
|
|
if err := mock.ExpectationsWereMet(); err != nil {
|
|
t.Errorf("sql expectations not met: %v", err)
|
|
return false
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
type checkErr func(error) (err error, ok bool)
|
|
|
|
type sqlExpectation func(sqlmock.Sqlmock) sqlmock.Sqlmock
|
|
|
|
func mockQuery(stmt string, cols []string, row []driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
|
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
|
m.ExpectBegin()
|
|
q := m.ExpectQuery(stmt).WithArgs(args...)
|
|
m.ExpectCommit()
|
|
result := sqlmock.NewRows(cols)
|
|
if len(row) > 0 {
|
|
result.AddRow(row...)
|
|
}
|
|
q.WillReturnRows(result)
|
|
return m
|
|
}
|
|
}
|
|
|
|
func mockQueryScanErr(stmt string, cols []string, row []driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
|
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
|
m.ExpectBegin()
|
|
q := m.ExpectQuery(stmt).WithArgs(args...)
|
|
m.ExpectRollback()
|
|
result := sqlmock.NewRows(cols)
|
|
if len(row) > 0 {
|
|
result.AddRow(row...)
|
|
}
|
|
q.WillReturnRows(result)
|
|
return m
|
|
}
|
|
}
|
|
|
|
func mockQueries(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
|
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
|
m.ExpectBegin()
|
|
q := m.ExpectQuery(stmt).WithArgs(args...)
|
|
m.ExpectCommit()
|
|
result := sqlmock.NewRows(cols)
|
|
count := uint64(len(rows))
|
|
for _, row := range rows {
|
|
if cols[len(cols)-1] == "count" {
|
|
row = append(row, count)
|
|
}
|
|
result.AddRow(row...)
|
|
}
|
|
q.WillReturnRows(result)
|
|
q.RowsWillBeClosed()
|
|
return m
|
|
}
|
|
}
|
|
|
|
func mockQueriesScanErr(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
|
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
|
m.ExpectBegin()
|
|
q := m.ExpectQuery(stmt).WithArgs(args...)
|
|
m.ExpectRollback()
|
|
result := sqlmock.NewRows(cols)
|
|
count := uint64(len(rows))
|
|
for _, row := range rows {
|
|
if cols[len(cols)-1] == "count" {
|
|
row = append(row, count)
|
|
}
|
|
result.AddRow(row...)
|
|
}
|
|
q.WillReturnRows(result)
|
|
q.RowsWillBeClosed()
|
|
return m
|
|
}
|
|
}
|
|
|
|
func mockQueryErr(stmt string, err error, args ...driver.Value) func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
|
return func(m sqlmock.Sqlmock) sqlmock.Sqlmock {
|
|
m.ExpectBegin()
|
|
q := m.ExpectQuery(stmt).WithArgs(args...)
|
|
q.WillReturnError(err)
|
|
m.ExpectRollback()
|
|
return m
|
|
}
|
|
}
|
|
|
|
func execMock(t testing.TB, exp sqlExpectation, run func(db *sql.DB)) {
|
|
db, mock, err := sqlmock.New()
|
|
require.NoError(t, err)
|
|
defer db.Close()
|
|
mock = exp(mock)
|
|
run(db)
|
|
assert.NoError(t, mock.ExpectationsWereMet())
|
|
}
|
|
|
|
var (
|
|
rowType = reflect.TypeOf(&sql.Row{})
|
|
rowsType = reflect.TypeOf(&sql.Rows{})
|
|
selectBuilderType = reflect.TypeOf(sq.SelectBuilder{})
|
|
)
|
|
|
|
func execScan(t testing.TB, client *database.DB, builder sq.SelectBuilder, scan interface{}, errCheck checkErr) (object interface{}, ok bool, didScan bool) {
|
|
scanType := reflect.TypeOf(scan)
|
|
err := validateScan(scanType)
|
|
if err != nil {
|
|
return err, false, false
|
|
}
|
|
|
|
stmt, args, err := builder.ToSql()
|
|
if err != nil {
|
|
return fmt.Errorf("unexpected error from sql builder: %w", err), false, false
|
|
}
|
|
|
|
//resultSet represents *sql.Row or *sql.Rows,
|
|
// depending on whats assignable to the scan function
|
|
var res []reflect.Value
|
|
|
|
//execute sql stmt
|
|
// if scan(*sql.Rows)...
|
|
if scanType.In(0).AssignableTo(rowsType) {
|
|
err = client.Query(func(rows *sql.Rows) error {
|
|
didScan = true
|
|
res = reflect.ValueOf(scan).Call([]reflect.Value{reflect.ValueOf(rows)})
|
|
if err, ok := res[1].Interface().(error); ok {
|
|
return err
|
|
}
|
|
return nil
|
|
}, stmt, args...)
|
|
|
|
// if scan(*sql.Row)...
|
|
} else if scanType.In(0).AssignableTo(rowType) {
|
|
err = client.QueryRow(func(r *sql.Row) error {
|
|
if r.Err() != nil {
|
|
return r.Err()
|
|
}
|
|
didScan = true
|
|
res = reflect.ValueOf(scan).Call([]reflect.Value{reflect.ValueOf(r)})
|
|
if err, ok := res[1].Interface().(error); ok {
|
|
return err
|
|
}
|
|
return nil
|
|
}, stmt, args...)
|
|
|
|
} else {
|
|
return errors.New("scan: parameter must be *sql.Row or *sql.Rows"), false, false
|
|
}
|
|
|
|
if err != nil {
|
|
err, ok := errCheck(err)
|
|
if !ok {
|
|
t.Fatal(err)
|
|
}
|
|
if didScan {
|
|
return res[0].Interface(), ok, didScan
|
|
}
|
|
return err, ok, didScan
|
|
}
|
|
|
|
//check for error
|
|
if res[1].Interface() != nil {
|
|
if err, ok := errCheck(res[1].Interface().(error)); !ok {
|
|
return fmt.Errorf("scan failed: %w", err), false, didScan
|
|
}
|
|
}
|
|
|
|
return res[0].Interface(), true, didScan
|
|
}
|
|
|
|
// validateScan checks that scanType describes a function
// with exactly one parameter and exactly two return values.
// A nil scanType, as returned by reflect.TypeOf(nil), is rejected as well
// instead of causing a panic.
func validateScan(scanType reflect.Type) error {
	switch {
	case scanType == nil || scanType.Kind() != reflect.Func:
		return errors.New("scan is not a function")
	case scanType.NumIn() != 1:
		return fmt.Errorf("scan: invalid number of inputs: want: 1 got %d", scanType.NumIn())
	case scanType.NumOut() != 2:
		return fmt.Errorf("scan: invalid number of outputs: want: 2 got %d", scanType.NumOut())
	}
	return nil
}
|
|
|
|
func execPrepare(prepare interface{}, args []reflect.Value) (builder sq.SelectBuilder, scan interface{}, err error) {
|
|
prepareVal := reflect.ValueOf(prepare)
|
|
if err := validatePrepare(prepareVal.Type()); err != nil {
|
|
return sq.SelectBuilder{}, nil, err
|
|
}
|
|
res := prepareVal.Call(args)
|
|
|
|
return res[0].Interface().(sq.SelectBuilder), res[1].Interface(), nil
|
|
}
|
|
|
|
func validatePrepare(prepareType reflect.Type) error {
|
|
if prepareType.Kind() != reflect.Func {
|
|
return errors.New("prepare is not a function")
|
|
}
|
|
if prepareType.NumIn() != 0 && prepareType.NumIn() != 2 {
|
|
return fmt.Errorf("prepare: invalid number of inputs: want: 0 or 2 got %d", prepareType.NumIn())
|
|
}
|
|
if prepareType.NumOut() != 2 {
|
|
return fmt.Errorf("prepare: invalid number of outputs: want: 2 got %d", prepareType.NumOut())
|
|
}
|
|
if prepareType.Out(0) != selectBuilderType {
|
|
return fmt.Errorf("prepare: first return value must be: %s got %s", selectBuilderType, prepareType.Out(0))
|
|
}
|
|
if prepareType.Out(1).Kind() != reflect.Func {
|
|
return fmt.Errorf("prepare: second return value must be: %s got %s", reflect.Func, prepareType.Out(1))
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func TestValidateScan(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
t reflect.Type
|
|
expectErr bool
|
|
}{
|
|
{
|
|
name: "not a func",
|
|
t: reflect.TypeOf(&struct{}{}),
|
|
expectErr: true,
|
|
},
|
|
{
|
|
name: "wong input count",
|
|
t: reflect.TypeOf(func() (*struct{}, error) {
|
|
log.Fatal("should not be executed")
|
|
return nil, nil
|
|
}),
|
|
expectErr: true,
|
|
},
|
|
{
|
|
name: "wrong output count",
|
|
t: reflect.TypeOf(func(interface{}) error {
|
|
log.Fatal("should not be executed")
|
|
return nil
|
|
}),
|
|
expectErr: true,
|
|
},
|
|
{
|
|
name: "correct",
|
|
t: reflect.TypeOf(func(interface{}) (*struct{}, error) {
|
|
log.Fatal("should not be executed")
|
|
return nil, nil
|
|
}),
|
|
expectErr: false,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := validateScan(tt.t)
|
|
if (err != nil) != tt.expectErr {
|
|
t.Errorf("unexpected err: %v", err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func TestValidatePrepare(t *testing.T) {
|
|
tests := []struct {
|
|
name string
|
|
t reflect.Type
|
|
expectErr bool
|
|
}{
|
|
{
|
|
name: "not a func",
|
|
t: reflect.TypeOf(&struct{}{}),
|
|
expectErr: true,
|
|
},
|
|
{
|
|
name: "wong input count",
|
|
t: reflect.TypeOf(func(int) (sq.SelectBuilder, func(*sql.Rows) (interface{}, error)) {
|
|
log.Fatal("should not be executed")
|
|
return sq.SelectBuilder{}, nil
|
|
}),
|
|
expectErr: true,
|
|
},
|
|
{
|
|
name: "wrong output count",
|
|
t: reflect.TypeOf(func() sq.SelectBuilder {
|
|
log.Fatal("should not be executed")
|
|
return sq.SelectBuilder{}
|
|
}),
|
|
expectErr: true,
|
|
},
|
|
{
|
|
name: "first output type wrong",
|
|
t: reflect.TypeOf(func() (*struct{}, func(*sql.Rows) (interface{}, error)) {
|
|
log.Fatal("should not be executed")
|
|
return nil, nil
|
|
}),
|
|
expectErr: true,
|
|
},
|
|
{
|
|
name: "second output type wrong",
|
|
t: reflect.TypeOf(func() (sq.SelectBuilder, *struct{}) {
|
|
log.Fatal("should not be executed")
|
|
return sq.SelectBuilder{}, nil
|
|
}),
|
|
expectErr: true,
|
|
},
|
|
{
|
|
name: "correct",
|
|
t: reflect.TypeOf(func(context.Context, prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (interface{}, error)) {
|
|
log.Fatal("should not be executed")
|
|
return sq.SelectBuilder{}, nil
|
|
}),
|
|
expectErr: false,
|
|
},
|
|
}
|
|
for _, tt := range tests {
|
|
t.Run(tt.name, func(t *testing.T) {
|
|
err := validatePrepare(tt.t)
|
|
if (err != nil) != tt.expectErr {
|
|
t.Errorf("unexpected err: %v", err)
|
|
}
|
|
})
|
|
}
|
|
}
|
|
|
|
func intervalDriverValue(t *testing.T, src time.Duration) pgtype.Interval {
|
|
interval := pgtype.Interval{}
|
|
err := interval.Set(src)
|
|
if err != nil {
|
|
t.Fatal(err)
|
|
}
|
|
return interval
|
|
}
|
|
|
|
// prepareDB is a stub returning fixed values.
// NOTE(review): presumably it is used as the prepareDatabase argument
// of prepare functions — confirm against the interface definition.
type prepareDB struct{}

// asOfSystemTime is the clause returned by (*prepareDB).Timetravel.
const asOfSystemTime = " AS OF SYSTEM TIME '-1 ms' "

// defaultPrepareArgs are call arguments for prepare functions
// which take a context and a database.
var defaultPrepareArgs = []reflect.Value{
	reflect.ValueOf(context.Background()),
	reflect.ValueOf(new(prepareDB)),
}

// Timetravel ignores the passed duration and always returns asOfSystemTime.
func (*prepareDB) Timetravel(time.Duration) string {
	return asOfSystemTime
}

// DatabaseName returns the fixed name "db".
func (*prepareDB) DatabaseName() string {
	return "db"
}

// Username returns the fixed name "user".
func (*prepareDB) Username() string {
	return "user"
}

// Type returns the fixed value "type".
func (*prepareDB) Type() string {
	return "type"
}
|