mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 01:37:31 +00:00
chore: use pgx v5 (#7577)
* chore: use pgx v5 * chore: update go version * remove direct pq dependency * remove unnecessary type * scan test * map scanner * converter * uint8 number array * duration * most unit tests work * unit tests work * chore: coverage * go 1.21 * linting * int64 gopfertammi * retry go 1.22 * retry go 1.22 * revert to go v1.21.5 * update go toolchain to 1.21.8 * go 1.21.8 * remove test flag * go 1.21.5 * linting * update toolchain * use correct array * use correct array * add byte array * correct value * correct error message * go 1.21 compatible
This commit is contained in:
@@ -6,7 +6,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/jackc/pgx/v4/stdlib"
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
|
@@ -19,6 +19,7 @@ type expectation func(m sqlmock.Sqlmock)
|
||||
func NewSQLMock(t *testing.T, expectations ...expectation) *SQLMock {
|
||||
db, mock, err := sqlmock.New(
|
||||
sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual),
|
||||
sqlmock.ValueConverterOption(new(TypeConverter)),
|
||||
)
|
||||
if err != nil {
|
||||
t.Fatal("create mock failed", err)
|
||||
@@ -97,23 +98,23 @@ func ExcpectExec(stmt string, opts ...ExecOpt) expectation {
|
||||
}
|
||||
}
|
||||
|
||||
type QueryOpt func(e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery
|
||||
type QueryOpt func(m sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery
|
||||
|
||||
func WithQueryArgs(args ...driver.Value) QueryOpt {
|
||||
return func(e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
|
||||
return func(_ sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
|
||||
return e.WithArgs(args...)
|
||||
}
|
||||
}
|
||||
|
||||
func WithQueryErr(err error) QueryOpt {
|
||||
return func(e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
|
||||
return func(_ sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
|
||||
return e.WillReturnError(err)
|
||||
}
|
||||
}
|
||||
|
||||
func WithQueryResult(columns []string, rows [][]driver.Value) QueryOpt {
|
||||
return func(e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
|
||||
mockedRows := sqlmock.NewRows(columns)
|
||||
return func(m sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
|
||||
mockedRows := m.NewRows(columns)
|
||||
for _, row := range rows {
|
||||
mockedRows = mockedRows.AddRow(row...)
|
||||
}
|
||||
@@ -125,7 +126,7 @@ func ExpectQuery(stmt string, opts ...QueryOpt) expectation {
|
||||
return func(m sqlmock.Sqlmock) {
|
||||
e := m.ExpectQuery(stmt)
|
||||
for _, opt := range opts {
|
||||
e = opt(e)
|
||||
e = opt(m, e)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
87
internal/database/mock/type_converter.go
Normal file
87
internal/database/mock/type_converter.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package mock
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var _ driver.ValueConverter = (*TypeConverter)(nil)
|
||||
|
||||
type TypeConverter struct{}
|
||||
|
||||
// ConvertValue converts a value to a driver Value.
|
||||
func (s TypeConverter) ConvertValue(v any) (driver.Value, error) {
|
||||
if driver.IsValue(v) {
|
||||
return v, nil
|
||||
}
|
||||
value := reflect.ValueOf(v)
|
||||
|
||||
if rawMessage, ok := v.(json.RawMessage); ok {
|
||||
return convertBytes(rawMessage), nil
|
||||
}
|
||||
|
||||
if value.Kind() == reflect.Slice {
|
||||
//nolint: exhaustive
|
||||
// only defined types
|
||||
switch value.Type().Elem().Kind() {
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
|
||||
return convertSigned(value), nil
|
||||
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
|
||||
return convertUnsigned(value), nil
|
||||
case reflect.String:
|
||||
return convertText(value), nil
|
||||
}
|
||||
}
|
||||
return v, nil
|
||||
}
|
||||
|
||||
// converts a text array to valid pgx v5 representation
|
||||
func convertSigned(array reflect.Value) string {
|
||||
slice := make([]string, array.Len())
|
||||
for i := 0; i < array.Len(); i++ {
|
||||
slice[i] = strconv.FormatInt(array.Index(i).Int(), 10)
|
||||
}
|
||||
|
||||
return "{" + strings.Join(slice, ",") + "}"
|
||||
}
|
||||
|
||||
// converts a text array to valid pgx v5 representation
|
||||
func convertUnsigned(array reflect.Value) string {
|
||||
slice := make([]string, array.Len())
|
||||
for i := 0; i < array.Len(); i++ {
|
||||
slice[i] = strconv.FormatUint(array.Index(i).Uint(), 10)
|
||||
}
|
||||
|
||||
return "{" + strings.Join(slice, ",") + "}"
|
||||
}
|
||||
|
||||
// converts a text array to valid pgx v5 representation
|
||||
func convertText(array reflect.Value) string {
|
||||
slice := make([]string, array.Len())
|
||||
for i := 0; i < array.Len(); i++ {
|
||||
slice[i] = array.Index(i).String()
|
||||
}
|
||||
|
||||
return "{" + strings.Join(slice, ",") + "}"
|
||||
}
|
||||
|
||||
func convertBytes(array []byte) string {
|
||||
var builder strings.Builder
|
||||
builder.Grow(hex.EncodedLen(len(array)) + 4)
|
||||
builder.WriteString(`\x`)
|
||||
builder.Write(AppendEncode(nil, array))
|
||||
return builder.String()
|
||||
}
|
||||
|
||||
// TODO: remove function after we compile using go 1.22 and use function of hex package `hex.AppendEncode`
|
||||
func AppendEncode(dst, src []byte) []byte {
|
||||
n := hex.EncodedLen(len(src))
|
||||
dst = slices.Grow(dst, n)
|
||||
hex.Encode(dst[len(dst):][:n], src)
|
||||
return dst[:len(dst)+n]
|
||||
}
|
@@ -6,7 +6,7 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
_ "github.com/jackc/pgx/v4/stdlib"
|
||||
_ "github.com/jackc/pgx/v5/stdlib"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
|
@@ -1,87 +1,166 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgtype"
|
||||
"github.com/jackc/pgx/v5/pgtype"
|
||||
)
|
||||
|
||||
type TextArray[t ~string] []t
|
||||
type TextArray[T ~string] pgtype.FlatArray[T]
|
||||
|
||||
// Scan implements the [database/sql.Scanner] interface.
|
||||
func (s *TextArray[t]) Scan(src any) error {
|
||||
array := new(pgtype.TextArray)
|
||||
if err := array.Scan(src); err != nil {
|
||||
func (s *TextArray[T]) Scan(src any) error {
|
||||
var typedArray []string
|
||||
err := pgtype.NewMap().SQLScanner(&typedArray).Scan(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return array.AssignTo(s)
|
||||
}
|
||||
|
||||
// Value implements the [database/sql/driver.Valuer] interface.
|
||||
func (s TextArray[t]) Value() (driver.Value, error) {
|
||||
if len(s) == 0 {
|
||||
return nil, nil
|
||||
(*s) = make(TextArray[T], len(typedArray))
|
||||
for i, value := range typedArray {
|
||||
(*s)[i] = T(value)
|
||||
}
|
||||
|
||||
array := pgtype.TextArray{}
|
||||
if err := array.Set(s); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return array.Value()
|
||||
}
|
||||
|
||||
type arrayField interface {
|
||||
~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32
|
||||
}
|
||||
|
||||
type Array[F arrayField] []F
|
||||
|
||||
// Scan implements the [database/sql.Scanner] interface.
|
||||
func (a *Array[F]) Scan(src any) error {
|
||||
array := new(pgtype.Int8Array)
|
||||
if err := array.Scan(src); err != nil {
|
||||
return err
|
||||
}
|
||||
elements := make([]int64, len(array.Elements))
|
||||
if err := array.AssignTo(&elements); err != nil {
|
||||
return err
|
||||
}
|
||||
*a = make([]F, len(elements))
|
||||
for i, element := range elements {
|
||||
(*a)[i] = F(element)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the [database/sql/driver.Valuer] interface.
|
||||
func (a Array[F]) Value() (driver.Value, error) {
|
||||
if len(a) == 0 {
|
||||
func (s TextArray[T]) Value() (driver.Value, error) {
|
||||
if len(s) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
array := pgtype.Int8Array{}
|
||||
if err := array.Set(a); err != nil {
|
||||
return nil, err
|
||||
typed := make([]string, len(s))
|
||||
|
||||
for i, value := range s {
|
||||
typed[i] = string(value)
|
||||
}
|
||||
|
||||
return array.Value()
|
||||
return []byte("{" + strings.Join(typed, ",") + "}"), nil
|
||||
}
|
||||
|
||||
type ByteArray[T ~byte] pgtype.FlatArray[T]
|
||||
|
||||
// Scan implements the [database/sql.Scanner] interface.
|
||||
func (s *ByteArray[T]) Scan(src any) error {
|
||||
var typedArray []byte
|
||||
typedArray, ok := src.([]byte)
|
||||
if !ok {
|
||||
// tests use a different src type
|
||||
err := pgtype.NewMap().SQLScanner(&typedArray).Scan(src)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
(*s) = make(ByteArray[T], len(typedArray))
|
||||
for i, value := range typedArray {
|
||||
(*s)[i] = T(value)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// Value implements the [database/sql/driver.Valuer] interface.
|
||||
func (s ByteArray[T]) Value() (driver.Value, error) {
|
||||
if len(s) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
typed := make([]byte, len(s))
|
||||
|
||||
for i, value := range s {
|
||||
typed[i] = byte(value)
|
||||
}
|
||||
|
||||
return typed, nil
|
||||
}
|
||||
|
||||
type numberField interface {
|
||||
~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 | ~int | ~uint
|
||||
}
|
||||
|
||||
type numberTypeField interface {
|
||||
int8 | uint8 | int16 | uint16 | int32 | uint32 | int64 | uint64 | int | uint
|
||||
}
|
||||
|
||||
var _ sql.Scanner = (*NumberArray[int8])(nil)
|
||||
|
||||
type NumberArray[F numberField] pgtype.FlatArray[F]
|
||||
|
||||
// Scan implements the [database/sql.Scanner] interface.
|
||||
func (a *NumberArray[F]) Scan(src any) (err error) {
|
||||
var (
|
||||
mapper func()
|
||||
scanner sql.Scanner
|
||||
)
|
||||
|
||||
//nolint: exhaustive
|
||||
// only defined types
|
||||
switch reflect.TypeOf(*a).Elem().Kind() {
|
||||
case reflect.Int8:
|
||||
mapper, scanner = castedScan[int8](a)
|
||||
case reflect.Uint8:
|
||||
// we provide int16 is a workaround because pgx thinks we want to scan a byte array if we provide uint8
|
||||
mapper, scanner = castedScan[int16](a)
|
||||
case reflect.Int16:
|
||||
mapper, scanner = castedScan[int16](a)
|
||||
case reflect.Uint16:
|
||||
mapper, scanner = castedScan[uint16](a)
|
||||
case reflect.Int32:
|
||||
mapper, scanner = castedScan[int32](a)
|
||||
case reflect.Uint32:
|
||||
mapper, scanner = castedScan[uint32](a)
|
||||
case reflect.Int64:
|
||||
mapper, scanner = castedScan[int64](a)
|
||||
case reflect.Uint64:
|
||||
mapper, scanner = castedScan[uint64](a)
|
||||
case reflect.Int:
|
||||
mapper, scanner = castedScan[int](a)
|
||||
case reflect.Uint:
|
||||
mapper, scanner = castedScan[uint](a)
|
||||
}
|
||||
|
||||
if err = scanner.Scan(src); err != nil {
|
||||
return err
|
||||
}
|
||||
mapper()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func castedScan[T numberTypeField, F numberField](a *NumberArray[F]) (mapper func(), scanner sql.Scanner) {
|
||||
var typedArray []T
|
||||
|
||||
mapper = func() {
|
||||
(*a) = make(NumberArray[F], len(typedArray))
|
||||
for i, value := range typedArray {
|
||||
(*a)[i] = F(value)
|
||||
}
|
||||
}
|
||||
scanner = pgtype.NewMap().SQLScanner(&typedArray)
|
||||
|
||||
return mapper, scanner
|
||||
}
|
||||
|
||||
type Map[V any] map[string]V
|
||||
|
||||
// Scan implements the [database/sql.Scanner] interface.
|
||||
func (m *Map[V]) Scan(src any) error {
|
||||
bytea := new(pgtype.Bytea)
|
||||
if err := bytea.Scan(src); err != nil {
|
||||
return err
|
||||
}
|
||||
if len(bytea.Bytes) == 0 {
|
||||
if src == nil {
|
||||
return nil
|
||||
}
|
||||
return json.Unmarshal(bytea.Bytes, &m)
|
||||
|
||||
bytes := src.([]byte)
|
||||
if len(bytes) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
return json.Unmarshal(bytes, &m)
|
||||
}
|
||||
|
||||
// Value implements the [database/sql/driver.Valuer] interface.
|
||||
@@ -96,14 +175,35 @@ type Duration time.Duration
|
||||
|
||||
// Scan implements the [database/sql.Scanner] interface.
|
||||
func (d *Duration) Scan(src any) error {
|
||||
switch duration := src.(type) {
|
||||
case *time.Duration:
|
||||
*d = Duration(*duration)
|
||||
return nil
|
||||
case time.Duration:
|
||||
*d = Duration(duration)
|
||||
return nil
|
||||
case *pgtype.Interval:
|
||||
*d = intervalToDuration(duration)
|
||||
return nil
|
||||
case pgtype.Interval:
|
||||
*d = intervalToDuration(&duration)
|
||||
return nil
|
||||
case int64:
|
||||
*d = Duration(duration)
|
||||
return nil
|
||||
}
|
||||
interval := new(pgtype.Interval)
|
||||
if err := interval.Scan(src); err != nil {
|
||||
return err
|
||||
}
|
||||
*d = Duration(time.Duration(interval.Microseconds*1000) + time.Duration(interval.Days)*24*time.Hour + time.Duration(interval.Months)*30*24*time.Hour)
|
||||
*d = intervalToDuration(interval)
|
||||
return nil
|
||||
}
|
||||
|
||||
func intervalToDuration(interval *pgtype.Interval) Duration {
|
||||
return Duration(time.Duration(interval.Microseconds*1000) + time.Duration(interval.Days)*24*time.Hour + time.Duration(interval.Months)*30*24*time.Hour)
|
||||
}
|
||||
|
||||
// NullDuration can be used for NULL intervals.
|
||||
// If Valid is false, the scanned value was NULL
|
||||
// This behavior is similar to [database/sql.NullString]
|
||||
|
@@ -1,9 +1,9 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
|
||||
func TestMap_Scan(t *testing.T) {
|
||||
type args struct {
|
||||
src any
|
||||
src []byte
|
||||
}
|
||||
type res[V any] struct {
|
||||
want Map[V]
|
||||
@@ -24,10 +24,19 @@ func TestMap_Scan(t *testing.T) {
|
||||
res[V]
|
||||
}
|
||||
tests := []testCase[string]{
|
||||
{
|
||||
"nil",
|
||||
Map[string]{},
|
||||
args{src: nil},
|
||||
res[string]{
|
||||
want: Map[string]{},
|
||||
err: false,
|
||||
},
|
||||
},
|
||||
{
|
||||
"null",
|
||||
Map[string]{},
|
||||
args{src: "invalid"},
|
||||
args{src: []byte("invalid")},
|
||||
res[string]{
|
||||
want: Map[string]{},
|
||||
err: true,
|
||||
@@ -119,83 +128,109 @@ func TestMap_Value(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestNullDuration_Scan(t *testing.T) {
|
||||
type typedInt int
|
||||
|
||||
func TestNumberArray_Scan(t *testing.T) {
|
||||
type args struct {
|
||||
src any
|
||||
}
|
||||
type res struct {
|
||||
want NullDuration
|
||||
want any
|
||||
err bool
|
||||
}
|
||||
type testCase struct {
|
||||
name string
|
||||
m sql.Scanner
|
||||
args args
|
||||
res res
|
||||
}
|
||||
tests := []testCase{
|
||||
{
|
||||
"invalid",
|
||||
args{src: "invalid"},
|
||||
res{
|
||||
want: NullDuration{
|
||||
Valid: false,
|
||||
},
|
||||
err: true,
|
||||
name: "typedInt",
|
||||
m: new(NumberArray[typedInt]),
|
||||
args: args{src: "{1,2}"},
|
||||
res: res{
|
||||
want: &NumberArray[typedInt]{1, 2},
|
||||
},
|
||||
},
|
||||
{
|
||||
"null",
|
||||
args{src: nil},
|
||||
res{
|
||||
want: NullDuration{
|
||||
Valid: false,
|
||||
},
|
||||
err: false,
|
||||
name: "int8",
|
||||
m: new(NumberArray[int8]),
|
||||
args: args{src: "{1,2}"},
|
||||
res: res{
|
||||
want: &NumberArray[int8]{1, 2},
|
||||
},
|
||||
},
|
||||
{
|
||||
"valid",
|
||||
args{src: "1:0:0"},
|
||||
res{
|
||||
want: NullDuration{
|
||||
Valid: true,
|
||||
Duration: time.Hour,
|
||||
},
|
||||
name: "uint8",
|
||||
m: new(NumberArray[uint8]),
|
||||
args: args{src: "{1,2}"},
|
||||
res: res{
|
||||
want: &NumberArray[uint8]{1, 2},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
d := new(NullDuration)
|
||||
if err := d.Scan(tt.args.src); (err != nil) != tt.res.err {
|
||||
t.Errorf("Scan() error = %v, wantErr %v", err, tt.res.err)
|
||||
}
|
||||
assert.Equal(t, tt.res.want, *d)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestArray_ScanInt32(t *testing.T) {
|
||||
type args struct {
|
||||
src any
|
||||
}
|
||||
type res[V arrayField] struct {
|
||||
want Array[V]
|
||||
err bool
|
||||
}
|
||||
type testCase[V arrayField] struct {
|
||||
name string
|
||||
m Array[V]
|
||||
args args
|
||||
res[V]
|
||||
}
|
||||
tests := []testCase[int32]{
|
||||
{
|
||||
"number",
|
||||
Array[int32]{},
|
||||
args{src: "{1,2}"},
|
||||
res[int32]{
|
||||
want: []int32{1, 2},
|
||||
name: "int16",
|
||||
m: new(NumberArray[int16]),
|
||||
args: args{src: "{1,2}"},
|
||||
res: res{
|
||||
want: &NumberArray[int16]{1, 2},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "uint16",
|
||||
m: new(NumberArray[uint16]),
|
||||
args: args{src: "{1,2}"},
|
||||
res: res{
|
||||
want: &NumberArray[uint16]{1, 2},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "int32",
|
||||
m: new(NumberArray[int32]),
|
||||
args: args{src: "{1,2}"},
|
||||
res: res{
|
||||
want: &NumberArray[int32]{1, 2},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "uint32",
|
||||
m: new(NumberArray[uint32]),
|
||||
args: args{src: "{1,2}"},
|
||||
res: res{
|
||||
want: &NumberArray[uint32]{1, 2},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "int64",
|
||||
m: new(NumberArray[int64]),
|
||||
args: args{src: "{1,2}"},
|
||||
res: res{
|
||||
want: &NumberArray[int64]{1, 2},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "uint64",
|
||||
m: new(NumberArray[uint64]),
|
||||
args: args{src: "{1,2}"},
|
||||
res: res{
|
||||
want: &NumberArray[uint64]{1, 2},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "int",
|
||||
m: new(NumberArray[int]),
|
||||
args: args{src: "{1,2}"},
|
||||
res: res{
|
||||
want: &NumberArray[int]{1, 2},
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "uint",
|
||||
m: new(NumberArray[uint]),
|
||||
args: args{src: "{1,2}"},
|
||||
res: res{
|
||||
want: &NumberArray[uint]{1, 2},
|
||||
},
|
||||
},
|
||||
}
|
||||
@@ -210,42 +245,80 @@ func TestArray_ScanInt32(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestArray_Value(t *testing.T) {
|
||||
type typedText string
|
||||
|
||||
func TestTextArray_Scan(t *testing.T) {
|
||||
type args struct {
|
||||
src any
|
||||
}
|
||||
type res struct {
|
||||
want sql.Scanner
|
||||
err bool
|
||||
}
|
||||
type testCase struct {
|
||||
name string
|
||||
m sql.Scanner
|
||||
args args
|
||||
res
|
||||
}
|
||||
tests := []testCase{
|
||||
{
|
||||
"string",
|
||||
new(TextArray[string]),
|
||||
args{src: "{asdf,fdas}"},
|
||||
res{
|
||||
want: &TextArray[string]{"asdf", "fdas"},
|
||||
},
|
||||
},
|
||||
{
|
||||
"typedText",
|
||||
new(TextArray[typedText]),
|
||||
args{src: "{asdf,fdas}"},
|
||||
res{
|
||||
want: &TextArray[typedText]{"asdf", "fdas"},
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.m.Scan(tt.args.src); (err != nil) != tt.res.err {
|
||||
t.Errorf("Scan() error = %v, wantErr %v", err, tt.res.err)
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.res.want, tt.m)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestTextArray_Value(t *testing.T) {
|
||||
type res struct {
|
||||
want driver.Value
|
||||
err bool
|
||||
}
|
||||
type testCase[V arrayField] struct {
|
||||
type testCase struct {
|
||||
name string
|
||||
a Array[V]
|
||||
m driver.Valuer
|
||||
res res
|
||||
}
|
||||
tests := []testCase[int32]{
|
||||
{
|
||||
"nil",
|
||||
nil,
|
||||
res{
|
||||
want: nil,
|
||||
},
|
||||
},
|
||||
tests := []testCase{
|
||||
{
|
||||
"empty",
|
||||
Array[int32]{},
|
||||
TextArray[string]{},
|
||||
res{
|
||||
want: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
"set",
|
||||
Array[int32]([]int32{1, 2}),
|
||||
TextArray[string]{"a", "s", "d", "f"},
|
||||
res{
|
||||
want: driver.Value(string([]byte(`{1,2}`))),
|
||||
want: driver.Value([]byte("{a,s,d,f}")),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.a.Value()
|
||||
got, err := tt.m.Value()
|
||||
if tt.res.err {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
@@ -256,3 +329,126 @@ func TestArray_Value(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type typedByte byte
|
||||
|
||||
func TestByteArray_Scan(t *testing.T) {
|
||||
wantedBytes := []byte("asdf")
|
||||
wantedTypedBytes := []typedByte("asdf")
|
||||
type args struct {
|
||||
src any
|
||||
}
|
||||
type res struct {
|
||||
want sql.Scanner
|
||||
err bool
|
||||
}
|
||||
type testCase struct {
|
||||
name string
|
||||
m sql.Scanner
|
||||
args args
|
||||
res
|
||||
}
|
||||
tests := []testCase{
|
||||
{
|
||||
"bytes",
|
||||
new(ByteArray[byte]),
|
||||
args{src: []byte("asdf")},
|
||||
res{
|
||||
want: (*ByteArray[byte])(&wantedBytes),
|
||||
},
|
||||
},
|
||||
{
|
||||
"typed",
|
||||
new(ByteArray[typedByte]),
|
||||
args{src: []byte("asdf")},
|
||||
res{
|
||||
want: (*ByteArray[typedByte])(&wantedTypedBytes),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.m.Scan(tt.args.src); (err != nil) != tt.res.err {
|
||||
t.Errorf("Scan() error = %v, wantErr %v", err, tt.res.err)
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.res.want, tt.m)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestByteArray_Value(t *testing.T) {
|
||||
type res struct {
|
||||
want driver.Value
|
||||
err bool
|
||||
}
|
||||
type testCase struct {
|
||||
name string
|
||||
m driver.Valuer
|
||||
res res
|
||||
}
|
||||
tests := []testCase{
|
||||
{
|
||||
"empty",
|
||||
ByteArray[byte]{},
|
||||
res{
|
||||
want: nil,
|
||||
},
|
||||
},
|
||||
{
|
||||
"set",
|
||||
ByteArray[byte]([]byte("{\"type\": \"object\", \"$schema\": \"urn:zitadel:schema:v1\"}")),
|
||||
res{
|
||||
want: driver.Value([]byte("{\"type\": \"object\", \"$schema\": \"urn:zitadel:schema:v1\"}")),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := tt.m.Value()
|
||||
if tt.res.err {
|
||||
assert.Error(t, err)
|
||||
}
|
||||
if !tt.res.err {
|
||||
require.NoError(t, err)
|
||||
assert.Equalf(t, tt.res.want, got, "Value()")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuration_Scan(t *testing.T) {
|
||||
duration := Duration(10)
|
||||
type args struct {
|
||||
src any
|
||||
}
|
||||
type res struct {
|
||||
want sql.Scanner
|
||||
err bool
|
||||
}
|
||||
type testCase[V ~string] struct {
|
||||
name string
|
||||
m sql.Scanner
|
||||
args args
|
||||
res
|
||||
}
|
||||
tests := []testCase[string]{
|
||||
{
|
||||
name: "int64",
|
||||
m: new(Duration),
|
||||
args: args{src: int64(duration)},
|
||||
res: res{
|
||||
want: &duration,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if err := tt.m.Scan(tt.args.src); (err != nil) != tt.res.err {
|
||||
t.Errorf("Scan() error = %v, wantErr %v", err, tt.res.err)
|
||||
}
|
||||
|
||||
assert.Equal(t, tt.res.want, tt.m)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user