From a3fcf6431adb9aeda5dbdcff7d98e322cd631bf9 Mon Sep 17 00:00:00 2001 From: Livio Spring Date: Wed, 12 Jun 2024 14:24:17 +0200 Subject: [PATCH] fix(oidc): remove MFA requirement on ZITADEL API based on user auth methods (#8069) # Which Problems Are Solved Request to the ZITADEL API currently require multi factor authentication if the user has set up any second factor. However, the login UI will only prompt the user to check factors that are allowed by the login policy. This can lead to situations, where the user has set up a factor (e.g. some OTP) which was not allowed by the policy, therefore will not have to verify the factor, the ZITADEL API however will require the check since the user has set it up. # How the Problems Are Solved The requirement for multi factor authentication based on the user's authentication methods is removed when accessing the ZITADEL APIs. Those requests will only require MFA in case the login policy does so because of `requireMFA` or `requireMFAForLocalUsers`. # Additional Changes None. # Additional Context - a customer reached out to support - discussed internally - relates #7822 - backport to 2.53.x (cherry picked from commit fb2b1610f9ada78cdfe4b2aed4129c96658ff27e) --- .../session/v2/session_integration_test.go | 9 ---- internal/api/oidc/oidc_integration_test.go | 31 ------------- .../eventstore/token_verifier.go | 4 +- internal/query/user_auth_method.go | 46 ------------------- internal/query/user_auth_method_test.go | 37 ++------------- 5 files changed, 5 insertions(+), 122 deletions(-) diff --git a/internal/api/grpc/session/v2/session_integration_test.go b/internal/api/grpc/session/v2/session_integration_test.go index c6fd1b5d42..92d3b0baf7 100644 --- a/internal/api/grpc/session/v2/session_integration_test.go +++ b/internal/api/grpc/session/v2/session_integration_test.go @@ -949,15 +949,6 @@ func Test_ZITADEL_API_missing_authentication(t *testing.T) { require.Nil(t, sessionResp) } -func Test_ZITADEL_API_missing_mfa(t *testing.T) { - id, token, _, _ := Tester.CreatePasswordSession(t, CTX, User.GetUserId(), integration.UserPassword) - - ctx := Tester.WithAuthorizationToken(context.Background(), token) - sessionResp, err := Tester.Client.SessionV2.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) - require.Error(t, err) - require.Nil(t, sessionResp) -} - func Test_ZITADEL_API_success(t *testing.T) { id, token, _, _ := Tester.CreateVerifiedWebAuthNSession(t, CTX, User.GetUserId()) diff --git a/internal/api/oidc/oidc_integration_test.go b/internal/api/oidc/oidc_integration_test.go index 6df8daa132..0baeb53363 100644 --- a/internal/api/oidc/oidc_integration_test.go +++ b/internal/api/oidc/oidc_integration_test.go @@ -120,37 +120,6 @@ func Test_ZITADEL_API_missing_authentication(t *testing.T) { require.Nil(t, myUserResp) } -func Test_ZITADEL_API_missing_mfa_2fa_setup(t *testing.T) { - clientID, _ := createClient(t) - userResp := Tester.CreateHumanUser(CTX) - Tester.SetUserPassword(CTX, userResp.GetUserId(), integration.UserPassword, false) - Tester.RegisterUserU2F(CTX, userResp.GetUserId()) - authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, zitadelAudienceScope) - sessionID, sessionToken, startTime, changeTime := Tester.CreatePasswordSession(t, CTXLOGIN, userResp.GetUserId(), integration.UserPassword) - linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ - AuthRequestId: authRequestID, - CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ - Session: &oidc_pb.Session{ - SessionId: sessionID, - SessionToken: sessionToken, - }, - }, - }) - require.NoError(t, err) - - // code exchange - code := assertCodeResponse(t, linkResp.GetCallbackUrl()) - tokens, err := exchangeTokens(t, clientID, code, redirectURI) - require.NoError(t, err) - assertIDTokenClaims(t, tokens.IDTokenClaims, userResp.GetUserId(), armPassword, startTime, changeTime, sessionID) - - ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) - - myUserResp, err := Tester.Client.Auth.GetMyUser(ctx, &auth.GetMyUserRequest{}) - require.Error(t, err) - require.Nil(t, myUserResp) -} - func Test_ZITADEL_API_missing_mfa_policy(t *testing.T) { clientID, _ := createClient(t) org := Tester.CreateOrganization(CTXIAM, fmt.Sprintf("ZITADEL_API_MISSING_MFA_%d", time.Now().UnixNano()), fmt.Sprintf("%d@mouse.com", time.Now().UnixNano())) diff --git a/internal/authz/repository/eventsourcing/eventstore/token_verifier.go b/internal/authz/repository/eventsourcing/eventstore/token_verifier.go index 603146f511..4d8823913d 100644 --- a/internal/authz/repository/eventsourcing/eventstore/token_verifier.go +++ b/internal/authz/repository/eventsourcing/eventstore/token_verifier.go @@ -190,8 +190,8 @@ func (repo *TokenVerifierRepo) checkAuthentication(ctx context.Context, authMeth if domain.RequiresMFA( requirements.ForceMFA, requirements.ForceMFALocalOnly, - !hasIDPAuthentication(authMethods)) || - domain.Has2FA(requirements.AuthMethods) { + !hasIDPAuthentication(authMethods), + ) { return zerrors.ThrowPermissionDenied(nil, "AUTHZ-Kl3p0", "mfa required") } return nil diff --git a/internal/query/user_auth_method.go b/internal/query/user_auth_method.go index a350d75360..f1e2721bab 100644 --- a/internal/query/user_auth_method.go +++ b/internal/query/user_auth_method.go @@ -11,7 +11,6 @@ import ( "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/call" - "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -180,7 +179,6 @@ func (q *Queries) ListActiveUserAuthMethodTypes(ctx context.Context, userID stri type UserAuthMethodRequirements struct { UserType domain.UserType - AuthMethods []domain.UserAuthMethodType ForceMFA bool ForceMFALocalOnly bool } @@ -422,30 +420,11 @@ func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareData if err != nil { return sq.SelectBuilder{}, nil } - authMethodsQuery, authMethodsArgs, err := prepareAggAuthMethodsQuery() - if err != nil { - return sq.SelectBuilder{}, nil - } - idpsQuery, err := prepareAuthMethodsIDPsQuery() - if err != nil { - return sq.SelectBuilder{}, nil - } return sq.Select( - NotifyPasswordSetCol.identifier(), - authMethodTypeTypes.identifier(), - userIDPsCountCount.identifier(), UserTypeCol.identifier(), forceMFAForce.identifier(), forceMFAForceLocalOnly.identifier()). From(userTable.identifier()). - LeftJoin(join(NotifyUserIDCol, UserIDCol)). - LeftJoin("("+authMethodsQuery+") AS "+authMethodTypeTable.alias+" ON "+ - authMethodTypeUserID.identifier()+" = "+UserIDCol.identifier()+" AND "+ - authMethodTypeInstanceID.identifier()+" = "+UserInstanceIDCol.identifier(), - authMethodsArgs...). - LeftJoin("(" + idpsQuery + ") AS " + userIDPsCountTable.alias + " ON " + - userIDPsCountUserID.identifier() + " = " + UserIDCol.identifier() + " AND " + - userIDPsCountInstanceID.identifier() + " = " + UserInstanceIDCol.identifier()). LeftJoin("(" + loginPolicyQuery + ") AS " + forceMFATable.alias + " ON " + "(" + forceMFAOrgID.identifier() + " = " + UserInstanceIDCol.identifier() + " OR " + forceMFAOrgID.identifier() + " = " + UserResourceOwnerCol.identifier() + ") AND " + forceMFAInstanceID.identifier() + " = " + UserInstanceIDCol.identifier()). @@ -453,16 +432,10 @@ func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareData Limit(1). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*UserAuthMethodRequirements, error) { - var passwordSet sql.NullBool - var authMethodTypes database.NumberArray[domain.UserAuthMethodType] - var idp sql.NullInt64 var userType sql.NullInt32 var forceMFA sql.NullBool var forceMFALocalOnly sql.NullBool err := row.Scan( - &passwordSet, - &authMethodTypes, - &idp, &userType, &forceMFA, &forceMFALocalOnly, @@ -473,16 +446,8 @@ func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareData } return nil, zerrors.ThrowInternal(err, "QUERY-Sf3rt", "Errors.Internal") } - if passwordSet.Valid && passwordSet.Bool { - authMethodTypes = append(authMethodTypes, domain.UserAuthMethodTypePassword) - } - if idp.Valid && idp.Int64 > 0 { - authMethodTypes = append(authMethodTypes, domain.UserAuthMethodTypeIDP) - } - return &UserAuthMethodRequirements{ UserType: domain.UserType(userType.Int32), - AuthMethods: authMethodTypes, ForceMFA: forceMFA.Bool, ForceMFALocalOnly: forceMFALocalOnly.Bool, }, nil @@ -513,17 +478,6 @@ func prepareAuthMethodQuery() (string, []interface{}, error) { ToSql() } -func prepareAggAuthMethodsQuery() (string, []interface{}, error) { - return sq.Select( - "array_agg(DISTINCT("+authMethodTypeType.identifier()+")) as method_types", - authMethodTypeUserID.identifier(), - authMethodTypeInstanceID.identifier()). - From(authMethodTypeTable.identifier()). - Where(sq.Eq{authMethodTypeState.identifier(): domain.MFAStateReady}). - GroupBy(authMethodTypeInstanceID.identifier(), authMethodTypeUserID.identifier()). - ToSql() -} - func prepareAuthMethodsForceMFAQuery() (string, error) { loginPolicyQuery, _, err := sq.Select( forceMFAForce.identifier(), diff --git a/internal/query/user_auth_method_test.go b/internal/query/user_auth_method_test.go index 698667d990..b75e6fd461 100644 --- a/internal/query/user_auth_method_test.go +++ b/internal/query/user_auth_method_test.go @@ -11,7 +11,6 @@ import ( sq "github.com/Masterminds/squirrel" - "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/zerrors" ) @@ -58,29 +57,16 @@ var ( "method_type", "idps_count", } - prepareAuthMethodTypesRequiredStmt = `SELECT projections.users12_notifications.password_set,` + - ` auth_method_types.method_types,` + - ` user_idps_count.count,` + - ` projections.users12.type,` + + prepareAuthMethodTypesRequiredStmt = `SELECT projections.users12.type,` + ` auth_methods_force_mfa.force_mfa,` + ` auth_methods_force_mfa.force_mfa_local_only` + ` FROM projections.users12` + - ` LEFT JOIN projections.users12_notifications ON projections.users12.id = projections.users12_notifications.user_id AND projections.users12.instance_id = projections.users12_notifications.instance_id` + - ` LEFT JOIN (SELECT array_agg(DISTINCT(auth_method_types.method_type)) as method_types, auth_method_types.user_id, auth_method_types.instance_id FROM projections.user_auth_methods4 AS auth_method_types` + - ` WHERE auth_method_types.state = $1 GROUP BY auth_method_types.instance_id, auth_method_types.user_id) AS auth_method_types` + - ` ON auth_method_types.user_id = projections.users12.id AND auth_method_types.instance_id = projections.users12.instance_id` + - ` LEFT JOIN (SELECT user_idps_count.user_id, user_idps_count.instance_id, COUNT(user_idps_count.user_id) AS count FROM projections.idp_user_links3 AS user_idps_count` + - ` GROUP BY user_idps_count.user_id, user_idps_count.instance_id) AS user_idps_count` + - ` ON user_idps_count.user_id = projections.users12.id AND user_idps_count.instance_id = projections.users12.instance_id` + ` LEFT JOIN (SELECT auth_methods_force_mfa.force_mfa, auth_methods_force_mfa.force_mfa_local_only, auth_methods_force_mfa.instance_id, auth_methods_force_mfa.aggregate_id, auth_methods_force_mfa.is_default FROM projections.login_policies5 AS auth_methods_force_mfa) AS auth_methods_force_mfa` + ` ON (auth_methods_force_mfa.aggregate_id = projections.users12.instance_id OR auth_methods_force_mfa.aggregate_id = projections.users12.resource_owner) AND auth_methods_force_mfa.instance_id = projections.users12.instance_id` + ` ORDER BY auth_methods_force_mfa.is_default LIMIT 1 ` prepareAuthMethodTypesRequiredCols = []string{ - "password_set", "type", - "method_types", - "idps_count", "force_mfa", "force_mfa_local_only", } @@ -356,9 +342,6 @@ func Test_UserAuthMethodPrepares(t *testing.T) { prepareAuthMethodTypesRequiredCols, [][]driver.Value{ { - true, - database.NumberArray[domain.UserAuthMethodType]{domain.UserAuthMethodTypePasswordless}, - 1, domain.UserTypeHuman, true, true, @@ -367,12 +350,7 @@ func Test_UserAuthMethodPrepares(t *testing.T) { ), }, object: &UserAuthMethodRequirements{ - UserType: domain.UserTypeHuman, - AuthMethods: []domain.UserAuthMethodType{ - domain.UserAuthMethodTypePasswordless, - domain.UserAuthMethodTypePassword, - domain.UserAuthMethodTypeIDP, - }, + UserType: domain.UserTypeHuman, ForceMFA: true, ForceMFALocalOnly: true, }, @@ -391,9 +369,6 @@ func Test_UserAuthMethodPrepares(t *testing.T) { prepareAuthMethodTypesRequiredCols, [][]driver.Value{ { - true, - database.NumberArray[domain.UserAuthMethodType]{domain.UserAuthMethodTypePasswordless, domain.UserAuthMethodTypeTOTP}, - 1, domain.UserTypeHuman, true, true, @@ -403,13 +378,7 @@ func Test_UserAuthMethodPrepares(t *testing.T) { }, object: &UserAuthMethodRequirements{ - UserType: domain.UserTypeHuman, - AuthMethods: []domain.UserAuthMethodType{ - domain.UserAuthMethodTypePasswordless, - domain.UserAuthMethodTypeTOTP, - domain.UserAuthMethodTypePassword, - domain.UserAuthMethodTypeIDP, - }, + UserType: domain.UserTypeHuman, ForceMFA: true, ForceMFALocalOnly: true, },