feat: allow to force MFA local only (#6234)

This PR adds an option to the LoginPolicy to "Force MFA for local users", so that users authenticated through an IDP must not configure (and verify) an MFA.
This commit is contained in:
Livio Spring
2023-07-20 06:06:16 +02:00
committed by GitHub
parent 1c3a15ff57
commit fed15574f6
49 changed files with 488 additions and 94 deletions

View File

@@ -80,11 +80,12 @@ var (
table: userIDPsCountTable,
}
forceMFATable = loginPolicyTable.setAlias("auth_methods_force_mfa")
forceMFAInstanceID = LoginPolicyColumnInstanceID.setTable(forceMFATable)
forceMFAOrgID = LoginPolicyColumnOrgID.setTable(forceMFATable)
forceMFAIsDefault = LoginPolicyColumnIsDefault.setTable(forceMFATable)
forceMFAForce = LoginPolicyColumnForceMFA.setTable(forceMFATable)
forceMFATable = loginPolicyTable.setAlias("auth_methods_force_mfa")
forceMFAInstanceID = LoginPolicyColumnInstanceID.setTable(forceMFATable)
forceMFAOrgID = LoginPolicyColumnOrgID.setTable(forceMFATable)
forceMFAIsDefault = LoginPolicyColumnIsDefault.setTable(forceMFATable)
forceMFAForce = LoginPolicyColumnForceMFA.setTable(forceMFATable)
forceMFAForceLocalOnly = LoginPolicyColumnForceMFALocalOnly.setTable(forceMFATable)
)
type AuthMethods struct {
@@ -176,11 +177,11 @@ func (q *Queries) ListActiveUserAuthMethodTypes(ctx context.Context, userID stri
return userAuthMethodTypes, err
}
func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID string, withOwnerRemoved bool) (userAuthMethodTypes []domain.UserAuthMethodType, forceMFA bool, err error) {
func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID string, withOwnerRemoved bool) (userAuthMethodTypes []domain.UserAuthMethodType, forceMFA, forceMFALocalOnly bool, err error) {
ctxData := authz.GetCtxData(ctx)
if ctxData.UserID != userID {
if err := q.checkPermission(ctx, domain.PermissionUserRead, ctxData.OrgID, userID); err != nil {
return nil, false, err
return nil, false, false, err
}
}
ctx, span := tracing.NewSpan(ctx)
@@ -196,12 +197,12 @@ func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID st
}
stmt, args, err := query.Where(eq).ToSql()
if err != nil {
return nil, false, errors.ThrowInvalidArgument(err, "QUERY-E5ut4", "Errors.Query.InvalidRequest")
return nil, false, false, errors.ThrowInvalidArgument(err, "QUERY-E5ut4", "Errors.Query.InvalidRequest")
}
rows, err := q.client.QueryContext(ctx, stmt, args...)
if err != nil || rows.Err() != nil {
return nil, false, errors.ThrowInternal(err, "QUERY-Dun75", "Errors.Internal")
return nil, false, false, errors.ThrowInternal(err, "QUERY-Dun75", "Errors.Internal")
}
return scan(rows)
}
@@ -408,7 +409,7 @@ func prepareActiveUserAuthMethodTypesQuery(ctx context.Context, db prepareDataba
}
}
func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) ([]domain.UserAuthMethodType, bool, error)) {
func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (_ []domain.UserAuthMethodType, forceMFA, forceMFALocalOnly bool, err error)) {
loginPolicyQuery, err := prepareAuthMethodsForceMFAQuery()
if err != nil {
return sq.SelectBuilder{}, nil
@@ -425,7 +426,8 @@ func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareData
NotifyPasswordSetCol.identifier(),
authMethodTypeTypes.identifier(),
userIDPsCountCount.identifier(),
forceMFAForce.identifier()).
forceMFAForce.identifier(),
forceMFAForceLocalOnly.identifier()).
From(userTable.identifier()).
LeftJoin(join(NotifyUserIDCol, UserIDCol)).
LeftJoin("("+authMethodsQuery+") AS "+authMethodTypeTable.alias+" ON "+
@@ -439,11 +441,12 @@ func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareData
"(" + forceMFAOrgID.identifier() + " = " + UserInstanceIDCol.identifier() + " OR " + forceMFAOrgID.identifier() + " = " + UserResourceOwnerCol.identifier() + ") AND " +
forceMFAInstanceID.identifier() + " = " + UserInstanceIDCol.identifier() + db.Timetravel(call.Took(ctx))).
PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) ([]domain.UserAuthMethodType, bool, error) {
func(rows *sql.Rows) ([]domain.UserAuthMethodType, bool, bool, error) {
userAuthMethodTypes := make([]domain.UserAuthMethodType, 0)
var passwordSet sql.NullBool
var idp sql.NullInt64
var forceMFA sql.NullBool
var forceMFALocalOnly sql.NullBool
for rows.Next() {
var authMethodType sql.NullInt16
err := rows.Scan(
@@ -451,9 +454,10 @@ func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareData
&authMethodType,
&idp,
&forceMFA,
&forceMFALocalOnly,
)
if err != nil {
return nil, false, err
return nil, false, false, err
}
if authMethodType.Valid {
userAuthMethodTypes = append(userAuthMethodTypes, domain.UserAuthMethodType(authMethodType.Int16))
@@ -468,10 +472,10 @@ func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareData
}
if err := rows.Close(); err != nil {
return nil, false, errors.ThrowInternal(err, "QUERY-W4zje", "Errors.Query.CloseRows")
return nil, false, false, errors.ThrowInternal(err, "QUERY-W4zje", "Errors.Query.CloseRows")
}
return userAuthMethodTypes, forceMFA.Bool, nil
return userAuthMethodTypes, forceMFA.Bool, forceMFALocalOnly.Bool, nil
}
}
@@ -502,6 +506,7 @@ func prepareAuthMethodQuery() (string, []interface{}, error) {
func prepareAuthMethodsForceMFAQuery() (string, error) {
loginPolicyQuery, _, err := sq.Select(
forceMFAForce.identifier(),
forceMFAForceLocalOnly.identifier(),
forceMFAInstanceID.identifier(),
forceMFAOrgID.identifier(),
).