From ec222a13d7a42a9a5e523a8e7f8dbb923747e654 Mon Sep 17 00:00:00 2001 From: Livio Spring Date: Tue, 28 May 2024 10:59:49 +0200 Subject: [PATCH] fix(oidc): IDP and passwordless user auth methods (#7998) # Which Problems Are Solved As already mentioned and (partially) fixed in #7992 we discovered, issues with v2 tokens that where obtained through an IDP, with passwordless authentication or with password authentication (wihtout any 2FA set up) using the v1 login for zitadel API calls - (Previous) authentication through an IdP is now correctly treated as auth method in case of a reauth even when the user is not redirected to the IdP - There were some cases where passwordless authentication was successfully checked but not correctly set as auth method, which denied access to ZITADEL API - Users with password and passwordless, but no 2FA set up which authenticate just wich password can access the ZITADEL API again Additionally while testing we found out that because of #7969 the login UI could completely break / block with the following error: `sql: Scan error on column index 3, name "state": converting NULL to int32 is unsupported (Internal)` # How the Problems Are Solved - IdP checks are treated the same way as other factors and it's ensured that a succeeded check within the configured timeframe will always provide the idp auth method - `MFATypesAllowed` checks for possible passwordless authentication - As with the v1 login, the token check now only requires MFA if the policy is set or the user has 2FA set up - UserAuthMethodsRequirements now always uses the correctly policy to check for MFA enforcement - `State` column is handled as nullable and additional events set the state to active (as before #7969) # Additional Changes - Console now also checks for 403 (mfa required) errors (e.g. after setting up the first 2FA in console) and redirects the user to the login UI (with the current id_token as id_token_hint) - Possible duplicates in auth methods / AMRs are removed now as well. # Additional Context - Bugs were introduced in #7822 and # and 7969 and only part of a pre-release. - partially already fixed with #7992 - Reported internally. --- .../app/services/authentication.service.ts | 4 + console/src/app/services/grpc.service.ts | 1 + .../services/interceptors/auth.interceptor.ts | 14 +++- internal/api/oidc/oidc_integration_test.go | 68 +++++++++++++++- .../eventsourcing/eventstore/auth_request.go | 33 ++++++-- .../eventstore/auth_request_test.go | 56 +++++++++++-- .../eventsourcing/eventstore/user.go | 2 +- .../eventsourcing/handler/user_session.go | 3 + .../eventstore/token_verifier.go | 4 +- internal/domain/auth_request.go | 3 +- internal/domain/user.go | 27 ++++++- internal/query/user_auth_method.go | 81 +++++++++++-------- internal/query/user_auth_method_test.go | 60 +++++++------- .../repository/view/model/user_session.go | 24 +++--- .../view/model/user_session_test.go | 5 +- 15 files changed, 281 insertions(+), 104 deletions(-) diff --git a/console/src/app/services/authentication.service.ts b/console/src/app/services/authentication.service.ts index 314b9edb7e..3ee3024e40 100644 --- a/console/src/app/services/authentication.service.ts +++ b/console/src/app/services/authentication.service.ts @@ -35,6 +35,10 @@ export class AuthenticationService { return from(this.oauthService.loadUserProfile()); } + public getIdToken(): string { + return this.oauthService.getIdToken(); + } + public async authenticate(partialConfig?: Partial, force: boolean = false): Promise { if (partialConfig) { Object.assign(this.authConfig, partialConfig); diff --git a/console/src/app/services/grpc.service.ts b/console/src/app/services/grpc.service.ts index 8a87f95ef0..5c4c0fd510 100644 --- a/console/src/app/services/grpc.service.ts +++ b/console/src/app/services/grpc.service.ts @@ -18,6 +18,7 @@ import { I18nInterceptor } from './interceptors/i18n.interceptor'; import { OrgInterceptor } from './interceptors/org.interceptor'; import { StorageService } from './storage.service'; import { FeatureServiceClient } from '../proto/generated/zitadel/feature/v2beta/Feature_serviceServiceClientPb'; +import { GrpcAuthService } from './grpc-auth.service'; @Injectable({ providedIn: 'root', diff --git a/console/src/app/services/interceptors/auth.interceptor.ts b/console/src/app/services/interceptors/auth.interceptor.ts index d21bb5cdaa..4ccdc768c7 100644 --- a/console/src/app/services/interceptors/auth.interceptor.ts +++ b/console/src/app/services/interceptors/auth.interceptor.ts @@ -2,11 +2,13 @@ import { Injectable } from '@angular/core'; import { MatDialog } from '@angular/material/dialog'; import { Request, UnaryInterceptor, UnaryResponse } from 'grpc-web'; import { Subject } from 'rxjs'; -import { debounceTime, filter, first, take } from 'rxjs/operators'; +import { debounceTime, filter, first, map, take, tap } from 'rxjs/operators'; import { WarnDialogComponent } from 'src/app/modules/warn-dialog/warn-dialog.component'; import { AuthenticationService } from '../authentication.service'; import { StorageService } from '../storage.service'; +import { AuthConfig } from 'angular-oauth2-oidc'; +import { GrpcAuthService } from '../grpc-auth.service'; const authorizationKey = 'Authorization'; const bearerPrefix = 'Bearer'; @@ -44,7 +46,7 @@ export class AuthInterceptor implements UnaryIn return response; }) .catch(async (error: any) => { - if (error.code === 16) { + if (error.code === 16 || (error.code === 7 && error.message === 'mfa required (AUTHZ-Kl3p0)')) { this.triggerDialog.next(true); } return Promise.reject(error); @@ -67,7 +69,13 @@ export class AuthInterceptor implements UnaryIn .pipe(take(1)) .subscribe((resp) => { if (resp) { - this.authenticationService.authenticate(undefined, true); + const idToken = this.authenticationService.getIdToken(); + const configWithPrompt: Partial = { + customQueryParams: { + id_token_hint: idToken, + }, + }; + this.authenticationService.authenticate(configWithPrompt, true); } }); } diff --git a/internal/api/oidc/oidc_integration_test.go b/internal/api/oidc/oidc_integration_test.go index 09e76391bd..c7a101afd3 100644 --- a/internal/api/oidc/oidc_integration_test.go +++ b/internal/api/oidc/oidc_integration_test.go @@ -18,6 +18,7 @@ import ( "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/integration" "github.com/zitadel/zitadel/pkg/grpc/auth" + mgmt "github.com/zitadel/zitadel/pkg/grpc/management" oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2beta" session "github.com/zitadel/zitadel/pkg/grpc/session/v2beta" user "github.com/zitadel/zitadel/pkg/grpc/user/v2beta" @@ -26,6 +27,7 @@ import ( var ( CTX context.Context CTXLOGIN context.Context + CTXIAM context.Context Tester *integration.Tester User *user.AddHumanUserResponse ) @@ -50,6 +52,7 @@ func TestMain(m *testing.M) { Tester.SetUserPassword(CTX, User.GetUserId(), integration.UserPassword, false) Tester.RegisterUserPasskey(CTX, User.GetUserId()) CTXLOGIN = Tester.WithAuthorization(ctx, integration.Login) + CTXIAM = Tester.WithAuthorization(ctx, integration.IAMOwner) return m.Run() }()) } @@ -117,10 +120,13 @@ func Test_ZITADEL_API_missing_authentication(t *testing.T) { require.Nil(t, myUserResp) } -func Test_ZITADEL_API_missing_mfa(t *testing.T) { +func Test_ZITADEL_API_missing_mfa_2fa_setup(t *testing.T) { clientID, _ := createClient(t) + userResp := Tester.CreateHumanUser(CTX) + Tester.SetUserPassword(CTX, userResp.GetUserId(), integration.UserPassword, false) + Tester.RegisterUserU2F(CTX, userResp.GetUserId()) authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, zitadelAudienceScope) - sessionID, sessionToken, startTime, changeTime := Tester.CreatePasswordSession(t, CTX, User.GetUserId(), integration.UserPassword) + sessionID, sessionToken, startTime, changeTime := Tester.CreatePasswordSession(t, CTXLOGIN, userResp.GetUserId(), integration.UserPassword) linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ AuthRequestId: authRequestID, CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ @@ -136,7 +142,7 @@ func Test_ZITADEL_API_missing_mfa(t *testing.T) { code := assertCodeResponse(t, linkResp.GetCallbackUrl()) tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPassword, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, userResp.GetUserId(), armPassword, startTime, changeTime) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) @@ -145,6 +151,62 @@ func Test_ZITADEL_API_missing_mfa(t *testing.T) { require.Nil(t, myUserResp) } +func Test_ZITADEL_API_missing_mfa_policy(t *testing.T) { + clientID, _ := createClient(t) + org := Tester.CreateOrganization(CTXIAM, fmt.Sprintf("ZITADEL_API_MISSING_MFA_%d", time.Now().UnixNano()), fmt.Sprintf("%d@mouse.com", time.Now().UnixNano())) + userID := org.CreatedAdmins[0].GetUserId() + Tester.SetUserPassword(CTXIAM, userID, integration.UserPassword, false) + authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, zitadelAudienceScope) + sessionID, sessionToken, startTime, changeTime := Tester.CreatePasswordSession(t, CTXLOGIN, userID, integration.UserPassword) + linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ + AuthRequestId: authRequestID, + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionID, + SessionToken: sessionToken, + }, + }, + }) + require.NoError(t, err) + + // code exchange + code := assertCodeResponse(t, linkResp.GetCallbackUrl()) + tokens, err := exchangeTokens(t, clientID, code, redirectURI) + require.NoError(t, err) + assertIDTokenClaims(t, tokens.IDTokenClaims, userID, armPassword, startTime, changeTime) + + ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) + + // pre check if request would succeed + myUserResp, err := Tester.Client.Auth.GetMyUser(ctx, &auth.GetMyUserRequest{}) + require.NoError(t, err) + require.Equal(t, userID, myUserResp.GetUser().GetId()) + + // require MFA + ctxOrg := metadata.AppendToOutgoingContext(CTXIAM, "x-zitadel-orgid", org.GetOrganizationId()) + _, err = Tester.Client.Mgmt.AddCustomLoginPolicy(ctxOrg, &mgmt.AddCustomLoginPolicyRequest{ + ForceMfa: true, + }) + require.NoError(t, err) + + // make sure policy is projected + retryDuration := 5 * time.Second + if ctxDeadline, ok := CTX.Deadline(); ok { + retryDuration = time.Until(ctxDeadline) + } + require.EventuallyWithT(t, func(ttt *assert.CollectT) { + got, getErr := Tester.Client.Mgmt.GetLoginPolicy(ctxOrg, &mgmt.GetLoginPolicyRequest{}) + assert.NoError(ttt, getErr) + assert.False(ttt, got.GetPolicy().IsDefault) + + }, retryDuration, time.Millisecond*100, "timeout waiting for login policy") + + // now it must fail + myUserResp, err = Tester.Client.Auth.GetMyUser(ctx, &auth.GetMyUserRequest{}) + require.Error(t, err) + require.Nil(t, myUserResp) +} + func Test_ZITADEL_API_success(t *testing.T) { clientID, _ := createClient(t) authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, zitadelAudienceScope) diff --git a/internal/auth/repository/eventsourcing/eventstore/auth_request.go b/internal/auth/repository/eventsourcing/eventstore/auth_request.go index 9de07b742f..a5105d1f5b 100644 --- a/internal/auth/repository/eventsourcing/eventstore/auth_request.go +++ b/internal/auth/repository/eventsourcing/eventstore/auth_request.go @@ -2,6 +2,7 @@ package eventstore import ( "context" + "slices" "strings" "time" @@ -1030,15 +1031,11 @@ func (repo *AuthRequestRepo) nextSteps(ctx context.Context, request *domain.Auth if err != nil { return nil, err } - if (!isInternalLogin || len(idps.Links) > 0) && len(request.LinkingUsers) == 0 && !checkVerificationTimeMaxAge(userSession.ExternalLoginVerification, request.LoginPolicy.ExternalLoginCheckLifetime, request) { - selectedIDPConfigID := request.SelectedIDPConfigID - if selectedIDPConfigID == "" { - selectedIDPConfigID = userSession.SelectedIDPConfigID + if (!isInternalLogin || len(idps.Links) > 0) && len(request.LinkingUsers) == 0 { + step := repo.idpChecked(request, idps.Links, userSession) + if step != nil { + return append(steps, step), nil } - if selectedIDPConfigID == "" { - selectedIDPConfigID = idps.Links[0].IDPID - } - return append(steps, &domain.ExternalLoginStep{SelectedIDPConfigID: selectedIDPConfigID}), nil } if isInternalLogin || (!isInternalLogin && len(request.LinkingUsers) > 0) { step := repo.firstFactorChecked(request, user, userSession) @@ -1198,6 +1195,7 @@ func (repo *AuthRequestRepo) firstFactorChecked(request *domain.AuthRequest, use var step domain.NextStep if request.LoginPolicy.PasswordlessType != domain.PasswordlessTypeNotAllowed && user.IsPasswordlessReady() { if checkVerificationTimeMaxAge(userSession.PasswordlessVerification, request.LoginPolicy.MultiFactorCheckLifetime, request) { + request.MFAsVerified = append(request.MFAsVerified, domain.MFATypeU2FUserVerification) request.AuthTime = userSession.PasswordlessVerification return nil } @@ -1225,8 +1223,27 @@ func (repo *AuthRequestRepo) firstFactorChecked(request *domain.AuthRequest, use return &domain.PasswordStep{} } +func (repo *AuthRequestRepo) idpChecked(request *domain.AuthRequest, idps []*query.IDPUserLink, userSession *user_model.UserSessionView) domain.NextStep { + if checkVerificationTimeMaxAge(userSession.ExternalLoginVerification, request.LoginPolicy.ExternalLoginCheckLifetime, request) { + request.IDPLoginChecked = true + request.AuthTime = userSession.ExternalLoginVerification + return nil + } + selectedIDPConfigID := request.SelectedIDPConfigID + if selectedIDPConfigID == "" { + selectedIDPConfigID = userSession.SelectedIDPConfigID + } + if selectedIDPConfigID == "" && len(idps) > 0 { + selectedIDPConfigID = idps[0].IDPID + } + return &domain.ExternalLoginStep{SelectedIDPConfigID: selectedIDPConfigID} +} + func (repo *AuthRequestRepo) mfaChecked(userSession *user_model.UserSessionView, request *domain.AuthRequest, user *user_model.UserView, isInternalAuthentication bool) (domain.NextStep, bool, error) { mfaLevel := request.MFALevel() + if slices.Contains(request.MFAsVerified, domain.MFATypeU2FUserVerification) { + return nil, true, nil + } allowedProviders, required := user.MFATypesAllowed(mfaLevel, request.LoginPolicy, isInternalAuthentication) promptRequired := (user.MFAMaxSetUp < mfaLevel) || (len(allowedProviders) == 0 && required) if promptRequired || !repo.mfaSkippedOrSetUp(user, request) { diff --git a/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go b/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go index 324039a765..35ce216877 100644 --- a/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go +++ b/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go @@ -89,7 +89,7 @@ func (m *mockViewUserSession) UserSessionsByAgentID(string, string) ([]*user_vie for i, user := range m.Users { sessions[i] = &user_view_model.UserSessionView{ ResourceOwner: user.ResourceOwner, - State: int32(user.SessionState), + State: sql.Null[domain.UserSessionState]{V: user.SessionState}, UserID: user.UserID, LoginName: sql.NullString{String: user.LoginName}, } @@ -1682,11 +1682,12 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { isInternal bool } tests := []struct { - name string - args args - want domain.NextStep - wantChecked bool - errFunc func(err error) bool + name string + args args + want domain.NextStep + wantChecked bool + errFunc func(err error) bool + wantMFAVerified []domain.MFAType }{ //{ // "required, prompt and false", //TODO: enable when LevelsOfAssurance is checked @@ -1718,6 +1719,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { nil, false, zerrors.IsPreconditionFailed, + nil, }, { "not set up, no mfas configured, no prompt and true", @@ -1737,6 +1739,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { nil, true, nil, + nil, }, { "not set up, prompt and false", @@ -1761,6 +1764,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { }, false, nil, + nil, }, { "not set up, forced by org, true", @@ -1787,6 +1791,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { }, false, nil, + nil, }, { "not set up and skipped, true", @@ -1807,6 +1812,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { nil, true, nil, + nil, }, { "checked second factor, true", @@ -1829,6 +1835,38 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { nil, true, nil, + []domain.MFAType{domain.MFATypeTOTP}, + }, + { + "checked passwordless, true", + args{ + request: &domain.AuthRequest{ + LoginPolicy: &domain.LoginPolicy{ + SecondFactors: []domain.SecondFactorType{domain.SecondFactorTypeTOTP}, + SecondFactorCheckLifetime: 18 * time.Hour, + MultiFactors: []domain.MultiFactorType{domain.MultiFactorTypeU2FWithPIN}, + MultiFactorCheckLifetime: 18 * time.Hour, + }, + MFAsVerified: []domain.MFAType{domain.MFATypeU2FUserVerification}, + }, + user: &user_model.UserView{ + HumanView: &user_model.HumanView{ + MFAMaxSetUp: domain.MFALevelMultiFactor, + PasswordlessTokens: []*user_model.WebAuthNView{ + { + TokenID: "tokenID", + State: user_model.MFAStateReady, + }, + }, + }, + }, + userSession: &user_model.UserSessionView{PasswordlessVerification: testNow.Add(-5 * time.Hour)}, + isInternal: true, + }, + nil, + true, + nil, + []domain.MFAType{domain.MFATypeU2FUserVerification}, }, { "not checked, check and false", @@ -1854,6 +1892,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { }, false, nil, + nil, }, { "external not checked or forced but set up, want step", @@ -1878,6 +1917,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { }, false, nil, + nil, }, { "external not forced but checked", @@ -1900,6 +1940,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { nil, true, nil, + []domain.MFAType{domain.MFATypeTOTP}, }, { "external not checked but required, want step", @@ -1927,6 +1968,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { }, false, nil, + nil, }, { "external not checked but local required", @@ -1950,6 +1992,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { nil, true, nil, + nil, }, } for _, tt := range tests { @@ -1964,6 +2007,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { t.Errorf("mfaChecked() checked = %v, want %v", ok, tt.wantChecked) } assert.Equal(t, tt.want, got) + assert.ElementsMatch(t, tt.args.request.MFAsVerified, tt.wantMFAVerified) }) } } diff --git a/internal/auth/repository/eventsourcing/eventstore/user.go b/internal/auth/repository/eventsourcing/eventstore/user.go index 83f09be6ae..cfa573e7e3 100644 --- a/internal/auth/repository/eventsourcing/eventstore/user.go +++ b/internal/auth/repository/eventsourcing/eventstore/user.go @@ -32,7 +32,7 @@ func (repo *UserRepo) UserSessionUserIDsByAgentID(ctx context.Context, agentID s } userIDs := make([]string, 0, len(userSessions)) for _, session := range userSessions { - if session.State == int32(domain.UserSessionStateActive) { + if session.State.V == domain.UserSessionStateActive { userIDs = append(userIDs, session.UserID) } } diff --git a/internal/auth/repository/eventsourcing/handler/user_session.go b/internal/auth/repository/eventsourcing/handler/user_session.go index de97a2062e..3147b336f3 100644 --- a/internal/auth/repository/eventsourcing/handler/user_session.go +++ b/internal/auth/repository/eventsourcing/handler/user_session.go @@ -220,6 +220,7 @@ func (u *UserSession) Reduce(event eventstore.Event) (_ *handler.Statement, err user.HumanPasswordCheckFailedType: columns, err := sessionColumns(event, handler.NewCol(view_model.UserSessionKeyPasswordVerification, time.Time{}), + handler.NewCol(view_model.UserSessionKeyState, domain.UserSessionStateActive), ) if err != nil { return nil, err @@ -241,6 +242,7 @@ func (u *UserSession) Reduce(event eventstore.Event) (_ *handler.Statement, err user.HumanU2FTokenCheckFailedType: columns, err := sessionColumns(event, handler.NewCol(view_model.UserSessionKeySecondFactorVerification, time.Time{}), + handler.NewCol(view_model.UserSessionKeyState, domain.UserSessionStateActive), ) if err != nil { return nil, err @@ -317,6 +319,7 @@ func (u *UserSession) Reduce(event eventstore.Event) (_ *handler.Statement, err columns, err := sessionColumns(event, handler.NewCol(view_model.UserSessionKeyPasswordlessVerification, time.Time{}), handler.NewCol(view_model.UserSessionKeyMultiFactorVerification, time.Time{}), + handler.NewCol(view_model.UserSessionKeyState, domain.UserSessionStateActive), ) if err != nil { return nil, err diff --git a/internal/authz/repository/eventsourcing/eventstore/token_verifier.go b/internal/authz/repository/eventsourcing/eventstore/token_verifier.go index c02e9cb329..603146f511 100644 --- a/internal/authz/repository/eventsourcing/eventstore/token_verifier.go +++ b/internal/authz/repository/eventsourcing/eventstore/token_verifier.go @@ -172,7 +172,7 @@ func (repo *TokenVerifierRepo) verifySessionToken(ctx context.Context, sessionID } // checkAuthentication ensures the session or token was authenticated (at least a single [domain.UserAuthMethodType]). -// It will also check if there was a multi factor authentication, if either MFA is forced by the login policy or if the user has set up any +// It will also check if there was a multi factor authentication, if either MFA is forced by the login policy or if the user has set up any second factor func (repo *TokenVerifierRepo) checkAuthentication(ctx context.Context, authMethods []domain.UserAuthMethodType, userID string) error { if len(authMethods) == 0 { return zerrors.ThrowPermissionDenied(nil, "AUTHZ-Kl3p0", "authentication required") @@ -191,7 +191,7 @@ func (repo *TokenVerifierRepo) checkAuthentication(ctx context.Context, authMeth requirements.ForceMFA, requirements.ForceMFALocalOnly, !hasIDPAuthentication(authMethods)) || - domain.HasMFA(requirements.AuthMethods) { + domain.Has2FA(requirements.AuthMethods) { return zerrors.ThrowPermissionDenied(nil, "AUTHZ-Kl3p0", "mfa required") } return nil diff --git a/internal/domain/auth_request.go b/internal/domain/auth_request.go index b7e463ff9a..1aaf2bb63d 100644 --- a/internal/domain/auth_request.go +++ b/internal/domain/auth_request.go @@ -1,6 +1,7 @@ package domain import ( + "slices" "strings" "time" @@ -81,7 +82,7 @@ func (a *AuthRequest) AuthMethods() []UserAuthMethodType { for _, mfa := range a.MFAsVerified { list = append(list, mfa.UserAuthMethodType()) } - return list + return slices.Compact(list) } type ExternalUser struct { diff --git a/internal/domain/user.go b/internal/domain/user.go index 3204d658da..da32e17490 100644 --- a/internal/domain/user.go +++ b/internal/domain/user.go @@ -62,7 +62,8 @@ func HasMFA(methods []UserAuthMethodType) bool { UserAuthMethodTypeOTPSMS, UserAuthMethodTypeOTPEmail, UserAuthMethodTypeIDP, - UserAuthMethodTypeOTP: + UserAuthMethodTypeOTP, + UserAuthMethodTypePrivateKey: factors++ case UserAuthMethodTypeUnspecified, userAuthMethodTypeCount: @@ -72,6 +73,30 @@ func HasMFA(methods []UserAuthMethodType) bool { return factors > 1 } +// Has2FA checks whether the auth factors provided are a second factor and will return true if at least one is. +func Has2FA(methods []UserAuthMethodType) bool { + var factors int + for _, method := range methods { + switch method { + case + UserAuthMethodTypeU2F, + UserAuthMethodTypeTOTP, + UserAuthMethodTypeOTPSMS, + UserAuthMethodTypeOTPEmail, + UserAuthMethodTypeOTP: + factors++ + case UserAuthMethodTypeUnspecified, + UserAuthMethodTypePassword, + UserAuthMethodTypePasswordless, + UserAuthMethodTypeIDP, + UserAuthMethodTypePrivateKey, + userAuthMethodTypeCount: + // ignore + } + } + return factors > 0 +} + // RequiresMFA checks whether the user requires to authenticate with multiple auth factors based on the LoginPolicy and the authentication type. // Internal authentication will require MFA if either option is activated. // External authentication will only require MFA if it's forced generally and not local only. diff --git a/internal/query/user_auth_method.go b/internal/query/user_auth_method.go index 3c24948d87..a350d75360 100644 --- a/internal/query/user_auth_method.go +++ b/internal/query/user_auth_method.go @@ -3,6 +3,7 @@ package query import ( "context" "database/sql" + "errors" "time" sq "github.com/Masterminds/squirrel" @@ -10,6 +11,7 @@ import ( "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/call" + "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -69,8 +71,12 @@ var ( authMethodTypeTable = userAuthMethodTable.setAlias("auth_method_types") authMethodTypeUserID = UserAuthMethodColumnUserID.setTable(authMethodTypeTable) authMethodTypeInstanceID = UserAuthMethodColumnInstanceID.setTable(authMethodTypeTable) - authMethodTypeTypes = UserAuthMethodColumnMethodType.setTable(authMethodTypeTable) - authMethodTypeState = UserAuthMethodColumnState.setTable(authMethodTypeTable) + authMethodTypeType = UserAuthMethodColumnMethodType.setTable(authMethodTypeTable) + authMethodTypeTypes = Column{ + name: "method_types", + table: authMethodTypeTable, + } + authMethodTypeState = UserAuthMethodColumnState.setTable(authMethodTypeTable) userIDPsCountTable = idpUserLinkTable.setAlias("user_idps_count") userIDPsCountUserID = IDPUserLinkUserIDCol.setTable(userIDPsCountTable) @@ -199,8 +205,8 @@ func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID st return nil, zerrors.ThrowInvalidArgument(err, "QUERY-E5ut4", "Errors.Query.InvalidRequest") } - err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { - requirements, err = scan(rows) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + requirements, err = scan(row) return err }, stmt, args...) if err != nil { @@ -360,7 +366,7 @@ func prepareActiveUserAuthMethodTypesQuery(ctx context.Context, db prepareDataba } return sq.Select( NotifyPasswordSetCol.identifier(), - authMethodTypeTypes.identifier(), + authMethodTypeType.identifier(), userIDPsCountCount.identifier()). From(userTable.identifier()). LeftJoin(join(NotifyUserIDCol, UserIDCol)). @@ -411,12 +417,12 @@ func prepareActiveUserAuthMethodTypesQuery(ctx context.Context, db prepareDataba } } -func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) { +func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { loginPolicyQuery, err := prepareAuthMethodsForceMFAQuery() if err != nil { return sq.SelectBuilder{}, nil } - authMethodsQuery, authMethodsArgs, err := prepareAuthMethodQuery() + authMethodsQuery, authMethodsArgs, err := prepareAggAuthMethodsQuery() if err != nil { return sq.SelectBuilder{}, nil } @@ -442,47 +448,41 @@ func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareData userIDPsCountInstanceID.identifier() + " = " + UserInstanceIDCol.identifier()). LeftJoin("(" + loginPolicyQuery + ") AS " + forceMFATable.alias + " ON " + "(" + forceMFAOrgID.identifier() + " = " + UserInstanceIDCol.identifier() + " OR " + forceMFAOrgID.identifier() + " = " + UserResourceOwnerCol.identifier() + ") AND " + - forceMFAInstanceID.identifier() + " = " + UserInstanceIDCol.identifier() + db.Timetravel(call.Took(ctx))). + forceMFAInstanceID.identifier() + " = " + UserInstanceIDCol.identifier()). + OrderBy(forceMFAIsDefault.identifier()). + Limit(1). PlaceholderFormat(sq.Dollar), - func(rows *sql.Rows) (*UserAuthMethodRequirements, error) { - userAuthMethodTypes := make([]domain.UserAuthMethodType, 0) + func(row *sql.Row) (*UserAuthMethodRequirements, error) { var passwordSet sql.NullBool + var authMethodTypes database.NumberArray[domain.UserAuthMethodType] var idp sql.NullInt64 var userType sql.NullInt32 var forceMFA sql.NullBool var forceMFALocalOnly sql.NullBool - for rows.Next() { - var authMethodType sql.NullInt16 - err := rows.Scan( - &passwordSet, - &authMethodType, - &idp, - &userType, - &forceMFA, - &forceMFALocalOnly, - ) - if err != nil { - return nil, err - } - if authMethodType.Valid { - userAuthMethodTypes = append(userAuthMethodTypes, domain.UserAuthMethodType(authMethodType.Int16)) + err := row.Scan( + &passwordSet, + &authMethodTypes, + &idp, + &userType, + &forceMFA, + &forceMFALocalOnly, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, zerrors.ThrowNotFound(err, "QUERY-SF3h2", "Errors.Internal") } + return nil, zerrors.ThrowInternal(err, "QUERY-Sf3rt", "Errors.Internal") } if passwordSet.Valid && passwordSet.Bool { - userAuthMethodTypes = append(userAuthMethodTypes, domain.UserAuthMethodTypePassword) + authMethodTypes = append(authMethodTypes, domain.UserAuthMethodTypePassword) } if idp.Valid && idp.Int64 > 0 { - logging.Error("IDP", idp.Int64) - userAuthMethodTypes = append(userAuthMethodTypes, domain.UserAuthMethodTypeIDP) - } - - if err := rows.Close(); err != nil { - return nil, zerrors.ThrowInternal(err, "QUERY-W4zje", "Errors.Query.CloseRows") + authMethodTypes = append(authMethodTypes, domain.UserAuthMethodTypeIDP) } return &UserAuthMethodRequirements{ UserType: domain.UserType(userType.Int32), - AuthMethods: userAuthMethodTypes, + AuthMethods: authMethodTypes, ForceMFA: forceMFA.Bool, ForceMFALocalOnly: forceMFALocalOnly.Bool, }, nil @@ -505,7 +505,7 @@ func prepareAuthMethodsIDPsQuery() (string, error) { func prepareAuthMethodQuery() (string, []interface{}, error) { return sq.Select( - "DISTINCT("+authMethodTypeTypes.identifier()+")", + "DISTINCT("+authMethodTypeType.identifier()+")", authMethodTypeUserID.identifier(), authMethodTypeInstanceID.identifier()). From(authMethodTypeTable.identifier()). @@ -513,15 +513,26 @@ func prepareAuthMethodQuery() (string, []interface{}, error) { ToSql() } +func prepareAggAuthMethodsQuery() (string, []interface{}, error) { + return sq.Select( + "array_agg(DISTINCT("+authMethodTypeType.identifier()+")) as method_types", + authMethodTypeUserID.identifier(), + authMethodTypeInstanceID.identifier()). + From(authMethodTypeTable.identifier()). + Where(sq.Eq{authMethodTypeState.identifier(): domain.MFAStateReady}). + GroupBy(authMethodTypeInstanceID.identifier(), authMethodTypeUserID.identifier()). + ToSql() +} + func prepareAuthMethodsForceMFAQuery() (string, error) { loginPolicyQuery, _, err := sq.Select( forceMFAForce.identifier(), forceMFAForceLocalOnly.identifier(), forceMFAInstanceID.identifier(), forceMFAOrgID.identifier(), + forceMFAIsDefault.identifier(), ). From(forceMFATable.identifier()). - OrderBy(forceMFAIsDefault.identifier()). ToSql() return loginPolicyQuery, err } diff --git a/internal/query/user_auth_method_test.go b/internal/query/user_auth_method_test.go index 578b14cec5..698667d990 100644 --- a/internal/query/user_auth_method_test.go +++ b/internal/query/user_auth_method_test.go @@ -11,7 +11,9 @@ import ( sq "github.com/Masterminds/squirrel" + "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/zerrors" ) var ( @@ -57,27 +59,27 @@ var ( "idps_count", } prepareAuthMethodTypesRequiredStmt = `SELECT projections.users12_notifications.password_set,` + - ` auth_method_types.method_type,` + + ` auth_method_types.method_types,` + ` user_idps_count.count,` + ` projections.users12.type,` + ` auth_methods_force_mfa.force_mfa,` + ` auth_methods_force_mfa.force_mfa_local_only` + ` FROM projections.users12` + ` LEFT JOIN projections.users12_notifications ON projections.users12.id = projections.users12_notifications.user_id AND projections.users12.instance_id = projections.users12_notifications.instance_id` + - ` LEFT JOIN (SELECT DISTINCT(auth_method_types.method_type), auth_method_types.user_id, auth_method_types.instance_id FROM projections.user_auth_methods4 AS auth_method_types` + - ` WHERE auth_method_types.state = $1) AS auth_method_types` + + ` LEFT JOIN (SELECT array_agg(DISTINCT(auth_method_types.method_type)) as method_types, auth_method_types.user_id, auth_method_types.instance_id FROM projections.user_auth_methods4 AS auth_method_types` + + ` WHERE auth_method_types.state = $1 GROUP BY auth_method_types.instance_id, auth_method_types.user_id) AS auth_method_types` + ` ON auth_method_types.user_id = projections.users12.id AND auth_method_types.instance_id = projections.users12.instance_id` + ` LEFT JOIN (SELECT user_idps_count.user_id, user_idps_count.instance_id, COUNT(user_idps_count.user_id) AS count FROM projections.idp_user_links3 AS user_idps_count` + ` GROUP BY user_idps_count.user_id, user_idps_count.instance_id) AS user_idps_count` + ` ON user_idps_count.user_id = projections.users12.id AND user_idps_count.instance_id = projections.users12.instance_id` + - ` LEFT JOIN (SELECT auth_methods_force_mfa.force_mfa, auth_methods_force_mfa.force_mfa_local_only, auth_methods_force_mfa.instance_id, auth_methods_force_mfa.aggregate_id FROM projections.login_policies5 AS auth_methods_force_mfa ORDER BY auth_methods_force_mfa.is_default) AS auth_methods_force_mfa` + + ` LEFT JOIN (SELECT auth_methods_force_mfa.force_mfa, auth_methods_force_mfa.force_mfa_local_only, auth_methods_force_mfa.instance_id, auth_methods_force_mfa.aggregate_id, auth_methods_force_mfa.is_default FROM projections.login_policies5 AS auth_methods_force_mfa) AS auth_methods_force_mfa` + ` ON (auth_methods_force_mfa.aggregate_id = projections.users12.instance_id OR auth_methods_force_mfa.aggregate_id = projections.users12.resource_owner) AND auth_methods_force_mfa.instance_id = projections.users12.instance_id` + - ` AS OF SYSTEM TIME '-1 ms + ` ORDER BY auth_methods_force_mfa.is_default LIMIT 1 ` prepareAuthMethodTypesRequiredCols = []string{ "password_set", "type", - "method_type", + "method_types", "idps_count", "force_mfa", "force_mfa_local_only", @@ -319,27 +321,33 @@ func Test_UserAuthMethodPrepares(t *testing.T) { }, { name: "prepareUserAuthMethodTypesRequiredQuery no result", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) { + prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db) - return builder, func(rows *sql.Rows) (*UserAuthMethodRequirements, error) { - return scan(rows) + return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) { + return scan(row) } }, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt), nil, nil, ), + err: func(err error) (error, bool) { + if !zerrors.IsNotFound(err) { + return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false + } + return nil, true + }, }, - object: &UserAuthMethodRequirements{AuthMethods: []domain.UserAuthMethodType{}, ForceMFA: false}, + object: (*UserAuthMethodRequirements)(nil), }, { name: "prepareUserAuthMethodTypesRequiredQuery one second factor", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) { + prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db) - return builder, func(rows *sql.Rows) (*UserAuthMethodRequirements, error) { - return scan(rows) + return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) { + return scan(row) } }, want: want{ @@ -349,7 +357,7 @@ func Test_UserAuthMethodPrepares(t *testing.T) { [][]driver.Value{ { true, - domain.UserAuthMethodTypePasswordless, + database.NumberArray[domain.UserAuthMethodType]{domain.UserAuthMethodTypePasswordless}, 1, domain.UserTypeHuman, true, @@ -371,10 +379,10 @@ func Test_UserAuthMethodPrepares(t *testing.T) { }, { name: "prepareUserAuthMethodTypesRequiredQuery multiple second factors", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) { + prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db) - return builder, func(rows *sql.Rows) (*UserAuthMethodRequirements, error) { - return scan(rows) + return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) { + return scan(row) } }, want: want{ @@ -384,15 +392,7 @@ func Test_UserAuthMethodPrepares(t *testing.T) { [][]driver.Value{ { true, - domain.UserAuthMethodTypePasswordless, - 1, - domain.UserTypeHuman, - true, - true, - }, - { - true, - domain.UserAuthMethodTypeTOTP, + database.NumberArray[domain.UserAuthMethodType]{domain.UserAuthMethodTypePasswordless, domain.UserAuthMethodTypeTOTP}, 1, domain.UserTypeHuman, true, @@ -416,10 +416,10 @@ func Test_UserAuthMethodPrepares(t *testing.T) { }, { name: "prepareUserAuthMethodTypesRequiredQuery sql err", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) { + prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db) - return builder, func(rows *sql.Rows) (*UserAuthMethodRequirements, error) { - return scan(rows) + return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) { + return scan(row) } }, want: want{ diff --git a/internal/user/repository/view/model/user_session.go b/internal/user/repository/view/model/user_session.go index 4208ea5b53..d761592551 100644 --- a/internal/user/repository/view/model/user_session.go +++ b/internal/user/repository/view/model/user_session.go @@ -35,12 +35,12 @@ const ( ) type UserSessionView struct { - CreationDate time.Time `json:"-" gorm:"column:creation_date"` - ChangeDate time.Time `json:"-" gorm:"column:change_date"` - ResourceOwner string `json:"-" gorm:"column:resource_owner"` - State int32 `json:"-" gorm:"column:state"` - UserAgentID string `json:"userAgentID" gorm:"column:user_agent_id;primary_key"` - UserID string `json:"userID" gorm:"column:user_id;primary_key"` + CreationDate time.Time `json:"-" gorm:"column:creation_date"` + ChangeDate time.Time `json:"-" gorm:"column:change_date"` + ResourceOwner string `json:"-" gorm:"column:resource_owner"` + State sql.Null[domain.UserSessionState] `json:"-" gorm:"column:state"` + UserAgentID string `json:"userAgentID" gorm:"column:user_agent_id;primary_key"` + UserID string `json:"userID" gorm:"column:user_id;primary_key"` // As of https://github.com/zitadel/zitadel/pull/7199 the following 4 attributes // are not projected in the user session handler anymore // and are therefore annotated with a `gorm:"-"`. @@ -79,7 +79,7 @@ func UserSessionToModel(userSession *UserSessionView) *model.UserSessionView { ChangeDate: userSession.ChangeDate, CreationDate: userSession.CreationDate, ResourceOwner: userSession.ResourceOwner, - State: domain.UserSessionState(userSession.State), + State: userSession.State.V, UserAgentID: userSession.UserAgentID, UserID: userSession.UserID, UserName: userSession.UserName.String, @@ -114,7 +114,7 @@ func (v *UserSessionView) AppendEvent(event eventstore.Event) error { case user.UserV1PasswordCheckSucceededType, user.HumanPasswordCheckSucceededType: v.PasswordVerification = sql.NullTime{Time: event.CreatedAt(), Valid: true} - v.State = int32(domain.UserSessionStateActive) + v.State.V = domain.UserSessionStateActive case user.UserIDPLoginCheckSucceededType: data := new(es_model.AuthRequest) err := data.SetData(event) @@ -123,12 +123,12 @@ func (v *UserSessionView) AppendEvent(event eventstore.Event) error { } v.ExternalLoginVerification = sql.NullTime{Time: event.CreatedAt(), Valid: true} v.SelectedIDPConfigID = sql.NullString{String: data.SelectedIDPConfigID, Valid: true} - v.State = int32(domain.UserSessionStateActive) + v.State.V = domain.UserSessionStateActive case user.HumanPasswordlessTokenCheckSucceededType: v.PasswordlessVerification = sql.NullTime{Time: event.CreatedAt(), Valid: true} v.MultiFactorVerification = sql.NullTime{Time: event.CreatedAt(), Valid: true} v.MultiFactorVerificationType = sql.NullInt32{Int32: int32(domain.MFATypeU2FUserVerification)} - v.State = int32(domain.UserSessionStateActive) + v.State.V = domain.UserSessionStateActive case user.HumanPasswordlessTokenCheckFailedType, user.HumanPasswordlessTokenRemovedType: v.PasswordlessVerification = sql.NullTime{Time: time.Time{}, Valid: true} @@ -207,7 +207,7 @@ func (v *UserSessionView) AppendEvent(event eventstore.Event) error { v.MultiFactorVerification = sql.NullTime{Time: time.Time{}, Valid: true} v.MultiFactorVerificationType = sql.NullInt32{Int32: int32(domain.MFALevelNotSetUp)} v.ExternalLoginVerification = sql.NullTime{Time: time.Time{}, Valid: true} - v.State = int32(domain.UserSessionStateTerminated) + v.State.V = domain.UserSessionStateTerminated case user.UserIDPLinkRemovedType, user.UserIDPLinkCascadeRemovedType: v.ExternalLoginVerification = sql.NullTime{Time: time.Time{}, Valid: true} v.SelectedIDPConfigID = sql.NullString{String: "", Valid: true} @@ -218,7 +218,7 @@ func (v *UserSessionView) AppendEvent(event eventstore.Event) error { func (v *UserSessionView) setSecondFactorVerification(verificationTime time.Time, mfaType domain.MFAType) { v.SecondFactorVerification = sql.NullTime{Time: verificationTime, Valid: true} v.SecondFactorVerificationType = sql.NullInt32{Int32: int32(mfaType)} - v.State = int32(domain.UserSessionStateActive) + v.State.V = domain.UserSessionStateActive } func (v *UserSessionView) EventTypes() []eventstore.EventType { diff --git a/internal/user/repository/view/model/user_session_test.go b/internal/user/repository/view/model/user_session_test.go index 25acd489c7..3e832ae1fd 100644 --- a/internal/user/repository/view/model/user_session_test.go +++ b/internal/user/repository/view/model/user_session_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/zitadel/zitadel/internal/crypto" + "github.com/zitadel/zitadel/internal/domain" es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models" "github.com/zitadel/zitadel/internal/repository/user" es_model "github.com/zitadel/zitadel/internal/user/repository/eventsourcing/model" @@ -209,7 +210,7 @@ func TestAppendEvent(t *testing.T) { ExternalLoginVerification: sql.NullTime{Time: time.Time{}, Valid: true}, PasswordlessVerification: sql.NullTime{Time: time.Time{}, Valid: true}, MultiFactorVerification: sql.NullTime{Time: time.Time{}, Valid: true}, - State: 1, + State: sql.Null[domain.UserSessionState]{V: domain.UserSessionStateTerminated}, }, }, { @@ -228,7 +229,7 @@ func TestAppendEvent(t *testing.T) { ExternalLoginVerification: sql.NullTime{Time: time.Time{}, Valid: true}, PasswordlessVerification: sql.NullTime{Time: time.Time{}, Valid: true}, MultiFactorVerification: sql.NullTime{Time: time.Time{}, Valid: true}, - State: 1, + State: sql.Null[domain.UserSessionState]{V: domain.UserSessionStateTerminated}, }, }, }