diff --git a/console/src/app/services/authentication.service.ts b/console/src/app/services/authentication.service.ts index 314b9edb7e..3ee3024e40 100644 --- a/console/src/app/services/authentication.service.ts +++ b/console/src/app/services/authentication.service.ts @@ -35,6 +35,10 @@ export class AuthenticationService { return from(this.oauthService.loadUserProfile()); } + public getIdToken(): string { + return this.oauthService.getIdToken(); + } + public async authenticate(partialConfig?: Partial, force: boolean = false): Promise { if (partialConfig) { Object.assign(this.authConfig, partialConfig); diff --git a/console/src/app/services/grpc.service.ts b/console/src/app/services/grpc.service.ts index 8a87f95ef0..5c4c0fd510 100644 --- a/console/src/app/services/grpc.service.ts +++ b/console/src/app/services/grpc.service.ts @@ -18,6 +18,7 @@ import { I18nInterceptor } from './interceptors/i18n.interceptor'; import { OrgInterceptor } from './interceptors/org.interceptor'; import { StorageService } from './storage.service'; import { FeatureServiceClient } from '../proto/generated/zitadel/feature/v2beta/Feature_serviceServiceClientPb'; +import { GrpcAuthService } from './grpc-auth.service'; @Injectable({ providedIn: 'root', diff --git a/console/src/app/services/interceptors/auth.interceptor.ts b/console/src/app/services/interceptors/auth.interceptor.ts index d21bb5cdaa..4ccdc768c7 100644 --- a/console/src/app/services/interceptors/auth.interceptor.ts +++ b/console/src/app/services/interceptors/auth.interceptor.ts @@ -2,11 +2,13 @@ import { Injectable } from '@angular/core'; import { MatDialog } from '@angular/material/dialog'; import { Request, UnaryInterceptor, UnaryResponse } from 'grpc-web'; import { Subject } from 'rxjs'; -import { debounceTime, filter, first, take } from 'rxjs/operators'; +import { debounceTime, filter, first, map, take, tap } from 'rxjs/operators'; import { WarnDialogComponent } from 'src/app/modules/warn-dialog/warn-dialog.component'; import { AuthenticationService } from '../authentication.service'; import { StorageService } from '../storage.service'; +import { AuthConfig } from 'angular-oauth2-oidc'; +import { GrpcAuthService } from '../grpc-auth.service'; const authorizationKey = 'Authorization'; const bearerPrefix = 'Bearer'; @@ -44,7 +46,7 @@ export class AuthInterceptor implements UnaryIn return response; }) .catch(async (error: any) => { - if (error.code === 16) { + if (error.code === 16 || (error.code === 7 && error.message === 'mfa required (AUTHZ-Kl3p0)')) { this.triggerDialog.next(true); } return Promise.reject(error); @@ -67,7 +69,13 @@ export class AuthInterceptor implements UnaryIn .pipe(take(1)) .subscribe((resp) => { if (resp) { - this.authenticationService.authenticate(undefined, true); + const idToken = this.authenticationService.getIdToken(); + const configWithPrompt: Partial = { + customQueryParams: { + id_token_hint: idToken, + }, + }; + this.authenticationService.authenticate(configWithPrompt, true); } }); } diff --git a/internal/api/oidc/oidc_integration_test.go b/internal/api/oidc/oidc_integration_test.go index 09e76391bd..c7a101afd3 100644 --- a/internal/api/oidc/oidc_integration_test.go +++ b/internal/api/oidc/oidc_integration_test.go @@ -18,6 +18,7 @@ import ( "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/integration" "github.com/zitadel/zitadel/pkg/grpc/auth" + mgmt "github.com/zitadel/zitadel/pkg/grpc/management" oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2beta" session "github.com/zitadel/zitadel/pkg/grpc/session/v2beta" user "github.com/zitadel/zitadel/pkg/grpc/user/v2beta" @@ -26,6 +27,7 @@ import ( var ( CTX context.Context CTXLOGIN context.Context + CTXIAM context.Context Tester *integration.Tester User *user.AddHumanUserResponse ) @@ -50,6 +52,7 @@ func TestMain(m *testing.M) { Tester.SetUserPassword(CTX, User.GetUserId(), integration.UserPassword, false) Tester.RegisterUserPasskey(CTX, User.GetUserId()) CTXLOGIN = Tester.WithAuthorization(ctx, integration.Login) + CTXIAM = Tester.WithAuthorization(ctx, integration.IAMOwner) return m.Run() }()) } @@ -117,10 +120,13 @@ func Test_ZITADEL_API_missing_authentication(t *testing.T) { require.Nil(t, myUserResp) } -func Test_ZITADEL_API_missing_mfa(t *testing.T) { +func Test_ZITADEL_API_missing_mfa_2fa_setup(t *testing.T) { clientID, _ := createClient(t) + userResp := Tester.CreateHumanUser(CTX) + Tester.SetUserPassword(CTX, userResp.GetUserId(), integration.UserPassword, false) + Tester.RegisterUserU2F(CTX, userResp.GetUserId()) authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, zitadelAudienceScope) - sessionID, sessionToken, startTime, changeTime := Tester.CreatePasswordSession(t, CTX, User.GetUserId(), integration.UserPassword) + sessionID, sessionToken, startTime, changeTime := Tester.CreatePasswordSession(t, CTXLOGIN, userResp.GetUserId(), integration.UserPassword) linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ AuthRequestId: authRequestID, CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ @@ -136,7 +142,7 @@ func Test_ZITADEL_API_missing_mfa(t *testing.T) { code := assertCodeResponse(t, linkResp.GetCallbackUrl()) tokens, err := exchangeTokens(t, clientID, code, redirectURI) require.NoError(t, err) - assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPassword, startTime, changeTime) + assertIDTokenClaims(t, tokens.IDTokenClaims, userResp.GetUserId(), armPassword, startTime, changeTime) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) @@ -145,6 +151,62 @@ func Test_ZITADEL_API_missing_mfa(t *testing.T) { require.Nil(t, myUserResp) } +func Test_ZITADEL_API_missing_mfa_policy(t *testing.T) { + clientID, _ := createClient(t) + org := Tester.CreateOrganization(CTXIAM, fmt.Sprintf("ZITADEL_API_MISSING_MFA_%d", time.Now().UnixNano()), fmt.Sprintf("%d@mouse.com", time.Now().UnixNano())) + userID := org.CreatedAdmins[0].GetUserId() + Tester.SetUserPassword(CTXIAM, userID, integration.UserPassword, false) + authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, zitadelAudienceScope) + sessionID, sessionToken, startTime, changeTime := Tester.CreatePasswordSession(t, CTXLOGIN, userID, integration.UserPassword) + linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ + AuthRequestId: authRequestID, + CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ + Session: &oidc_pb.Session{ + SessionId: sessionID, + SessionToken: sessionToken, + }, + }, + }) + require.NoError(t, err) + + // code exchange + code := assertCodeResponse(t, linkResp.GetCallbackUrl()) + tokens, err := exchangeTokens(t, clientID, code, redirectURI) + require.NoError(t, err) + assertIDTokenClaims(t, tokens.IDTokenClaims, userID, armPassword, startTime, changeTime) + + ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) + + // pre check if request would succeed + myUserResp, err := Tester.Client.Auth.GetMyUser(ctx, &auth.GetMyUserRequest{}) + require.NoError(t, err) + require.Equal(t, userID, myUserResp.GetUser().GetId()) + + // require MFA + ctxOrg := metadata.AppendToOutgoingContext(CTXIAM, "x-zitadel-orgid", org.GetOrganizationId()) + _, err = Tester.Client.Mgmt.AddCustomLoginPolicy(ctxOrg, &mgmt.AddCustomLoginPolicyRequest{ + ForceMfa: true, + }) + require.NoError(t, err) + + // make sure policy is projected + retryDuration := 5 * time.Second + if ctxDeadline, ok := CTX.Deadline(); ok { + retryDuration = time.Until(ctxDeadline) + } + require.EventuallyWithT(t, func(ttt *assert.CollectT) { + got, getErr := Tester.Client.Mgmt.GetLoginPolicy(ctxOrg, &mgmt.GetLoginPolicyRequest{}) + assert.NoError(ttt, getErr) + assert.False(ttt, got.GetPolicy().IsDefault) + + }, retryDuration, time.Millisecond*100, "timeout waiting for login policy") + + // now it must fail + myUserResp, err = Tester.Client.Auth.GetMyUser(ctx, &auth.GetMyUserRequest{}) + require.Error(t, err) + require.Nil(t, myUserResp) +} + func Test_ZITADEL_API_success(t *testing.T) { clientID, _ := createClient(t) authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, zitadelAudienceScope) diff --git a/internal/auth/repository/eventsourcing/eventstore/auth_request.go b/internal/auth/repository/eventsourcing/eventstore/auth_request.go index 9de07b742f..a5105d1f5b 100644 --- a/internal/auth/repository/eventsourcing/eventstore/auth_request.go +++ b/internal/auth/repository/eventsourcing/eventstore/auth_request.go @@ -2,6 +2,7 @@ package eventstore import ( "context" + "slices" "strings" "time" @@ -1030,15 +1031,11 @@ func (repo *AuthRequestRepo) nextSteps(ctx context.Context, request *domain.Auth if err != nil { return nil, err } - if (!isInternalLogin || len(idps.Links) > 0) && len(request.LinkingUsers) == 0 && !checkVerificationTimeMaxAge(userSession.ExternalLoginVerification, request.LoginPolicy.ExternalLoginCheckLifetime, request) { - selectedIDPConfigID := request.SelectedIDPConfigID - if selectedIDPConfigID == "" { - selectedIDPConfigID = userSession.SelectedIDPConfigID + if (!isInternalLogin || len(idps.Links) > 0) && len(request.LinkingUsers) == 0 { + step := repo.idpChecked(request, idps.Links, userSession) + if step != nil { + return append(steps, step), nil } - if selectedIDPConfigID == "" { - selectedIDPConfigID = idps.Links[0].IDPID - } - return append(steps, &domain.ExternalLoginStep{SelectedIDPConfigID: selectedIDPConfigID}), nil } if isInternalLogin || (!isInternalLogin && len(request.LinkingUsers) > 0) { step := repo.firstFactorChecked(request, user, userSession) @@ -1198,6 +1195,7 @@ func (repo *AuthRequestRepo) firstFactorChecked(request *domain.AuthRequest, use var step domain.NextStep if request.LoginPolicy.PasswordlessType != domain.PasswordlessTypeNotAllowed && user.IsPasswordlessReady() { if checkVerificationTimeMaxAge(userSession.PasswordlessVerification, request.LoginPolicy.MultiFactorCheckLifetime, request) { + request.MFAsVerified = append(request.MFAsVerified, domain.MFATypeU2FUserVerification) request.AuthTime = userSession.PasswordlessVerification return nil } @@ -1225,8 +1223,27 @@ func (repo *AuthRequestRepo) firstFactorChecked(request *domain.AuthRequest, use return &domain.PasswordStep{} } +func (repo *AuthRequestRepo) idpChecked(request *domain.AuthRequest, idps []*query.IDPUserLink, userSession *user_model.UserSessionView) domain.NextStep { + if checkVerificationTimeMaxAge(userSession.ExternalLoginVerification, request.LoginPolicy.ExternalLoginCheckLifetime, request) { + request.IDPLoginChecked = true + request.AuthTime = userSession.ExternalLoginVerification + return nil + } + selectedIDPConfigID := request.SelectedIDPConfigID + if selectedIDPConfigID == "" { + selectedIDPConfigID = userSession.SelectedIDPConfigID + } + if selectedIDPConfigID == "" && len(idps) > 0 { + selectedIDPConfigID = idps[0].IDPID + } + return &domain.ExternalLoginStep{SelectedIDPConfigID: selectedIDPConfigID} +} + func (repo *AuthRequestRepo) mfaChecked(userSession *user_model.UserSessionView, request *domain.AuthRequest, user *user_model.UserView, isInternalAuthentication bool) (domain.NextStep, bool, error) { mfaLevel := request.MFALevel() + if slices.Contains(request.MFAsVerified, domain.MFATypeU2FUserVerification) { + return nil, true, nil + } allowedProviders, required := user.MFATypesAllowed(mfaLevel, request.LoginPolicy, isInternalAuthentication) promptRequired := (user.MFAMaxSetUp < mfaLevel) || (len(allowedProviders) == 0 && required) if promptRequired || !repo.mfaSkippedOrSetUp(user, request) { diff --git a/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go b/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go index 324039a765..35ce216877 100644 --- a/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go +++ b/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go @@ -89,7 +89,7 @@ func (m *mockViewUserSession) UserSessionsByAgentID(string, string) ([]*user_vie for i, user := range m.Users { sessions[i] = &user_view_model.UserSessionView{ ResourceOwner: user.ResourceOwner, - State: int32(user.SessionState), + State: sql.Null[domain.UserSessionState]{V: user.SessionState}, UserID: user.UserID, LoginName: sql.NullString{String: user.LoginName}, } @@ -1682,11 +1682,12 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { isInternal bool } tests := []struct { - name string - args args - want domain.NextStep - wantChecked bool - errFunc func(err error) bool + name string + args args + want domain.NextStep + wantChecked bool + errFunc func(err error) bool + wantMFAVerified []domain.MFAType }{ //{ // "required, prompt and false", //TODO: enable when LevelsOfAssurance is checked @@ -1718,6 +1719,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { nil, false, zerrors.IsPreconditionFailed, + nil, }, { "not set up, no mfas configured, no prompt and true", @@ -1737,6 +1739,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { nil, true, nil, + nil, }, { "not set up, prompt and false", @@ -1761,6 +1764,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { }, false, nil, + nil, }, { "not set up, forced by org, true", @@ -1787,6 +1791,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { }, false, nil, + nil, }, { "not set up and skipped, true", @@ -1807,6 +1812,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { nil, true, nil, + nil, }, { "checked second factor, true", @@ -1829,6 +1835,38 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { nil, true, nil, + []domain.MFAType{domain.MFATypeTOTP}, + }, + { + "checked passwordless, true", + args{ + request: &domain.AuthRequest{ + LoginPolicy: &domain.LoginPolicy{ + SecondFactors: []domain.SecondFactorType{domain.SecondFactorTypeTOTP}, + SecondFactorCheckLifetime: 18 * time.Hour, + MultiFactors: []domain.MultiFactorType{domain.MultiFactorTypeU2FWithPIN}, + MultiFactorCheckLifetime: 18 * time.Hour, + }, + MFAsVerified: []domain.MFAType{domain.MFATypeU2FUserVerification}, + }, + user: &user_model.UserView{ + HumanView: &user_model.HumanView{ + MFAMaxSetUp: domain.MFALevelMultiFactor, + PasswordlessTokens: []*user_model.WebAuthNView{ + { + TokenID: "tokenID", + State: user_model.MFAStateReady, + }, + }, + }, + }, + userSession: &user_model.UserSessionView{PasswordlessVerification: testNow.Add(-5 * time.Hour)}, + isInternal: true, + }, + nil, + true, + nil, + []domain.MFAType{domain.MFATypeU2FUserVerification}, }, { "not checked, check and false", @@ -1854,6 +1892,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { }, false, nil, + nil, }, { "external not checked or forced but set up, want step", @@ -1878,6 +1917,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { }, false, nil, + nil, }, { "external not forced but checked", @@ -1900,6 +1940,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { nil, true, nil, + []domain.MFAType{domain.MFATypeTOTP}, }, { "external not checked but required, want step", @@ -1927,6 +1968,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { }, false, nil, + nil, }, { "external not checked but local required", @@ -1950,6 +1992,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { nil, true, nil, + nil, }, } for _, tt := range tests { @@ -1964,6 +2007,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) { t.Errorf("mfaChecked() checked = %v, want %v", ok, tt.wantChecked) } assert.Equal(t, tt.want, got) + assert.ElementsMatch(t, tt.args.request.MFAsVerified, tt.wantMFAVerified) }) } } diff --git a/internal/auth/repository/eventsourcing/eventstore/user.go b/internal/auth/repository/eventsourcing/eventstore/user.go index 83f09be6ae..cfa573e7e3 100644 --- a/internal/auth/repository/eventsourcing/eventstore/user.go +++ b/internal/auth/repository/eventsourcing/eventstore/user.go @@ -32,7 +32,7 @@ func (repo *UserRepo) UserSessionUserIDsByAgentID(ctx context.Context, agentID s } userIDs := make([]string, 0, len(userSessions)) for _, session := range userSessions { - if session.State == int32(domain.UserSessionStateActive) { + if session.State.V == domain.UserSessionStateActive { userIDs = append(userIDs, session.UserID) } } diff --git a/internal/auth/repository/eventsourcing/handler/user_session.go b/internal/auth/repository/eventsourcing/handler/user_session.go index de97a2062e..3147b336f3 100644 --- a/internal/auth/repository/eventsourcing/handler/user_session.go +++ b/internal/auth/repository/eventsourcing/handler/user_session.go @@ -220,6 +220,7 @@ func (u *UserSession) Reduce(event eventstore.Event) (_ *handler.Statement, err user.HumanPasswordCheckFailedType: columns, err := sessionColumns(event, handler.NewCol(view_model.UserSessionKeyPasswordVerification, time.Time{}), + handler.NewCol(view_model.UserSessionKeyState, domain.UserSessionStateActive), ) if err != nil { return nil, err @@ -241,6 +242,7 @@ func (u *UserSession) Reduce(event eventstore.Event) (_ *handler.Statement, err user.HumanU2FTokenCheckFailedType: columns, err := sessionColumns(event, handler.NewCol(view_model.UserSessionKeySecondFactorVerification, time.Time{}), + handler.NewCol(view_model.UserSessionKeyState, domain.UserSessionStateActive), ) if err != nil { return nil, err @@ -317,6 +319,7 @@ func (u *UserSession) Reduce(event eventstore.Event) (_ *handler.Statement, err columns, err := sessionColumns(event, handler.NewCol(view_model.UserSessionKeyPasswordlessVerification, time.Time{}), handler.NewCol(view_model.UserSessionKeyMultiFactorVerification, time.Time{}), + handler.NewCol(view_model.UserSessionKeyState, domain.UserSessionStateActive), ) if err != nil { return nil, err diff --git a/internal/authz/repository/eventsourcing/eventstore/token_verifier.go b/internal/authz/repository/eventsourcing/eventstore/token_verifier.go index c02e9cb329..603146f511 100644 --- a/internal/authz/repository/eventsourcing/eventstore/token_verifier.go +++ b/internal/authz/repository/eventsourcing/eventstore/token_verifier.go @@ -172,7 +172,7 @@ func (repo *TokenVerifierRepo) verifySessionToken(ctx context.Context, sessionID } // checkAuthentication ensures the session or token was authenticated (at least a single [domain.UserAuthMethodType]). -// It will also check if there was a multi factor authentication, if either MFA is forced by the login policy or if the user has set up any +// It will also check if there was a multi factor authentication, if either MFA is forced by the login policy or if the user has set up any second factor func (repo *TokenVerifierRepo) checkAuthentication(ctx context.Context, authMethods []domain.UserAuthMethodType, userID string) error { if len(authMethods) == 0 { return zerrors.ThrowPermissionDenied(nil, "AUTHZ-Kl3p0", "authentication required") @@ -191,7 +191,7 @@ func (repo *TokenVerifierRepo) checkAuthentication(ctx context.Context, authMeth requirements.ForceMFA, requirements.ForceMFALocalOnly, !hasIDPAuthentication(authMethods)) || - domain.HasMFA(requirements.AuthMethods) { + domain.Has2FA(requirements.AuthMethods) { return zerrors.ThrowPermissionDenied(nil, "AUTHZ-Kl3p0", "mfa required") } return nil diff --git a/internal/domain/auth_request.go b/internal/domain/auth_request.go index b7e463ff9a..1aaf2bb63d 100644 --- a/internal/domain/auth_request.go +++ b/internal/domain/auth_request.go @@ -1,6 +1,7 @@ package domain import ( + "slices" "strings" "time" @@ -81,7 +82,7 @@ func (a *AuthRequest) AuthMethods() []UserAuthMethodType { for _, mfa := range a.MFAsVerified { list = append(list, mfa.UserAuthMethodType()) } - return list + return slices.Compact(list) } type ExternalUser struct { diff --git a/internal/domain/user.go b/internal/domain/user.go index 3204d658da..da32e17490 100644 --- a/internal/domain/user.go +++ b/internal/domain/user.go @@ -62,7 +62,8 @@ func HasMFA(methods []UserAuthMethodType) bool { UserAuthMethodTypeOTPSMS, UserAuthMethodTypeOTPEmail, UserAuthMethodTypeIDP, - UserAuthMethodTypeOTP: + UserAuthMethodTypeOTP, + UserAuthMethodTypePrivateKey: factors++ case UserAuthMethodTypeUnspecified, userAuthMethodTypeCount: @@ -72,6 +73,30 @@ func HasMFA(methods []UserAuthMethodType) bool { return factors > 1 } +// Has2FA checks whether the auth factors provided are a second factor and will return true if at least one is. +func Has2FA(methods []UserAuthMethodType) bool { + var factors int + for _, method := range methods { + switch method { + case + UserAuthMethodTypeU2F, + UserAuthMethodTypeTOTP, + UserAuthMethodTypeOTPSMS, + UserAuthMethodTypeOTPEmail, + UserAuthMethodTypeOTP: + factors++ + case UserAuthMethodTypeUnspecified, + UserAuthMethodTypePassword, + UserAuthMethodTypePasswordless, + UserAuthMethodTypeIDP, + UserAuthMethodTypePrivateKey, + userAuthMethodTypeCount: + // ignore + } + } + return factors > 0 +} + // RequiresMFA checks whether the user requires to authenticate with multiple auth factors based on the LoginPolicy and the authentication type. // Internal authentication will require MFA if either option is activated. // External authentication will only require MFA if it's forced generally and not local only. diff --git a/internal/query/user_auth_method.go b/internal/query/user_auth_method.go index 3c24948d87..a350d75360 100644 --- a/internal/query/user_auth_method.go +++ b/internal/query/user_auth_method.go @@ -3,6 +3,7 @@ package query import ( "context" "database/sql" + "errors" "time" sq "github.com/Masterminds/squirrel" @@ -10,6 +11,7 @@ import ( "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/call" + "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -69,8 +71,12 @@ var ( authMethodTypeTable = userAuthMethodTable.setAlias("auth_method_types") authMethodTypeUserID = UserAuthMethodColumnUserID.setTable(authMethodTypeTable) authMethodTypeInstanceID = UserAuthMethodColumnInstanceID.setTable(authMethodTypeTable) - authMethodTypeTypes = UserAuthMethodColumnMethodType.setTable(authMethodTypeTable) - authMethodTypeState = UserAuthMethodColumnState.setTable(authMethodTypeTable) + authMethodTypeType = UserAuthMethodColumnMethodType.setTable(authMethodTypeTable) + authMethodTypeTypes = Column{ + name: "method_types", + table: authMethodTypeTable, + } + authMethodTypeState = UserAuthMethodColumnState.setTable(authMethodTypeTable) userIDPsCountTable = idpUserLinkTable.setAlias("user_idps_count") userIDPsCountUserID = IDPUserLinkUserIDCol.setTable(userIDPsCountTable) @@ -199,8 +205,8 @@ func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID st return nil, zerrors.ThrowInvalidArgument(err, "QUERY-E5ut4", "Errors.Query.InvalidRequest") } - err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { - requirements, err = scan(rows) + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + requirements, err = scan(row) return err }, stmt, args...) if err != nil { @@ -360,7 +366,7 @@ func prepareActiveUserAuthMethodTypesQuery(ctx context.Context, db prepareDataba } return sq.Select( NotifyPasswordSetCol.identifier(), - authMethodTypeTypes.identifier(), + authMethodTypeType.identifier(), userIDPsCountCount.identifier()). From(userTable.identifier()). LeftJoin(join(NotifyUserIDCol, UserIDCol)). @@ -411,12 +417,12 @@ func prepareActiveUserAuthMethodTypesQuery(ctx context.Context, db prepareDataba } } -func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) { +func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { loginPolicyQuery, err := prepareAuthMethodsForceMFAQuery() if err != nil { return sq.SelectBuilder{}, nil } - authMethodsQuery, authMethodsArgs, err := prepareAuthMethodQuery() + authMethodsQuery, authMethodsArgs, err := prepareAggAuthMethodsQuery() if err != nil { return sq.SelectBuilder{}, nil } @@ -442,47 +448,41 @@ func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareData userIDPsCountInstanceID.identifier() + " = " + UserInstanceIDCol.identifier()). LeftJoin("(" + loginPolicyQuery + ") AS " + forceMFATable.alias + " ON " + "(" + forceMFAOrgID.identifier() + " = " + UserInstanceIDCol.identifier() + " OR " + forceMFAOrgID.identifier() + " = " + UserResourceOwnerCol.identifier() + ") AND " + - forceMFAInstanceID.identifier() + " = " + UserInstanceIDCol.identifier() + db.Timetravel(call.Took(ctx))). + forceMFAInstanceID.identifier() + " = " + UserInstanceIDCol.identifier()). + OrderBy(forceMFAIsDefault.identifier()). + Limit(1). PlaceholderFormat(sq.Dollar), - func(rows *sql.Rows) (*UserAuthMethodRequirements, error) { - userAuthMethodTypes := make([]domain.UserAuthMethodType, 0) + func(row *sql.Row) (*UserAuthMethodRequirements, error) { var passwordSet sql.NullBool + var authMethodTypes database.NumberArray[domain.UserAuthMethodType] var idp sql.NullInt64 var userType sql.NullInt32 var forceMFA sql.NullBool var forceMFALocalOnly sql.NullBool - for rows.Next() { - var authMethodType sql.NullInt16 - err := rows.Scan( - &passwordSet, - &authMethodType, - &idp, - &userType, - &forceMFA, - &forceMFALocalOnly, - ) - if err != nil { - return nil, err - } - if authMethodType.Valid { - userAuthMethodTypes = append(userAuthMethodTypes, domain.UserAuthMethodType(authMethodType.Int16)) + err := row.Scan( + &passwordSet, + &authMethodTypes, + &idp, + &userType, + &forceMFA, + &forceMFALocalOnly, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, zerrors.ThrowNotFound(err, "QUERY-SF3h2", "Errors.Internal") } + return nil, zerrors.ThrowInternal(err, "QUERY-Sf3rt", "Errors.Internal") } if passwordSet.Valid && passwordSet.Bool { - userAuthMethodTypes = append(userAuthMethodTypes, domain.UserAuthMethodTypePassword) + authMethodTypes = append(authMethodTypes, domain.UserAuthMethodTypePassword) } if idp.Valid && idp.Int64 > 0 { - logging.Error("IDP", idp.Int64) - userAuthMethodTypes = append(userAuthMethodTypes, domain.UserAuthMethodTypeIDP) - } - - if err := rows.Close(); err != nil { - return nil, zerrors.ThrowInternal(err, "QUERY-W4zje", "Errors.Query.CloseRows") + authMethodTypes = append(authMethodTypes, domain.UserAuthMethodTypeIDP) } return &UserAuthMethodRequirements{ UserType: domain.UserType(userType.Int32), - AuthMethods: userAuthMethodTypes, + AuthMethods: authMethodTypes, ForceMFA: forceMFA.Bool, ForceMFALocalOnly: forceMFALocalOnly.Bool, }, nil @@ -505,7 +505,7 @@ func prepareAuthMethodsIDPsQuery() (string, error) { func prepareAuthMethodQuery() (string, []interface{}, error) { return sq.Select( - "DISTINCT("+authMethodTypeTypes.identifier()+")", + "DISTINCT("+authMethodTypeType.identifier()+")", authMethodTypeUserID.identifier(), authMethodTypeInstanceID.identifier()). From(authMethodTypeTable.identifier()). @@ -513,15 +513,26 @@ func prepareAuthMethodQuery() (string, []interface{}, error) { ToSql() } +func prepareAggAuthMethodsQuery() (string, []interface{}, error) { + return sq.Select( + "array_agg(DISTINCT("+authMethodTypeType.identifier()+")) as method_types", + authMethodTypeUserID.identifier(), + authMethodTypeInstanceID.identifier()). + From(authMethodTypeTable.identifier()). + Where(sq.Eq{authMethodTypeState.identifier(): domain.MFAStateReady}). + GroupBy(authMethodTypeInstanceID.identifier(), authMethodTypeUserID.identifier()). + ToSql() +} + func prepareAuthMethodsForceMFAQuery() (string, error) { loginPolicyQuery, _, err := sq.Select( forceMFAForce.identifier(), forceMFAForceLocalOnly.identifier(), forceMFAInstanceID.identifier(), forceMFAOrgID.identifier(), + forceMFAIsDefault.identifier(), ). From(forceMFATable.identifier()). - OrderBy(forceMFAIsDefault.identifier()). ToSql() return loginPolicyQuery, err } diff --git a/internal/query/user_auth_method_test.go b/internal/query/user_auth_method_test.go index 578b14cec5..698667d990 100644 --- a/internal/query/user_auth_method_test.go +++ b/internal/query/user_auth_method_test.go @@ -11,7 +11,9 @@ import ( sq "github.com/Masterminds/squirrel" + "github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/zerrors" ) var ( @@ -57,27 +59,27 @@ var ( "idps_count", } prepareAuthMethodTypesRequiredStmt = `SELECT projections.users12_notifications.password_set,` + - ` auth_method_types.method_type,` + + ` auth_method_types.method_types,` + ` user_idps_count.count,` + ` projections.users12.type,` + ` auth_methods_force_mfa.force_mfa,` + ` auth_methods_force_mfa.force_mfa_local_only` + ` FROM projections.users12` + ` LEFT JOIN projections.users12_notifications ON projections.users12.id = projections.users12_notifications.user_id AND projections.users12.instance_id = projections.users12_notifications.instance_id` + - ` LEFT JOIN (SELECT DISTINCT(auth_method_types.method_type), auth_method_types.user_id, auth_method_types.instance_id FROM projections.user_auth_methods4 AS auth_method_types` + - ` WHERE auth_method_types.state = $1) AS auth_method_types` + + ` LEFT JOIN (SELECT array_agg(DISTINCT(auth_method_types.method_type)) as method_types, auth_method_types.user_id, auth_method_types.instance_id FROM projections.user_auth_methods4 AS auth_method_types` + + ` WHERE auth_method_types.state = $1 GROUP BY auth_method_types.instance_id, auth_method_types.user_id) AS auth_method_types` + ` ON auth_method_types.user_id = projections.users12.id AND auth_method_types.instance_id = projections.users12.instance_id` + ` LEFT JOIN (SELECT user_idps_count.user_id, user_idps_count.instance_id, COUNT(user_idps_count.user_id) AS count FROM projections.idp_user_links3 AS user_idps_count` + ` GROUP BY user_idps_count.user_id, user_idps_count.instance_id) AS user_idps_count` + ` ON user_idps_count.user_id = projections.users12.id AND user_idps_count.instance_id = projections.users12.instance_id` + - ` LEFT JOIN (SELECT auth_methods_force_mfa.force_mfa, auth_methods_force_mfa.force_mfa_local_only, auth_methods_force_mfa.instance_id, auth_methods_force_mfa.aggregate_id FROM projections.login_policies5 AS auth_methods_force_mfa ORDER BY auth_methods_force_mfa.is_default) AS auth_methods_force_mfa` + + ` LEFT JOIN (SELECT auth_methods_force_mfa.force_mfa, auth_methods_force_mfa.force_mfa_local_only, auth_methods_force_mfa.instance_id, auth_methods_force_mfa.aggregate_id, auth_methods_force_mfa.is_default FROM projections.login_policies5 AS auth_methods_force_mfa) AS auth_methods_force_mfa` + ` ON (auth_methods_force_mfa.aggregate_id = projections.users12.instance_id OR auth_methods_force_mfa.aggregate_id = projections.users12.resource_owner) AND auth_methods_force_mfa.instance_id = projections.users12.instance_id` + - ` AS OF SYSTEM TIME '-1 ms + ` ORDER BY auth_methods_force_mfa.is_default LIMIT 1 ` prepareAuthMethodTypesRequiredCols = []string{ "password_set", "type", - "method_type", + "method_types", "idps_count", "force_mfa", "force_mfa_local_only", @@ -319,27 +321,33 @@ func Test_UserAuthMethodPrepares(t *testing.T) { }, { name: "prepareUserAuthMethodTypesRequiredQuery no result", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) { + prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db) - return builder, func(rows *sql.Rows) (*UserAuthMethodRequirements, error) { - return scan(rows) + return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) { + return scan(row) } }, want: want{ - sqlExpectations: mockQueries( + sqlExpectations: mockQueriesScanErr( regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt), nil, nil, ), + err: func(err error) (error, bool) { + if !zerrors.IsNotFound(err) { + return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false + } + return nil, true + }, }, - object: &UserAuthMethodRequirements{AuthMethods: []domain.UserAuthMethodType{}, ForceMFA: false}, + object: (*UserAuthMethodRequirements)(nil), }, { name: "prepareUserAuthMethodTypesRequiredQuery one second factor", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) { + prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db) - return builder, func(rows *sql.Rows) (*UserAuthMethodRequirements, error) { - return scan(rows) + return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) { + return scan(row) } }, want: want{ @@ -349,7 +357,7 @@ func Test_UserAuthMethodPrepares(t *testing.T) { [][]driver.Value{ { true, - domain.UserAuthMethodTypePasswordless, + database.NumberArray[domain.UserAuthMethodType]{domain.UserAuthMethodTypePasswordless}, 1, domain.UserTypeHuman, true, @@ -371,10 +379,10 @@ func Test_UserAuthMethodPrepares(t *testing.T) { }, { name: "prepareUserAuthMethodTypesRequiredQuery multiple second factors", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) { + prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db) - return builder, func(rows *sql.Rows) (*UserAuthMethodRequirements, error) { - return scan(rows) + return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) { + return scan(row) } }, want: want{ @@ -384,15 +392,7 @@ func Test_UserAuthMethodPrepares(t *testing.T) { [][]driver.Value{ { true, - domain.UserAuthMethodTypePasswordless, - 1, - domain.UserTypeHuman, - true, - true, - }, - { - true, - domain.UserAuthMethodTypeTOTP, + database.NumberArray[domain.UserAuthMethodType]{domain.UserAuthMethodTypePasswordless, domain.UserAuthMethodTypeTOTP}, 1, domain.UserTypeHuman, true, @@ -416,10 +416,10 @@ func Test_UserAuthMethodPrepares(t *testing.T) { }, { name: "prepareUserAuthMethodTypesRequiredQuery sql err", - prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) { + prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) { builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db) - return builder, func(rows *sql.Rows) (*UserAuthMethodRequirements, error) { - return scan(rows) + return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) { + return scan(row) } }, want: want{ diff --git a/internal/user/repository/view/model/user_session.go b/internal/user/repository/view/model/user_session.go index 4208ea5b53..d761592551 100644 --- a/internal/user/repository/view/model/user_session.go +++ b/internal/user/repository/view/model/user_session.go @@ -35,12 +35,12 @@ const ( ) type UserSessionView struct { - CreationDate time.Time `json:"-" gorm:"column:creation_date"` - ChangeDate time.Time `json:"-" gorm:"column:change_date"` - ResourceOwner string `json:"-" gorm:"column:resource_owner"` - State int32 `json:"-" gorm:"column:state"` - UserAgentID string `json:"userAgentID" gorm:"column:user_agent_id;primary_key"` - UserID string `json:"userID" gorm:"column:user_id;primary_key"` + CreationDate time.Time `json:"-" gorm:"column:creation_date"` + ChangeDate time.Time `json:"-" gorm:"column:change_date"` + ResourceOwner string `json:"-" gorm:"column:resource_owner"` + State sql.Null[domain.UserSessionState] `json:"-" gorm:"column:state"` + UserAgentID string `json:"userAgentID" gorm:"column:user_agent_id;primary_key"` + UserID string `json:"userID" gorm:"column:user_id;primary_key"` // As of https://github.com/zitadel/zitadel/pull/7199 the following 4 attributes // are not projected in the user session handler anymore // and are therefore annotated with a `gorm:"-"`. @@ -79,7 +79,7 @@ func UserSessionToModel(userSession *UserSessionView) *model.UserSessionView { ChangeDate: userSession.ChangeDate, CreationDate: userSession.CreationDate, ResourceOwner: userSession.ResourceOwner, - State: domain.UserSessionState(userSession.State), + State: userSession.State.V, UserAgentID: userSession.UserAgentID, UserID: userSession.UserID, UserName: userSession.UserName.String, @@ -114,7 +114,7 @@ func (v *UserSessionView) AppendEvent(event eventstore.Event) error { case user.UserV1PasswordCheckSucceededType, user.HumanPasswordCheckSucceededType: v.PasswordVerification = sql.NullTime{Time: event.CreatedAt(), Valid: true} - v.State = int32(domain.UserSessionStateActive) + v.State.V = domain.UserSessionStateActive case user.UserIDPLoginCheckSucceededType: data := new(es_model.AuthRequest) err := data.SetData(event) @@ -123,12 +123,12 @@ func (v *UserSessionView) AppendEvent(event eventstore.Event) error { } v.ExternalLoginVerification = sql.NullTime{Time: event.CreatedAt(), Valid: true} v.SelectedIDPConfigID = sql.NullString{String: data.SelectedIDPConfigID, Valid: true} - v.State = int32(domain.UserSessionStateActive) + v.State.V = domain.UserSessionStateActive case user.HumanPasswordlessTokenCheckSucceededType: v.PasswordlessVerification = sql.NullTime{Time: event.CreatedAt(), Valid: true} v.MultiFactorVerification = sql.NullTime{Time: event.CreatedAt(), Valid: true} v.MultiFactorVerificationType = sql.NullInt32{Int32: int32(domain.MFATypeU2FUserVerification)} - v.State = int32(domain.UserSessionStateActive) + v.State.V = domain.UserSessionStateActive case user.HumanPasswordlessTokenCheckFailedType, user.HumanPasswordlessTokenRemovedType: v.PasswordlessVerification = sql.NullTime{Time: time.Time{}, Valid: true} @@ -207,7 +207,7 @@ func (v *UserSessionView) AppendEvent(event eventstore.Event) error { v.MultiFactorVerification = sql.NullTime{Time: time.Time{}, Valid: true} v.MultiFactorVerificationType = sql.NullInt32{Int32: int32(domain.MFALevelNotSetUp)} v.ExternalLoginVerification = sql.NullTime{Time: time.Time{}, Valid: true} - v.State = int32(domain.UserSessionStateTerminated) + v.State.V = domain.UserSessionStateTerminated case user.UserIDPLinkRemovedType, user.UserIDPLinkCascadeRemovedType: v.ExternalLoginVerification = sql.NullTime{Time: time.Time{}, Valid: true} v.SelectedIDPConfigID = sql.NullString{String: "", Valid: true} @@ -218,7 +218,7 @@ func (v *UserSessionView) AppendEvent(event eventstore.Event) error { func (v *UserSessionView) setSecondFactorVerification(verificationTime time.Time, mfaType domain.MFAType) { v.SecondFactorVerification = sql.NullTime{Time: verificationTime, Valid: true} v.SecondFactorVerificationType = sql.NullInt32{Int32: int32(mfaType)} - v.State = int32(domain.UserSessionStateActive) + v.State.V = domain.UserSessionStateActive } func (v *UserSessionView) EventTypes() []eventstore.EventType { diff --git a/internal/user/repository/view/model/user_session_test.go b/internal/user/repository/view/model/user_session_test.go index 25acd489c7..3e832ae1fd 100644 --- a/internal/user/repository/view/model/user_session_test.go +++ b/internal/user/repository/view/model/user_session_test.go @@ -9,6 +9,7 @@ import ( "github.com/stretchr/testify/assert" "github.com/zitadel/zitadel/internal/crypto" + "github.com/zitadel/zitadel/internal/domain" es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models" "github.com/zitadel/zitadel/internal/repository/user" es_model "github.com/zitadel/zitadel/internal/user/repository/eventsourcing/model" @@ -209,7 +210,7 @@ func TestAppendEvent(t *testing.T) { ExternalLoginVerification: sql.NullTime{Time: time.Time{}, Valid: true}, PasswordlessVerification: sql.NullTime{Time: time.Time{}, Valid: true}, MultiFactorVerification: sql.NullTime{Time: time.Time{}, Valid: true}, - State: 1, + State: sql.Null[domain.UserSessionState]{V: domain.UserSessionStateTerminated}, }, }, { @@ -228,7 +229,7 @@ func TestAppendEvent(t *testing.T) { ExternalLoginVerification: sql.NullTime{Time: time.Time{}, Valid: true}, PasswordlessVerification: sql.NullTime{Time: time.Time{}, Valid: true}, MultiFactorVerification: sql.NullTime{Time: time.Time{}, Valid: true}, - State: 1, + State: sql.Null[domain.UserSessionState]{V: domain.UserSessionStateTerminated}, }, }, }