chore: use pgx v5 (#7577)

* chore: use pgx v5

* chore: update go version

* remove direct pq dependency

* remove unnecessary type

* scan test

* map scanner

* converter

* uint8 number array

* duration

* most unit tests work

* unit tests work

* chore: coverage

* go 1.21

* linting

* int64 gopfertammi

* retry go 1.22

* retry go 1.22

* revert to go v1.21.5

* update go toolchain to 1.21.8

* go 1.21.8

* remove test flag

* go 1.21.5

* linting

* update toolchain

* use correct array

* use correct array

* add byte array

* correct value

* correct error message

* go 1.21 compatible
This commit is contained in:
Silvan
2024-03-27 14:48:22 +01:00
committed by GitHub
parent 2ea0b520fd
commit 56df515e5f
49 changed files with 801 additions and 493 deletions

View File

@@ -2,6 +2,7 @@ package admin
import (
"context"
"time"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
@@ -394,11 +395,11 @@ func (s *Server) getLoginPolicy(ctx context.Context, orgID string, orgIDPs []str
return nil, err
}
if !queriedLogin.IsDefault {
pwCheck := durationpb.New(queriedLogin.PasswordCheckLifetime)
externalLogin := durationpb.New(queriedLogin.ExternalLoginCheckLifetime)
mfaInitSkip := durationpb.New(queriedLogin.MFAInitSkipLifetime)
secondFactor := durationpb.New(queriedLogin.SecondFactorCheckLifetime)
multiFactor := durationpb.New(queriedLogin.MultiFactorCheckLifetime)
pwCheck := durationpb.New(time.Duration(queriedLogin.PasswordCheckLifetime))
externalLogin := durationpb.New(time.Duration(queriedLogin.ExternalLoginCheckLifetime))
mfaInitSkip := durationpb.New(time.Duration(queriedLogin.MFAInitSkipLifetime))
secondFactor := durationpb.New(time.Duration(queriedLogin.SecondFactorCheckLifetime))
multiFactor := durationpb.New(time.Duration(queriedLogin.MultiFactorCheckLifetime))
secondFactors := []policy_pb.SecondFactorType{}
for _, factor := range queriedLogin.SecondFactors {

View File

@@ -1153,11 +1153,11 @@ func (s *Server) dataOrgsV1ToDataOrgs(ctx context.Context, dataOrgs *v1_pb.Impor
if err != nil {
return nil, err
}
org.LoginPolicy.ExternalLoginCheckLifetime = durationpb.New(defaultLoginPolicy.ExternalLoginCheckLifetime)
org.LoginPolicy.MultiFactorCheckLifetime = durationpb.New(defaultLoginPolicy.MultiFactorCheckLifetime)
org.LoginPolicy.SecondFactorCheckLifetime = durationpb.New(defaultLoginPolicy.SecondFactorCheckLifetime)
org.LoginPolicy.PasswordCheckLifetime = durationpb.New(defaultLoginPolicy.PasswordCheckLifetime)
org.LoginPolicy.MfaInitSkipLifetime = durationpb.New(defaultLoginPolicy.MFAInitSkipLifetime)
org.LoginPolicy.ExternalLoginCheckLifetime = durationpb.New(time.Duration(defaultLoginPolicy.ExternalLoginCheckLifetime))
org.LoginPolicy.MultiFactorCheckLifetime = durationpb.New(time.Duration(defaultLoginPolicy.MultiFactorCheckLifetime))
org.LoginPolicy.SecondFactorCheckLifetime = durationpb.New(time.Duration(defaultLoginPolicy.SecondFactorCheckLifetime))
org.LoginPolicy.PasswordCheckLifetime = durationpb.New(time.Duration(defaultLoginPolicy.PasswordCheckLifetime))
org.LoginPolicy.MfaInitSkipLifetime = durationpb.New(time.Duration(defaultLoginPolicy.MFAInitSkipLifetime))
if orgV1.SecondFactors != nil {
org.LoginPolicy.SecondFactors = make([]policy.SecondFactorType, len(orgV1.SecondFactors))

View File

@@ -1,6 +1,8 @@
package policy
import (
"time"
"google.golang.org/protobuf/types/known/durationpb"
"google.golang.org/protobuf/types/known/timestamppb"
@@ -26,11 +28,11 @@ func ModelLoginPolicyToPb(policy *query.LoginPolicy) *policy_pb.LoginPolicy {
DisableLoginWithEmail: policy.DisableLoginWithEmail,
DisableLoginWithPhone: policy.DisableLoginWithPhone,
DefaultRedirectUri: policy.DefaultRedirectURI,
PasswordCheckLifetime: durationpb.New(policy.PasswordCheckLifetime),
ExternalLoginCheckLifetime: durationpb.New(policy.ExternalLoginCheckLifetime),
MfaInitSkipLifetime: durationpb.New(policy.MFAInitSkipLifetime),
SecondFactorCheckLifetime: durationpb.New(policy.SecondFactorCheckLifetime),
MultiFactorCheckLifetime: durationpb.New(policy.MultiFactorCheckLifetime),
PasswordCheckLifetime: durationpb.New(time.Duration(policy.PasswordCheckLifetime)),
ExternalLoginCheckLifetime: durationpb.New(time.Duration(policy.ExternalLoginCheckLifetime)),
MfaInitSkipLifetime: durationpb.New(time.Duration(policy.MFAInitSkipLifetime)),
SecondFactorCheckLifetime: durationpb.New(time.Duration(policy.SecondFactorCheckLifetime)),
MultiFactorCheckLifetime: durationpb.New(time.Duration(policy.MultiFactorCheckLifetime)),
SecondFactors: ModelSecondFactorTypesToPb(policy.SecondFactors),
MultiFactors: ModelMultiFactorTypesToPb(policy.MultiFactors),
Idps: idp_grpc.IDPLoginPolicyLinksToPb(policy.IDPLinks),

View File

@@ -1,6 +1,8 @@
package settings
import (
"time"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/zitadel/zitadel/internal/command"
@@ -32,11 +34,11 @@ func loginSettingsToPb(current *query.LoginPolicy) *settings.LoginSettings {
DisableLoginWithEmail: current.DisableLoginWithEmail,
DisableLoginWithPhone: current.DisableLoginWithPhone,
DefaultRedirectUri: current.DefaultRedirectURI,
PasswordCheckLifetime: durationpb.New(current.PasswordCheckLifetime),
ExternalLoginCheckLifetime: durationpb.New(current.ExternalLoginCheckLifetime),
MfaInitSkipLifetime: durationpb.New(current.MFAInitSkipLifetime),
SecondFactorCheckLifetime: durationpb.New(current.SecondFactorCheckLifetime),
MultiFactorCheckLifetime: durationpb.New(current.MultiFactorCheckLifetime),
PasswordCheckLifetime: durationpb.New(time.Duration(current.PasswordCheckLifetime)),
ExternalLoginCheckLifetime: durationpb.New(time.Duration(current.ExternalLoginCheckLifetime)),
MfaInitSkipLifetime: durationpb.New(time.Duration(current.MFAInitSkipLifetime)),
SecondFactorCheckLifetime: durationpb.New(time.Duration(current.SecondFactorCheckLifetime)),
MultiFactorCheckLifetime: durationpb.New(time.Duration(current.MultiFactorCheckLifetime)),
SecondFactors: second,
MultiFactors: multi,
ResourceOwnerType: isDefaultToResourceOwnerTypePb(current.IsDefault),

View File

@@ -13,6 +13,7 @@ import (
"github.com/zitadel/zitadel/internal/api/grpc"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query"
settings "github.com/zitadel/zitadel/pkg/grpc/settings/v2beta"
@@ -34,11 +35,11 @@ func Test_loginSettingsToPb(t *testing.T) {
DisableLoginWithEmail: true,
DisableLoginWithPhone: true,
DefaultRedirectURI: "example.com",
PasswordCheckLifetime: time.Hour,
ExternalLoginCheckLifetime: time.Minute,
MFAInitSkipLifetime: time.Millisecond,
SecondFactorCheckLifetime: time.Microsecond,
MultiFactorCheckLifetime: time.Nanosecond,
PasswordCheckLifetime: database.Duration(time.Hour),
ExternalLoginCheckLifetime: database.Duration(time.Minute),
MFAInitSkipLifetime: database.Duration(time.Millisecond),
SecondFactorCheckLifetime: database.Duration(time.Microsecond),
MultiFactorCheckLifetime: database.Duration(time.Nanosecond),
SecondFactors: []domain.SecondFactorType{
domain.SecondFactorTypeTOTP,
domain.SecondFactorTypeU2F,

View File

@@ -912,11 +912,11 @@ func queryLoginPolicyToDomain(policy *query.LoginPolicy) *domain.LoginPolicy {
IgnoreUnknownUsernames: policy.IgnoreUnknownUsernames,
AllowDomainDiscovery: policy.AllowDomainDiscovery,
DefaultRedirectURI: policy.DefaultRedirectURI,
PasswordCheckLifetime: policy.PasswordCheckLifetime,
ExternalLoginCheckLifetime: policy.ExternalLoginCheckLifetime,
MFAInitSkipLifetime: policy.MFAInitSkipLifetime,
SecondFactorCheckLifetime: policy.SecondFactorCheckLifetime,
MultiFactorCheckLifetime: policy.MultiFactorCheckLifetime,
PasswordCheckLifetime: time.Duration(policy.PasswordCheckLifetime),
ExternalLoginCheckLifetime: time.Duration(policy.ExternalLoginCheckLifetime),
MFAInitSkipLifetime: time.Duration(policy.MFAInitSkipLifetime),
SecondFactorCheckLifetime: time.Duration(policy.SecondFactorCheckLifetime),
MultiFactorCheckLifetime: time.Duration(policy.MultiFactorCheckLifetime),
DisableLoginWithEmail: policy.DisableLoginWithEmail,
DisableLoginWithPhone: policy.DisableLoginWithPhone,
}

View File

@@ -13,6 +13,7 @@ import (
cache "github.com/zitadel/zitadel/internal/auth_request/repository"
"github.com/zitadel/zitadel/internal/auth_request/repository/mock"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models"
@@ -518,8 +519,8 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
loginPolicyProvider: &mockLoginPolicy{
policy: &query.LoginPolicy{
SecondFactors: []domain.SecondFactorType{domain.SecondFactorTypeTOTP},
PasswordCheckLifetime: 10 * 24 * time.Hour,
SecondFactorCheckLifetime: 18 * time.Hour,
PasswordCheckLifetime: database.Duration(10 * 24 * time.Hour),
SecondFactorCheckLifetime: database.Duration(18 * time.Hour),
},
},
privacyPolicyProvider: &mockPrivacyPolicy{
@@ -820,7 +821,7 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
},
loginPolicyProvider: &mockLoginPolicy{
policy: &query.LoginPolicy{
MultiFactorCheckLifetime: 10 * time.Hour,
MultiFactorCheckLifetime: database.Duration(10 * time.Hour),
},
},
idpUserLinksProvider: &mockIDPUserLinks{},
@@ -845,7 +846,7 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
},
loginPolicyProvider: &mockLoginPolicy{
policy: &query.LoginPolicy{
MultiFactorCheckLifetime: 10 * time.Hour,
MultiFactorCheckLifetime: database.Duration(10 * time.Hour),
},
},
idpUserLinksProvider: &mockIDPUserLinks{},
@@ -871,7 +872,7 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
},
loginPolicyProvider: &mockLoginPolicy{
policy: &query.LoginPolicy{
MultiFactorCheckLifetime: 10 * time.Hour,
MultiFactorCheckLifetime: database.Duration(10 * time.Hour),
},
},
idpUserLinksProvider: &mockIDPUserLinks{},
@@ -953,7 +954,7 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
orgViewProvider: &mockViewOrg{State: domain.OrgStateActive},
loginPolicyProvider: &mockLoginPolicy{
policy: &query.LoginPolicy{
SecondFactorCheckLifetime: 18 * time.Hour,
SecondFactorCheckLifetime: database.Duration(18 * time.Hour),
},
},
idpUserLinksProvider: &mockIDPUserLinks{},
@@ -986,7 +987,7 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
orgViewProvider: &mockViewOrg{State: domain.OrgStateActive},
loginPolicyProvider: &mockLoginPolicy{
policy: &query.LoginPolicy{
SecondFactorCheckLifetime: 18 * time.Hour,
SecondFactorCheckLifetime: database.Duration(18 * time.Hour),
},
},
idpUserLinksProvider: &mockIDPUserLinks{
@@ -1054,7 +1055,7 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
},
loginPolicyProvider: &mockLoginPolicy{
policy: &query.LoginPolicy{
PasswordCheckLifetime: 10 * 24 * time.Hour,
PasswordCheckLifetime: database.Duration(10 * 24 * time.Hour),
},
},
idpUserLinksProvider: &mockIDPUserLinks{},
@@ -1591,7 +1592,7 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
},
loginPolicyProvider: &mockLoginPolicy{
policy: &query.LoginPolicy{
SecondFactorCheckLifetime: 18 * time.Hour,
SecondFactorCheckLifetime: database.Duration(18 * time.Hour),
},
},
userEventProvider: &mockEventUser{},

View File

@@ -14,6 +14,7 @@ import (
"github.com/zitadel/zitadel/internal/crypto"
z_db "github.com/zitadel/zitadel/internal/database"
db_mock "github.com/zitadel/zitadel/internal/database/mock"
"github.com/zitadel/zitadel/internal/zerrors"
)
@@ -452,7 +453,7 @@ type db struct {
func dbMock(t *testing.T, expectations ...func(m sqlmock.Sqlmock)) db {
t.Helper()
client, mock, err := sqlmock.New()
client, mock, err := sqlmock.New(sqlmock.ValueConverterOption(new(db_mock.TypeConverter)))
if err != nil {
t.Fatalf("unable to create sql mock: %v", err)
}
@@ -478,7 +479,7 @@ func expectQueryScanErr(stmt string, cols []string, rows [][]driver.Value, args
m.ExpectBegin()
q := m.ExpectQuery(regexp.QuoteMeta(stmt)).WithArgs(args...)
m.ExpectRollback()
result := sqlmock.NewRows(cols)
result := m.NewRows(cols)
count := uint64(len(rows))
for _, row := range rows {
if cols[len(cols)-1] == "count" {
@@ -496,7 +497,7 @@ func expectQuery(stmt string, cols []string, rows [][]driver.Value, args ...driv
m.ExpectBegin()
q := m.ExpectQuery(regexp.QuoteMeta(stmt)).WithArgs(args...)
m.ExpectCommit()
result := sqlmock.NewRows(cols)
result := m.NewRows(cols)
count := uint64(len(rows))
for _, row := range rows {
if cols[len(cols)-1] == "count" {

View File

@@ -6,7 +6,7 @@ import (
"strings"
"time"
_ "github.com/jackc/pgx/v4/stdlib"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/mitchellh/mapstructure"
"github.com/zitadel/logging"

View File

@@ -19,6 +19,7 @@ type expectation func(m sqlmock.Sqlmock)
func NewSQLMock(t *testing.T, expectations ...expectation) *SQLMock {
db, mock, err := sqlmock.New(
sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual),
sqlmock.ValueConverterOption(new(TypeConverter)),
)
if err != nil {
t.Fatal("create mock failed", err)
@@ -97,23 +98,23 @@ func ExcpectExec(stmt string, opts ...ExecOpt) expectation {
}
}
type QueryOpt func(e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery
type QueryOpt func(m sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery
func WithQueryArgs(args ...driver.Value) QueryOpt {
return func(e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
return func(_ sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
return e.WithArgs(args...)
}
}
func WithQueryErr(err error) QueryOpt {
return func(e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
return func(_ sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
return e.WillReturnError(err)
}
}
func WithQueryResult(columns []string, rows [][]driver.Value) QueryOpt {
return func(e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
mockedRows := sqlmock.NewRows(columns)
return func(m sqlmock.Sqlmock, e *sqlmock.ExpectedQuery) *sqlmock.ExpectedQuery {
mockedRows := m.NewRows(columns)
for _, row := range rows {
mockedRows = mockedRows.AddRow(row...)
}
@@ -125,7 +126,7 @@ func ExpectQuery(stmt string, opts ...QueryOpt) expectation {
return func(m sqlmock.Sqlmock) {
e := m.ExpectQuery(stmt)
for _, opt := range opts {
e = opt(e)
e = opt(m, e)
}
}
}

View File

@@ -0,0 +1,87 @@
package mock
import (
"database/sql/driver"
"encoding/hex"
"encoding/json"
"reflect"
"slices"
"strconv"
"strings"
)
var _ driver.ValueConverter = (*TypeConverter)(nil)
type TypeConverter struct{}
// ConvertValue converts a value to a driver Value.
func (s TypeConverter) ConvertValue(v any) (driver.Value, error) {
if driver.IsValue(v) {
return v, nil
}
value := reflect.ValueOf(v)
if rawMessage, ok := v.(json.RawMessage); ok {
return convertBytes(rawMessage), nil
}
if value.Kind() == reflect.Slice {
//nolint: exhaustive
// only defined types
switch value.Type().Elem().Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Int:
return convertSigned(value), nil
case reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uint:
return convertUnsigned(value), nil
case reflect.String:
return convertText(value), nil
}
}
return v, nil
}
// converts a text array to valid pgx v5 representation
func convertSigned(array reflect.Value) string {
slice := make([]string, array.Len())
for i := 0; i < array.Len(); i++ {
slice[i] = strconv.FormatInt(array.Index(i).Int(), 10)
}
return "{" + strings.Join(slice, ",") + "}"
}
// converts a text array to valid pgx v5 representation
func convertUnsigned(array reflect.Value) string {
slice := make([]string, array.Len())
for i := 0; i < array.Len(); i++ {
slice[i] = strconv.FormatUint(array.Index(i).Uint(), 10)
}
return "{" + strings.Join(slice, ",") + "}"
}
// converts a text array to valid pgx v5 representation
func convertText(array reflect.Value) string {
slice := make([]string, array.Len())
for i := 0; i < array.Len(); i++ {
slice[i] = array.Index(i).String()
}
return "{" + strings.Join(slice, ",") + "}"
}
func convertBytes(array []byte) string {
var builder strings.Builder
builder.Grow(hex.EncodedLen(len(array)) + 4)
builder.WriteString(`\x`)
builder.Write(AppendEncode(nil, array))
return builder.String()
}
// TODO: remove function after we compile using go 1.22 and use function of hex package `hex.AppendEncode`
func AppendEncode(dst, src []byte) []byte {
n := hex.EncodedLen(len(src))
dst = slices.Grow(dst, n)
hex.Encode(dst[len(dst):][:n], src)
return dst[:len(dst)+n]
}

View File

@@ -6,7 +6,7 @@ import (
"strings"
"time"
_ "github.com/jackc/pgx/v4/stdlib"
_ "github.com/jackc/pgx/v5/stdlib"
"github.com/mitchellh/mapstructure"
"github.com/zitadel/logging"

View File

@@ -1,87 +1,166 @@
package database
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"reflect"
"strings"
"time"
"github.com/jackc/pgtype"
"github.com/jackc/pgx/v5/pgtype"
)
type TextArray[t ~string] []t
type TextArray[T ~string] pgtype.FlatArray[T]
// Scan implements the [database/sql.Scanner] interface.
func (s *TextArray[t]) Scan(src any) error {
array := new(pgtype.TextArray)
if err := array.Scan(src); err != nil {
func (s *TextArray[T]) Scan(src any) error {
var typedArray []string
err := pgtype.NewMap().SQLScanner(&typedArray).Scan(src)
if err != nil {
return err
}
return array.AssignTo(s)
}
// Value implements the [database/sql/driver.Valuer] interface.
func (s TextArray[t]) Value() (driver.Value, error) {
if len(s) == 0 {
return nil, nil
(*s) = make(TextArray[T], len(typedArray))
for i, value := range typedArray {
(*s)[i] = T(value)
}
array := pgtype.TextArray{}
if err := array.Set(s); err != nil {
return nil, err
}
return array.Value()
}
type arrayField interface {
~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32
}
type Array[F arrayField] []F
// Scan implements the [database/sql.Scanner] interface.
func (a *Array[F]) Scan(src any) error {
array := new(pgtype.Int8Array)
if err := array.Scan(src); err != nil {
return err
}
elements := make([]int64, len(array.Elements))
if err := array.AssignTo(&elements); err != nil {
return err
}
*a = make([]F, len(elements))
for i, element := range elements {
(*a)[i] = F(element)
}
return nil
}
// Value implements the [database/sql/driver.Valuer] interface.
func (a Array[F]) Value() (driver.Value, error) {
if len(a) == 0 {
func (s TextArray[T]) Value() (driver.Value, error) {
if len(s) == 0 {
return nil, nil
}
array := pgtype.Int8Array{}
if err := array.Set(a); err != nil {
return nil, err
typed := make([]string, len(s))
for i, value := range s {
typed[i] = string(value)
}
return array.Value()
return []byte("{" + strings.Join(typed, ",") + "}"), nil
}
type ByteArray[T ~byte] pgtype.FlatArray[T]
// Scan implements the [database/sql.Scanner] interface.
func (s *ByteArray[T]) Scan(src any) error {
var typedArray []byte
typedArray, ok := src.([]byte)
if !ok {
// tests use a different src type
err := pgtype.NewMap().SQLScanner(&typedArray).Scan(src)
if err != nil {
return err
}
}
(*s) = make(ByteArray[T], len(typedArray))
for i, value := range typedArray {
(*s)[i] = T(value)
}
return nil
}
// Value implements the [database/sql/driver.Valuer] interface.
func (s ByteArray[T]) Value() (driver.Value, error) {
if len(s) == 0 {
return nil, nil
}
typed := make([]byte, len(s))
for i, value := range s {
typed[i] = byte(value)
}
return typed, nil
}
type numberField interface {
~int8 | ~uint8 | ~int16 | ~uint16 | ~int32 | ~uint32 | ~int64 | ~uint64 | ~int | ~uint
}
type numberTypeField interface {
int8 | uint8 | int16 | uint16 | int32 | uint32 | int64 | uint64 | int | uint
}
var _ sql.Scanner = (*NumberArray[int8])(nil)
type NumberArray[F numberField] pgtype.FlatArray[F]
// Scan implements the [database/sql.Scanner] interface.
func (a *NumberArray[F]) Scan(src any) (err error) {
var (
mapper func()
scanner sql.Scanner
)
//nolint: exhaustive
// only defined types
switch reflect.TypeOf(*a).Elem().Kind() {
case reflect.Int8:
mapper, scanner = castedScan[int8](a)
case reflect.Uint8:
// we provide int16 is a workaround because pgx thinks we want to scan a byte array if we provide uint8
mapper, scanner = castedScan[int16](a)
case reflect.Int16:
mapper, scanner = castedScan[int16](a)
case reflect.Uint16:
mapper, scanner = castedScan[uint16](a)
case reflect.Int32:
mapper, scanner = castedScan[int32](a)
case reflect.Uint32:
mapper, scanner = castedScan[uint32](a)
case reflect.Int64:
mapper, scanner = castedScan[int64](a)
case reflect.Uint64:
mapper, scanner = castedScan[uint64](a)
case reflect.Int:
mapper, scanner = castedScan[int](a)
case reflect.Uint:
mapper, scanner = castedScan[uint](a)
}
if err = scanner.Scan(src); err != nil {
return err
}
mapper()
return nil
}
func castedScan[T numberTypeField, F numberField](a *NumberArray[F]) (mapper func(), scanner sql.Scanner) {
var typedArray []T
mapper = func() {
(*a) = make(NumberArray[F], len(typedArray))
for i, value := range typedArray {
(*a)[i] = F(value)
}
}
scanner = pgtype.NewMap().SQLScanner(&typedArray)
return mapper, scanner
}
type Map[V any] map[string]V
// Scan implements the [database/sql.Scanner] interface.
func (m *Map[V]) Scan(src any) error {
bytea := new(pgtype.Bytea)
if err := bytea.Scan(src); err != nil {
return err
}
if len(bytea.Bytes) == 0 {
if src == nil {
return nil
}
return json.Unmarshal(bytea.Bytes, &m)
bytes := src.([]byte)
if len(bytes) == 0 {
return nil
}
return json.Unmarshal(bytes, &m)
}
// Value implements the [database/sql/driver.Valuer] interface.
@@ -96,14 +175,35 @@ type Duration time.Duration
// Scan implements the [database/sql.Scanner] interface.
func (d *Duration) Scan(src any) error {
switch duration := src.(type) {
case *time.Duration:
*d = Duration(*duration)
return nil
case time.Duration:
*d = Duration(duration)
return nil
case *pgtype.Interval:
*d = intervalToDuration(duration)
return nil
case pgtype.Interval:
*d = intervalToDuration(&duration)
return nil
case int64:
*d = Duration(duration)
return nil
}
interval := new(pgtype.Interval)
if err := interval.Scan(src); err != nil {
return err
}
*d = Duration(time.Duration(interval.Microseconds*1000) + time.Duration(interval.Days)*24*time.Hour + time.Duration(interval.Months)*30*24*time.Hour)
*d = intervalToDuration(interval)
return nil
}
func intervalToDuration(interval *pgtype.Interval) Duration {
return Duration(time.Duration(interval.Microseconds*1000) + time.Duration(interval.Days)*24*time.Hour + time.Duration(interval.Months)*30*24*time.Hour)
}
// NullDuration can be used for NULL intervals.
// If Valid is false, the scanned value was NULL
// This behavior is similar to [database/sql.NullString]

View File

@@ -1,9 +1,9 @@
package database
import (
"database/sql"
"database/sql/driver"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -11,7 +11,7 @@ import (
func TestMap_Scan(t *testing.T) {
type args struct {
src any
src []byte
}
type res[V any] struct {
want Map[V]
@@ -24,10 +24,19 @@ func TestMap_Scan(t *testing.T) {
res[V]
}
tests := []testCase[string]{
{
"nil",
Map[string]{},
args{src: nil},
res[string]{
want: Map[string]{},
err: false,
},
},
{
"null",
Map[string]{},
args{src: "invalid"},
args{src: []byte("invalid")},
res[string]{
want: Map[string]{},
err: true,
@@ -119,83 +128,109 @@ func TestMap_Value(t *testing.T) {
}
}
func TestNullDuration_Scan(t *testing.T) {
type typedInt int
func TestNumberArray_Scan(t *testing.T) {
type args struct {
src any
}
type res struct {
want NullDuration
want any
err bool
}
type testCase struct {
name string
m sql.Scanner
args args
res res
}
tests := []testCase{
{
"invalid",
args{src: "invalid"},
res{
want: NullDuration{
Valid: false,
},
err: true,
name: "typedInt",
m: new(NumberArray[typedInt]),
args: args{src: "{1,2}"},
res: res{
want: &NumberArray[typedInt]{1, 2},
},
},
{
"null",
args{src: nil},
res{
want: NullDuration{
Valid: false,
},
err: false,
name: "int8",
m: new(NumberArray[int8]),
args: args{src: "{1,2}"},
res: res{
want: &NumberArray[int8]{1, 2},
},
},
{
"valid",
args{src: "1:0:0"},
res{
want: NullDuration{
Valid: true,
Duration: time.Hour,
},
name: "uint8",
m: new(NumberArray[uint8]),
args: args{src: "{1,2}"},
res: res{
want: &NumberArray[uint8]{1, 2},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
d := new(NullDuration)
if err := d.Scan(tt.args.src); (err != nil) != tt.res.err {
t.Errorf("Scan() error = %v, wantErr %v", err, tt.res.err)
}
assert.Equal(t, tt.res.want, *d)
})
}
}
func TestArray_ScanInt32(t *testing.T) {
type args struct {
src any
}
type res[V arrayField] struct {
want Array[V]
err bool
}
type testCase[V arrayField] struct {
name string
m Array[V]
args args
res[V]
}
tests := []testCase[int32]{
{
"number",
Array[int32]{},
args{src: "{1,2}"},
res[int32]{
want: []int32{1, 2},
name: "int16",
m: new(NumberArray[int16]),
args: args{src: "{1,2}"},
res: res{
want: &NumberArray[int16]{1, 2},
},
},
{
name: "uint16",
m: new(NumberArray[uint16]),
args: args{src: "{1,2}"},
res: res{
want: &NumberArray[uint16]{1, 2},
},
},
{
name: "int32",
m: new(NumberArray[int32]),
args: args{src: "{1,2}"},
res: res{
want: &NumberArray[int32]{1, 2},
},
},
{
name: "uint32",
m: new(NumberArray[uint32]),
args: args{src: "{1,2}"},
res: res{
want: &NumberArray[uint32]{1, 2},
},
},
{
name: "int64",
m: new(NumberArray[int64]),
args: args{src: "{1,2}"},
res: res{
want: &NumberArray[int64]{1, 2},
},
},
{
name: "uint64",
m: new(NumberArray[uint64]),
args: args{src: "{1,2}"},
res: res{
want: &NumberArray[uint64]{1, 2},
},
},
{
name: "int",
m: new(NumberArray[int]),
args: args{src: "{1,2}"},
res: res{
want: &NumberArray[int]{1, 2},
},
},
{
name: "uint",
m: new(NumberArray[uint]),
args: args{src: "{1,2}"},
res: res{
want: &NumberArray[uint]{1, 2},
},
},
}
@@ -210,42 +245,80 @@ func TestArray_ScanInt32(t *testing.T) {
}
}
func TestArray_Value(t *testing.T) {
type typedText string
func TestTextArray_Scan(t *testing.T) {
type args struct {
src any
}
type res struct {
want sql.Scanner
err bool
}
type testCase struct {
name string
m sql.Scanner
args args
res
}
tests := []testCase{
{
"string",
new(TextArray[string]),
args{src: "{asdf,fdas}"},
res{
want: &TextArray[string]{"asdf", "fdas"},
},
},
{
"typedText",
new(TextArray[typedText]),
args{src: "{asdf,fdas}"},
res{
want: &TextArray[typedText]{"asdf", "fdas"},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.m.Scan(tt.args.src); (err != nil) != tt.res.err {
t.Errorf("Scan() error = %v, wantErr %v", err, tt.res.err)
}
assert.Equal(t, tt.res.want, tt.m)
})
}
}
func TestTextArray_Value(t *testing.T) {
type res struct {
want driver.Value
err bool
}
type testCase[V arrayField] struct {
type testCase struct {
name string
a Array[V]
m driver.Valuer
res res
}
tests := []testCase[int32]{
{
"nil",
nil,
res{
want: nil,
},
},
tests := []testCase{
{
"empty",
Array[int32]{},
TextArray[string]{},
res{
want: nil,
},
},
{
"set",
Array[int32]([]int32{1, 2}),
TextArray[string]{"a", "s", "d", "f"},
res{
want: driver.Value(string([]byte(`{1,2}`))),
want: driver.Value([]byte("{a,s,d,f}")),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.a.Value()
got, err := tt.m.Value()
if tt.res.err {
assert.Error(t, err)
}
@@ -256,3 +329,126 @@ func TestArray_Value(t *testing.T) {
})
}
}
type typedByte byte
func TestByteArray_Scan(t *testing.T) {
wantedBytes := []byte("asdf")
wantedTypedBytes := []typedByte("asdf")
type args struct {
src any
}
type res struct {
want sql.Scanner
err bool
}
type testCase struct {
name string
m sql.Scanner
args args
res
}
tests := []testCase{
{
"bytes",
new(ByteArray[byte]),
args{src: []byte("asdf")},
res{
want: (*ByteArray[byte])(&wantedBytes),
},
},
{
"typed",
new(ByteArray[typedByte]),
args{src: []byte("asdf")},
res{
want: (*ByteArray[typedByte])(&wantedTypedBytes),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.m.Scan(tt.args.src); (err != nil) != tt.res.err {
t.Errorf("Scan() error = %v, wantErr %v", err, tt.res.err)
}
assert.Equal(t, tt.res.want, tt.m)
})
}
}
func TestByteArray_Value(t *testing.T) {
type res struct {
want driver.Value
err bool
}
type testCase struct {
name string
m driver.Valuer
res res
}
tests := []testCase{
{
"empty",
ByteArray[byte]{},
res{
want: nil,
},
},
{
"set",
ByteArray[byte]([]byte("{\"type\": \"object\", \"$schema\": \"urn:zitadel:schema:v1\"}")),
res{
want: driver.Value([]byte("{\"type\": \"object\", \"$schema\": \"urn:zitadel:schema:v1\"}")),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := tt.m.Value()
if tt.res.err {
assert.Error(t, err)
}
if !tt.res.err {
require.NoError(t, err)
assert.Equalf(t, tt.res.want, got, "Value()")
}
})
}
}
func TestDuration_Scan(t *testing.T) {
duration := Duration(10)
type args struct {
src any
}
type res struct {
want sql.Scanner
err bool
}
type testCase[V ~string] struct {
name string
m sql.Scanner
args args
res
}
tests := []testCase[string]{
{
name: "int64",
m: new(Duration),
args: args{src: int64(duration)},
res: res{
want: &duration,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.m.Scan(tt.args.src); (err != nil) != tt.res.err {
t.Errorf("Scan() error = %v, wantErr %v", err, tt.res.err)
}
assert.Equal(t, tt.res.want, tt.m)
})
}
}

View File

@@ -7,7 +7,7 @@ import (
"sync"
"time"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"

View File

@@ -8,7 +8,7 @@ import (
"testing"
"time"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/service"

View File

@@ -11,6 +11,7 @@ import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/zitadel/zitadel/internal/database"
db_mock "github.com/zitadel/zitadel/internal/database/mock"
"github.com/zitadel/zitadel/internal/zerrors"
)
@@ -99,7 +100,7 @@ func TestStatementHandler_handleLock(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client, mock, err := sqlmock.New()
client, mock, err := sqlmock.New(sqlmock.ValueConverterOption(new(db_mock.TypeConverter)))
if err != nil {
t.Fatal(err)
}
@@ -209,7 +210,7 @@ func TestStatementHandler_renewLock(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client, mock, err := sqlmock.New()
client, mock, err := sqlmock.New(sqlmock.ValueConverterOption(new(db_mock.TypeConverter)))
if err != nil {
t.Fatal(err)
}
@@ -283,7 +284,7 @@ func TestStatementHandler_Unlock(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
client, mock, err := sqlmock.New()
client, mock, err := sqlmock.New(sqlmock.ValueConverterOption(new(db_mock.TypeConverter)))
if err != nil {
t.Fatal(err)
}

View File

@@ -10,7 +10,7 @@ import (
"sync"
"time"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/logging"
@@ -331,7 +331,7 @@ func (ai *existingInstances) AppendEvents(events ...eventstore.Event) {
case instance.InstanceAddedEventType:
*ai = append(*ai, event.Aggregate().InstanceID)
case instance.InstanceRemovedEventType:
slices.DeleteFunc(*ai, func(s string) bool {
*ai = slices.DeleteFunc(*ai, func(s string) bool {
return s == event.Aggregate().InstanceID
})
}

View File

@@ -6,7 +6,7 @@ import (
"fmt"
"strings"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/eventstore/handler"

View File

@@ -10,10 +10,11 @@ import (
"testing"
"time"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database/mock"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/zerrors"
)
@@ -213,11 +214,11 @@ func TestHandler_updateLastUpdated(t *testing.T) {
"projection",
"instance",
"aggregate id",
"aggregate type",
eventstore.AggregateType("aggregate type"),
uint64(42),
mock.AnyType[time.Time]{},
float64(42),
uint16(0),
uint32(0),
),
mock.WithExecRowsAffected(1),
),

View File

@@ -227,7 +227,7 @@ func instanceIDsFilter(builder *eventstore.SearchQueryBuilder, query *SearchQuer
if builder.GetInstanceIDs() == nil {
return nil
}
query.InstanceIDs = NewFilter(FieldInstanceID, builder.GetInstanceIDs(), OperationIn)
query.InstanceIDs = NewFilter(FieldInstanceID, database.TextArray[string](builder.GetInstanceIDs()), OperationIn)
return query.InstanceIDs
}
@@ -256,11 +256,7 @@ func eventTypeFilter(query *eventstore.SearchQuery) *Filter {
if len(query.GetEventTypes()) == 1 {
return NewFilter(FieldEventType, query.GetEventTypes()[0], OperationEquals)
}
eventTypes := make(database.TextArray[eventstore.EventType], len(query.GetEventTypes()))
for i, eventType := range query.GetEventTypes() {
eventTypes[i] = eventType
}
return NewFilter(FieldEventType, eventTypes, OperationIn)
return NewFilter(FieldEventType, database.TextArray[eventstore.EventType](query.GetEventTypes()), OperationIn)
}
func aggregateTypeFilter(query *eventstore.SearchQuery) *Filter {
@@ -270,11 +266,7 @@ func aggregateTypeFilter(query *eventstore.SearchQuery) *Filter {
if len(query.GetAggregateTypes()) == 1 {
return NewFilter(FieldAggregateType, query.GetAggregateTypes()[0], OperationEquals)
}
aggregateTypes := make(database.TextArray[eventstore.AggregateType], len(query.GetAggregateTypes()))
for i, aggregateType := range query.GetAggregateTypes() {
aggregateTypes[i] = aggregateType
}
return NewFilter(FieldAggregateType, aggregateTypes, OperationIn)
return NewFilter(FieldAggregateType, database.TextArray[eventstore.AggregateType](query.GetAggregateTypes()), OperationIn)
}
func eventDataFilter(query *eventstore.SearchQuery) *Filter {

View File

@@ -10,8 +10,7 @@ import (
"strings"
"github.com/cockroachdb/cockroach-go/v2/crdb"
"github.com/jackc/pgconn"
"github.com/lib/pq"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
@@ -438,11 +437,6 @@ func (db *CRDB) placeholder(query string) string {
}
func (db *CRDB) isUniqueViolationError(err error) bool {
if pqErr, ok := err.(*pq.Error); ok {
if pqErr.Code == "23505" {
return true
}
}
if pgxErr, ok := err.(*pgconn.PgError); ok {
if pgxErr.Code == "23505" {
return true

View File

@@ -14,6 +14,7 @@ import (
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/cockroach"
db_mock "github.com/zitadel/zitadel/internal/database/mock"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/repository"
"github.com/zitadel/zitadel/internal/zerrors"
@@ -872,7 +873,7 @@ func (m *dbMock) expectQuery(t *testing.T, expectedQuery string, args []driver.V
m.mock.ExpectBegin()
query := m.mock.ExpectQuery(expectedQuery).WithArgs(args...)
m.mock.ExpectCommit()
rows := sqlmock.NewRows([]string{"sequence"})
rows := m.mock.NewRows([]string{"sequence"})
for _, event := range events {
rows = rows.AddRow(event.Seq)
}
@@ -884,7 +885,7 @@ func (m *dbMock) expectQueryScanErr(t *testing.T, expectedQuery string, args []d
m.mock.ExpectBegin()
query := m.mock.ExpectQuery(expectedQuery).WithArgs(args...)
m.mock.ExpectRollback()
rows := sqlmock.NewRows([]string{"sequence"})
rows := m.mock.NewRows([]string{"sequence"})
for _, event := range events {
rows = rows.AddRow(event.Seq)
}
@@ -900,7 +901,7 @@ func (m *dbMock) expectQueryErr(t *testing.T, expectedQuery string, args []drive
func newMockClient(t *testing.T) *dbMock {
t.Helper()
db, mock, err := sqlmock.New()
db, mock, err := sqlmock.New(sqlmock.ValueConverterOption(new(db_mock.TypeConverter)))
if err != nil {
t.Errorf("unable to create mock client: %v", err)
t.FailNow()

View File

@@ -11,7 +11,7 @@ import (
"sync"
"github.com/cockroachdb/cockroach-go/v2/crdb"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/eventstore"

View File

@@ -8,7 +8,7 @@ import (
"fmt"
"strings"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/eventstore"

View File

@@ -42,8 +42,8 @@ type App struct {
type OIDCApp struct {
RedirectURIs database.TextArray[string]
ResponseTypes database.Array[domain.OIDCResponseType]
GrantTypes database.Array[domain.OIDCGrantType]
ResponseTypes database.NumberArray[domain.OIDCResponseType]
GrantTypes database.NumberArray[domain.OIDCGrantType]
AppType domain.OIDCApplicationType
ClientID string
AuthMethodType domain.OIDCAuthMethodType
@@ -835,8 +835,8 @@ type sqlOIDCConfig struct {
iDTokenUserinfoAssertion sql.NullBool
clockSkew sql.NullInt64
additionalOrigins database.TextArray[string]
responseTypes database.Array[domain.OIDCResponseType]
grantTypes database.Array[domain.OIDCGrantType]
responseTypes database.NumberArray[domain.OIDCResponseType]
grantTypes database.NumberArray[domain.OIDCGrantType]
skipNativeAppSuccessPage sql.NullBool
}

View File

@@ -421,8 +421,8 @@ func Test_AppsPrepare(t *testing.T) {
domain.OIDCVersionV1,
"oidc-client-id",
database.TextArray[string]{"https://redirect.to/me"},
database.Array[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
database.Array[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
domain.OIDCApplicationTypeUserAgent,
domain.OIDCAuthMethodTypeNone,
database.TextArray[string]{"post.logout.ch"},
@@ -461,8 +461,8 @@ func Test_AppsPrepare(t *testing.T) {
Version: domain.OIDCVersionV1,
ClientID: "oidc-client-id",
RedirectURIs: database.TextArray[string]{"https://redirect.to/me"},
ResponseTypes: database.Array[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
GrantTypes: database.Array[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
ResponseTypes: database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
GrantTypes: database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
AppType: domain.OIDCApplicationTypeUserAgent,
AuthMethodType: domain.OIDCAuthMethodTypeNone,
PostLogoutRedirectURIs: database.TextArray[string]{"post.logout.ch"},
@@ -507,8 +507,8 @@ func Test_AppsPrepare(t *testing.T) {
domain.OIDCVersionV1,
"oidc-client-id",
database.TextArray[string]{"https://redirect.to/me"},
database.Array[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
database.Array[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
domain.OIDCApplicationTypeUserAgent,
domain.OIDCAuthMethodTypeNone,
database.TextArray[string]{"post.logout.ch"},
@@ -547,8 +547,8 @@ func Test_AppsPrepare(t *testing.T) {
Version: domain.OIDCVersionV1,
ClientID: "oidc-client-id",
RedirectURIs: database.TextArray[string]{"https://redirect.to/me"},
ResponseTypes: database.Array[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
GrantTypes: database.Array[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
ResponseTypes: database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
GrantTypes: database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
AppType: domain.OIDCApplicationTypeUserAgent,
AuthMethodType: domain.OIDCAuthMethodTypeNone,
PostLogoutRedirectURIs: database.TextArray[string]{"post.logout.ch"},
@@ -593,8 +593,8 @@ func Test_AppsPrepare(t *testing.T) {
domain.OIDCVersionV1,
"oidc-client-id",
database.TextArray[string]{"https://redirect.to/me"},
database.Array[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
database.Array[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
domain.OIDCApplicationTypeUserAgent,
domain.OIDCAuthMethodTypeNone,
database.TextArray[string]{"post.logout.ch"},
@@ -633,8 +633,8 @@ func Test_AppsPrepare(t *testing.T) {
Version: domain.OIDCVersionV1,
ClientID: "oidc-client-id",
RedirectURIs: database.TextArray[string]{"https://redirect.to/me"},
ResponseTypes: database.Array[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
GrantTypes: database.Array[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
ResponseTypes: database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
GrantTypes: database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
AppType: domain.OIDCApplicationTypeUserAgent,
AuthMethodType: domain.OIDCAuthMethodTypeNone,
PostLogoutRedirectURIs: database.TextArray[string]{"post.logout.ch"},
@@ -679,8 +679,8 @@ func Test_AppsPrepare(t *testing.T) {
domain.OIDCVersionV1,
"oidc-client-id",
database.TextArray[string]{"https://redirect.to/me"},
database.Array[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
database.Array[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
domain.OIDCApplicationTypeUserAgent,
domain.OIDCAuthMethodTypeNone,
database.TextArray[string]{"post.logout.ch"},
@@ -719,8 +719,8 @@ func Test_AppsPrepare(t *testing.T) {
Version: domain.OIDCVersionV1,
ClientID: "oidc-client-id",
RedirectURIs: database.TextArray[string]{"https://redirect.to/me"},
ResponseTypes: database.Array[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
GrantTypes: database.Array[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
ResponseTypes: database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
GrantTypes: database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
AppType: domain.OIDCApplicationTypeUserAgent,
AuthMethodType: domain.OIDCAuthMethodTypeNone,
PostLogoutRedirectURIs: database.TextArray[string]{"post.logout.ch"},
@@ -765,8 +765,8 @@ func Test_AppsPrepare(t *testing.T) {
domain.OIDCVersionV1,
"oidc-client-id",
database.TextArray[string]{"https://redirect.to/me"},
database.Array[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
database.Array[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
domain.OIDCApplicationTypeUserAgent,
domain.OIDCAuthMethodTypeNone,
database.TextArray[string]{"post.logout.ch"},
@@ -805,8 +805,8 @@ func Test_AppsPrepare(t *testing.T) {
Version: domain.OIDCVersionV1,
ClientID: "oidc-client-id",
RedirectURIs: database.TextArray[string]{"https://redirect.to/me"},
ResponseTypes: database.Array[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
GrantTypes: database.Array[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
ResponseTypes: database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
GrantTypes: database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
AppType: domain.OIDCApplicationTypeUserAgent,
AuthMethodType: domain.OIDCAuthMethodTypeNone,
PostLogoutRedirectURIs: database.TextArray[string]{"post.logout.ch"},
@@ -851,8 +851,8 @@ func Test_AppsPrepare(t *testing.T) {
domain.OIDCVersionV1,
"oidc-client-id",
database.TextArray[string]{"https://redirect.to/me"},
database.Array[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
database.Array[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
domain.OIDCApplicationTypeNative,
domain.OIDCAuthMethodTypeNone,
database.TextArray[string]{"post.logout.ch"},
@@ -891,8 +891,8 @@ func Test_AppsPrepare(t *testing.T) {
Version: domain.OIDCVersionV1,
ClientID: "oidc-client-id",
RedirectURIs: database.TextArray[string]{"https://redirect.to/me"},
ResponseTypes: database.Array[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
GrantTypes: database.Array[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
ResponseTypes: database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
GrantTypes: database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
AppType: domain.OIDCApplicationTypeNative,
AuthMethodType: domain.OIDCAuthMethodTypeNone,
PostLogoutRedirectURIs: database.TextArray[string]{"post.logout.ch"},
@@ -937,8 +937,8 @@ func Test_AppsPrepare(t *testing.T) {
domain.OIDCVersionV1,
"oidc-client-id",
database.TextArray[string]{"https://redirect.to/me"},
database.Array[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
database.Array[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
domain.OIDCApplicationTypeUserAgent,
domain.OIDCAuthMethodTypeNone,
database.TextArray[string]{"post.logout.ch"},
@@ -1051,8 +1051,8 @@ func Test_AppsPrepare(t *testing.T) {
Version: domain.OIDCVersionV1,
ClientID: "oidc-client-id",
RedirectURIs: database.TextArray[string]{"https://redirect.to/me"},
ResponseTypes: database.Array[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
GrantTypes: database.Array[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
ResponseTypes: database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
GrantTypes: database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
AppType: domain.OIDCApplicationTypeUserAgent,
AuthMethodType: domain.OIDCAuthMethodTypeNone,
PostLogoutRedirectURIs: database.TextArray[string]{"post.logout.ch"},
@@ -1120,6 +1120,9 @@ func Test_AppsPrepare(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if tt.name == "prepareAppsQuery oidc app" {
_ = tt.name
}
assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err, defaultPrepareArgs...)
})
}
@@ -1300,8 +1303,8 @@ func Test_AppPrepare(t *testing.T) {
domain.OIDCVersionV1,
"oidc-client-id",
database.TextArray[string]{"https://redirect.to/me"},
database.Array[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
database.Array[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
domain.OIDCApplicationTypeUserAgent,
domain.OIDCAuthMethodTypeNone,
database.TextArray[string]{"post.logout.ch"},
@@ -1335,8 +1338,8 @@ func Test_AppPrepare(t *testing.T) {
Version: domain.OIDCVersionV1,
ClientID: "oidc-client-id",
RedirectURIs: database.TextArray[string]{"https://redirect.to/me"},
ResponseTypes: database.Array[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
GrantTypes: database.Array[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
ResponseTypes: database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
GrantTypes: database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
AppType: domain.OIDCApplicationTypeUserAgent,
AuthMethodType: domain.OIDCAuthMethodTypeNone,
PostLogoutRedirectURIs: database.TextArray[string]{"post.logout.ch"},
@@ -1442,8 +1445,8 @@ func Test_AppPrepare(t *testing.T) {
domain.OIDCVersionV1,
"oidc-client-id",
database.TextArray[string]{"https://redirect.to/me"},
database.Array[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
database.Array[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
domain.OIDCApplicationTypeUserAgent,
domain.OIDCAuthMethodTypeNone,
database.TextArray[string]{"post.logout.ch"},
@@ -1477,8 +1480,8 @@ func Test_AppPrepare(t *testing.T) {
Version: domain.OIDCVersionV1,
ClientID: "oidc-client-id",
RedirectURIs: database.TextArray[string]{"https://redirect.to/me"},
ResponseTypes: database.Array[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
GrantTypes: database.Array[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
ResponseTypes: database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
GrantTypes: database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
AppType: domain.OIDCApplicationTypeUserAgent,
AuthMethodType: domain.OIDCAuthMethodTypeNone,
PostLogoutRedirectURIs: database.TextArray[string]{"post.logout.ch"},
@@ -1521,8 +1524,8 @@ func Test_AppPrepare(t *testing.T) {
domain.OIDCVersionV1,
"oidc-client-id",
database.TextArray[string]{"https://redirect.to/me"},
database.Array[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
database.Array[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
domain.OIDCApplicationTypeUserAgent,
domain.OIDCAuthMethodTypeNone,
database.TextArray[string]{"post.logout.ch"},
@@ -1556,8 +1559,8 @@ func Test_AppPrepare(t *testing.T) {
Version: domain.OIDCVersionV1,
ClientID: "oidc-client-id",
RedirectURIs: database.TextArray[string]{"https://redirect.to/me"},
ResponseTypes: database.Array[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
GrantTypes: database.Array[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
ResponseTypes: database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
GrantTypes: database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
AppType: domain.OIDCApplicationTypeUserAgent,
AuthMethodType: domain.OIDCAuthMethodTypeNone,
PostLogoutRedirectURIs: database.TextArray[string]{"post.logout.ch"},
@@ -1600,8 +1603,8 @@ func Test_AppPrepare(t *testing.T) {
domain.OIDCVersionV1,
"oidc-client-id",
database.TextArray[string]{"https://redirect.to/me"},
database.Array[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
database.Array[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
domain.OIDCApplicationTypeUserAgent,
domain.OIDCAuthMethodTypeNone,
database.TextArray[string]{"post.logout.ch"},
@@ -1635,8 +1638,8 @@ func Test_AppPrepare(t *testing.T) {
Version: domain.OIDCVersionV1,
ClientID: "oidc-client-id",
RedirectURIs: database.TextArray[string]{"https://redirect.to/me"},
ResponseTypes: database.Array[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
GrantTypes: database.Array[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
ResponseTypes: database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
GrantTypes: database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
AppType: domain.OIDCApplicationTypeUserAgent,
AuthMethodType: domain.OIDCAuthMethodTypeNone,
PostLogoutRedirectURIs: database.TextArray[string]{"post.logout.ch"},
@@ -1679,8 +1682,8 @@ func Test_AppPrepare(t *testing.T) {
domain.OIDCVersionV1,
"oidc-client-id",
database.TextArray[string]{"https://redirect.to/me"},
database.Array[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
database.Array[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
domain.OIDCApplicationTypeUserAgent,
domain.OIDCAuthMethodTypeNone,
database.TextArray[string]{"post.logout.ch"},
@@ -1714,8 +1717,8 @@ func Test_AppPrepare(t *testing.T) {
Version: domain.OIDCVersionV1,
ClientID: "oidc-client-id",
RedirectURIs: database.TextArray[string]{"https://redirect.to/me"},
ResponseTypes: database.Array[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
GrantTypes: database.Array[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
ResponseTypes: database.NumberArray[domain.OIDCResponseType]{domain.OIDCResponseTypeIDTokenToken},
GrantTypes: database.NumberArray[domain.OIDCGrantType]{domain.OIDCGrantTypeImplicit},
AppType: domain.OIDCApplicationTypeUserAgent,
AuthMethodType: domain.OIDCAuthMethodTypeNone,
PostLogoutRedirectURIs: database.TextArray[string]{"post.logout.ch"},

View File

@@ -61,7 +61,7 @@ func (q *Queries) AuthRequestByID(ctx context.Context, shouldTriggerBulk bool, i
var (
scope database.TextArray[string]
prompt database.Array[domain.Prompt]
prompt database.NumberArray[domain.Prompt]
locales database.TextArray[string]
)

View File

@@ -65,7 +65,7 @@ func TestQueries_AuthRequestByID(t *testing.T) {
"clientID",
database.TextArray[string]{"a", "b", "c"},
"example.com",
database.Array[domain.Prompt]{domain.PromptLogin, domain.PromptConsent},
database.NumberArray[domain.Prompt]{domain.PromptLogin, domain.PromptConsent},
database.TextArray[string]{"en", "fi"},
"me@example.com",
int64(time.Minute),
@@ -99,11 +99,11 @@ func TestQueries_AuthRequestByID(t *testing.T) {
"clientID",
database.TextArray[string]{"a", "b", "c"},
"example.com",
database.Array[domain.Prompt]{domain.PromptLogin, domain.PromptConsent},
database.NumberArray[domain.Prompt]{domain.PromptLogin, domain.PromptConsent},
database.TextArray[string]{"en", "fi"},
sql.NullString{},
sql.NullInt64{},
sql.NullString{},
nil,
nil,
nil,
}, "123", "instanceID"),
want: &AuthRequest{
ID: "id",
@@ -151,11 +151,11 @@ func TestQueries_AuthRequestByID(t *testing.T) {
"clientID",
database.TextArray[string]{"a", "b", "c"},
"example.com",
database.Array[domain.Prompt]{domain.PromptLogin, domain.PromptConsent},
database.NumberArray[domain.Prompt]{domain.PromptLogin, domain.PromptConsent},
database.TextArray[string]{"en", "fi"},
sql.NullString{},
sql.NullInt64{},
sql.NullString{},
nil,
nil,
nil,
}, "123", "instanceID"),
wantErr: zerrors.ThrowPermissionDeniedf(nil, "OIDCv2-aL0ag", "Errors.AuthRequest.WrongLoginClient"),
},

View File

@@ -17,6 +17,7 @@ import (
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database"
db_mock "github.com/zitadel/zitadel/internal/database/mock"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/repository/deviceauth"
@@ -188,7 +189,7 @@ var (
)
func TestQueries_DeviceAuthRequestByUserCode(t *testing.T) {
client, mock, err := sqlmock.New()
client, mock, err := sqlmock.New(sqlmock.ValueConverterOption(new(db_mock.TypeConverter)))
if err != nil {
t.Fatalf("failed to build mock client: %v", err)
}
@@ -196,7 +197,7 @@ func TestQueries_DeviceAuthRequestByUserCode(t *testing.T) {
mock.ExpectBegin()
mock.ExpectQuery(expectedDeviceAuthWhereUserCodeQuery).WillReturnRows(
sqlmock.NewRows(deviceAuthSelectColumns).AddRow(expectedDeviceAuthValues...),
mock.NewRows(deviceAuthSelectColumns).AddRow(expectedDeviceAuthValues...),
)
mock.ExpectCommit()
q := Queries{

View File

@@ -30,8 +30,8 @@ type LoginPolicy struct {
AllowExternalIDPs bool
ForceMFA bool
ForceMFALocalOnly bool
SecondFactors database.Array[domain.SecondFactorType]
MultiFactors database.Array[domain.MultiFactorType]
SecondFactors database.NumberArray[domain.SecondFactorType]
MultiFactors database.NumberArray[domain.MultiFactorType]
PasswordlessType domain.PasswordlessType
IsDefault bool
HidePasswordReset bool
@@ -40,22 +40,22 @@ type LoginPolicy struct {
DisableLoginWithEmail bool
DisableLoginWithPhone bool
DefaultRedirectURI string
PasswordCheckLifetime time.Duration
ExternalLoginCheckLifetime time.Duration
MFAInitSkipLifetime time.Duration
SecondFactorCheckLifetime time.Duration
MultiFactorCheckLifetime time.Duration
PasswordCheckLifetime database.Duration
ExternalLoginCheckLifetime database.Duration
MFAInitSkipLifetime database.Duration
SecondFactorCheckLifetime database.Duration
MultiFactorCheckLifetime database.Duration
IDPLinks []*IDPLoginPolicyLink
}
type SecondFactors struct {
SearchResponse
Factors database.Array[domain.SecondFactorType]
Factors database.NumberArray[domain.SecondFactorType]
}
type MultiFactors struct {
SearchResponse
Factors database.Array[domain.MultiFactorType]
Factors database.NumberArray[domain.MultiFactorType]
}
var (

View File

@@ -84,6 +84,7 @@ var (
)
func Test_LoginPolicyPrepares(t *testing.T) {
duration := 2 * time.Hour
type want struct {
sqlExpectations sqlExpectation
err checkErr
@@ -129,8 +130,8 @@ func Test_LoginPolicyPrepares(t *testing.T) {
true,
true,
true,
database.Array[domain.SecondFactorType]{domain.SecondFactorTypeTOTP},
database.Array[domain.MultiFactorType]{domain.MultiFactorTypeU2FWithPIN},
database.NumberArray[domain.SecondFactorType]{domain.SecondFactorTypeTOTP},
database.NumberArray[domain.MultiFactorType]{domain.MultiFactorTypeU2FWithPIN},
domain.PasswordlessTypeAllowed,
true,
true,
@@ -139,11 +140,11 @@ func Test_LoginPolicyPrepares(t *testing.T) {
true,
true,
"https://example.com/redirect",
time.Hour * 2,
time.Hour * 2,
time.Hour * 2,
time.Hour * 2,
time.Hour * 2,
&duration,
&duration,
&duration,
&duration,
&duration,
},
),
},
@@ -157,8 +158,8 @@ func Test_LoginPolicyPrepares(t *testing.T) {
AllowExternalIDPs: true,
ForceMFA: true,
ForceMFALocalOnly: true,
SecondFactors: database.Array[domain.SecondFactorType]{domain.SecondFactorTypeTOTP},
MultiFactors: database.Array[domain.MultiFactorType]{domain.MultiFactorTypeU2FWithPIN},
SecondFactors: database.NumberArray[domain.SecondFactorType]{domain.SecondFactorTypeTOTP},
MultiFactors: database.NumberArray[domain.MultiFactorType]{domain.MultiFactorTypeU2FWithPIN},
PasswordlessType: domain.PasswordlessTypeAllowed,
IsDefault: true,
HidePasswordReset: true,
@@ -167,11 +168,11 @@ func Test_LoginPolicyPrepares(t *testing.T) {
DisableLoginWithEmail: true,
DisableLoginWithPhone: true,
DefaultRedirectURI: "https://example.com/redirect",
PasswordCheckLifetime: time.Hour * 2,
ExternalLoginCheckLifetime: time.Hour * 2,
MFAInitSkipLifetime: time.Hour * 2,
SecondFactorCheckLifetime: time.Hour * 2,
MultiFactorCheckLifetime: time.Hour * 2,
PasswordCheckLifetime: database.Duration(duration),
ExternalLoginCheckLifetime: database.Duration(duration),
MFAInitSkipLifetime: database.Duration(duration),
SecondFactorCheckLifetime: database.Duration(duration),
MultiFactorCheckLifetime: database.Duration(duration),
},
},
{
@@ -217,7 +218,7 @@ func Test_LoginPolicyPrepares(t *testing.T) {
regexp.QuoteMeta(prepareLoginPolicy2FAsStmt),
prepareLoginPolicy2FAsCols,
[]driver.Value{
database.Array[domain.SecondFactorType]{domain.SecondFactorTypeTOTP},
database.NumberArray[domain.SecondFactorType]{domain.SecondFactorTypeTOTP},
},
),
},
@@ -225,7 +226,7 @@ func Test_LoginPolicyPrepares(t *testing.T) {
SearchResponse: SearchResponse{
Count: 1,
},
Factors: database.Array[domain.SecondFactorType]{domain.SecondFactorTypeTOTP},
Factors: database.NumberArray[domain.SecondFactorType]{domain.SecondFactorTypeTOTP},
},
},
{
@@ -236,11 +237,11 @@ func Test_LoginPolicyPrepares(t *testing.T) {
regexp.QuoteMeta(prepareLoginPolicy2FAsStmt),
prepareLoginPolicy2FAsCols,
[]driver.Value{
database.Array[domain.SecondFactorType]{},
database.NumberArray[domain.SecondFactorType]{},
},
),
},
object: &SecondFactors{Factors: database.Array[domain.SecondFactorType]{}},
object: &SecondFactors{Factors: database.NumberArray[domain.SecondFactorType]{}},
},
{
name: "prepareLoginPolicy2FAsQuery sql err",
@@ -285,7 +286,7 @@ func Test_LoginPolicyPrepares(t *testing.T) {
regexp.QuoteMeta(prepareLoginPolicyMFAsStmt),
prepareLoginPolicyMFAsCols,
[]driver.Value{
database.Array[domain.MultiFactorType]{domain.MultiFactorTypeU2FWithPIN},
database.NumberArray[domain.MultiFactorType]{domain.MultiFactorTypeU2FWithPIN},
},
),
},
@@ -293,7 +294,7 @@ func Test_LoginPolicyPrepares(t *testing.T) {
SearchResponse: SearchResponse{
Count: 1,
},
Factors: database.Array[domain.MultiFactorType]{domain.MultiFactorTypeU2FWithPIN},
Factors: database.NumberArray[domain.MultiFactorType]{domain.MultiFactorTypeU2FWithPIN},
},
},
{
@@ -304,11 +305,11 @@ func Test_LoginPolicyPrepares(t *testing.T) {
regexp.QuoteMeta(prepareLoginPolicyMFAsStmt),
prepareLoginPolicyMFAsCols,
[]driver.Value{
database.Array[domain.MultiFactorType]{},
database.NumberArray[domain.MultiFactorType]{},
},
),
},
object: &MultiFactors{Factors: database.Array[domain.MultiFactorType]{}},
object: &MultiFactors{Factors: database.NumberArray[domain.MultiFactorType]{}},
},
{
name: "prepareLoginPolicyMFAsQuery sql err",

View File

@@ -12,6 +12,7 @@ import (
"github.com/DATA-DOG/go-sqlmock"
"github.com/zitadel/zitadel/internal/database"
db_mock "github.com/zitadel/zitadel/internal/database/mock"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/zerrors"
)
@@ -405,7 +406,10 @@ func TestQueries_IsOrgUnique(t *testing.T) {
},
}
for _, tt := range tests {
client, mock, err := sqlmock.New(sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual))
client, mock, err := sqlmock.New(
sqlmock.QueryMatcherOption(sqlmock.QueryMatcherEqual),
sqlmock.ValueConverterOption(new(db_mock.TypeConverter)),
)
if err != nil {
t.Fatalf("unable to mock db: %v", err)
}

View File

@@ -13,11 +13,11 @@ import (
"github.com/DATA-DOG/go-sqlmock"
sq "github.com/Masterminds/squirrel"
"github.com/jackc/pgtype"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/internal/database"
db_mock "github.com/zitadel/zitadel/internal/database/mock"
)
var (
@@ -35,7 +35,7 @@ var (
func assertPrepare(t *testing.T, prepareFunc, expectedObject interface{}, sqlExpectation sqlExpectation, isErr checkErr, prepareArgs ...reflect.Value) bool {
t.Helper()
client, mock, err := sqlmock.New()
client, mock, err := sqlmock.New(sqlmock.ValueConverterOption(new(db_mock.TypeConverter)))
if err != nil {
t.Fatalf("failed to build mock client: %v", err)
}
@@ -85,7 +85,7 @@ func mockQuery(stmt string, cols []string, row []driver.Value, args ...driver.Va
m.ExpectBegin()
q := m.ExpectQuery(stmt).WithArgs(args...)
m.ExpectCommit()
result := sqlmock.NewRows(cols)
result := m.NewRows(cols)
if len(row) > 0 {
result.AddRow(row...)
}
@@ -99,7 +99,7 @@ func mockQueryScanErr(stmt string, cols []string, row []driver.Value, args ...dr
m.ExpectBegin()
q := m.ExpectQuery(stmt).WithArgs(args...)
m.ExpectRollback()
result := sqlmock.NewRows(cols)
result := m.NewRows(cols)
if len(row) > 0 {
result.AddRow(row...)
}
@@ -113,7 +113,7 @@ func mockQueries(stmt string, cols []string, rows [][]driver.Value, args ...driv
m.ExpectBegin()
q := m.ExpectQuery(stmt).WithArgs(args...)
m.ExpectCommit()
result := sqlmock.NewRows(cols)
result := m.NewRows(cols)
count := uint64(len(rows))
for _, row := range rows {
if cols[len(cols)-1] == "count" {
@@ -132,7 +132,7 @@ func mockQueriesScanErr(stmt string, cols []string, rows [][]driver.Value, args
m.ExpectBegin()
q := m.ExpectQuery(stmt).WithArgs(args...)
m.ExpectRollback()
result := sqlmock.NewRows(cols)
result := m.NewRows(cols)
count := uint64(len(rows))
for _, row := range rows {
if cols[len(cols)-1] == "count" {
@@ -157,7 +157,7 @@ func mockQueryErr(stmt string, err error, args ...driver.Value) func(m sqlmock.S
}
func execMock(t testing.TB, exp sqlExpectation, run func(db *sql.DB)) {
db, mock, err := sqlmock.New()
db, mock, err := sqlmock.New(sqlmock.ValueConverterOption(new(db_mock.TypeConverter)))
require.NoError(t, err)
defer db.Close()
mock = exp(mock)
@@ -172,6 +172,8 @@ var (
)
func execScan(t testing.TB, client *database.DB, builder sq.SelectBuilder, scan interface{}, errCheck checkErr) (object interface{}, ok bool, didScan bool) {
t.Helper()
scanType := reflect.TypeOf(scan)
err := validateScan(scanType)
if err != nil {
@@ -388,15 +390,6 @@ func TestValidatePrepare(t *testing.T) {
}
}
func intervalDriverValue(t *testing.T, src time.Duration) pgtype.Interval {
interval := pgtype.Interval{}
err := interval.Set(src)
if err != nil {
t.Fatal(err)
}
return interval
}
type prepareDB struct{}
const asOfSystemTime = " AS OF SYSTEM TIME '-1 ms' "

View File

@@ -451,8 +451,8 @@ func (p *appProjection) reduceOIDCConfigAdded(event eventstore.Event) (*handler.
handler.NewCol(AppOIDCConfigColumnClientID, e.ClientID),
handler.NewCol(AppOIDCConfigColumnClientSecret, e.ClientSecret),
handler.NewCol(AppOIDCConfigColumnRedirectUris, database.TextArray[string](e.RedirectUris)),
handler.NewCol(AppOIDCConfigColumnResponseTypes, database.Array[domain.OIDCResponseType](e.ResponseTypes)),
handler.NewCol(AppOIDCConfigColumnGrantTypes, database.Array[domain.OIDCGrantType](e.GrantTypes)),
handler.NewCol(AppOIDCConfigColumnResponseTypes, database.NumberArray[domain.OIDCResponseType](e.ResponseTypes)),
handler.NewCol(AppOIDCConfigColumnGrantTypes, database.NumberArray[domain.OIDCGrantType](e.GrantTypes)),
handler.NewCol(AppOIDCConfigColumnApplicationType, e.ApplicationType),
handler.NewCol(AppOIDCConfigColumnAuthMethodType, e.AuthMethodType),
handler.NewCol(AppOIDCConfigColumnPostLogoutRedirectUris, database.TextArray[string](e.PostLogoutRedirectUris)),
@@ -494,10 +494,10 @@ func (p *appProjection) reduceOIDCConfigChanged(event eventstore.Event) (*handle
cols = append(cols, handler.NewCol(AppOIDCConfigColumnRedirectUris, database.TextArray[string](*e.RedirectUris)))
}
if e.ResponseTypes != nil {
cols = append(cols, handler.NewCol(AppOIDCConfigColumnResponseTypes, database.Array[domain.OIDCResponseType](*e.ResponseTypes)))
cols = append(cols, handler.NewCol(AppOIDCConfigColumnResponseTypes, database.NumberArray[domain.OIDCResponseType](*e.ResponseTypes)))
}
if e.GrantTypes != nil {
cols = append(cols, handler.NewCol(AppOIDCConfigColumnGrantTypes, database.Array[domain.OIDCGrantType](*e.GrantTypes)))
cols = append(cols, handler.NewCol(AppOIDCConfigColumnGrantTypes, database.NumberArray[domain.OIDCGrantType](*e.GrantTypes)))
}
if e.ApplicationType != nil {
cols = append(cols, handler.NewCol(AppOIDCConfigColumnApplicationType, *e.ApplicationType))

View File

@@ -455,8 +455,8 @@ func TestAppProjection_reduces(t *testing.T) {
"client-id",
anyArg{},
database.TextArray[string]{"redirect.one.ch", "redirect.two.ch"},
database.Array[domain.OIDCResponseType]{1, 2},
database.Array[domain.OIDCGrantType]{1, 2},
database.NumberArray[domain.OIDCResponseType]{1, 2},
database.NumberArray[domain.OIDCGrantType]{1, 2},
domain.OIDCApplicationTypeNative,
domain.OIDCAuthMethodTypeNone,
database.TextArray[string]{"logout.one.ch", "logout.two.ch"},
@@ -522,8 +522,8 @@ func TestAppProjection_reduces(t *testing.T) {
expectedArgs: []interface{}{
domain.OIDCVersionV1,
database.TextArray[string]{"redirect.one.ch", "redirect.two.ch"},
database.Array[domain.OIDCResponseType]{1, 2},
database.Array[domain.OIDCGrantType]{1, 2},
database.NumberArray[domain.OIDCResponseType]{1, 2},
database.NumberArray[domain.OIDCGrantType]{1, 2},
domain.OIDCApplicationTypeNative,
domain.OIDCAuthMethodTypeNone,
database.TextArray[string]{"logout.one.ch", "logout.two.ch"},

View File

@@ -10,6 +10,7 @@ import (
"github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/database"
db_mock "github.com/zitadel/zitadel/internal/database/mock"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
"github.com/zitadel/zitadel/internal/repository/instance"
@@ -384,15 +385,15 @@ func Test_quotaProjection_IncrementUsage(t *testing.T) {
name: "",
fields: fields{
client: func() *database.DB {
db, mock, _ := sqlmock.New()
db, mock, _ := sqlmock.New(sqlmock.ValueConverterOption(new(db_mock.TypeConverter)))
mock.ExpectQuery(regexp.QuoteMeta(incrementQuotaStatement)).
WithArgs(
"instance_id",
1,
quota.Unit(1),
testNow,
2,
uint64(2),
).
WillReturnRows(sqlmock.NewRows([]string{"key"}).
WillReturnRows(mock.NewRows([]string{"key"}).
AddRow(3))
return &database.DB{DB: db}
}(),

View File

@@ -92,7 +92,7 @@ func prepareQuotaQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilde
From(quotasTable.identifier()).
PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*Quota, error) {
q := new(Quota)
var interval database.Duration
var interval database.NullDuration
var now time.Time
err := row.Scan(&q.ID, &q.From, &interval, &q.Amount, &q.Limit, &now)
if err != nil {
@@ -101,7 +101,7 @@ func prepareQuotaQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilde
}
return nil, zerrors.ThrowInternal(err, "QUERY-LqySK", "Errors.Internal")
}
q.ResetInterval = time.Duration(interval)
q.ResetInterval = interval.Duration
q.CurrentPeriodStart = pushPeriodStart(q.From, q.ResetInterval, now)
return q, nil
}

View File

@@ -9,6 +9,8 @@ import (
"testing"
"time"
"github.com/jackc/pgx/v5/pgtype"
"github.com/zitadel/zitadel/internal/zerrors"
)
@@ -70,7 +72,9 @@ func Test_QuotaPrepare(t *testing.T) {
[]driver.Value{
"quota-id",
dayNow,
intervalDriverValue(t, time.Hour*24),
&pgtype.Interval{
Days: 1,
},
uint64(1000),
true,
testNow,

View File

@@ -32,7 +32,7 @@ type UserSchema struct {
Type string
Revision uint32
Schema json.RawMessage
PossibleAuthenticators database.Array[domain.AuthenticatorType]
PossibleAuthenticators database.NumberArray[domain.AuthenticatorType]
}
type UserSchemaSearchQueries struct {
@@ -144,6 +144,7 @@ func prepareUserSchemaQuery() (sq.SelectBuilder, func(*sql.Row) (*UserSchema, er
PlaceholderFormat(sq.Dollar),
func(row *sql.Row) (*UserSchema, error) {
u := new(UserSchema)
var schema database.ByteArray[byte]
err := row.Scan(
&u.ID,
&u.EventDate,
@@ -152,16 +153,19 @@ func prepareUserSchemaQuery() (sq.SelectBuilder, func(*sql.Row) (*UserSchema, er
&u.State,
&u.Type,
&u.Revision,
&u.Schema,
&schema,
&u.PossibleAuthenticators,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, zerrors.ThrowNotFound(err, "QUERY-SAF3t", "Errors.Metadata.NotFound")
return nil, zerrors.ThrowNotFound(err, "QUERY-SAF3t", "Errors.UserSchema.NotExists")
}
return nil, zerrors.ThrowInternal(err, "QUERY-WRB2Q", "Errors.Internal")
}
u.Schema = json.RawMessage(schema)
return u, nil
}
}
@@ -181,8 +185,12 @@ func prepareUserSchemasQuery() (sq.SelectBuilder, func(*sql.Rows) (*UserSchemas,
From(userSchemaTable.identifier()).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*UserSchemas, error) {
schema := make([]*UserSchema, 0)
var count uint64
schemas := make([]*UserSchema, 0)
var (
schema database.ByteArray[byte]
count uint64
)
for rows.Next() {
u := new(UserSchema)
err := rows.Scan(
@@ -193,7 +201,7 @@ func prepareUserSchemasQuery() (sq.SelectBuilder, func(*sql.Rows) (*UserSchemas,
&u.State,
&u.Type,
&u.Revision,
&u.Schema,
&schema,
&u.PossibleAuthenticators,
&count,
)
@@ -201,7 +209,8 @@ func prepareUserSchemasQuery() (sq.SelectBuilder, func(*sql.Rows) (*UserSchemas,
return nil, err
}
schema = append(schema, u)
u.Schema = json.RawMessage(schema)
schemas = append(schemas, u)
}
if err := rows.Close(); err != nil {
@@ -209,7 +218,7 @@ func prepareUserSchemasQuery() (sq.SelectBuilder, func(*sql.Rows) (*UserSchemas,
}
return &UserSchemas{
UserSchemas: schema,
UserSchemas: schemas,
SearchResponse: SearchResponse{
Count: count,
},

View File

@@ -102,7 +102,7 @@ func Test_UserSchemaPrepares(t *testing.T) {
"type",
1,
json.RawMessage(`{"$schema":"urn:zitadel:schema:v1","properties":{"name":{"type":"string","urn:zitadel:schema:permission":{"self":"rw"}}},"type":"object"}`),
database.Array[domain.AuthenticatorType]{domain.AuthenticatorTypeUsername, domain.AuthenticatorTypePassword},
database.NumberArray[domain.AuthenticatorType]{domain.AuthenticatorTypeUsername, domain.AuthenticatorTypePassword},
},
},
),
@@ -123,7 +123,7 @@ func Test_UserSchemaPrepares(t *testing.T) {
Type: "type",
Revision: 1,
Schema: json.RawMessage(`{"$schema":"urn:zitadel:schema:v1","properties":{"name":{"type":"string","urn:zitadel:schema:permission":{"self":"rw"}}},"type":"object"}`),
PossibleAuthenticators: database.Array[domain.AuthenticatorType]{domain.AuthenticatorTypeUsername, domain.AuthenticatorTypePassword},
PossibleAuthenticators: database.NumberArray[domain.AuthenticatorType]{domain.AuthenticatorTypeUsername, domain.AuthenticatorTypePassword},
},
},
},
@@ -145,7 +145,7 @@ func Test_UserSchemaPrepares(t *testing.T) {
"type1",
1,
json.RawMessage(`{"$schema":"urn:zitadel:schema:v1","properties":{"name":{"type":"string","urn:zitadel:schema:permission":{"self":"rw"}}},"type":"object"}`),
database.Array[domain.AuthenticatorType]{domain.AuthenticatorTypeUsername, domain.AuthenticatorTypePassword},
database.NumberArray[domain.AuthenticatorType]{domain.AuthenticatorTypeUsername, domain.AuthenticatorTypePassword},
},
{
"id-2",
@@ -156,7 +156,7 @@ func Test_UserSchemaPrepares(t *testing.T) {
"type2",
2,
json.RawMessage(`{"$schema":"urn:zitadel:schema:v1","properties":{"name":{"type":"string","urn:zitadel:schema:permission":{"self":"rw"}}},"type":"object"}`),
database.Array[domain.AuthenticatorType]{domain.AuthenticatorTypeUsername, domain.AuthenticatorTypePassword},
database.NumberArray[domain.AuthenticatorType]{domain.AuthenticatorTypeUsername, domain.AuthenticatorTypePassword},
},
},
),
@@ -177,7 +177,7 @@ func Test_UserSchemaPrepares(t *testing.T) {
Type: "type1",
Revision: 1,
Schema: json.RawMessage(`{"$schema":"urn:zitadel:schema:v1","properties":{"name":{"type":"string","urn:zitadel:schema:permission":{"self":"rw"}}},"type":"object"}`),
PossibleAuthenticators: database.Array[domain.AuthenticatorType]{domain.AuthenticatorTypeUsername, domain.AuthenticatorTypePassword},
PossibleAuthenticators: database.NumberArray[domain.AuthenticatorType]{domain.AuthenticatorTypeUsername, domain.AuthenticatorTypePassword},
},
{
ID: "id-2",
@@ -190,7 +190,7 @@ func Test_UserSchemaPrepares(t *testing.T) {
Type: "type2",
Revision: 2,
Schema: json.RawMessage(`{"$schema":"urn:zitadel:schema:v1","properties":{"name":{"type":"string","urn:zitadel:schema:permission":{"self":"rw"}}},"type":"object"}`),
PossibleAuthenticators: database.Array[domain.AuthenticatorType]{domain.AuthenticatorTypeUsername, domain.AuthenticatorTypePassword},
PossibleAuthenticators: database.NumberArray[domain.AuthenticatorType]{domain.AuthenticatorTypeUsername, domain.AuthenticatorTypePassword},
},
},
},
@@ -246,7 +246,7 @@ func Test_UserSchemaPrepares(t *testing.T) {
"type",
1,
json.RawMessage(`{"$schema":"urn:zitadel:schema:v1","properties":{"name":{"type":"string","urn:zitadel:schema:permission":{"self":"rw"}}},"type":"object"}`),
database.Array[domain.AuthenticatorType]{domain.AuthenticatorTypeUsername, domain.AuthenticatorTypePassword},
database.NumberArray[domain.AuthenticatorType]{domain.AuthenticatorTypeUsername, domain.AuthenticatorTypePassword},
},
),
},
@@ -261,7 +261,7 @@ func Test_UserSchemaPrepares(t *testing.T) {
Type: "type",
Revision: 1,
Schema: json.RawMessage(`{"$schema":"urn:zitadel:schema:v1","properties":{"name":{"type":"string","urn:zitadel:schema:permission":{"self":"rw"}}},"type":"object"}`),
PossibleAuthenticators: database.Array[domain.AuthenticatorType]{domain.AuthenticatorTypeUsername, domain.AuthenticatorTypePassword},
PossibleAuthenticators: database.NumberArray[domain.AuthenticatorType]{domain.AuthenticatorTypeUsername, domain.AuthenticatorTypePassword},
},
},
{

View File

@@ -65,7 +65,7 @@ func TestQueries_GetOIDCUserInfo(t *testing.T) {
args: args{
userID: "231965491734773762",
},
mock: mockQueryErr(expQuery, sql.ErrConnDone, "231965491734773762", "instanceID", nil),
mock: mockQueryErr(expQuery, sql.ErrConnDone, "231965491734773762", "instanceID", database.TextArray[string](nil)),
wantErr: sql.ErrConnDone,
},
{
@@ -73,7 +73,7 @@ func TestQueries_GetOIDCUserInfo(t *testing.T) {
args: args{
userID: "231965491734773762",
},
mock: mockQuery(expQuery, []string{"json_build_object"}, []driver.Value{testdataUserInfoNotFound}, "231965491734773762", "instanceID", nil),
mock: mockQuery(expQuery, []string{"json_build_object"}, []driver.Value{testdataUserInfoNotFound}, "231965491734773762", "instanceID", database.TextArray[string](nil)),
wantErr: zerrors.ThrowNotFound(nil, "QUERY-ahs4S", "Errors.User.NotFound"),
},
{
@@ -81,7 +81,7 @@ func TestQueries_GetOIDCUserInfo(t *testing.T) {
args: args{
userID: "231965491734773762",
},
mock: mockQuery(expQuery, []string{"json_build_object"}, []driver.Value{testdataUserInfoHumanNoMD}, "231965491734773762", "instanceID", nil),
mock: mockQuery(expQuery, []string{"json_build_object"}, []driver.Value{testdataUserInfoHumanNoMD}, "231965491734773762", "instanceID", database.TextArray[string](nil)),
want: &OIDCUserInfo{
User: &User{
ID: "231965491734773762",
@@ -120,7 +120,7 @@ func TestQueries_GetOIDCUserInfo(t *testing.T) {
args: args{
userID: "231965491734773762",
},
mock: mockQuery(expQuery, []string{"json_build_object"}, []driver.Value{testdataUserInfoHuman}, "231965491734773762", "instanceID", nil),
mock: mockQuery(expQuery, []string{"json_build_object"}, []driver.Value{testdataUserInfoHuman}, "231965491734773762", "instanceID", database.TextArray[string](nil)),
want: &OIDCUserInfo{
User: &User{
ID: "231965491734773762",
@@ -277,7 +277,7 @@ func TestQueries_GetOIDCUserInfo(t *testing.T) {
args: args{
userID: "240707570677841922",
},
mock: mockQuery(expQuery, []string{"json_build_object"}, []driver.Value{testdataUserInfoMachine}, "240707570677841922", "instanceID", nil),
mock: mockQuery(expQuery, []string{"json_build_object"}, []driver.Value{testdataUserInfoMachine}, "240707570677841922", "instanceID", database.TextArray[string](nil)),
want: &OIDCUserInfo{
User: &User{
ID: "240707570677841922",

View File

@@ -13,6 +13,7 @@ import (
"github.com/DATA-DOG/go-sqlmock"
db_mock "github.com/zitadel/zitadel/internal/database/mock"
"github.com/zitadel/zitadel/internal/static"
)
@@ -278,7 +279,7 @@ type db struct {
func prepareDB(t *testing.T, expectations ...expectation) db {
t.Helper()
client, mock, err := sqlmock.New()
client, mock, err := sqlmock.New(sqlmock.ValueConverterOption(new(db_mock.TypeConverter)))
if err != nil {
t.Fatalf("unable to create sql mock: %v", err)
}
@@ -295,7 +296,7 @@ type expectation func(m sqlmock.Sqlmock)
func expectExists(query string, value bool, args ...driver.Value) expectation {
return func(m sqlmock.Sqlmock) {
m.ExpectQuery(regexp.QuoteMeta(query)).WithArgs(args...).WillReturnRows(sqlmock.NewRows([]string{"exists"}).AddRow(value))
m.ExpectQuery(regexp.QuoteMeta(query)).WithArgs(args...).WillReturnRows(m.NewRows([]string{"exists"}).AddRow(value))
}
}
@@ -307,7 +308,7 @@ func expectQueryErr(query string, err error, args ...driver.Value) expectation {
func expectQuery(stmt string, cols []string, rows [][]driver.Value, args ...driver.Value) func(m sqlmock.Sqlmock) {
return func(m sqlmock.Sqlmock) {
q := m.ExpectQuery(regexp.QuoteMeta(stmt)).WithArgs(args...)
result := sqlmock.NewRows(cols)
result := m.NewRows(cols)
count := uint64(len(rows))
for _, row := range rows {
if cols[len(cols)-1] == "count" {

View File

@@ -6,10 +6,10 @@ import (
"strconv"
"testing"
"github.com/zitadel/zitadel/internal/domain"
"github.com/DATA-DOG/go-sqlmock"
"github.com/jinzhu/gorm"
db_mock "github.com/zitadel/zitadel/internal/database/mock"
"github.com/zitadel/zitadel/internal/domain"
)
var (
@@ -130,15 +130,15 @@ func (db *dbMock) close() {
func mockDB(t *testing.T) *dbMock {
mockDB := dbMock{}
db, mock, err := sqlmock.New()
db, mock, err := sqlmock.New(sqlmock.ValueConverterOption(new(db_mock.TypeConverter)))
if err != nil {
t.Fatalf("error occured while creating stub db %v", err)
t.Fatalf("error occurred while creating stub db %v", err)
}
mockDB.mock = mock
mockDB.db, err = gorm.Open("postgres", db)
if err != nil {
t.Fatalf("error occured while connecting to stub db: %v", err)
t.Fatalf("error occurred while connecting to stub db: %v", err)
}
mockDB.mock.MatchExpectationsInOrder(true)
@@ -178,7 +178,7 @@ func (db *dbMock) expectGetByID(table, key, value string) *dbMock {
db.mock.ExpectBegin()
db.mock.ExpectQuery(query).
WithArgs(value).
WillReturnRows(sqlmock.NewRows([]string{key}).
WillReturnRows(db.mock.NewRows([]string{key}).
AddRow(key))
db.mock.ExpectCommit()
@@ -201,7 +201,7 @@ func (db *dbMock) expectGetByQuery(table, key, method, value string) *dbMock {
db.mock.ExpectBegin()
db.mock.ExpectQuery(query).
WithArgs(value).
WillReturnRows(sqlmock.NewRows([]string{key}).
WillReturnRows(db.mock.NewRows([]string{key}).
AddRow(key))
db.mock.ExpectCommit()
@@ -213,7 +213,7 @@ func (db *dbMock) expectGetByQueryCaseSensitive(table, key, method, value string
db.mock.ExpectBegin()
db.mock.ExpectQuery(query).
WithArgs(value).
WillReturnRows(sqlmock.NewRows([]string{key}).
WillReturnRows(db.mock.NewRows([]string{key}).
AddRow(key))
db.mock.ExpectCommit()
@@ -259,15 +259,15 @@ func (db *dbMock) expectRemove(table, key, value string) *dbMock {
}
func (db *dbMock) expectRemoveKeys(table string, keys ...Key) *dbMock {
keynames := make([]interface{}, len(keys))
keyvalues := make([]driver.Value, len(keys))
keyNames := make([]interface{}, len(keys))
keyValues := make([]driver.Value, len(keys))
for i, key := range keys {
keynames[i] = key.Key.ToColumnName()
keyvalues[i] = key.Value
keyNames[i] = key.Key.ToColumnName()
keyValues[i] = key.Value
}
query := fmt.Sprintf(expectedRemoveByKeys(len(keys), table), keynames...)
query := fmt.Sprintf(expectedRemoveByKeys(len(keys), table), keyNames...)
db.mock.ExpectExec(query).
WithArgs(keyvalues...).
WithArgs(keyValues...).
WillReturnResult(sqlmock.NewResult(1, 1))
return db
@@ -318,7 +318,7 @@ func (db *dbMock) expectGetSearchRequestNoParams(table string, resultAmount, tot
query := fmt.Sprintf(expectedSearch, table)
queryCount := fmt.Sprintf(expectedSearchCount, table)
rows := sqlmock.NewRows([]string{"id"})
rows := db.mock.NewRows([]string{"id"})
for i := 0; i < resultAmount; i++ {
rows.AddRow(fmt.Sprintf("hodor-%d", i))
}
@@ -326,7 +326,7 @@ func (db *dbMock) expectGetSearchRequestNoParams(table string, resultAmount, tot
db.mock.ExpectBegin()
db.mock.ExpectQuery(queryCount).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(total))
WillReturnRows(db.mock.NewRows([]string{"count"}).AddRow(total))
db.mock.ExpectQuery(query).
WillReturnRows(rows)
@@ -338,14 +338,14 @@ func (db *dbMock) expectGetSearchRequestWithLimit(table string, limit, resultAmo
query := fmt.Sprintf(expectedSearchLimit, table, limit)
queryCount := fmt.Sprintf(expectedSearchLimitCount, table)
rows := sqlmock.NewRows([]string{"id"})
rows := db.mock.NewRows([]string{"id"})
for i := 0; i < resultAmount; i++ {
rows.AddRow(fmt.Sprintf("hodor-%d", i))
}
db.mock.ExpectBegin()
db.mock.ExpectQuery(queryCount).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(total))
WillReturnRows(db.mock.NewRows([]string{"count"}).AddRow(total))
db.mock.ExpectQuery(query).
WillReturnRows(rows)
db.mock.ExpectCommit()
@@ -356,14 +356,14 @@ func (db *dbMock) expectGetSearchRequestWithOffset(table string, offset, resultA
query := fmt.Sprintf(expectedSearchOffset, table, offset)
queryCount := fmt.Sprintf(expectedSearchOffsetCount, table)
rows := sqlmock.NewRows([]string{"id"})
rows := db.mock.NewRows([]string{"id"})
for i := 0; i < resultAmount; i++ {
rows.AddRow(fmt.Sprintf("hodor-%d", i))
}
db.mock.ExpectBegin()
db.mock.ExpectQuery(queryCount).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(total))
WillReturnRows(db.mock.NewRows([]string{"count"}).AddRow(total))
db.mock.ExpectQuery(query).
WillReturnRows(rows)
db.mock.ExpectCommit()
@@ -374,14 +374,14 @@ func (db *dbMock) expectGetSearchRequestWithSorting(table, sorting string, sorti
query := fmt.Sprintf(expectedSearchSorting, table, sortingColumn.ToColumnName(), sorting)
queryCount := fmt.Sprintf(expectedSearchSortingCount, table)
rows := sqlmock.NewRows([]string{"id"})
rows := db.mock.NewRows([]string{"id"})
for i := 0; i < resultAmount; i++ {
rows.AddRow(fmt.Sprintf("hodor-%d", i))
}
db.mock.ExpectBegin()
db.mock.ExpectQuery(queryCount).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(total))
WillReturnRows(db.mock.NewRows([]string{"count"}).AddRow(total))
db.mock.ExpectQuery(query).
WillReturnRows(rows)
db.mock.ExpectCommit()
@@ -392,7 +392,7 @@ func (db *dbMock) expectGetSearchRequestWithSearchQuery(table, key, method, valu
query := fmt.Sprintf(expectedSearchQuery, table, key, method)
queryCount := fmt.Sprintf(expectedSearchQueryCount, table, key, method)
rows := sqlmock.NewRows([]string{"id"})
rows := db.mock.NewRows([]string{"id"})
for i := 0; i < resultAmount; i++ {
rows.AddRow(fmt.Sprintf("hodor-%d", i))
}
@@ -400,7 +400,7 @@ func (db *dbMock) expectGetSearchRequestWithSearchQuery(table, key, method, valu
db.mock.ExpectBegin()
db.mock.ExpectQuery(queryCount).
WithArgs(value).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(total))
WillReturnRows(db.mock.NewRows([]string{"count"}).AddRow(total))
db.mock.ExpectQuery(query).
WithArgs(value).
WillReturnRows(rows)
@@ -412,7 +412,7 @@ func (db *dbMock) expectGetSearchRequestWithAllParams(table, key, method, value,
query := fmt.Sprintf(expectedSearchQueryAllParams, table, key, method, sortingColumn.ToColumnName(), sorting, limit, offset)
queryCount := fmt.Sprintf(expectedSearchQueryAllParamCount, table, key, method)
rows := sqlmock.NewRows([]string{"id"})
rows := db.mock.NewRows([]string{"id"})
for i := 0; i < resultAmount; i++ {
rows.AddRow(fmt.Sprintf("hodor-%d", i))
}
@@ -420,7 +420,7 @@ func (db *dbMock) expectGetSearchRequestWithAllParams(table, key, method, value,
db.mock.ExpectBegin()
db.mock.ExpectQuery(queryCount).
WithArgs(value).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(total))
WillReturnRows(db.mock.NewRows([]string{"count"}).AddRow(total))
db.mock.ExpectQuery(query).
WithArgs(value).
WillReturnRows(rows)
@@ -432,14 +432,14 @@ func (db *dbMock) expectGetSearchRequestErr(table string, resultAmount, total in
query := fmt.Sprintf(expectedSearch, table)
queryCount := fmt.Sprintf(expectedSearchCount, table)
rows := sqlmock.NewRows([]string{"id"})
rows := db.mock.NewRows([]string{"id"})
for i := 0; i < resultAmount; i++ {
rows.AddRow(fmt.Sprintf("hodor-%d", i))
}
db.mock.ExpectBegin()
db.mock.ExpectQuery(queryCount).
WillReturnRows(sqlmock.NewRows([]string{"count"}).AddRow(total))
WillReturnRows(db.mock.NewRows([]string{"count"}).AddRow(total))
db.mock.ExpectQuery(query).
WillReturnError(err)
db.mock.ExpectCommit()