mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 03:27:32 +00:00
perf: improve scalability of session api (#9635)
This pull request improves the scalability of the session API by enhancing middleware tracing and refining SQL query behavior for user authentication methods. # Which Problems Are Solved - Eventstore subscriptions locked each other during they wrote the events to the event channels of the subscribers in push. - `ListUserAuthMethodTypesRequired` query used `Bitmap heap scan` to join the tables needed. - The auth and oidc package triggered projections often when data were read. - The session API triggered the user projection each time a user was searched to write the user check command. # How the Problems Are Solved - the `sync.Mutex` was replaced with `sync.RWMutex` to allow parallel read of the map - The query was refactored to use index scans only - if the data should already be up-to-date `shouldTriggerBulk` is set to false - as the user should already exist for some time the trigger was removed. # Additional Changes - refactoring of `tracing#Span.End` calls # Additional Context - part of https://github.com/zitadel/zitadel/issues/9239 --------- Co-authored-by: Tim Möhlmann <tim+github@zitadel.com>
This commit is contained in:
@@ -20,7 +20,6 @@ func AccessStorageInterceptor(svc *logstore.Service[*record.AccessLog]) grpc.Una
|
|||||||
if !svc.Enabled() {
|
if !svc.Enabled() {
|
||||||
return handler(ctx, req)
|
return handler(ctx, req)
|
||||||
}
|
}
|
||||||
|
|
||||||
reqMd, _ := metadata.FromIncomingContext(ctx)
|
reqMd, _ := metadata.FromIncomingContext(ctx)
|
||||||
|
|
||||||
resp, handlerErr := handler(ctx, req)
|
resp, handlerErr := handler(ctx, req)
|
||||||
|
@@ -34,7 +34,7 @@ func ExecutionHandler(queries *query.Queries) grpc.UnaryServerInterceptor {
|
|||||||
|
|
||||||
func executeTargetsForRequest(ctx context.Context, targets []execution.Target, fullMethod string, req interface{}) (_ interface{}, err error) {
|
func executeTargetsForRequest(ctx context.Context, targets []execution.Target, fullMethod string, req interface{}) (_ interface{}, err error) {
|
||||||
ctx, span := tracing.NewSpan(ctx)
|
ctx, span := tracing.NewSpan(ctx)
|
||||||
defer span.EndWithError(err)
|
defer func() { span.EndWithError(err) }()
|
||||||
|
|
||||||
// if no targets are found, return without any calls
|
// if no targets are found, return without any calls
|
||||||
if len(targets) == 0 {
|
if len(targets) == 0 {
|
||||||
@@ -56,7 +56,7 @@ func executeTargetsForRequest(ctx context.Context, targets []execution.Target, f
|
|||||||
|
|
||||||
func executeTargetsForResponse(ctx context.Context, targets []execution.Target, fullMethod string, req, resp interface{}) (_ interface{}, err error) {
|
func executeTargetsForResponse(ctx context.Context, targets []execution.Target, fullMethod string, req, resp interface{}) (_ interface{}, err error) {
|
||||||
ctx, span := tracing.NewSpan(ctx)
|
ctx, span := tracing.NewSpan(ctx)
|
||||||
defer span.EndWithError(err)
|
defer func() { span.EndWithError(err) }()
|
||||||
|
|
||||||
// if no targets are found, return without any calls
|
// if no targets are found, return without any calls
|
||||||
if len(targets) == 0 {
|
if len(targets) == 0 {
|
||||||
|
@@ -255,7 +255,7 @@ type userSearchByID struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u userSearchByID) search(ctx context.Context, q *query.Queries) (*query.User, error) {
|
func (u userSearchByID) search(ctx context.Context, q *query.Queries) (*query.User, error) {
|
||||||
return q.GetUserByID(ctx, true, u.id)
|
return q.GetUserByID(ctx, false, u.id)
|
||||||
}
|
}
|
||||||
|
|
||||||
type userSearchByLoginName struct {
|
type userSearchByLoginName struct {
|
||||||
|
@@ -150,7 +150,7 @@ func (o *OPStorage) audienceFromProjectID(ctx context.Context, projectID string)
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
appIDs, err := o.query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, true)
|
appIDs, err := o.query.SearchClientIDs(ctx, &query.AppSearchQueries{Queries: []query.SearchQuery{projectIDQuery}}, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -546,11 +546,7 @@ func CreateCodeCallbackURL(ctx context.Context, authReq op.AuthRequest, authoriz
|
|||||||
code: code,
|
code: code,
|
||||||
state: authReq.GetState(),
|
state: authReq.GetState(),
|
||||||
}
|
}
|
||||||
callback, err := op.AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder())
|
return op.AuthResponseURL(authReq.GetRedirectURI(), authReq.GetResponseType(), authReq.GetResponseMode(), &codeResponse, authorizer.Encoder())
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
return callback, err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) CreateTokenCallbackURL(ctx context.Context, req op.AuthRequest) (string, error) {
|
func (s *Server) CreateTokenCallbackURL(ctx context.Context, req op.AuthRequest) (string, error) {
|
||||||
|
@@ -23,7 +23,7 @@ type OrgRepository struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (repo *OrgRepository) GetMyPasswordComplexityPolicy(ctx context.Context) (*iam_model.PasswordComplexityPolicyView, error) {
|
func (repo *OrgRepository) GetMyPasswordComplexityPolicy(ctx context.Context) (*iam_model.PasswordComplexityPolicyView, error) {
|
||||||
policy, err := repo.Query.PasswordComplexityPolicyByOrg(ctx, true, authz.GetCtxData(ctx).OrgID, false)
|
policy, err := repo.Query.PasswordComplexityPolicyByOrg(ctx, false, authz.GetCtxData(ctx).OrgID, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@@ -1,6 +1,7 @@
|
|||||||
package eventstore
|
package eventstore
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"slices"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/zitadel/logging"
|
"github.com/zitadel/logging"
|
||||||
@@ -8,7 +9,7 @@ import (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
subscriptions = map[AggregateType][]*Subscription{}
|
subscriptions = map[AggregateType][]*Subscription{}
|
||||||
subsMutext sync.Mutex
|
subsMutex sync.RWMutex
|
||||||
)
|
)
|
||||||
|
|
||||||
type Subscription struct {
|
type Subscription struct {
|
||||||
@@ -27,8 +28,8 @@ func SubscribeAggregates(eventQueue chan Event, aggregates ...AggregateType) *Su
|
|||||||
types: types,
|
types: types,
|
||||||
}
|
}
|
||||||
|
|
||||||
subsMutext.Lock()
|
subsMutex.Lock()
|
||||||
defer subsMutext.Unlock()
|
defer subsMutex.Unlock()
|
||||||
|
|
||||||
for _, aggregate := range aggregates {
|
for _, aggregate := range aggregates {
|
||||||
subscriptions[aggregate] = append(subscriptions[aggregate], sub)
|
subscriptions[aggregate] = append(subscriptions[aggregate], sub)
|
||||||
@@ -45,8 +46,8 @@ func SubscribeEventTypes(eventQueue chan Event, types map[AggregateType][]EventT
|
|||||||
types: types,
|
types: types,
|
||||||
}
|
}
|
||||||
|
|
||||||
subsMutext.Lock()
|
subsMutex.Lock()
|
||||||
defer subsMutext.Unlock()
|
defer subsMutex.Unlock()
|
||||||
|
|
||||||
for aggregate := range types {
|
for aggregate := range types {
|
||||||
subscriptions[aggregate] = append(subscriptions[aggregate], sub)
|
subscriptions[aggregate] = append(subscriptions[aggregate], sub)
|
||||||
@@ -56,8 +57,8 @@ func SubscribeEventTypes(eventQueue chan Event, types map[AggregateType][]EventT
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (es *Eventstore) notify(events []Event) {
|
func (es *Eventstore) notify(events []Event) {
|
||||||
subsMutext.Lock()
|
subsMutex.RLock()
|
||||||
defer subsMutext.Unlock()
|
defer subsMutex.RUnlock()
|
||||||
for _, event := range events {
|
for _, event := range events {
|
||||||
subs, ok := subscriptions[event.Aggregate().Type]
|
subs, ok := subscriptions[event.Aggregate().Type]
|
||||||
if !ok {
|
if !ok {
|
||||||
@@ -71,14 +72,11 @@ func (es *Eventstore) notify(events []Event) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
//subscription for certain events
|
//subscription for certain events
|
||||||
for _, eventType := range eventTypes {
|
if slices.Contains(eventTypes, event.Type()) {
|
||||||
if event.Type() == eventType {
|
select {
|
||||||
select {
|
case sub.Events <- event:
|
||||||
case sub.Events <- event:
|
default:
|
||||||
default:
|
logging.Debug("unable to push event")
|
||||||
logging.Debug("unable to push event")
|
|
||||||
}
|
|
||||||
break
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -86,8 +84,8 @@ func (es *Eventstore) notify(events []Event) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Subscription) Unsubscribe() {
|
func (s *Subscription) Unsubscribe() {
|
||||||
subsMutext.Lock()
|
subsMutex.Lock()
|
||||||
defer subsMutext.Unlock()
|
defer subsMutex.Unlock()
|
||||||
for aggregate := range s.types {
|
for aggregate := range s.types {
|
||||||
subs, ok := subscriptions[aggregate]
|
subs, ok := subscriptions[aggregate]
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@@ -47,7 +47,7 @@ func (es *Eventstore) FillFields(ctx context.Context, events ...eventstore.FillF
|
|||||||
// Search implements the [eventstore.Search] method
|
// Search implements the [eventstore.Search] method
|
||||||
func (es *Eventstore) Search(ctx context.Context, conditions ...map[eventstore.FieldType]any) (result []*eventstore.SearchResult, err error) {
|
func (es *Eventstore) Search(ctx context.Context, conditions ...map[eventstore.FieldType]any) (result []*eventstore.SearchResult, err error) {
|
||||||
ctx, span := tracing.NewSpan(ctx)
|
ctx, span := tracing.NewSpan(ctx)
|
||||||
defer span.EndWithError(err)
|
defer func() { span.EndWithError(err) }()
|
||||||
|
|
||||||
var builder strings.Builder
|
var builder strings.Builder
|
||||||
args := buildSearchStatement(ctx, &builder, conditions...)
|
args := buildSearchStatement(ctx, &builder, conditions...)
|
||||||
|
@@ -42,7 +42,7 @@ func CallTargets(
|
|||||||
info ContextInfo,
|
info ContextInfo,
|
||||||
) (_ interface{}, err error) {
|
) (_ interface{}, err error) {
|
||||||
ctx, span := tracing.NewSpan(ctx)
|
ctx, span := tracing.NewSpan(ctx)
|
||||||
defer span.EndWithError(err)
|
defer func() { span.EndWithError(err) }()
|
||||||
|
|
||||||
for _, target := range targets {
|
for _, target := range targets {
|
||||||
// call the type of target
|
// call the type of target
|
||||||
@@ -72,7 +72,7 @@ func CallTarget(
|
|||||||
info ContextInfoRequest,
|
info ContextInfoRequest,
|
||||||
) (res []byte, err error) {
|
) (res []byte, err error) {
|
||||||
ctx, span := tracing.NewSpan(ctx)
|
ctx, span := tracing.NewSpan(ctx)
|
||||||
defer span.EndWithError(err)
|
defer func() { span.EndWithError(err) }()
|
||||||
|
|
||||||
switch target.GetTargetType() {
|
switch target.GetTargetType() {
|
||||||
// get request, ignore response and return request and error for handling in list of targets
|
// get request, ignore response and return request and error for handling in list of targets
|
||||||
|
@@ -3,6 +3,7 @@ package integration
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -157,6 +158,7 @@ func (i *Instance) CreateHumanUser(ctx context.Context) *user_v2.AddHumanUserRes
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
logging.OnError(err).Panic("create human user")
|
logging.OnError(err).Panic("create human user")
|
||||||
|
i.TriggerUserByID(ctx, resp.GetUserId())
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -181,6 +183,7 @@ func (i *Instance) CreateHumanUserNoPhone(ctx context.Context) *user_v2.AddHuman
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
logging.OnError(err).Panic("create human user")
|
logging.OnError(err).Panic("create human user")
|
||||||
|
i.TriggerUserByID(ctx, resp.GetUserId())
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -212,9 +215,26 @@ func (i *Instance) CreateHumanUserWithTOTP(ctx context.Context, secret string) *
|
|||||||
TotpSecret: gu.Ptr(secret),
|
TotpSecret: gu.Ptr(secret),
|
||||||
})
|
})
|
||||||
logging.OnError(err).Panic("create human user")
|
logging.OnError(err).Panic("create human user")
|
||||||
|
i.TriggerUserByID(ctx, resp.GetUserId())
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TriggerUserByID makes sure the user projection gets triggered after creation.
|
||||||
|
func (i *Instance) TriggerUserByID(ctx context.Context, users ...string) {
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(len(users))
|
||||||
|
for _, user := range users {
|
||||||
|
go func(user string) {
|
||||||
|
defer wg.Done()
|
||||||
|
_, err := i.Client.UserV2.GetUserByID(ctx, &user_v2.GetUserByIDRequest{
|
||||||
|
UserId: user,
|
||||||
|
})
|
||||||
|
logging.OnError(err).Warn("get user by ID for trigger failed")
|
||||||
|
}(user)
|
||||||
|
}
|
||||||
|
wg.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
func (i *Instance) CreateOrganization(ctx context.Context, name, adminEmail string) *org.AddOrganizationResponse {
|
func (i *Instance) CreateOrganization(ctx context.Context, name, adminEmail string) *org.AddOrganizationResponse {
|
||||||
resp, err := i.Client.OrgV2.AddOrganization(ctx, &org.AddOrganizationRequest{
|
resp, err := i.Client.OrgV2.AddOrganization(ctx, &org.AddOrganizationRequest{
|
||||||
Name: name,
|
Name: name,
|
||||||
@@ -238,6 +258,13 @@ func (i *Instance) CreateOrganization(ctx context.Context, name, adminEmail stri
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
logging.OnError(err).Panic("create org")
|
logging.OnError(err).Panic("create org")
|
||||||
|
|
||||||
|
users := make([]string, len(resp.GetCreatedAdmins()))
|
||||||
|
for i, admin := range resp.GetCreatedAdmins() {
|
||||||
|
users[i] = admin.GetUserId()
|
||||||
|
}
|
||||||
|
i.TriggerUserByID(ctx, users...)
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -302,6 +329,7 @@ func (i *Instance) CreateHumanUserVerified(ctx context.Context, org, email, phon
|
|||||||
},
|
},
|
||||||
})
|
})
|
||||||
logging.OnError(err).Panic("create human user")
|
logging.OnError(err).Panic("create human user")
|
||||||
|
i.TriggerUserByID(ctx, resp.GetUserId())
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -313,6 +341,7 @@ func (i *Instance) CreateMachineUser(ctx context.Context) *mgmt.AddMachineUserRe
|
|||||||
AccessTokenType: user_pb.AccessTokenType_ACCESS_TOKEN_TYPE_BEARER,
|
AccessTokenType: user_pb.AccessTokenType_ACCESS_TOKEN_TYPE_BEARER,
|
||||||
})
|
})
|
||||||
logging.OnError(err).Panic("create human user")
|
logging.OnError(err).Panic("create human user")
|
||||||
|
i.TriggerUserByID(ctx, resp.GetUserId())
|
||||||
return resp
|
return resp
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -428,34 +428,6 @@ func (q *Queries) GetUserByLoginName(ctx context.Context, shouldTriggered bool,
|
|||||||
return user, err
|
return user, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deprecated: use either GetUserByID or GetUserByLoginName
|
|
||||||
func (q *Queries) GetUser(ctx context.Context, shouldTriggerBulk bool, queries ...SearchQuery) (user *User, err error) {
|
|
||||||
ctx, span := tracing.NewSpan(ctx)
|
|
||||||
defer func() { span.EndWithError(err) }()
|
|
||||||
|
|
||||||
if shouldTriggerBulk {
|
|
||||||
triggerUserProjections(ctx)
|
|
||||||
}
|
|
||||||
|
|
||||||
query, scan := prepareUserQuery(ctx, q.client)
|
|
||||||
for _, q := range queries {
|
|
||||||
query = q.toQuery(query)
|
|
||||||
}
|
|
||||||
eq := sq.Eq{
|
|
||||||
UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
|
|
||||||
}
|
|
||||||
stmt, args, err := query.Where(eq).ToSql()
|
|
||||||
if err != nil {
|
|
||||||
return nil, zerrors.ThrowInternal(err, "QUERY-Dnhr2", "Errors.Query.SQLStatment")
|
|
||||||
}
|
|
||||||
|
|
||||||
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
|
|
||||||
user, err = scan(row)
|
|
||||||
return err
|
|
||||||
}, stmt, args...)
|
|
||||||
return user, err
|
|
||||||
}
|
|
||||||
|
|
||||||
func (q *Queries) GetHumanProfile(ctx context.Context, userID string, queries ...SearchQuery) (profile *Profile, err error) {
|
func (q *Queries) GetHumanProfile(ctx context.Context, userID string, queries ...SearchQuery) (profile *Profile, err error) {
|
||||||
ctx, span := tracing.NewSpan(ctx)
|
ctx, span := tracing.NewSpan(ctx)
|
||||||
defer func() { span.EndWithError(err) }()
|
defer func() { span.EndWithError(err) }()
|
||||||
|
@@ -3,6 +3,7 @@ package query
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"database/sql"
|
"database/sql"
|
||||||
|
_ "embed"
|
||||||
"errors"
|
"errors"
|
||||||
"slices"
|
"slices"
|
||||||
"time"
|
"time"
|
||||||
@@ -212,6 +213,9 @@ type UserAuthMethodRequirements struct {
|
|||||||
ForceMFALocalOnly bool
|
ForceMFALocalOnly bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
//go:embed user_auth_method_types_required.sql
|
||||||
|
var listUserAuthMethodTypesStmt string
|
||||||
|
|
||||||
func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID string) (requirements *UserAuthMethodRequirements, err error) {
|
func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID string) (requirements *UserAuthMethodRequirements, err error) {
|
||||||
ctxData := authz.GetCtxData(ctx)
|
ctxData := authz.GetCtxData(ctx)
|
||||||
if ctxData.UserID != userID {
|
if ctxData.UserID != userID {
|
||||||
@@ -222,20 +226,33 @@ func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID st
|
|||||||
ctx, span := tracing.NewSpan(ctx)
|
ctx, span := tracing.NewSpan(ctx)
|
||||||
defer func() { span.EndWithError(err) }()
|
defer func() { span.EndWithError(err) }()
|
||||||
|
|
||||||
query, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, q.client)
|
err = q.client.QueryRowContext(ctx,
|
||||||
eq := sq.Eq{
|
func(row *sql.Row) error {
|
||||||
UserIDCol.identifier(): userID,
|
var userType sql.NullInt32
|
||||||
UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
|
var forceMFA sql.NullBool
|
||||||
}
|
var forceMFALocalOnly sql.NullBool
|
||||||
stmt, args, err := query.Where(eq).ToSql()
|
err := row.Scan(
|
||||||
if err != nil {
|
&userType,
|
||||||
return nil, zerrors.ThrowInvalidArgument(err, "QUERY-E5ut4", "Errors.Query.InvalidRequest")
|
&forceMFA,
|
||||||
}
|
&forceMFALocalOnly,
|
||||||
|
)
|
||||||
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
|
if err != nil {
|
||||||
requirements, err = scan(row)
|
if errors.Is(err, sql.ErrNoRows) {
|
||||||
return err
|
return zerrors.ThrowNotFound(err, "QUERY-SF3h2", "Errors.Internal")
|
||||||
}, stmt, args...)
|
}
|
||||||
|
return zerrors.ThrowInternal(err, "QUERY-Sf3rt", "Errors.Internal")
|
||||||
|
}
|
||||||
|
requirements = &UserAuthMethodRequirements{
|
||||||
|
UserType: domain.UserType(userType.Int32),
|
||||||
|
ForceMFA: forceMFA.Bool,
|
||||||
|
ForceMFALocalOnly: forceMFALocalOnly.Bool,
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
listUserAuthMethodTypesStmt,
|
||||||
|
userID,
|
||||||
|
authz.GetInstance(ctx).InstanceID(),
|
||||||
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, zerrors.ThrowInternal(err, "QUERY-Dun75", "Errors.Internal")
|
return nil, zerrors.ThrowInternal(err, "QUERY-Dun75", "Errors.Internal")
|
||||||
}
|
}
|
||||||
@@ -461,45 +478,6 @@ func prepareUserAuthMethodTypesQuery(ctx context.Context, db prepareDatabase, ac
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) {
|
|
||||||
loginPolicyQuery, err := prepareAuthMethodsForceMFAQuery()
|
|
||||||
if err != nil {
|
|
||||||
return sq.SelectBuilder{}, nil
|
|
||||||
}
|
|
||||||
return sq.Select(
|
|
||||||
UserTypeCol.identifier(),
|
|
||||||
forceMFAForce.identifier(),
|
|
||||||
forceMFAForceLocalOnly.identifier()).
|
|
||||||
From(userTable.identifier()).
|
|
||||||
LeftJoin("(" + loginPolicyQuery + ") AS " + forceMFATable.alias + " ON " +
|
|
||||||
"(" + forceMFAOrgID.identifier() + " = " + UserInstanceIDCol.identifier() + " OR " + forceMFAOrgID.identifier() + " = " + UserResourceOwnerCol.identifier() + ") AND " +
|
|
||||||
forceMFAInstanceID.identifier() + " = " + UserInstanceIDCol.identifier()).
|
|
||||||
OrderBy(forceMFAIsDefault.identifier()).
|
|
||||||
Limit(1).
|
|
||||||
PlaceholderFormat(sq.Dollar),
|
|
||||||
func(row *sql.Row) (*UserAuthMethodRequirements, error) {
|
|
||||||
var userType sql.NullInt32
|
|
||||||
var forceMFA sql.NullBool
|
|
||||||
var forceMFALocalOnly sql.NullBool
|
|
||||||
err := row.Scan(
|
|
||||||
&userType,
|
|
||||||
&forceMFA,
|
|
||||||
&forceMFALocalOnly,
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
if errors.Is(err, sql.ErrNoRows) {
|
|
||||||
return nil, zerrors.ThrowNotFound(err, "QUERY-SF3h2", "Errors.Internal")
|
|
||||||
}
|
|
||||||
return nil, zerrors.ThrowInternal(err, "QUERY-Sf3rt", "Errors.Internal")
|
|
||||||
}
|
|
||||||
return &UserAuthMethodRequirements{
|
|
||||||
UserType: domain.UserType(userType.Int32),
|
|
||||||
ForceMFA: forceMFA.Bool,
|
|
||||||
ForceMFALocalOnly: forceMFALocalOnly.Bool,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func prepareAuthMethodsIDPsQuery() (string, error) {
|
func prepareAuthMethodsIDPsQuery() (string, error) {
|
||||||
idpsQuery, _, err := sq.Select(
|
idpsQuery, _, err := sq.Select(
|
||||||
userIDPsCountUserID.identifier(),
|
userIDPsCountUserID.identifier(),
|
||||||
@@ -536,16 +514,3 @@ func prepareAuthMethodQuery(activeOnly bool, includeWithoutDomain bool, queryDom
|
|||||||
|
|
||||||
return q.ToSql()
|
return q.ToSql()
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareAuthMethodsForceMFAQuery() (string, error) {
|
|
||||||
loginPolicyQuery, _, err := sq.Select(
|
|
||||||
forceMFAForce.identifier(),
|
|
||||||
forceMFAForceLocalOnly.identifier(),
|
|
||||||
forceMFAInstanceID.identifier(),
|
|
||||||
forceMFAOrgID.identifier(),
|
|
||||||
forceMFAIsDefault.identifier(),
|
|
||||||
).
|
|
||||||
From(forceMFATable.identifier()).
|
|
||||||
ToSql()
|
|
||||||
return loginPolicyQuery, err
|
|
||||||
}
|
|
||||||
|
@@ -14,7 +14,6 @@ import (
|
|||||||
|
|
||||||
"github.com/zitadel/zitadel/internal/api/authz"
|
"github.com/zitadel/zitadel/internal/api/authz"
|
||||||
"github.com/zitadel/zitadel/internal/domain"
|
"github.com/zitadel/zitadel/internal/domain"
|
||||||
"github.com/zitadel/zitadel/internal/zerrors"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestUser_authMethodsCheckPermission(t *testing.T) {
|
func TestUser_authMethodsCheckPermission(t *testing.T) {
|
||||||
@@ -664,106 +663,6 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
|
|||||||
},
|
},
|
||||||
object: (*AuthMethodTypes)(nil),
|
object: (*AuthMethodTypes)(nil),
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "prepareUserAuthMethodTypesRequiredQuery no result",
|
|
||||||
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) {
|
|
||||||
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
|
|
||||||
return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) {
|
|
||||||
return scan(row)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
want: want{
|
|
||||||
sqlExpectations: mockQueriesScanErr(
|
|
||||||
regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt),
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
),
|
|
||||||
err: func(err error) (error, bool) {
|
|
||||||
if !zerrors.IsNotFound(err) {
|
|
||||||
return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false
|
|
||||||
}
|
|
||||||
return nil, true
|
|
||||||
},
|
|
||||||
},
|
|
||||||
object: (*UserAuthMethodRequirements)(nil),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "prepareUserAuthMethodTypesRequiredQuery one second factor",
|
|
||||||
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) {
|
|
||||||
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
|
|
||||||
return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) {
|
|
||||||
return scan(row)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
want: want{
|
|
||||||
sqlExpectations: mockQueries(
|
|
||||||
regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt),
|
|
||||||
prepareAuthMethodTypesRequiredCols,
|
|
||||||
[][]driver.Value{
|
|
||||||
{
|
|
||||||
domain.UserTypeHuman,
|
|
||||||
true,
|
|
||||||
true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
|
||||||
},
|
|
||||||
object: &UserAuthMethodRequirements{
|
|
||||||
UserType: domain.UserTypeHuman,
|
|
||||||
ForceMFA: true,
|
|
||||||
ForceMFALocalOnly: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "prepareUserAuthMethodTypesRequiredQuery multiple second factors",
|
|
||||||
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) {
|
|
||||||
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
|
|
||||||
return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) {
|
|
||||||
return scan(row)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
want: want{
|
|
||||||
sqlExpectations: mockQueries(
|
|
||||||
regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt),
|
|
||||||
prepareAuthMethodTypesRequiredCols,
|
|
||||||
[][]driver.Value{
|
|
||||||
{
|
|
||||||
domain.UserTypeHuman,
|
|
||||||
true,
|
|
||||||
true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
),
|
|
||||||
},
|
|
||||||
|
|
||||||
object: &UserAuthMethodRequirements{
|
|
||||||
UserType: domain.UserTypeHuman,
|
|
||||||
ForceMFA: true,
|
|
||||||
ForceMFALocalOnly: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "prepareUserAuthMethodTypesRequiredQuery sql err",
|
|
||||||
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) {
|
|
||||||
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
|
|
||||||
return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) {
|
|
||||||
return scan(row)
|
|
||||||
}
|
|
||||||
},
|
|
||||||
want: want{
|
|
||||||
sqlExpectations: mockQueryErr(
|
|
||||||
regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt),
|
|
||||||
sql.ErrConnDone,
|
|
||||||
),
|
|
||||||
err: func(err error) (error, bool) {
|
|
||||||
if !errors.Is(err, sql.ErrConnDone) {
|
|
||||||
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
|
|
||||||
}
|
|
||||||
return nil, true
|
|
||||||
},
|
|
||||||
},
|
|
||||||
object: nil,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
17
internal/query/user_auth_method_types_required.sql
Normal file
17
internal/query/user_auth_method_types_required.sql
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
SELECT
|
||||||
|
projections.users14.type
|
||||||
|
, auth_methods_force_mfa.force_mfa
|
||||||
|
, auth_methods_force_mfa.force_mfa_local_only
|
||||||
|
FROM
|
||||||
|
projections.users14
|
||||||
|
LEFT JOIN
|
||||||
|
projections.login_policies5 AS auth_methods_force_mfa
|
||||||
|
ON
|
||||||
|
auth_methods_force_mfa.instance_id = projections.users14.instance_id
|
||||||
|
AND auth_methods_force_mfa.aggregate_id = ANY(ARRAY[projections.users14.instance_id, projections.users14.resource_owner])
|
||||||
|
WHERE
|
||||||
|
projections.users14.id = $1
|
||||||
|
AND projections.users14.instance_id = $2
|
||||||
|
ORDER BY
|
||||||
|
auth_methods_force_mfa.is_default
|
||||||
|
LIMIT 1;
|
Reference in New Issue
Block a user