Merge branch 'main' into v3.x

# Conflicts:
#	internal/query/user.go
#	internal/query/user_auth_method.go
#	internal/query/user_auth_method_test.go
This commit is contained in:
Livio Spring
2025-04-01 08:45:47 +02:00
20 changed files with 146 additions and 238 deletions

View File

@@ -2,15 +2,19 @@ package setup
import (
"context"
"database/sql"
_ "embed"
"errors"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/eventstore"
)
var (
//go:embed 52.sql
//go:embed 52/alter.sql
renameTableIfNotExisting string
//go:embed 52/check.sql
checkIfTableIsExisting string
)
type IDPTemplate6LDAP2 struct {
@@ -18,7 +22,23 @@ type IDPTemplate6LDAP2 struct {
}
func (mig *IDPTemplate6LDAP2) Execute(ctx context.Context, _ eventstore.Event) error {
_, err := mig.dbClient.ExecContext(ctx, renameTableIfNotExisting)
var count int
err := mig.dbClient.QueryRowContext(ctx,
func(row *sql.Row) error {
if err := row.Scan(&count); err != nil {
return err
}
return row.Err()
},
checkIfTableIsExisting,
)
if err == nil {
return nil
}
if !errors.Is(err, sql.ErrNoRows) {
return err
}
_, err = mig.dbClient.ExecContext(ctx, renameTableIfNotExisting)
return err
}

View File

@@ -1,2 +1,2 @@
ALTER TABLE IF EXISTS projections.idp_templates6_ldap3 RENAME COLUMN rootCA TO root_ca;
ALTER TABLE IF EXISTS projections.idp_templates6_ldap3 RENAME TO idp_templates6_ldap2;
ALTER TABLE IF EXISTS projections.idp_templates6_ldap3 RENAME TO idp_templates6_ldap2;

4
cmd/setup/52/check.sql Normal file
View File

@@ -0,0 +1,4 @@
SELECT 1
FROM information_schema.tables
WHERE table_schema = 'projections'
AND table_name = 'idp_templates6_ldap2';

View File

@@ -20,7 +20,6 @@ func AccessStorageInterceptor(svc *logstore.Service[*record.AccessLog]) grpc.Una
if !svc.Enabled() {
return handler(ctx, req)
}
reqMd, _ := metadata.FromIncomingContext(ctx)
resp, handlerErr := handler(ctx, req)

View File

@@ -34,7 +34,7 @@ func ExecutionHandler(queries *query.Queries) grpc.UnaryServerInterceptor {
func executeTargetsForRequest(ctx context.Context, targets []execution.Target, fullMethod string, req interface{}) (_ interface{}, err error) {
ctx, span := tracing.NewSpan(ctx)
defer span.EndWithError(err)
defer func() { span.EndWithError(err) }()
// if no targets are found, return without any calls
if len(targets) == 0 {
@@ -56,7 +56,7 @@ func executeTargetsForRequest(ctx context.Context, targets []execution.Target, f
func executeTargetsForResponse(ctx context.Context, targets []execution.Target, fullMethod string, req, resp interface{}) (_ interface{}, err error) {
ctx, span := tracing.NewSpan(ctx)
defer span.EndWithError(err)
defer func() { span.EndWithError(err) }()
// if no targets are found, return without any calls
if len(targets) == 0 {

View File

@@ -255,7 +255,7 @@ type userSearchByID struct {
}
func (u userSearchByID) search(ctx context.Context, q *query.Queries) (*query.User, error) {
return q.GetUserByID(ctx, true, u.id)
return q.GetUserByID(ctx, false, u.id)
}
type userSearchByLoginName struct {

View File

@@ -150,7 +150,7 @@ func (o *OPStorage) audienceFromProjectID(ctx context.Context, projectID string)
if err != nil {
return nil, err
}
appIDs, err := o.query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, true)
appIDs, err := o.query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, false)
if err != nil {
return nil, err
}
@@ -546,11 +546,7 @@ func CreateCodeCallbackURL(ctx context.Context, authReq op.AuthRequest, authoriz
code: code,
state: authReq.GetState(),
}
callback, err := op.AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder())
if err != nil {
return "", err
}
return callback, err
return op.AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder())
}
func (s *Server) CreateTokenCallbackURL(ctx context.Context, req op.AuthRequest) (string, error) {

View File

@@ -21,7 +21,9 @@ import (
)
func TestServer_JWTProfile(t *testing.T) {
user, name, keyData, err := Instance.CreateOIDCJWTProfileClient(CTX)
user, name, keyData, err := Instance.CreateOIDCJWTProfileClient(CTX, time.Hour)
require.NoError(t, err)
_, _, keyDataExpired, err := Instance.CreateOIDCJWTProfileClient(CTX, 10*time.Second)
require.NoError(t, err)
type claims struct {
@@ -104,6 +106,12 @@ func TestServer_JWTProfile(t *testing.T) {
resourceOwnerPrimaryDomain: Instance.DefaultOrg.PrimaryDomain,
},
},
{
name: "key expired",
keyData: keyDataExpired,
scope: []string{oidc.ScopeOpenID},
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
@@ -123,6 +131,9 @@ func TestServer_JWTProfile(t *testing.T) {
},
time.Minute, time.Second,
)
if tt.wantErr {
return
}
provider, err := rp.NewRelyingPartyOIDC(CTX, Instance.OIDCIssuer(), "", "", redirectURI, tt.scope)
require.NoError(t, err)

View File

@@ -789,7 +789,7 @@ func (repo *AuthRequestRepo) checkLoginName(ctx context.Context, request *domain
}
// if there's an active (human) user, let's use it
if user != nil && !user.HumanView.IsZero() && domain.UserState(user.State).IsEnabled() {
request.SetUserInfo(user.ID, loginNameInput, user.PreferredLoginName, "", "", user.ResourceOwner)
request.SetUserInfo(user.ID, loginNameInput, preferredLoginName, "", "", user.ResourceOwner)
return nil
}
// the user was either not found or not active
@@ -1055,9 +1055,6 @@ func (repo *AuthRequestRepo) nextSteps(ctx context.Context, request *domain.Auth
if err != nil {
return nil, err
}
if user.PreferredLoginName != "" {
request.LoginName = user.PreferredLoginName
}
userSession, err := userSessionByIDs(ctx, repo.UserSessionViewProvider, repo.UserEventProvider, request.AgentID, user)
if err != nil {
return nil, err

View File

@@ -23,7 +23,7 @@ type OrgRepository struct {
}
func (repo *OrgRepository) GetMyPasswordComplexityPolicy(ctx context.Context) (*iam_model.PasswordComplexityPolicyView, error) {
policy, err := repo.Query.PasswordComplexityPolicyByOrg(ctx, true, authz.GetCtxData(ctx).OrgID, false)
policy, err := repo.Query.PasswordComplexityPolicyByOrg(ctx, false, authz.GetCtxData(ctx).OrgID, false)
if err != nil {
return nil, err
}

View File

@@ -1,6 +1,7 @@
package eventstore
import (
"slices"
"sync"
"github.com/zitadel/logging"
@@ -8,7 +9,7 @@ import (
var (
subscriptions = map[AggregateType][]*Subscription{}
subsMutext sync.Mutex
subsMutex sync.RWMutex
)
type Subscription struct {
@@ -27,8 +28,8 @@ func SubscribeAggregates(eventQueue chan Event, aggregates ...AggregateType) *Su
types: types,
}
subsMutext.Lock()
defer subsMutext.Unlock()
subsMutex.Lock()
defer subsMutex.Unlock()
for _, aggregate := range aggregates {
subscriptions[aggregate] = append(subscriptions[aggregate], sub)
@@ -45,8 +46,8 @@ func SubscribeEventTypes(eventQueue chan Event, types map[AggregateType][]EventT
types: types,
}
subsMutext.Lock()
defer subsMutext.Unlock()
subsMutex.Lock()
defer subsMutex.Unlock()
for aggregate := range types {
subscriptions[aggregate] = append(subscriptions[aggregate], sub)
@@ -56,8 +57,8 @@ func SubscribeEventTypes(eventQueue chan Event, types map[AggregateType][]EventT
}
func (es *Eventstore) notify(events []Event) {
subsMutext.Lock()
defer subsMutext.Unlock()
subsMutex.RLock()
defer subsMutex.RUnlock()
for _, event := range events {
subs, ok := subscriptions[event.Aggregate().Type]
if !ok {
@@ -71,14 +72,11 @@ func (es *Eventstore) notify(events []Event) {
continue
}
//subscription for certain events
for _, eventType := range eventTypes {
if event.Type() == eventType {
select {
case sub.Events <- event:
default:
logging.Debug("unable to push event")
}
break
if slices.Contains(eventTypes, event.Type()) {
select {
case sub.Events <- event:
default:
logging.Debug("unable to push event")
}
}
}
@@ -86,8 +84,8 @@ func (es *Eventstore) notify(events []Event) {
}
func (s *Subscription) Unsubscribe() {
subsMutext.Lock()
defer subsMutext.Unlock()
subsMutex.Lock()
defer subsMutex.Unlock()
for aggregate := range s.types {
subs, ok := subscriptions[aggregate]
if !ok {

View File

@@ -47,7 +47,7 @@ func (es *Eventstore) FillFields(ctx context.Context, events ...eventstore.FillF
// Search implements the [eventstore.Search] method
func (es *Eventstore) Search(ctx context.Context, conditions ...map[eventstore.FieldType]any) (result []*eventstore.SearchResult, err error) {
ctx, span := tracing.NewSpan(ctx)
defer span.EndWithError(err)
defer func() { span.EndWithError(err) }()
var builder strings.Builder
args := buildSearchStatement(ctx, &builder, conditions...)

View File

@@ -42,7 +42,7 @@ func CallTargets(
info ContextInfo,
) (_ interface{}, err error) {
ctx, span := tracing.NewSpan(ctx)
defer span.EndWithError(err)
defer func() { span.EndWithError(err) }()
for _, target := range targets {
// call the type of target
@@ -72,7 +72,7 @@ func CallTarget(
info ContextInfoRequest,
) (res []byte, err error) {
ctx, span := tracing.NewSpan(ctx)
defer span.EndWithError(err)
defer func() { span.EndWithError(err) }()
switch target.GetTargetType() {
// get request, ignore response and return request and error for handling in list of targets

View File

@@ -3,6 +3,7 @@ package integration
import (
"context"
"fmt"
"sync"
"testing"
"time"
@@ -157,6 +158,7 @@ func (i *Instance) CreateHumanUser(ctx context.Context) *user_v2.AddHumanUserRes
},
})
logging.OnError(err).Panic("create human user")
i.TriggerUserByID(ctx, resp.GetUserId())
return resp
}
@@ -181,6 +183,7 @@ func (i *Instance) CreateHumanUserNoPhone(ctx context.Context) *user_v2.AddHuman
},
})
logging.OnError(err).Panic("create human user")
i.TriggerUserByID(ctx, resp.GetUserId())
return resp
}
@@ -212,9 +215,26 @@ func (i *Instance) CreateHumanUserWithTOTP(ctx context.Context, secret string) *
TotpSecret: gu.Ptr(secret),
})
logging.OnError(err).Panic("create human user")
i.TriggerUserByID(ctx, resp.GetUserId())
return resp
}
// TriggerUserByID makes sure the user projection gets triggered after creation.
func (i *Instance) TriggerUserByID(ctx context.Context, users ...string) {
var wg sync.WaitGroup
wg.Add(len(users))
for _, user := range users {
go func(user string) {
defer wg.Done()
_, err := i.Client.UserV2.GetUserByID(ctx, &user_v2.GetUserByIDRequest{
UserId: user,
})
logging.OnError(err).Warn("get user by ID for trigger failed")
}(user)
}
wg.Wait()
}
func (i *Instance) CreateOrganization(ctx context.Context, name, adminEmail string) *org.AddOrganizationResponse {
resp, err := i.Client.OrgV2.AddOrganization(ctx, &org.AddOrganizationRequest{
Name: name,
@@ -238,6 +258,13 @@ func (i *Instance) CreateOrganization(ctx context.Context, name, adminEmail stri
},
})
logging.OnError(err).Panic("create org")
users := make([]string, len(resp.GetCreatedAdmins()))
for i, admin := range resp.GetCreatedAdmins() {
users[i] = admin.GetUserId()
}
i.TriggerUserByID(ctx, users...)
return resp
}
@@ -302,6 +329,7 @@ func (i *Instance) CreateHumanUserVerified(ctx context.Context, org, email, phon
},
})
logging.OnError(err).Panic("create human user")
i.TriggerUserByID(ctx, resp.GetUserId())
return resp
}
@@ -313,6 +341,7 @@ func (i *Instance) CreateMachineUser(ctx context.Context) *mgmt.AddMachineUserRe
AccessTokenType: user_pb.AccessTokenType_ACCESS_TOKEN_TYPE_BEARER,
})
logging.OnError(err).Panic("create human user")
i.TriggerUserByID(ctx, resp.GetUserId())
return resp
}

View File

@@ -438,7 +438,7 @@ func (i *Instance) CreateOIDCCredentialsClientInactive(ctx context.Context) (mac
return machine, name, secret.GetClientId(), secret.GetClientSecret(), nil
}
func (i *Instance) CreateOIDCJWTProfileClient(ctx context.Context) (machine *management.AddMachineUserResponse, name string, keyData []byte, err error) {
func (i *Instance) CreateOIDCJWTProfileClient(ctx context.Context, keyLifetime time.Duration) (machine *management.AddMachineUserResponse, name string, keyData []byte, err error) {
name = gofakeit.Username()
machine, err = i.Client.Mgmt.AddMachineUser(ctx, &management.AddMachineUserRequest{
Name: name,
@@ -451,7 +451,7 @@ func (i *Instance) CreateOIDCJWTProfileClient(ctx context.Context) (machine *man
keyResp, err := i.Client.Mgmt.AddMachineKey(ctx, &management.AddMachineKeyRequest{
UserId: machine.GetUserId(),
Type: authn.KeyType_KEY_TYPE_JSON,
ExpirationDate: timestamppb.New(time.Now().Add(time.Hour)),
ExpirationDate: timestamppb.New(time.Now().Add(keyLifetime)),
})
if err != nil {
return nil, "", nil, err

View File

@@ -3,9 +3,10 @@ from projections.authn_keys2 k
join projections.users14 u
on k.instance_id = u.instance_id
and k.identifier = u.id
join projections.users14_machines m
join projections.users14_machines m
on u.instance_id = m.instance_id
and u.id = m.user_id
where k.instance_id = $1
and k.id = $2
and u.id = $3;
and u.id = $3
and k.expiration > current_timestamp;

View File

@@ -427,34 +427,6 @@ func (q *Queries) GetUserByLoginName(ctx context.Context, shouldTriggered bool,
return user, err
}
// Deprecated: use either GetUserByID or GetUserByLoginName
func (q *Queries) GetUser(ctx context.Context, shouldTriggerBulk bool, queries ...SearchQuery) (user *User, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if shouldTriggerBulk {
triggerUserProjections(ctx)
}
query, scan := prepareUserQuery()
for _, q := range queries {
query = q.toQuery(query)
}
eq := sq.Eq{
UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
}
stmt, args, err := query.Where(eq).ToSql()
if err != nil {
return nil, zerrors.ThrowInternal(err, "QUERY-Dnhr2", "Errors.Query.SQLStatment")
}
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
user, err = scan(row)
return err
}, stmt, args...)
return user, err
}
func (q *Queries) GetHumanProfile(ctx context.Context, userID string, queries ...SearchQuery) (profile *Profile, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()

View File

@@ -3,6 +3,7 @@ package query
import (
"context"
"database/sql"
_ "embed"
"errors"
"slices"
"time"
@@ -211,6 +212,9 @@ type UserAuthMethodRequirements struct {
ForceMFALocalOnly bool
}
//go:embed user_auth_method_types_required.sql
var listUserAuthMethodTypesStmt string
func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID string) (requirements *UserAuthMethodRequirements, err error) {
ctxData := authz.GetCtxData(ctx)
if ctxData.UserID != userID {
@@ -221,20 +225,33 @@ func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID st
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareUserAuthMethodTypesRequiredQuery()
eq := sq.Eq{
UserIDCol.identifier(): userID,
UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
}
stmt, args, err := query.Where(eq).ToSql()
if err != nil {
return nil, zerrors.ThrowInvalidArgument(err, "QUERY-E5ut4", "Errors.Query.InvalidRequest")
}
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
requirements, err = scan(row)
return err
}, stmt, args...)
err = q.client.QueryRowContext(ctx,
func(row *sql.Row) error {
var userType sql.NullInt32
var forceMFA sql.NullBool
var forceMFALocalOnly sql.NullBool
err := row.Scan(
&userType,
&forceMFA,
&forceMFALocalOnly,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return zerrors.ThrowNotFound(err, "QUERY-SF3h2", "Errors.Internal")
}
return zerrors.ThrowInternal(err, "QUERY-Sf3rt", "Errors.Internal")
}
requirements = &UserAuthMethodRequirements{
UserType: domain.UserType(userType.Int32),
ForceMFA: forceMFA.Bool,
ForceMFALocalOnly: forceMFALocalOnly.Bool,
}
return nil
},
listUserAuthMethodTypesStmt,
userID,
authz.GetInstance(ctx).InstanceID(),
)
if err != nil {
return nil, zerrors.ThrowInternal(err, "QUERY-Dun75", "Errors.Internal")
}
@@ -460,45 +477,6 @@ func prepareUserAuthMethodTypesQuery(activeOnly bool, includeWithoutDomain bool,
}
}
func prepareUserAuthMethodTypesRequiredQuery() (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) {
loginPolicyQuery, err := prepareAuthMethodsForceMFAQuery()
if err != nil {
return sq.SelectBuilder{}, nil
}
return sq.Select(
UserTypeCol.identifier(),
forceMFAForce.identifier(),
forceMFAForceLocalOnly.identifier()).
From(userTable.identifier()).
LeftJoin("(" + loginPolicyQuery + ") AS " + forceMFATable.alias + " ON " +
"(" + forceMFAOrgID.identifier() + " = " + UserInstanceIDCol.identifier() + " OR " + forceMFAOrgID.identifier() + " = " + UserResourceOwnerCol.identifier() + ") AND " +
forceMFAInstanceID.identifier() + " = " + UserInstanceIDCol.identifier()).
OrderBy(forceMFAIsDefault.identifier()).
Limit(1).
PlaceholderFormat(sq.Dollar),
func(row *sql.Row) (*UserAuthMethodRequirements, error) {
var userType sql.NullInt32
var forceMFA sql.NullBool
var forceMFALocalOnly sql.NullBool
err := row.Scan(
&userType,
&forceMFA,
&forceMFALocalOnly,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, zerrors.ThrowNotFound(err, "QUERY-SF3h2", "Errors.Internal")
}
return nil, zerrors.ThrowInternal(err, "QUERY-Sf3rt", "Errors.Internal")
}
return &UserAuthMethodRequirements{
UserType: domain.UserType(userType.Int32),
ForceMFA: forceMFA.Bool,
ForceMFALocalOnly: forceMFALocalOnly.Bool,
}, nil
}
}
func prepareAuthMethodsIDPsQuery() (string, error) {
idpsQuery, _, err := sq.Select(
userIDPsCountUserID.identifier(),
@@ -535,16 +513,3 @@ func prepareAuthMethodQuery(activeOnly bool, includeWithoutDomain bool, queryDom
return q.ToSql()
}
func prepareAuthMethodsForceMFAQuery() (string, error) {
loginPolicyQuery, _, err := sq.Select(
forceMFAForce.identifier(),
forceMFAForceLocalOnly.identifier(),
forceMFAInstanceID.identifier(),
forceMFAOrgID.identifier(),
forceMFAIsDefault.identifier(),
).
From(forceMFATable.identifier()).
ToSql()
return loginPolicyQuery, err
}

View File

@@ -14,7 +14,6 @@ import (
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/zerrors"
)
func TestUser_authMethodsCheckPermission(t *testing.T) {
@@ -660,106 +659,6 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
},
object: (*AuthMethodTypes)(nil),
},
{
name: "prepareUserAuthMethodTypesRequiredQuery no result",
prepare: func() (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) {
builder, scan := prepareUserAuthMethodTypesRequiredQuery()
return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) {
return scan(row)
}
},
want: want{
sqlExpectations: mockQueriesScanErr(
regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt),
nil,
nil,
),
err: func(err error) (error, bool) {
if !zerrors.IsNotFound(err) {
return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false
}
return nil, true
},
},
object: (*UserAuthMethodRequirements)(nil),
},
{
name: "prepareUserAuthMethodTypesRequiredQuery one second factor",
prepare: func() (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) {
builder, scan := prepareUserAuthMethodTypesRequiredQuery()
return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) {
return scan(row)
}
},
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt),
prepareAuthMethodTypesRequiredCols,
[][]driver.Value{
{
domain.UserTypeHuman,
true,
true,
},
},
),
},
object: &UserAuthMethodRequirements{
UserType: domain.UserTypeHuman,
ForceMFA: true,
ForceMFALocalOnly: true,
},
},
{
name: "prepareUserAuthMethodTypesRequiredQuery multiple second factors",
prepare: func() (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) {
builder, scan := prepareUserAuthMethodTypesRequiredQuery()
return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) {
return scan(row)
}
},
want: want{
sqlExpectations: mockQueries(
regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt),
prepareAuthMethodTypesRequiredCols,
[][]driver.Value{
{
domain.UserTypeHuman,
true,
true,
},
},
),
},
object: &UserAuthMethodRequirements{
UserType: domain.UserTypeHuman,
ForceMFA: true,
ForceMFALocalOnly: true,
},
},
{
name: "prepareUserAuthMethodTypesRequiredQuery sql err",
prepare: func() (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) {
builder, scan := prepareUserAuthMethodTypesRequiredQuery()
return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) {
return scan(row)
}
},
want: want{
sqlExpectations: mockQueryErr(
regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt),
sql.ErrConnDone,
),
err: func(err error) (error, bool) {
if !errors.Is(err, sql.ErrConnDone) {
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
}
return nil, true
},
},
object: nil,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {

View File

@@ -0,0 +1,17 @@
SELECT
projections.users14.type
, auth_methods_force_mfa.force_mfa
, auth_methods_force_mfa.force_mfa_local_only
FROM
projections.users14
LEFT JOIN
projections.login_policies5 AS auth_methods_force_mfa
ON
auth_methods_force_mfa.instance_id = projections.users14.instance_id
AND auth_methods_force_mfa.aggregate_id = ANY(ARRAY[projections.users14.instance_id, projections.users14.resource_owner])
WHERE
projections.users14.id = $1
AND projections.users14.instance_id = $2
ORDER BY
auth_methods_force_mfa.is_default
LIMIT 1;