diff --git a/internal/api/grpc/session/v2/session_integration_test.go b/internal/api/grpc/session/v2/session_integration_test.go index c6fd1b5d42..92d3b0baf7 100644 --- a/internal/api/grpc/session/v2/session_integration_test.go +++ b/internal/api/grpc/session/v2/session_integration_test.go @@ -949,15 +949,6 @@ func Test_ZITADEL_API_missing_authentication(t *testing.T) { require.Nil(t, sessionResp) } -func Test_ZITADEL_API_missing_mfa(t *testing.T) { - id, token, _, _ := Tester.CreatePasswordSession(t, CTX, User.GetUserId(), integration.UserPassword) - - ctx := Tester.WithAuthorizationToken(context.Background(), token) - sessionResp, err := Tester.Client.SessionV2.GetSession(ctx, &session.GetSessionRequest{SessionId: id}) - require.Error(t, err) - require.Nil(t, sessionResp) -} - func Test_ZITADEL_API_success(t *testing.T) { id, token, _, _ := Tester.CreateVerifiedWebAuthNSession(t, CTX, User.GetUserId()) diff --git a/internal/api/oidc/oidc_integration_test.go b/internal/api/oidc/oidc_integration_test.go index 6df8daa132..0baeb53363 100644 --- a/internal/api/oidc/oidc_integration_test.go +++ b/internal/api/oidc/oidc_integration_test.go @@ -120,37 +120,6 @@ func Test_ZITADEL_API_missing_authentication(t *testing.T) { require.Nil(t, myUserResp) } -func Test_ZITADEL_API_missing_mfa_2fa_setup(t *testing.T) { - clientID, _ := createClient(t) - userResp := Tester.CreateHumanUser(CTX) - Tester.SetUserPassword(CTX, userResp.GetUserId(), integration.UserPassword, false) - Tester.RegisterUserU2F(CTX, userResp.GetUserId()) - authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, zitadelAudienceScope) - sessionID, sessionToken, startTime, changeTime := Tester.CreatePasswordSession(t, CTXLOGIN, userResp.GetUserId(), integration.UserPassword) - linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ - AuthRequestId: authRequestID, - CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ - Session: &oidc_pb.Session{ - SessionId: sessionID, - SessionToken: sessionToken, - }, - }, - }) - require.NoError(t, err) - - // code exchange - code := assertCodeResponse(t, linkResp.GetCallbackUrl()) - tokens, err := exchangeTokens(t, clientID, code, redirectURI) - require.NoError(t, err) - assertIDTokenClaims(t, tokens.IDTokenClaims, userResp.GetUserId(), armPassword, startTime, changeTime, sessionID) - - ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) - - myUserResp, err := Tester.Client.Auth.GetMyUser(ctx, &auth.GetMyUserRequest{}) - require.Error(t, err) - require.Nil(t, myUserResp) -} - func Test_ZITADEL_API_missing_mfa_policy(t *testing.T) { clientID, _ := createClient(t) org := Tester.CreateOrganization(CTXIAM, fmt.Sprintf("ZITADEL_API_MISSING_MFA_%d", time.Now().UnixNano()), fmt.Sprintf("%d@mouse.com", time.Now().UnixNano())) diff --git a/internal/authz/repository/eventsourcing/eventstore/token_verifier.go b/internal/authz/repository/eventsourcing/eventstore/token_verifier.go index 603146f511..4d8823913d 100644 --- a/internal/authz/repository/eventsourcing/eventstore/token_verifier.go +++ b/internal/authz/repository/eventsourcing/eventstore/token_verifier.go @@ -190,8 +190,8 @@ func (repo *TokenVerifierRepo) checkAuthentication(ctx context.Context, authMeth if domain.RequiresMFA( requirements.ForceMFA, requirements.ForceMFALocalOnly, - !hasIDPAuthentication(authMethods)) || - domain.Has2FA(requirements.AuthMethods) { + !hasIDPAuthentication(authMethods), + ) { return zerrors.ThrowPermissionDenied(nil, "AUTHZ-Kl3p0", "mfa required") } return nil diff --git a/internal/query/user_auth_method.go b/internal/query/user_auth_method.go index a350d75360..f1e2721bab 100644 --- a/internal/query/user_auth_method.go +++ b/internal/query/user_auth_method.go @@ -11,7 +11,6 @@ import ( "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/call" - "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -180,7 +179,6 @@ func (q *Queries) ListActiveUserAuthMethodTypes(ctx context.Context, userID stri type UserAuthMethodRequirements struct { UserType domain.UserType - AuthMethods []domain.UserAuthMethodType ForceMFA bool ForceMFALocalOnly bool } @@ -422,30 +420,11 @@ func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareData if err != nil { return sq.SelectBuilder{}, nil } - authMethodsQuery, authMethodsArgs, err := prepareAggAuthMethodsQuery() - if err != nil { - return sq.SelectBuilder{}, nil - } - idpsQuery, err := prepareAuthMethodsIDPsQuery() - if err != nil { - return sq.SelectBuilder{}, nil - } return sq.Select( - NotifyPasswordSetCol.identifier(), - authMethodTypeTypes.identifier(), - userIDPsCountCount.identifier(), UserTypeCol.identifier(), forceMFAForce.identifier(), forceMFAForceLocalOnly.identifier()). From(userTable.identifier()). - LeftJoin(join(NotifyUserIDCol, UserIDCol)). - LeftJoin("("+authMethodsQuery+") AS "+authMethodTypeTable.alias+" ON "+ - authMethodTypeUserID.identifier()+" = "+UserIDCol.identifier()+" AND "+ - authMethodTypeInstanceID.identifier()+" = "+UserInstanceIDCol.identifier(), - authMethodsArgs...). - LeftJoin("(" + idpsQuery + ") AS " + userIDPsCountTable.alias + " ON " + - userIDPsCountUserID.identifier() + " = " + UserIDCol.identifier() + " AND " + - userIDPsCountInstanceID.identifier() + " = " + UserInstanceIDCol.identifier()). LeftJoin("(" + loginPolicyQuery + ") AS " + forceMFATable.alias + " ON " + "(" + forceMFAOrgID.identifier() + " = " + UserInstanceIDCol.identifier() + " OR " + forceMFAOrgID.identifier() + " = " + UserResourceOwnerCol.identifier() + ") AND " + forceMFAInstanceID.identifier() + " = " + UserInstanceIDCol.identifier()). @@ -453,16 +432,10 @@ func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareData Limit(1). PlaceholderFormat(sq.Dollar), func(row *sql.Row) (*UserAuthMethodRequirements, error) { - var passwordSet sql.NullBool - var authMethodTypes database.NumberArray[domain.UserAuthMethodType] - var idp sql.NullInt64 var userType sql.NullInt32 var forceMFA sql.NullBool var forceMFALocalOnly sql.NullBool err := row.Scan( - &passwordSet, - &authMethodTypes, - &idp, &userType, &forceMFA, &forceMFALocalOnly, @@ -473,16 +446,8 @@ func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareData } return nil, zerrors.ThrowInternal(err, "QUERY-Sf3rt", "Errors.Internal") } - if passwordSet.Valid && passwordSet.Bool { - authMethodTypes = append(authMethodTypes, domain.UserAuthMethodTypePassword) - } - if idp.Valid && idp.Int64 > 0 { - authMethodTypes = append(authMethodTypes, domain.UserAuthMethodTypeIDP) - } - return &UserAuthMethodRequirements{ UserType: domain.UserType(userType.Int32), - AuthMethods: authMethodTypes, ForceMFA: forceMFA.Bool, ForceMFALocalOnly: forceMFALocalOnly.Bool, }, nil @@ -513,17 +478,6 @@ func prepareAuthMethodQuery() (string, []interface{}, error) { ToSql() } -func prepareAggAuthMethodsQuery() (string, []interface{}, error) { - return sq.Select( - "array_agg(DISTINCT("+authMethodTypeType.identifier()+")) as method_types", - authMethodTypeUserID.identifier(), - authMethodTypeInstanceID.identifier()). - From(authMethodTypeTable.identifier()). - Where(sq.Eq{authMethodTypeState.identifier(): domain.MFAStateReady}). - GroupBy(authMethodTypeInstanceID.identifier(), authMethodTypeUserID.identifier()). - ToSql() -} - func prepareAuthMethodsForceMFAQuery() (string, error) { loginPolicyQuery, _, err := sq.Select( forceMFAForce.identifier(), diff --git a/internal/query/user_auth_method_test.go b/internal/query/user_auth_method_test.go index 698667d990..b75e6fd461 100644 --- a/internal/query/user_auth_method_test.go +++ b/internal/query/user_auth_method_test.go @@ -11,7 +11,6 @@ import ( sq "github.com/Masterminds/squirrel" - "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/zerrors" ) @@ -58,29 +57,16 @@ var ( "method_type", "idps_count", } - prepareAuthMethodTypesRequiredStmt = `SELECT projections.users12_notifications.password_set,` + - ` auth_method_types.method_types,` + - ` user_idps_count.count,` + - ` projections.users12.type,` + + prepareAuthMethodTypesRequiredStmt = `SELECT projections.users12.type,` + ` auth_methods_force_mfa.force_mfa,` + ` auth_methods_force_mfa.force_mfa_local_only` + ` FROM projections.users12` + - ` LEFT JOIN projections.users12_notifications ON projections.users12.id = projections.users12_notifications.user_id AND projections.users12.instance_id = projections.users12_notifications.instance_id` + - ` LEFT JOIN (SELECT array_agg(DISTINCT(auth_method_types.method_type)) as method_types, auth_method_types.user_id, auth_method_types.instance_id FROM projections.user_auth_methods4 AS auth_method_types` + - ` WHERE auth_method_types.state = $1 GROUP BY auth_method_types.instance_id, auth_method_types.user_id) AS auth_method_types` + - ` ON auth_method_types.user_id = projections.users12.id AND auth_method_types.instance_id = projections.users12.instance_id` + - ` LEFT JOIN (SELECT user_idps_count.user_id, user_idps_count.instance_id, COUNT(user_idps_count.user_id) AS count FROM projections.idp_user_links3 AS user_idps_count` + - ` GROUP BY user_idps_count.user_id, user_idps_count.instance_id) AS user_idps_count` + - ` ON user_idps_count.user_id = projections.users12.id AND user_idps_count.instance_id = projections.users12.instance_id` + ` LEFT JOIN (SELECT auth_methods_force_mfa.force_mfa, auth_methods_force_mfa.force_mfa_local_only, auth_methods_force_mfa.instance_id, auth_methods_force_mfa.aggregate_id, auth_methods_force_mfa.is_default FROM projections.login_policies5 AS auth_methods_force_mfa) AS auth_methods_force_mfa` + ` ON (auth_methods_force_mfa.aggregate_id = projections.users12.instance_id OR auth_methods_force_mfa.aggregate_id = projections.users12.resource_owner) AND auth_methods_force_mfa.instance_id = projections.users12.instance_id` + ` ORDER BY auth_methods_force_mfa.is_default LIMIT 1 ` prepareAuthMethodTypesRequiredCols = []string{ - "password_set", "type", - "method_types", - "idps_count", "force_mfa", "force_mfa_local_only", } @@ -356,9 +342,6 @@ func Test_UserAuthMethodPrepares(t *testing.T) { prepareAuthMethodTypesRequiredCols, [][]driver.Value{ { - true, - database.NumberArray[domain.UserAuthMethodType]{domain.UserAuthMethodTypePasswordless}, - 1, domain.UserTypeHuman, true, true, @@ -367,12 +350,7 @@ func Test_UserAuthMethodPrepares(t *testing.T) { ), }, object: &UserAuthMethodRequirements{ - UserType: domain.UserTypeHuman, - AuthMethods: []domain.UserAuthMethodType{ - domain.UserAuthMethodTypePasswordless, - domain.UserAuthMethodTypePassword, - domain.UserAuthMethodTypeIDP, - }, + UserType: domain.UserTypeHuman, ForceMFA: true, ForceMFALocalOnly: true, }, @@ -391,9 +369,6 @@ func Test_UserAuthMethodPrepares(t *testing.T) { prepareAuthMethodTypesRequiredCols, [][]driver.Value{ { - true, - database.NumberArray[domain.UserAuthMethodType]{domain.UserAuthMethodTypePasswordless, domain.UserAuthMethodTypeTOTP}, - 1, domain.UserTypeHuman, true, true, @@ -403,13 +378,7 @@ func Test_UserAuthMethodPrepares(t *testing.T) { }, object: &UserAuthMethodRequirements{ - UserType: domain.UserTypeHuman, - AuthMethods: []domain.UserAuthMethodType{ - domain.UserAuthMethodTypePasswordless, - domain.UserAuthMethodTypeTOTP, - domain.UserAuthMethodTypePassword, - domain.UserAuthMethodTypeIDP, - }, + UserType: domain.UserTypeHuman, ForceMFA: true, ForceMFALocalOnly: true, },