perf: improve scalability of session api (#9635)

This pull request improves the scalability of the session API by
enhancing middleware tracing and refining SQL query behavior for user
authentication methods.

# Which Problems Are Solved

- Eventstore subscriptions locked each other during they wrote the
events to the event channels of the subscribers in push.
- `ListUserAuthMethodTypesRequired` query used `Bitmap heap scan` to
join the tables needed.
- The auth and oidc package triggered projections often when data were
read.
- The session API triggered the user projection each time a user was
searched to write the user check command.

# How the Problems Are Solved

- the `sync.Mutex` was replaced with `sync.RWMutex` to allow parallel
read of the map
- The query was refactored to use index scans only
- if the data should already be up-to-date `shouldTriggerBulk` is set to
false
- as the user should already exist for some time the trigger was
removed.

# Additional Changes

- refactoring of `tracing#Span.End` calls

# Additional Context

- part of https://github.com/zitadel/zitadel/issues/9239

---------

Co-authored-by: Tim Möhlmann <tim+github@zitadel.com>
This commit is contained in:
Silvan
2025-03-28 13:36:05 +01:00
committed by GitHub
parent 79d1e7d434
commit 817670f1f7
13 changed files with 101 additions and 226 deletions

View File

@@ -3,6 +3,7 @@ package query
import (
"context"
"database/sql"
_ "embed"
"errors"
"slices"
"time"
@@ -212,6 +213,9 @@ type UserAuthMethodRequirements struct {
ForceMFALocalOnly bool
}
//go:embed user_auth_method_types_required.sql
var listUserAuthMethodTypesStmt string
func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID string) (requirements *UserAuthMethodRequirements, err error) {
ctxData := authz.GetCtxData(ctx)
if ctxData.UserID != userID {
@@ -222,20 +226,33 @@ func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID st
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, q.client)
eq := sq.Eq{
UserIDCol.identifier(): userID,
UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
}
stmt, args, err := query.Where(eq).ToSql()
if err != nil {
return nil, zerrors.ThrowInvalidArgument(err, "QUERY-E5ut4", "Errors.Query.InvalidRequest")
}
err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
requirements, err = scan(row)
return err
}, stmt, args...)
err = q.client.QueryRowContext(ctx,
func(row *sql.Row) error {
var userType sql.NullInt32
var forceMFA sql.NullBool
var forceMFALocalOnly sql.NullBool
err := row.Scan(
&userType,
&forceMFA,
&forceMFALocalOnly,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return zerrors.ThrowNotFound(err, "QUERY-SF3h2", "Errors.Internal")
}
return zerrors.ThrowInternal(err, "QUERY-Sf3rt", "Errors.Internal")
}
requirements = &UserAuthMethodRequirements{
UserType: domain.UserType(userType.Int32),
ForceMFA: forceMFA.Bool,
ForceMFALocalOnly: forceMFALocalOnly.Bool,
}
return nil
},
listUserAuthMethodTypesStmt,
userID,
authz.GetInstance(ctx).InstanceID(),
)
if err != nil {
return nil, zerrors.ThrowInternal(err, "QUERY-Dun75", "Errors.Internal")
}
@@ -461,45 +478,6 @@ func prepareUserAuthMethodTypesQuery(ctx context.Context, db prepareDatabase, ac
}
}
func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) {
loginPolicyQuery, err := prepareAuthMethodsForceMFAQuery()
if err != nil {
return sq.SelectBuilder{}, nil
}
return sq.Select(
UserTypeCol.identifier(),
forceMFAForce.identifier(),
forceMFAForceLocalOnly.identifier()).
From(userTable.identifier()).
LeftJoin("(" + loginPolicyQuery + ") AS " + forceMFATable.alias + " ON " +
"(" + forceMFAOrgID.identifier() + " = " + UserInstanceIDCol.identifier() + " OR " + forceMFAOrgID.identifier() + " = " + UserResourceOwnerCol.identifier() + ") AND " +
forceMFAInstanceID.identifier() + " = " + UserInstanceIDCol.identifier()).
OrderBy(forceMFAIsDefault.identifier()).
Limit(1).
PlaceholderFormat(sq.Dollar),
func(row *sql.Row) (*UserAuthMethodRequirements, error) {
var userType sql.NullInt32
var forceMFA sql.NullBool
var forceMFALocalOnly sql.NullBool
err := row.Scan(
&userType,
&forceMFA,
&forceMFALocalOnly,
)
if err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, zerrors.ThrowNotFound(err, "QUERY-SF3h2", "Errors.Internal")
}
return nil, zerrors.ThrowInternal(err, "QUERY-Sf3rt", "Errors.Internal")
}
return &UserAuthMethodRequirements{
UserType: domain.UserType(userType.Int32),
ForceMFA: forceMFA.Bool,
ForceMFALocalOnly: forceMFALocalOnly.Bool,
}, nil
}
}
func prepareAuthMethodsIDPsQuery() (string, error) {
idpsQuery, _, err := sq.Select(
userIDPsCountUserID.identifier(),
@@ -536,16 +514,3 @@ func prepareAuthMethodQuery(activeOnly bool, includeWithoutDomain bool, queryDom
return q.ToSql()
}
func prepareAuthMethodsForceMFAQuery() (string, error) {
loginPolicyQuery, _, err := sq.Select(
forceMFAForce.identifier(),
forceMFAForceLocalOnly.identifier(),
forceMFAInstanceID.identifier(),
forceMFAOrgID.identifier(),
forceMFAIsDefault.identifier(),
).
From(forceMFATable.identifier()).
ToSql()
return loginPolicyQuery, err
}