diff --git a/internal/api/grpc/session/v2/integration_test/session_test.go b/internal/api/grpc/session/v2/integration_test/session_test.go index a9581eea6a0..88eedd22258 100644 --- a/internal/api/grpc/session/v2/integration_test/session_test.go +++ b/internal/api/grpc/session/v2/integration_test/session_test.go @@ -970,6 +970,17 @@ func Test_ZITADEL_API_missing_authentication(t *testing.T) { }, retryDuration, tick) } +func Test_ZITADEL_API_missing_mfa(t *testing.T) { + mfaUser := createFullUser(CTX) + registerTOTP(CTX, t, mfaUser.GetUserId()) + id, token, _, _ := Instance.CreatePasswordSession(t, LoginCTX, mfaUser.GetUserId(), integration.UserPassword) + ctx := integration.WithAuthorizationToken(context.Background(), token) + + sessionResp, err := Instance.Client.SessionV2.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) + require.Error(t, err) + require.Nil(t, sessionResp) +} + func Test_ZITADEL_API_success(t *testing.T) { id, token, _, _ := Instance.CreateVerifiedWebAuthNSession(t, LoginCTX, User.GetUserId()) ctx := integration.WithAuthorizationToken(context.Background(), token) diff --git a/internal/api/oidc/integration_test/oidc_test.go b/internal/api/oidc/integration_test/oidc_test.go index f7ed2da635e..7a1455a1acd 100644 --- a/internal/api/oidc/integration_test/oidc_test.go +++ b/internal/api/oidc/integration_test/oidc_test.go @@ -120,6 +120,38 @@ func Test_ZITADEL_API_missing_authentication(t *testing.T) { require.Nil(t, myUserResp) } +func Test_ZITADEL_API_missing_mfa_2fa_setup(t *testing.T) { + clientID, _ := createClient(t, Instance) + org := Instance.CreateOrganization(CTXIAM, integration.OrganizationName(), integration.Email()) + userID := org.CreatedAdmins[0].GetUserId() + Instance.SetUserPassword(CTXIAM, userID, integration.UserPassword, false) + Instance.RegisterUserU2F(CTXIAM, userID) + authRequestID := createAuthRequest(t, Instance, clientID, redirectURI, oidc.ScopeOpenID, zitadelAudienceScope) + sessionID, sessionToken, startTime, changeTime := Instance.CreatePasswordSession(t, CTXLOGIN, userID, integration.UserPassword) + linkResp, err := Instance.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ + AuthRequestId: authRequestID, + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionID, + SessionToken: sessionToken, + }, + }, + }) + require.NoError(t, err) + + // code exchange + code := assertCodeResponse(t, linkResp.GetCallbackUrl()) + tokens, err := exchangeTokens(t, Instance, clientID, code, redirectURI) + require.NoError(t, err) + assertIDTokenClaims(t, tokens.IDTokenClaims, userID, armPassword, startTime, changeTime, sessionID) + + ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) + + myUserResp, err := Instance.Client.Auth.GetMyUser(ctx, &auth.GetMyUserRequest{}) + require.Error(t, err) + require.Nil(t, myUserResp) +} + func Test_ZITADEL_API_missing_mfa_policy(t *testing.T) { clientID, _ := createClient(t, Instance) org := Instance.CreateOrganization(CTXIAM, integration.OrganizationName(), integration.Email()) diff --git a/internal/authz/repository/eventsourcing/eventstore/token_verifier.go b/internal/authz/repository/eventsourcing/eventstore/token_verifier.go index d6c14afea36..a71ad4a880f 100644 --- a/internal/authz/repository/eventsourcing/eventstore/token_verifier.go +++ b/internal/authz/repository/eventsourcing/eventstore/token_verifier.go @@ -4,6 +4,7 @@ import ( "context" "encoding/base64" "fmt" + "slices" "strings" "time" @@ -177,6 +178,7 @@ func (repo *TokenVerifierRepo) checkAuthentication(ctx context.Context, authMeth if len(authMethods) == 0 { return zerrors.ThrowPermissionDenied(nil, "AUTHZ-Kl3p0", "authentication required") } + // if the user has MFA, we don't need to check any mfa requirements if domain.HasMFA(authMethods) { return nil } @@ -184,19 +186,38 @@ func (repo *TokenVerifierRepo) checkAuthentication(ctx context.Context, authMeth if err != nil { return err } + // machine users do not have interactive logins, so we don't check for MFA requirements if requirements.UserType == domain.UserTypeMachine { return nil } - if domain.RequiresMFA( - requirements.ForceMFA, - requirements.ForceMFALocalOnly, - !hasIDPAuthentication(authMethods), - ) { + // we'll only require 2FA factors, that are allowed by the policy + allowedFactors := allowed2FAFactors(requirements.AllowedSecondFactors, requirements.SetUpFactors) + // if either the user has set up a factor that is allowed by the policy + // or the policy requires MFA, we'll require it and can directly return the error + // since the token/session was not authenticated with MFA + if domain.Has2FA(allowedFactors) || + domain.RequiresMFA( + requirements.ForceMFA, + requirements.ForceMFALocalOnly, + !hasIDPAuthentication(authMethods), + ) { return zerrors.ThrowPermissionDenied(nil, "AUTHZ-Kl3p0", "mfa required") } return nil } +func allowed2FAFactors(factors []domain.SecondFactorType, authMethods []domain.UserAuthMethodType) []domain.UserAuthMethodType { + allowedFactors := make([]domain.UserAuthMethodType, 0, len(factors)) + for _, method := range authMethods { + factorType := domain.AuthMethodToSecondFactor(method) + if factorType != domain.SecondFactorTypeUnspecified && + slices.Contains(factors, factorType) { + allowedFactors = append(allowedFactors, method) + } + } + return allowedFactors +} + func hasIDPAuthentication(authMethods []domain.UserAuthMethodType) bool { for _, method := range authMethods { if method == domain.UserAuthMethodTypeIDP { diff --git a/internal/domain/user.go b/internal/domain/user.go index da32e174904..c59ff0563f5 100644 --- a/internal/domain/user.go +++ b/internal/domain/user.go @@ -107,6 +107,25 @@ func RequiresMFA(forceMFA, forceMFALocalOnly, isInternalLogin bool) bool { return forceMFA && !forceMFALocalOnly } +// AuthMethodToSecondFactor maps user auth methods to their corresponding second factor types +func AuthMethodToSecondFactor(method UserAuthMethodType) SecondFactorType { + switch method { + case UserAuthMethodTypeTOTP: + return SecondFactorTypeTOTP + case UserAuthMethodTypeU2F: + return SecondFactorTypeU2F + case UserAuthMethodTypeOTPSMS: + return SecondFactorTypeOTPSMS + case UserAuthMethodTypeOTPEmail: + return SecondFactorTypeOTPEmail + case UserAuthMethodTypeOTP: + return SecondFactorTypeOTPSMS + default: + // First-factor methods: password, IDP, passwordless, private key + return 0 + } +} + type PersonalAccessTokenState int32 const ( diff --git a/internal/query/user_auth_method.go b/internal/query/user_auth_method.go index 803c0908678..ab8c5634e71 100644 --- a/internal/query/user_auth_method.go +++ b/internal/query/user_auth_method.go @@ -12,6 +12,7 @@ import ( "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -222,9 +223,11 @@ func (q *Queries) ListUserAuthMethodTypes(ctx context.Context, userID string, ac } type UserAuthMethodRequirements struct { - UserType domain.UserType - ForceMFA bool - ForceMFALocalOnly bool + UserType domain.UserType + ForceMFA bool + ForceMFALocalOnly bool + AllowedSecondFactors []domain.SecondFactorType + SetUpFactors []domain.UserAuthMethodType } //go:embed user_auth_method_types_required.sql @@ -245,10 +248,14 @@ func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID st var userType sql.NullInt32 var forceMFA sql.NullBool var forceMFALocalOnly sql.NullBool + var allowedSecondFactors database.NumberArray[domain.SecondFactorType] + var setUpFactors database.NumberArray[domain.UserAuthMethodType] err := row.Scan( &userType, &forceMFA, &forceMFALocalOnly, + &allowedSecondFactors, + &setUpFactors, ) if err != nil { if errors.Is(err, sql.ErrNoRows) { @@ -257,9 +264,11 @@ func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID st return zerrors.ThrowInternal(err, "QUERY-Sf3rt", "Errors.Internal") } requirements = &UserAuthMethodRequirements{ - UserType: domain.UserType(userType.Int32), - ForceMFA: forceMFA.Bool, - ForceMFALocalOnly: forceMFALocalOnly.Bool, + UserType: domain.UserType(userType.Int32), + ForceMFA: forceMFA.Bool, + ForceMFALocalOnly: forceMFALocalOnly.Bool, + AllowedSecondFactors: allowedSecondFactors, + SetUpFactors: setUpFactors, } return nil }, diff --git a/internal/query/user_auth_method_types_required.sql b/internal/query/user_auth_method_types_required.sql index d10420f0ebe..3c24b39fc9e 100644 --- a/internal/query/user_auth_method_types_required.sql +++ b/internal/query/user_auth_method_types_required.sql @@ -1,17 +1,28 @@ -SELECT +SELECT projections.users14.type , auth_methods_force_mfa.force_mfa - , auth_methods_force_mfa.force_mfa_local_only -FROM - projections.users14 -LEFT JOIN + , auth_methods_force_mfa.force_mfa_local_only + , auth_methods_force_mfa.second_factors + , user_auth_methods5.auth_method_types +FROM + projections.users14 +LEFT JOIN projections.login_policies5 AS auth_methods_force_mfa ON auth_methods_force_mfa.instance_id = projections.users14.instance_id AND auth_methods_force_mfa.aggregate_id = ANY(ARRAY[projections.users14.instance_id, projections.users14.resource_owner]) -WHERE +LEFT JOIN LATERAL ( + SELECT + ARRAY_AGG(projections.user_auth_methods5.method_type) AS auth_method_types + FROM + projections.user_auth_methods5 + WHERE + projections.user_auth_methods5.user_id = projections.users14.id + AND projections.user_auth_methods5.instance_id = projections.users14.instance_id + ) AS user_auth_methods5 ON TRUE +WHERE projections.users14.id = $1 AND projections.users14.instance_id = $2 -ORDER BY - auth_methods_force_mfa.is_default +ORDER BY + auth_methods_force_mfa.is_default LIMIT 1; \ No newline at end of file