mirror of
https://github.com/zitadel/zitadel.git
synced 2025-12-06 17:22:28 +00:00
perf: improve scalability of session api (#9635)
This pull request improves the scalability of the session API by enhancing middleware tracing and refining SQL query behavior for user authentication methods. # Which Problems Are Solved - Eventstore subscriptions locked each other during they wrote the events to the event channels of the subscribers in push. - `ListUserAuthMethodTypesRequired` query used `Bitmap heap scan` to join the tables needed. - The auth and oidc package triggered projections often when data were read. - The session API triggered the user projection each time a user was searched to write the user check command. # How the Problems Are Solved - the `sync.Mutex` was replaced with `sync.RWMutex` to allow parallel read of the map - The query was refactored to use index scans only - if the data should already be up-to-date `shouldTriggerBulk` is set to false - as the user should already exist for some time the trigger was removed. # Additional Changes - refactoring of `tracing#Span.End` calls # Additional Context - part of https://github.com/zitadel/zitadel/issues/9239 --------- Co-authored-by: Tim Möhlmann <tim+github@zitadel.com>
This commit is contained in:
@@ -428,34 +428,6 @@ func (q *Queries) GetUserByLoginName(ctx context.Context, shouldTriggered bool,
|
||||
return user, err
|
||||
}
|
||||
|
||||
// Deprecated: use either GetUserByID or GetUserByLoginName
|
||||
func (q *Queries) GetUser(ctx context.Context, shouldTriggerBulk bool, queries ...SearchQuery) (user *User, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
if shouldTriggerBulk {
|
||||
triggerUserProjections(ctx)
|
||||
}
|
||||
|
||||
query, scan := prepareUserQuery(ctx, q.client)
|
||||
for _, q := range queries {
|
||||
query = q.toQuery(query)
|
||||
}
|
||||
eq := sq.Eq{
|
||||
UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
|
||||
}
|
||||
stmt, args, err := query.Where(eq).ToSql()
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "QUERY-Dnhr2", "Errors.Query.SQLStatment")
|
||||
}
|
||||
|
||||
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
|
||||
user, err = scan(row)
|
||||
return err
|
||||
}, stmt, args...)
|
||||
return user, err
|
||||
}
|
||||
|
||||
func (q *Queries) GetHumanProfile(ctx context.Context, userID string, queries ...SearchQuery) (profile *Profile, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
@@ -3,6 +3,7 @@ package query
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
_ "embed"
|
||||
"errors"
|
||||
"slices"
|
||||
"time"
|
||||
@@ -212,6 +213,9 @@ type UserAuthMethodRequirements struct {
|
||||
ForceMFALocalOnly bool
|
||||
}
|
||||
|
||||
//go:embed user_auth_method_types_required.sql
|
||||
var listUserAuthMethodTypesStmt string
|
||||
|
||||
func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID string) (requirements *UserAuthMethodRequirements, err error) {
|
||||
ctxData := authz.GetCtxData(ctx)
|
||||
if ctxData.UserID != userID {
|
||||
@@ -222,20 +226,33 @@ func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID st
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
query, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, q.client)
|
||||
eq := sq.Eq{
|
||||
UserIDCol.identifier(): userID,
|
||||
UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
|
||||
}
|
||||
stmt, args, err := query.Where(eq).ToSql()
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInvalidArgument(err, "QUERY-E5ut4", "Errors.Query.InvalidRequest")
|
||||
}
|
||||
|
||||
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
|
||||
requirements, err = scan(row)
|
||||
return err
|
||||
}, stmt, args...)
|
||||
err = q.client.QueryRowContext(ctx,
|
||||
func(row *sql.Row) error {
|
||||
var userType sql.NullInt32
|
||||
var forceMFA sql.NullBool
|
||||
var forceMFALocalOnly sql.NullBool
|
||||
err := row.Scan(
|
||||
&userType,
|
||||
&forceMFA,
|
||||
&forceMFALocalOnly,
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return zerrors.ThrowNotFound(err, "QUERY-SF3h2", "Errors.Internal")
|
||||
}
|
||||
return zerrors.ThrowInternal(err, "QUERY-Sf3rt", "Errors.Internal")
|
||||
}
|
||||
requirements = &UserAuthMethodRequirements{
|
||||
UserType: domain.UserType(userType.Int32),
|
||||
ForceMFA: forceMFA.Bool,
|
||||
ForceMFALocalOnly: forceMFALocalOnly.Bool,
|
||||
}
|
||||
return nil
|
||||
},
|
||||
listUserAuthMethodTypesStmt,
|
||||
userID,
|
||||
authz.GetInstance(ctx).InstanceID(),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, zerrors.ThrowInternal(err, "QUERY-Dun75", "Errors.Internal")
|
||||
}
|
||||
@@ -461,45 +478,6 @@ func prepareUserAuthMethodTypesQuery(ctx context.Context, db prepareDatabase, ac
|
||||
}
|
||||
}
|
||||
|
||||
func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) {
|
||||
loginPolicyQuery, err := prepareAuthMethodsForceMFAQuery()
|
||||
if err != nil {
|
||||
return sq.SelectBuilder{}, nil
|
||||
}
|
||||
return sq.Select(
|
||||
UserTypeCol.identifier(),
|
||||
forceMFAForce.identifier(),
|
||||
forceMFAForceLocalOnly.identifier()).
|
||||
From(userTable.identifier()).
|
||||
LeftJoin("(" + loginPolicyQuery + ") AS " + forceMFATable.alias + " ON " +
|
||||
"(" + forceMFAOrgID.identifier() + " = " + UserInstanceIDCol.identifier() + " OR " + forceMFAOrgID.identifier() + " = " + UserResourceOwnerCol.identifier() + ") AND " +
|
||||
forceMFAInstanceID.identifier() + " = " + UserInstanceIDCol.identifier()).
|
||||
OrderBy(forceMFAIsDefault.identifier()).
|
||||
Limit(1).
|
||||
PlaceholderFormat(sq.Dollar),
|
||||
func(row *sql.Row) (*UserAuthMethodRequirements, error) {
|
||||
var userType sql.NullInt32
|
||||
var forceMFA sql.NullBool
|
||||
var forceMFALocalOnly sql.NullBool
|
||||
err := row.Scan(
|
||||
&userType,
|
||||
&forceMFA,
|
||||
&forceMFALocalOnly,
|
||||
)
|
||||
if err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, zerrors.ThrowNotFound(err, "QUERY-SF3h2", "Errors.Internal")
|
||||
}
|
||||
return nil, zerrors.ThrowInternal(err, "QUERY-Sf3rt", "Errors.Internal")
|
||||
}
|
||||
return &UserAuthMethodRequirements{
|
||||
UserType: domain.UserType(userType.Int32),
|
||||
ForceMFA: forceMFA.Bool,
|
||||
ForceMFALocalOnly: forceMFALocalOnly.Bool,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func prepareAuthMethodsIDPsQuery() (string, error) {
|
||||
idpsQuery, _, err := sq.Select(
|
||||
userIDPsCountUserID.identifier(),
|
||||
@@ -536,16 +514,3 @@ func prepareAuthMethodQuery(activeOnly bool, includeWithoutDomain bool, queryDom
|
||||
|
||||
return q.ToSql()
|
||||
}
|
||||
|
||||
func prepareAuthMethodsForceMFAQuery() (string, error) {
|
||||
loginPolicyQuery, _, err := sq.Select(
|
||||
forceMFAForce.identifier(),
|
||||
forceMFAForceLocalOnly.identifier(),
|
||||
forceMFAInstanceID.identifier(),
|
||||
forceMFAOrgID.identifier(),
|
||||
forceMFAIsDefault.identifier(),
|
||||
).
|
||||
From(forceMFATable.identifier()).
|
||||
ToSql()
|
||||
return loginPolicyQuery, err
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func TestUser_authMethodsCheckPermission(t *testing.T) {
|
||||
@@ -664,106 +663,6 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
|
||||
},
|
||||
object: (*AuthMethodTypes)(nil),
|
||||
},
|
||||
{
|
||||
name: "prepareUserAuthMethodTypesRequiredQuery no result",
|
||||
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) {
|
||||
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
|
||||
return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) {
|
||||
return scan(row)
|
||||
}
|
||||
},
|
||||
want: want{
|
||||
sqlExpectations: mockQueriesScanErr(
|
||||
regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt),
|
||||
nil,
|
||||
nil,
|
||||
),
|
||||
err: func(err error) (error, bool) {
|
||||
if !zerrors.IsNotFound(err) {
|
||||
return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false
|
||||
}
|
||||
return nil, true
|
||||
},
|
||||
},
|
||||
object: (*UserAuthMethodRequirements)(nil),
|
||||
},
|
||||
{
|
||||
name: "prepareUserAuthMethodTypesRequiredQuery one second factor",
|
||||
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) {
|
||||
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
|
||||
return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) {
|
||||
return scan(row)
|
||||
}
|
||||
},
|
||||
want: want{
|
||||
sqlExpectations: mockQueries(
|
||||
regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt),
|
||||
prepareAuthMethodTypesRequiredCols,
|
||||
[][]driver.Value{
|
||||
{
|
||||
domain.UserTypeHuman,
|
||||
true,
|
||||
true,
|
||||
},
|
||||
},
|
||||
),
|
||||
},
|
||||
object: &UserAuthMethodRequirements{
|
||||
UserType: domain.UserTypeHuman,
|
||||
ForceMFA: true,
|
||||
ForceMFALocalOnly: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "prepareUserAuthMethodTypesRequiredQuery multiple second factors",
|
||||
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) {
|
||||
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
|
||||
return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) {
|
||||
return scan(row)
|
||||
}
|
||||
},
|
||||
want: want{
|
||||
sqlExpectations: mockQueries(
|
||||
regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt),
|
||||
prepareAuthMethodTypesRequiredCols,
|
||||
[][]driver.Value{
|
||||
{
|
||||
domain.UserTypeHuman,
|
||||
true,
|
||||
true,
|
||||
},
|
||||
},
|
||||
),
|
||||
},
|
||||
|
||||
object: &UserAuthMethodRequirements{
|
||||
UserType: domain.UserTypeHuman,
|
||||
ForceMFA: true,
|
||||
ForceMFALocalOnly: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "prepareUserAuthMethodTypesRequiredQuery sql err",
|
||||
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) {
|
||||
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
|
||||
return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) {
|
||||
return scan(row)
|
||||
}
|
||||
},
|
||||
want: want{
|
||||
sqlExpectations: mockQueryErr(
|
||||
regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt),
|
||||
sql.ErrConnDone,
|
||||
),
|
||||
err: func(err error) (error, bool) {
|
||||
if !errors.Is(err, sql.ErrConnDone) {
|
||||
return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false
|
||||
}
|
||||
return nil, true
|
||||
},
|
||||
},
|
||||
object: nil,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
|
||||
17
internal/query/user_auth_method_types_required.sql
Normal file
17
internal/query/user_auth_method_types_required.sql
Normal file
@@ -0,0 +1,17 @@
|
||||
SELECT
|
||||
projections.users14.type
|
||||
, auth_methods_force_mfa.force_mfa
|
||||
, auth_methods_force_mfa.force_mfa_local_only
|
||||
FROM
|
||||
projections.users14
|
||||
LEFT JOIN
|
||||
projections.login_policies5 AS auth_methods_force_mfa
|
||||
ON
|
||||
auth_methods_force_mfa.instance_id = projections.users14.instance_id
|
||||
AND auth_methods_force_mfa.aggregate_id = ANY(ARRAY[projections.users14.instance_id, projections.users14.resource_owner])
|
||||
WHERE
|
||||
projections.users14.id = $1
|
||||
AND projections.users14.instance_id = $2
|
||||
ORDER BY
|
||||
auth_methods_force_mfa.is_default
|
||||
LIMIT 1;
|
||||
Reference in New Issue
Block a user