zitadel/internal/v2/database/sql_helper_test.go

513 lines
9.7 KiB
Go

package database
import (
"errors"
"reflect"
"testing"
)
func TestCloseTx(t *testing.T) {
type args struct {
tx *testTx
err error
}
tests := []struct {
name string
args args
assertErr func(t *testing.T, err error) bool
}{
{
name: "exec err",
args: args{
tx: &testTx{
rollback: execution{
shouldExecute: true,
},
},
err: errExec,
},
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errExec)
if !is {
t.Errorf("execution error expected, got: %v", err)
}
return is
},
},
{
name: "exec err and rollback err",
args: args{
tx: &testTx{
rollback: execution{
err: true,
shouldExecute: true,
},
},
err: errExec,
},
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errExec)
if !is {
t.Errorf("execution error expected, got: %v", err)
}
return is
},
},
{
name: "commit Err",
args: args{
tx: &testTx{
commit: execution{
err: true,
shouldExecute: true,
},
},
err: nil,
},
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errCommit)
if !is {
t.Errorf("commit error expected, got: %v", err)
}
return is
},
},
{
name: "no err",
args: args{
tx: &testTx{
commit: execution{
shouldExecute: true,
},
},
err: nil,
},
assertErr: func(t *testing.T, err error) bool {
is := err == nil
if !is {
t.Errorf("no error expected, got: %v", err)
}
return is
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := CloseTx(tt.args.tx, tt.args.err)
tt.assertErr(t, err)
tt.args.tx.assert(t)
})
}
}
func TestMapRows(t *testing.T) {
type args struct {
rows *testRows
mapper DestMapper[string]
}
var emptyString string
tests := []struct {
name string
args args
wantResult []*string
assertErr func(t *testing.T, err error) bool
}{
{
name: "no rows, close err",
args: args{
rows: &testRows{
closeErr: true,
},
mapper: nil,
},
wantResult: nil,
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errClose)
if !is {
t.Errorf("close error expected, got: %v", err)
}
return is
},
},
{
name: "no rows, close err",
args: args{
rows: &testRows{
hasErr: true,
},
mapper: nil,
},
wantResult: nil,
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errRows)
if !is {
t.Errorf("rows error expected, got: %v", err)
}
return is
},
},
{
name: "scan err",
args: args{
rows: &testRows{
scanErr: true,
nextCount: 1,
},
mapper: func(index int, scan func(dest ...any) error) (*string, error) {
var s string
if err := scan(&s); err != nil {
return nil, err
}
return &s, nil
},
},
wantResult: nil,
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errScan)
if !is {
t.Errorf("scan error expected, got: %v", err)
}
return is
},
},
{
name: "exec err",
args: args{
rows: &testRows{
nextCount: 1,
},
mapper: func(index int, scan func(dest ...any) error) (*string, error) {
return nil, errExec
},
},
wantResult: nil,
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errExec)
if !is {
t.Errorf("exec error expected, got: %v", err)
}
return is
},
},
{
name: "exec err, close err",
args: args{
rows: &testRows{
closeErr: true,
nextCount: 1,
},
mapper: func(index int, scan func(dest ...any) error) (*string, error) {
return nil, errExec
},
},
wantResult: nil,
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errExec)
if !is {
t.Errorf("exec error expected, got: %v", err)
}
return is
},
},
{
name: "rows err",
args: args{
rows: &testRows{
nextCount: 1,
hasErr: true,
},
mapper: func(index int, scan func(dest ...any) error) (*string, error) {
var s string
if err := scan(&s); err != nil {
return nil, err
}
return &s, nil
},
},
wantResult: nil,
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errRows)
if !is {
t.Errorf("rows error expected, got: %v", err)
}
return is
},
},
{
name: "no err",
args: args{
rows: &testRows{
nextCount: 1,
},
mapper: func(index int, scan func(dest ...any) error) (*string, error) {
var s string
if err := scan(&s); err != nil {
return nil, err
}
return &s, nil
},
},
wantResult: []*string{&emptyString},
assertErr: func(t *testing.T, err error) bool {
is := err == nil
if !is {
t.Errorf("no error expected, got: %v", err)
}
return is
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotResult, err := MapRows(tt.args.rows, tt.args.mapper)
tt.assertErr(t, err)
if !reflect.DeepEqual(gotResult, tt.wantResult) {
t.Errorf("MapRows() = %v, want %v", gotResult, tt.wantResult)
}
})
}
}
func TestMapRowsToObject(t *testing.T) {
type args struct {
rows *testRows
mapper func(scan func(dest ...any) error) error
}
tests := []struct {
name string
args args
assertErr func(t *testing.T, err error) bool
}{
{
name: "no rows, close err",
args: args{
rows: &testRows{
closeErr: true,
},
mapper: nil,
},
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errClose)
if !is {
t.Errorf("close error expected, got: %v", err)
}
return is
},
},
{
name: "no rows, close err",
args: args{
rows: &testRows{
hasErr: true,
},
mapper: nil,
},
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errRows)
if !is {
t.Errorf("rows error expected, got: %v", err)
}
return is
},
},
{
name: "scan err",
args: args{
rows: &testRows{
scanErr: true,
nextCount: 1,
},
mapper: func(scan func(dest ...any) error) error {
var s string
if err := scan(&s); err != nil {
return err
}
return nil
},
},
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errScan)
if !is {
t.Errorf("scan error expected, got: %v", err)
}
return is
},
},
{
name: "exec err",
args: args{
rows: &testRows{
nextCount: 1,
},
mapper: func(scan func(dest ...any) error) error {
return errExec
},
},
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errExec)
if !is {
t.Errorf("exec error expected, got: %v", err)
}
return is
},
},
{
name: "exec err, close err",
args: args{
rows: &testRows{
closeErr: true,
nextCount: 1,
},
mapper: func(scan func(dest ...any) error) error {
return errExec
},
},
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errExec)
if !is {
t.Errorf("exec error expected, got: %v", err)
}
return is
},
},
{
name: "rows err",
args: args{
rows: &testRows{
nextCount: 1,
hasErr: true,
},
mapper: func(scan func(dest ...any) error) error {
var s string
return scan(&s)
},
},
assertErr: func(t *testing.T, err error) bool {
is := errors.Is(err, errRows)
if !is {
t.Errorf("rows error expected, got: %v", err)
}
return is
},
},
{
name: "no err",
args: args{
rows: &testRows{
nextCount: 1,
},
mapper: func(scan func(dest ...any) error) error {
var s string
return scan(&s)
},
},
assertErr: func(t *testing.T, err error) bool {
is := err == nil
if !is {
t.Errorf("no error expected, got: %v", err)
}
return is
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := MapRowsToObject(tt.args.rows, tt.args.mapper)
tt.assertErr(t, err)
})
}
}
var _ Tx = (*testTx)(nil)
type testTx struct {
commit, rollback execution
}
type execution struct {
err bool
didExecute bool
shouldExecute bool
}
var (
errCommit = errors.New("commit err")
errRollback = errors.New("rollback err")
errExec = errors.New("exec err")
)
// Commit implements Tx.
func (t *testTx) Commit() error {
t.commit.didExecute = true
if t.commit.err {
return errCommit
}
return nil
}
// Rollback implements Tx.
func (t *testTx) Rollback() error {
t.rollback.didExecute = true
if t.rollback.err {
return errRollback
}
return nil
}
func (tx *testTx) assert(t *testing.T) {
if tx.commit.didExecute != tx.commit.shouldExecute {
t.Errorf("unexpected execution of commit: should %v, did: %v", tx.commit.shouldExecute, tx.commit.didExecute)
}
if tx.rollback.didExecute != tx.rollback.shouldExecute {
t.Errorf("unexpected execution of rollback: should %v, did: %v", tx.rollback.shouldExecute, tx.rollback.didExecute)
}
}
var _ Rows = (*testRows)(nil)
var (
errClose = errors.New("err close")
errRows = errors.New("err rows")
errScan = errors.New("err scan")
)
type testRows struct {
closeErr bool
scanErr bool
hasErr bool
nextCount int
}
// Close implements Rows.
func (t *testRows) Close() error {
if t.closeErr {
return errClose
}
return nil
}
// Err implements Rows.
func (t *testRows) Err() error {
if t.hasErr {
return errRows
}
if t.closeErr {
return errClose
}
return nil
}
// Next implements Rows.
func (t *testRows) Next() bool {
t.nextCount--
return t.nextCount >= 0
}
// Scan implements Rows.
func (t *testRows) Scan(dest ...any) error {
if t.scanErr {
return errScan
}
return nil
}