fix(oidc): IDP and passwordless user auth methods (#7998)

# Which Problems Are Solved

As already mentioned and (partially) fixed in #7992 we discovered,
issues with v2 tokens that where obtained through an IDP, with
passwordless authentication or with password authentication (wihtout any
2FA set up) using the v1 login for zitadel API calls
- (Previous) authentication through an IdP is now correctly treated as
auth method in case of a reauth even when the user is not redirected to
the IdP
- There were some cases where passwordless authentication was
successfully checked but not correctly set as auth method, which denied
access to ZITADEL API
- Users with password and passwordless, but no 2FA set up which
authenticate just wich password can access the ZITADEL API again

Additionally while testing we found out that because of #7969 the login
UI could completely break / block with the following error:
`sql: Scan error on column index 3, name "state": converting NULL to
int32 is unsupported (Internal)`
# How the Problems Are Solved

- IdP checks are treated the same way as other factors and it's ensured
that a succeeded check within the configured timeframe will always
provide the idp auth method
- `MFATypesAllowed` checks for possible passwordless authentication
- As with the v1 login, the token check now only requires MFA if the
policy is set or the user has 2FA set up
- UserAuthMethodsRequirements now always uses the correctly policy to
check for MFA enforcement
- `State` column is handled as nullable and additional events set the
state to active (as before #7969)

# Additional Changes

- Console now also checks for 403 (mfa required) errors (e.g. after
setting up the first 2FA in console) and redirects the user to the login
UI (with the current id_token as id_token_hint)
- Possible duplicates in auth methods / AMRs are removed now as well.

# Additional Context

- Bugs were introduced in #7822 and # and 7969 and only part of a
pre-release.
- partially already fixed with #7992
- Reported internally.
This commit is contained in:
Livio Spring 2024-05-28 10:59:49 +02:00 committed by GitHub
parent 4dc86c2415
commit ec222a13d7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 281 additions and 104 deletions

View File

@ -35,6 +35,10 @@ export class AuthenticationService {
return from(this.oauthService.loadUserProfile()); return from(this.oauthService.loadUserProfile());
} }
public getIdToken(): string {
return this.oauthService.getIdToken();
}
public async authenticate(partialConfig?: Partial<AuthConfig>, force: boolean = false): Promise<boolean> { public async authenticate(partialConfig?: Partial<AuthConfig>, force: boolean = false): Promise<boolean> {
if (partialConfig) { if (partialConfig) {
Object.assign(this.authConfig, partialConfig); Object.assign(this.authConfig, partialConfig);

View File

@ -18,6 +18,7 @@ import { I18nInterceptor } from './interceptors/i18n.interceptor';
import { OrgInterceptor } from './interceptors/org.interceptor'; import { OrgInterceptor } from './interceptors/org.interceptor';
import { StorageService } from './storage.service'; import { StorageService } from './storage.service';
import { FeatureServiceClient } from '../proto/generated/zitadel/feature/v2beta/Feature_serviceServiceClientPb'; import { FeatureServiceClient } from '../proto/generated/zitadel/feature/v2beta/Feature_serviceServiceClientPb';
import { GrpcAuthService } from './grpc-auth.service';
@Injectable({ @Injectable({
providedIn: 'root', providedIn: 'root',

View File

@ -2,11 +2,13 @@ import { Injectable } from '@angular/core';
import { MatDialog } from '@angular/material/dialog'; import { MatDialog } from '@angular/material/dialog';
import { Request, UnaryInterceptor, UnaryResponse } from 'grpc-web'; import { Request, UnaryInterceptor, UnaryResponse } from 'grpc-web';
import { Subject } from 'rxjs'; import { Subject } from 'rxjs';
import { debounceTime, filter, first, take } from 'rxjs/operators'; import { debounceTime, filter, first, map, take, tap } from 'rxjs/operators';
import { WarnDialogComponent } from 'src/app/modules/warn-dialog/warn-dialog.component'; import { WarnDialogComponent } from 'src/app/modules/warn-dialog/warn-dialog.component';
import { AuthenticationService } from '../authentication.service'; import { AuthenticationService } from '../authentication.service';
import { StorageService } from '../storage.service'; import { StorageService } from '../storage.service';
import { AuthConfig } from 'angular-oauth2-oidc';
import { GrpcAuthService } from '../grpc-auth.service';
const authorizationKey = 'Authorization'; const authorizationKey = 'Authorization';
const bearerPrefix = 'Bearer'; const bearerPrefix = 'Bearer';
@ -44,7 +46,7 @@ export class AuthInterceptor<TReq = unknown, TResp = unknown> implements UnaryIn
return response; return response;
}) })
.catch(async (error: any) => { .catch(async (error: any) => {
if (error.code === 16) { if (error.code === 16 || (error.code === 7 && error.message === 'mfa required (AUTHZ-Kl3p0)')) {
this.triggerDialog.next(true); this.triggerDialog.next(true);
} }
return Promise.reject(error); return Promise.reject(error);
@ -67,7 +69,13 @@ export class AuthInterceptor<TReq = unknown, TResp = unknown> implements UnaryIn
.pipe(take(1)) .pipe(take(1))
.subscribe((resp) => { .subscribe((resp) => {
if (resp) { if (resp) {
this.authenticationService.authenticate(undefined, true); const idToken = this.authenticationService.getIdToken();
const configWithPrompt: Partial<AuthConfig> = {
customQueryParams: {
id_token_hint: idToken,
},
};
this.authenticationService.authenticate(configWithPrompt, true);
} }
}); });
} }

View File

@ -18,6 +18,7 @@ import (
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/integration" "github.com/zitadel/zitadel/internal/integration"
"github.com/zitadel/zitadel/pkg/grpc/auth" "github.com/zitadel/zitadel/pkg/grpc/auth"
mgmt "github.com/zitadel/zitadel/pkg/grpc/management"
oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2beta" oidc_pb "github.com/zitadel/zitadel/pkg/grpc/oidc/v2beta"
session "github.com/zitadel/zitadel/pkg/grpc/session/v2beta" session "github.com/zitadel/zitadel/pkg/grpc/session/v2beta"
user "github.com/zitadel/zitadel/pkg/grpc/user/v2beta" user "github.com/zitadel/zitadel/pkg/grpc/user/v2beta"
@ -26,6 +27,7 @@ import (
var ( var (
CTX context.Context CTX context.Context
CTXLOGIN context.Context CTXLOGIN context.Context
CTXIAM context.Context
Tester *integration.Tester Tester *integration.Tester
User *user.AddHumanUserResponse User *user.AddHumanUserResponse
) )
@ -50,6 +52,7 @@ func TestMain(m *testing.M) {
Tester.SetUserPassword(CTX, User.GetUserId(), integration.UserPassword, false) Tester.SetUserPassword(CTX, User.GetUserId(), integration.UserPassword, false)
Tester.RegisterUserPasskey(CTX, User.GetUserId()) Tester.RegisterUserPasskey(CTX, User.GetUserId())
CTXLOGIN = Tester.WithAuthorization(ctx, integration.Login) CTXLOGIN = Tester.WithAuthorization(ctx, integration.Login)
CTXIAM = Tester.WithAuthorization(ctx, integration.IAMOwner)
return m.Run() return m.Run()
}()) }())
} }
@ -117,10 +120,13 @@ func Test_ZITADEL_API_missing_authentication(t *testing.T) {
require.Nil(t, myUserResp) require.Nil(t, myUserResp)
} }
func Test_ZITADEL_API_missing_mfa(t *testing.T) { func Test_ZITADEL_API_missing_mfa_2fa_setup(t *testing.T) {
clientID, _ := createClient(t) clientID, _ := createClient(t)
userResp := Tester.CreateHumanUser(CTX)
Tester.SetUserPassword(CTX, userResp.GetUserId(), integration.UserPassword, false)
Tester.RegisterUserU2F(CTX, userResp.GetUserId())
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, zitadelAudienceScope) authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, zitadelAudienceScope)
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasswordSession(t, CTX, User.GetUserId(), integration.UserPassword) sessionID, sessionToken, startTime, changeTime := Tester.CreatePasswordSession(t, CTXLOGIN, userResp.GetUserId(), integration.UserPassword)
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{ linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
AuthRequestId: authRequestID, AuthRequestId: authRequestID,
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{ CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
@ -136,7 +142,7 @@ func Test_ZITADEL_API_missing_mfa(t *testing.T) {
code := assertCodeResponse(t, linkResp.GetCallbackUrl()) code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code, redirectURI) tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err) require.NoError(t, err)
assertIDTokenClaims(t, tokens.IDTokenClaims, User.GetUserId(), armPassword, startTime, changeTime) assertIDTokenClaims(t, tokens.IDTokenClaims, userResp.GetUserId(), armPassword, startTime, changeTime)
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken)) ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
@ -145,6 +151,62 @@ func Test_ZITADEL_API_missing_mfa(t *testing.T) {
require.Nil(t, myUserResp) require.Nil(t, myUserResp)
} }
func Test_ZITADEL_API_missing_mfa_policy(t *testing.T) {
clientID, _ := createClient(t)
org := Tester.CreateOrganization(CTXIAM, fmt.Sprintf("ZITADEL_API_MISSING_MFA_%d", time.Now().UnixNano()), fmt.Sprintf("%d@mouse.com", time.Now().UnixNano()))
userID := org.CreatedAdmins[0].GetUserId()
Tester.SetUserPassword(CTXIAM, userID, integration.UserPassword, false)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, zitadelAudienceScope)
sessionID, sessionToken, startTime, changeTime := Tester.CreatePasswordSession(t, CTXLOGIN, userID, integration.UserPassword)
linkResp, err := Tester.Client.OIDCv2.CreateCallback(CTXLOGIN, &oidc_pb.CreateCallbackRequest{
AuthRequestId: authRequestID,
CallbackKind: &oidc_pb.CreateCallbackRequest_Session{
Session: &oidc_pb.Session{
SessionId: sessionID,
SessionToken: sessionToken,
},
},
})
require.NoError(t, err)
// code exchange
code := assertCodeResponse(t, linkResp.GetCallbackUrl())
tokens, err := exchangeTokens(t, clientID, code, redirectURI)
require.NoError(t, err)
assertIDTokenClaims(t, tokens.IDTokenClaims, userID, armPassword, startTime, changeTime)
ctx := metadata.AppendToOutgoingContext(context.Background(), "Authorization", fmt.Sprintf("%s %s", tokens.TokenType, tokens.AccessToken))
// pre check if request would succeed
myUserResp, err := Tester.Client.Auth.GetMyUser(ctx, &auth.GetMyUserRequest{})
require.NoError(t, err)
require.Equal(t, userID, myUserResp.GetUser().GetId())
// require MFA
ctxOrg := metadata.AppendToOutgoingContext(CTXIAM, "x-zitadel-orgid", org.GetOrganizationId())
_, err = Tester.Client.Mgmt.AddCustomLoginPolicy(ctxOrg, &mgmt.AddCustomLoginPolicyRequest{
ForceMfa: true,
})
require.NoError(t, err)
// make sure policy is projected
retryDuration := 5 * time.Second
if ctxDeadline, ok := CTX.Deadline(); ok {
retryDuration = time.Until(ctxDeadline)
}
require.EventuallyWithT(t, func(ttt *assert.CollectT) {
got, getErr := Tester.Client.Mgmt.GetLoginPolicy(ctxOrg, &mgmt.GetLoginPolicyRequest{})
assert.NoError(ttt, getErr)
assert.False(ttt, got.GetPolicy().IsDefault)
}, retryDuration, time.Millisecond*100, "timeout waiting for login policy")
// now it must fail
myUserResp, err = Tester.Client.Auth.GetMyUser(ctx, &auth.GetMyUserRequest{})
require.Error(t, err)
require.Nil(t, myUserResp)
}
func Test_ZITADEL_API_success(t *testing.T) { func Test_ZITADEL_API_success(t *testing.T) {
clientID, _ := createClient(t) clientID, _ := createClient(t)
authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, zitadelAudienceScope) authRequestID := createAuthRequest(t, clientID, redirectURI, oidc.ScopeOpenID, zitadelAudienceScope)

View File

@ -2,6 +2,7 @@ package eventstore
import ( import (
"context" "context"
"slices"
"strings" "strings"
"time" "time"
@ -1030,15 +1031,11 @@ func (repo *AuthRequestRepo) nextSteps(ctx context.Context, request *domain.Auth
if err != nil { if err != nil {
return nil, err return nil, err
} }
if (!isInternalLogin || len(idps.Links) > 0) && len(request.LinkingUsers) == 0 && !checkVerificationTimeMaxAge(userSession.ExternalLoginVerification, request.LoginPolicy.ExternalLoginCheckLifetime, request) { if (!isInternalLogin || len(idps.Links) > 0) && len(request.LinkingUsers) == 0 {
selectedIDPConfigID := request.SelectedIDPConfigID step := repo.idpChecked(request, idps.Links, userSession)
if selectedIDPConfigID == "" { if step != nil {
selectedIDPConfigID = userSession.SelectedIDPConfigID return append(steps, step), nil
} }
if selectedIDPConfigID == "" {
selectedIDPConfigID = idps.Links[0].IDPID
}
return append(steps, &domain.ExternalLoginStep{SelectedIDPConfigID: selectedIDPConfigID}), nil
} }
if isInternalLogin || (!isInternalLogin && len(request.LinkingUsers) > 0) { if isInternalLogin || (!isInternalLogin && len(request.LinkingUsers) > 0) {
step := repo.firstFactorChecked(request, user, userSession) step := repo.firstFactorChecked(request, user, userSession)
@ -1198,6 +1195,7 @@ func (repo *AuthRequestRepo) firstFactorChecked(request *domain.AuthRequest, use
var step domain.NextStep var step domain.NextStep
if request.LoginPolicy.PasswordlessType != domain.PasswordlessTypeNotAllowed && user.IsPasswordlessReady() { if request.LoginPolicy.PasswordlessType != domain.PasswordlessTypeNotAllowed && user.IsPasswordlessReady() {
if checkVerificationTimeMaxAge(userSession.PasswordlessVerification, request.LoginPolicy.MultiFactorCheckLifetime, request) { if checkVerificationTimeMaxAge(userSession.PasswordlessVerification, request.LoginPolicy.MultiFactorCheckLifetime, request) {
request.MFAsVerified = append(request.MFAsVerified, domain.MFATypeU2FUserVerification)
request.AuthTime = userSession.PasswordlessVerification request.AuthTime = userSession.PasswordlessVerification
return nil return nil
} }
@ -1225,8 +1223,27 @@ func (repo *AuthRequestRepo) firstFactorChecked(request *domain.AuthRequest, use
return &domain.PasswordStep{} return &domain.PasswordStep{}
} }
func (repo *AuthRequestRepo) idpChecked(request *domain.AuthRequest, idps []*query.IDPUserLink, userSession *user_model.UserSessionView) domain.NextStep {
if checkVerificationTimeMaxAge(userSession.ExternalLoginVerification, request.LoginPolicy.ExternalLoginCheckLifetime, request) {
request.IDPLoginChecked = true
request.AuthTime = userSession.ExternalLoginVerification
return nil
}
selectedIDPConfigID := request.SelectedIDPConfigID
if selectedIDPConfigID == "" {
selectedIDPConfigID = userSession.SelectedIDPConfigID
}
if selectedIDPConfigID == "" && len(idps) > 0 {
selectedIDPConfigID = idps[0].IDPID
}
return &domain.ExternalLoginStep{SelectedIDPConfigID: selectedIDPConfigID}
}
func (repo *AuthRequestRepo) mfaChecked(userSession *user_model.UserSessionView, request *domain.AuthRequest, user *user_model.UserView, isInternalAuthentication bool) (domain.NextStep, bool, error) { func (repo *AuthRequestRepo) mfaChecked(userSession *user_model.UserSessionView, request *domain.AuthRequest, user *user_model.UserView, isInternalAuthentication bool) (domain.NextStep, bool, error) {
mfaLevel := request.MFALevel() mfaLevel := request.MFALevel()
if slices.Contains(request.MFAsVerified, domain.MFATypeU2FUserVerification) {
return nil, true, nil
}
allowedProviders, required := user.MFATypesAllowed(mfaLevel, request.LoginPolicy, isInternalAuthentication) allowedProviders, required := user.MFATypesAllowed(mfaLevel, request.LoginPolicy, isInternalAuthentication)
promptRequired := (user.MFAMaxSetUp < mfaLevel) || (len(allowedProviders) == 0 && required) promptRequired := (user.MFAMaxSetUp < mfaLevel) || (len(allowedProviders) == 0 && required)
if promptRequired || !repo.mfaSkippedOrSetUp(user, request) { if promptRequired || !repo.mfaSkippedOrSetUp(user, request) {

View File

@ -89,7 +89,7 @@ func (m *mockViewUserSession) UserSessionsByAgentID(string, string) ([]*user_vie
for i, user := range m.Users { for i, user := range m.Users {
sessions[i] = &user_view_model.UserSessionView{ sessions[i] = &user_view_model.UserSessionView{
ResourceOwner: user.ResourceOwner, ResourceOwner: user.ResourceOwner,
State: int32(user.SessionState), State: sql.Null[domain.UserSessionState]{V: user.SessionState},
UserID: user.UserID, UserID: user.UserID,
LoginName: sql.NullString{String: user.LoginName}, LoginName: sql.NullString{String: user.LoginName},
} }
@ -1682,11 +1682,12 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) {
isInternal bool isInternal bool
} }
tests := []struct { tests := []struct {
name string name string
args args args args
want domain.NextStep want domain.NextStep
wantChecked bool wantChecked bool
errFunc func(err error) bool errFunc func(err error) bool
wantMFAVerified []domain.MFAType
}{ }{
//{ //{
// "required, prompt and false", //TODO: enable when LevelsOfAssurance is checked // "required, prompt and false", //TODO: enable when LevelsOfAssurance is checked
@ -1718,6 +1719,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) {
nil, nil,
false, false,
zerrors.IsPreconditionFailed, zerrors.IsPreconditionFailed,
nil,
}, },
{ {
"not set up, no mfas configured, no prompt and true", "not set up, no mfas configured, no prompt and true",
@ -1737,6 +1739,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) {
nil, nil,
true, true,
nil, nil,
nil,
}, },
{ {
"not set up, prompt and false", "not set up, prompt and false",
@ -1761,6 +1764,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) {
}, },
false, false,
nil, nil,
nil,
}, },
{ {
"not set up, forced by org, true", "not set up, forced by org, true",
@ -1787,6 +1791,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) {
}, },
false, false,
nil, nil,
nil,
}, },
{ {
"not set up and skipped, true", "not set up and skipped, true",
@ -1807,6 +1812,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) {
nil, nil,
true, true,
nil, nil,
nil,
}, },
{ {
"checked second factor, true", "checked second factor, true",
@ -1829,6 +1835,38 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) {
nil, nil,
true, true,
nil, nil,
[]domain.MFAType{domain.MFATypeTOTP},
},
{
"checked passwordless, true",
args{
request: &domain.AuthRequest{
LoginPolicy: &domain.LoginPolicy{
SecondFactors: []domain.SecondFactorType{domain.SecondFactorTypeTOTP},
SecondFactorCheckLifetime: 18 * time.Hour,
MultiFactors: []domain.MultiFactorType{domain.MultiFactorTypeU2FWithPIN},
MultiFactorCheckLifetime: 18 * time.Hour,
},
MFAsVerified: []domain.MFAType{domain.MFATypeU2FUserVerification},
},
user: &user_model.UserView{
HumanView: &user_model.HumanView{
MFAMaxSetUp: domain.MFALevelMultiFactor,
PasswordlessTokens: []*user_model.WebAuthNView{
{
TokenID: "tokenID",
State: user_model.MFAStateReady,
},
},
},
},
userSession: &user_model.UserSessionView{PasswordlessVerification: testNow.Add(-5 * time.Hour)},
isInternal: true,
},
nil,
true,
nil,
[]domain.MFAType{domain.MFATypeU2FUserVerification},
}, },
{ {
"not checked, check and false", "not checked, check and false",
@ -1854,6 +1892,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) {
}, },
false, false,
nil, nil,
nil,
}, },
{ {
"external not checked or forced but set up, want step", "external not checked or forced but set up, want step",
@ -1878,6 +1917,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) {
}, },
false, false,
nil, nil,
nil,
}, },
{ {
"external not forced but checked", "external not forced but checked",
@ -1900,6 +1940,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) {
nil, nil,
true, true,
nil, nil,
[]domain.MFAType{domain.MFATypeTOTP},
}, },
{ {
"external not checked but required, want step", "external not checked but required, want step",
@ -1927,6 +1968,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) {
}, },
false, false,
nil, nil,
nil,
}, },
{ {
"external not checked but local required", "external not checked but local required",
@ -1950,6 +1992,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) {
nil, nil,
true, true,
nil, nil,
nil,
}, },
} }
for _, tt := range tests { for _, tt := range tests {
@ -1964,6 +2007,7 @@ func TestAuthRequestRepo_mfaChecked(t *testing.T) {
t.Errorf("mfaChecked() checked = %v, want %v", ok, tt.wantChecked) t.Errorf("mfaChecked() checked = %v, want %v", ok, tt.wantChecked)
} }
assert.Equal(t, tt.want, got) assert.Equal(t, tt.want, got)
assert.ElementsMatch(t, tt.args.request.MFAsVerified, tt.wantMFAVerified)
}) })
} }
} }

View File

@ -32,7 +32,7 @@ func (repo *UserRepo) UserSessionUserIDsByAgentID(ctx context.Context, agentID s
} }
userIDs := make([]string, 0, len(userSessions)) userIDs := make([]string, 0, len(userSessions))
for _, session := range userSessions { for _, session := range userSessions {
if session.State == int32(domain.UserSessionStateActive) { if session.State.V == domain.UserSessionStateActive {
userIDs = append(userIDs, session.UserID) userIDs = append(userIDs, session.UserID)
} }
} }

View File

@ -220,6 +220,7 @@ func (u *UserSession) Reduce(event eventstore.Event) (_ *handler.Statement, err
user.HumanPasswordCheckFailedType: user.HumanPasswordCheckFailedType:
columns, err := sessionColumns(event, columns, err := sessionColumns(event,
handler.NewCol(view_model.UserSessionKeyPasswordVerification, time.Time{}), handler.NewCol(view_model.UserSessionKeyPasswordVerification, time.Time{}),
handler.NewCol(view_model.UserSessionKeyState, domain.UserSessionStateActive),
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -241,6 +242,7 @@ func (u *UserSession) Reduce(event eventstore.Event) (_ *handler.Statement, err
user.HumanU2FTokenCheckFailedType: user.HumanU2FTokenCheckFailedType:
columns, err := sessionColumns(event, columns, err := sessionColumns(event,
handler.NewCol(view_model.UserSessionKeySecondFactorVerification, time.Time{}), handler.NewCol(view_model.UserSessionKeySecondFactorVerification, time.Time{}),
handler.NewCol(view_model.UserSessionKeyState, domain.UserSessionStateActive),
) )
if err != nil { if err != nil {
return nil, err return nil, err
@ -317,6 +319,7 @@ func (u *UserSession) Reduce(event eventstore.Event) (_ *handler.Statement, err
columns, err := sessionColumns(event, columns, err := sessionColumns(event,
handler.NewCol(view_model.UserSessionKeyPasswordlessVerification, time.Time{}), handler.NewCol(view_model.UserSessionKeyPasswordlessVerification, time.Time{}),
handler.NewCol(view_model.UserSessionKeyMultiFactorVerification, time.Time{}), handler.NewCol(view_model.UserSessionKeyMultiFactorVerification, time.Time{}),
handler.NewCol(view_model.UserSessionKeyState, domain.UserSessionStateActive),
) )
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -172,7 +172,7 @@ func (repo *TokenVerifierRepo) verifySessionToken(ctx context.Context, sessionID
} }
// checkAuthentication ensures the session or token was authenticated (at least a single [domain.UserAuthMethodType]). // checkAuthentication ensures the session or token was authenticated (at least a single [domain.UserAuthMethodType]).
// It will also check if there was a multi factor authentication, if either MFA is forced by the login policy or if the user has set up any // It will also check if there was a multi factor authentication, if either MFA is forced by the login policy or if the user has set up any second factor
func (repo *TokenVerifierRepo) checkAuthentication(ctx context.Context, authMethods []domain.UserAuthMethodType, userID string) error { func (repo *TokenVerifierRepo) checkAuthentication(ctx context.Context, authMethods []domain.UserAuthMethodType, userID string) error {
if len(authMethods) == 0 { if len(authMethods) == 0 {
return zerrors.ThrowPermissionDenied(nil, "AUTHZ-Kl3p0", "authentication required") return zerrors.ThrowPermissionDenied(nil, "AUTHZ-Kl3p0", "authentication required")
@ -191,7 +191,7 @@ func (repo *TokenVerifierRepo) checkAuthentication(ctx context.Context, authMeth
requirements.ForceMFA, requirements.ForceMFA,
requirements.ForceMFALocalOnly, requirements.ForceMFALocalOnly,
!hasIDPAuthentication(authMethods)) || !hasIDPAuthentication(authMethods)) ||
domain.HasMFA(requirements.AuthMethods) { domain.Has2FA(requirements.AuthMethods) {
return zerrors.ThrowPermissionDenied(nil, "AUTHZ-Kl3p0", "mfa required") return zerrors.ThrowPermissionDenied(nil, "AUTHZ-Kl3p0", "mfa required")
} }
return nil return nil

View File

@ -1,6 +1,7 @@
package domain package domain
import ( import (
"slices"
"strings" "strings"
"time" "time"
@ -81,7 +82,7 @@ func (a *AuthRequest) AuthMethods() []UserAuthMethodType {
for _, mfa := range a.MFAsVerified { for _, mfa := range a.MFAsVerified {
list = append(list, mfa.UserAuthMethodType()) list = append(list, mfa.UserAuthMethodType())
} }
return list return slices.Compact(list)
} }
type ExternalUser struct { type ExternalUser struct {

View File

@ -62,7 +62,8 @@ func HasMFA(methods []UserAuthMethodType) bool {
UserAuthMethodTypeOTPSMS, UserAuthMethodTypeOTPSMS,
UserAuthMethodTypeOTPEmail, UserAuthMethodTypeOTPEmail,
UserAuthMethodTypeIDP, UserAuthMethodTypeIDP,
UserAuthMethodTypeOTP: UserAuthMethodTypeOTP,
UserAuthMethodTypePrivateKey:
factors++ factors++
case UserAuthMethodTypeUnspecified, case UserAuthMethodTypeUnspecified,
userAuthMethodTypeCount: userAuthMethodTypeCount:
@ -72,6 +73,30 @@ func HasMFA(methods []UserAuthMethodType) bool {
return factors > 1 return factors > 1
} }
// Has2FA checks whether the auth factors provided are a second factor and will return true if at least one is.
func Has2FA(methods []UserAuthMethodType) bool {
var factors int
for _, method := range methods {
switch method {
case
UserAuthMethodTypeU2F,
UserAuthMethodTypeTOTP,
UserAuthMethodTypeOTPSMS,
UserAuthMethodTypeOTPEmail,
UserAuthMethodTypeOTP:
factors++
case UserAuthMethodTypeUnspecified,
UserAuthMethodTypePassword,
UserAuthMethodTypePasswordless,
UserAuthMethodTypeIDP,
UserAuthMethodTypePrivateKey,
userAuthMethodTypeCount:
// ignore
}
}
return factors > 0
}
// RequiresMFA checks whether the user requires to authenticate with multiple auth factors based on the LoginPolicy and the authentication type. // RequiresMFA checks whether the user requires to authenticate with multiple auth factors based on the LoginPolicy and the authentication type.
// Internal authentication will require MFA if either option is activated. // Internal authentication will require MFA if either option is activated.
// External authentication will only require MFA if it's forced generally and not local only. // External authentication will only require MFA if it's forced generally and not local only.

View File

@ -3,6 +3,7 @@ package query
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"time" "time"
sq "github.com/Masterminds/squirrel" sq "github.com/Masterminds/squirrel"
@ -10,6 +11,7 @@ import (
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call" "github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/telemetry/tracing"
@ -69,8 +71,12 @@ var (
authMethodTypeTable = userAuthMethodTable.setAlias("auth_method_types") authMethodTypeTable = userAuthMethodTable.setAlias("auth_method_types")
authMethodTypeUserID = UserAuthMethodColumnUserID.setTable(authMethodTypeTable) authMethodTypeUserID = UserAuthMethodColumnUserID.setTable(authMethodTypeTable)
authMethodTypeInstanceID = UserAuthMethodColumnInstanceID.setTable(authMethodTypeTable) authMethodTypeInstanceID = UserAuthMethodColumnInstanceID.setTable(authMethodTypeTable)
authMethodTypeTypes = UserAuthMethodColumnMethodType.setTable(authMethodTypeTable) authMethodTypeType = UserAuthMethodColumnMethodType.setTable(authMethodTypeTable)
authMethodTypeState = UserAuthMethodColumnState.setTable(authMethodTypeTable) authMethodTypeTypes = Column{
name: "method_types",
table: authMethodTypeTable,
}
authMethodTypeState = UserAuthMethodColumnState.setTable(authMethodTypeTable)
userIDPsCountTable = idpUserLinkTable.setAlias("user_idps_count") userIDPsCountTable = idpUserLinkTable.setAlias("user_idps_count")
userIDPsCountUserID = IDPUserLinkUserIDCol.setTable(userIDPsCountTable) userIDPsCountUserID = IDPUserLinkUserIDCol.setTable(userIDPsCountTable)
@ -199,8 +205,8 @@ func (q *Queries) ListUserAuthMethodTypesRequired(ctx context.Context, userID st
return nil, zerrors.ThrowInvalidArgument(err, "QUERY-E5ut4", "Errors.Query.InvalidRequest") return nil, zerrors.ThrowInvalidArgument(err, "QUERY-E5ut4", "Errors.Query.InvalidRequest")
} }
err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
requirements, err = scan(rows) requirements, err = scan(row)
return err return err
}, stmt, args...) }, stmt, args...)
if err != nil { if err != nil {
@ -360,7 +366,7 @@ func prepareActiveUserAuthMethodTypesQuery(ctx context.Context, db prepareDataba
} }
return sq.Select( return sq.Select(
NotifyPasswordSetCol.identifier(), NotifyPasswordSetCol.identifier(),
authMethodTypeTypes.identifier(), authMethodTypeType.identifier(),
userIDPsCountCount.identifier()). userIDPsCountCount.identifier()).
From(userTable.identifier()). From(userTable.identifier()).
LeftJoin(join(NotifyUserIDCol, UserIDCol)). LeftJoin(join(NotifyUserIDCol, UserIDCol)).
@ -411,12 +417,12 @@ func prepareActiveUserAuthMethodTypesQuery(ctx context.Context, db prepareDataba
} }
} }
func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) { func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) {
loginPolicyQuery, err := prepareAuthMethodsForceMFAQuery() loginPolicyQuery, err := prepareAuthMethodsForceMFAQuery()
if err != nil { if err != nil {
return sq.SelectBuilder{}, nil return sq.SelectBuilder{}, nil
} }
authMethodsQuery, authMethodsArgs, err := prepareAuthMethodQuery() authMethodsQuery, authMethodsArgs, err := prepareAggAuthMethodsQuery()
if err != nil { if err != nil {
return sq.SelectBuilder{}, nil return sq.SelectBuilder{}, nil
} }
@ -442,47 +448,41 @@ func prepareUserAuthMethodTypesRequiredQuery(ctx context.Context, db prepareData
userIDPsCountInstanceID.identifier() + " = " + UserInstanceIDCol.identifier()). userIDPsCountInstanceID.identifier() + " = " + UserInstanceIDCol.identifier()).
LeftJoin("(" + loginPolicyQuery + ") AS " + forceMFATable.alias + " ON " + LeftJoin("(" + loginPolicyQuery + ") AS " + forceMFATable.alias + " ON " +
"(" + forceMFAOrgID.identifier() + " = " + UserInstanceIDCol.identifier() + " OR " + forceMFAOrgID.identifier() + " = " + UserResourceOwnerCol.identifier() + ") AND " + "(" + forceMFAOrgID.identifier() + " = " + UserInstanceIDCol.identifier() + " OR " + forceMFAOrgID.identifier() + " = " + UserResourceOwnerCol.identifier() + ") AND " +
forceMFAInstanceID.identifier() + " = " + UserInstanceIDCol.identifier() + db.Timetravel(call.Took(ctx))). forceMFAInstanceID.identifier() + " = " + UserInstanceIDCol.identifier()).
OrderBy(forceMFAIsDefault.identifier()).
Limit(1).
PlaceholderFormat(sq.Dollar), PlaceholderFormat(sq.Dollar),
func(rows *sql.Rows) (*UserAuthMethodRequirements, error) { func(row *sql.Row) (*UserAuthMethodRequirements, error) {
userAuthMethodTypes := make([]domain.UserAuthMethodType, 0)
var passwordSet sql.NullBool var passwordSet sql.NullBool
var authMethodTypes database.NumberArray[domain.UserAuthMethodType]
var idp sql.NullInt64 var idp sql.NullInt64
var userType sql.NullInt32 var userType sql.NullInt32
var forceMFA sql.NullBool var forceMFA sql.NullBool
var forceMFALocalOnly sql.NullBool var forceMFALocalOnly sql.NullBool
for rows.Next() { err := row.Scan(
var authMethodType sql.NullInt16 &passwordSet,
err := rows.Scan( &authMethodTypes,
&passwordSet, &idp,
&authMethodType, &userType,
&idp, &forceMFA,
&userType, &forceMFALocalOnly,
&forceMFA, )
&forceMFALocalOnly, if err != nil {
) if errors.Is(err, sql.ErrNoRows) {
if err != nil { return nil, zerrors.ThrowNotFound(err, "QUERY-SF3h2", "Errors.Internal")
return nil, err
}
if authMethodType.Valid {
userAuthMethodTypes = append(userAuthMethodTypes, domain.UserAuthMethodType(authMethodType.Int16))
} }
return nil, zerrors.ThrowInternal(err, "QUERY-Sf3rt", "Errors.Internal")
} }
if passwordSet.Valid && passwordSet.Bool { if passwordSet.Valid && passwordSet.Bool {
userAuthMethodTypes = append(userAuthMethodTypes, domain.UserAuthMethodTypePassword) authMethodTypes = append(authMethodTypes, domain.UserAuthMethodTypePassword)
} }
if idp.Valid && idp.Int64 > 0 { if idp.Valid && idp.Int64 > 0 {
logging.Error("IDP", idp.Int64) authMethodTypes = append(authMethodTypes, domain.UserAuthMethodTypeIDP)
userAuthMethodTypes = append(userAuthMethodTypes, domain.UserAuthMethodTypeIDP)
}
if err := rows.Close(); err != nil {
return nil, zerrors.ThrowInternal(err, "QUERY-W4zje", "Errors.Query.CloseRows")
} }
return &UserAuthMethodRequirements{ return &UserAuthMethodRequirements{
UserType: domain.UserType(userType.Int32), UserType: domain.UserType(userType.Int32),
AuthMethods: userAuthMethodTypes, AuthMethods: authMethodTypes,
ForceMFA: forceMFA.Bool, ForceMFA: forceMFA.Bool,
ForceMFALocalOnly: forceMFALocalOnly.Bool, ForceMFALocalOnly: forceMFALocalOnly.Bool,
}, nil }, nil
@ -505,7 +505,7 @@ func prepareAuthMethodsIDPsQuery() (string, error) {
func prepareAuthMethodQuery() (string, []interface{}, error) { func prepareAuthMethodQuery() (string, []interface{}, error) {
return sq.Select( return sq.Select(
"DISTINCT("+authMethodTypeTypes.identifier()+")", "DISTINCT("+authMethodTypeType.identifier()+")",
authMethodTypeUserID.identifier(), authMethodTypeUserID.identifier(),
authMethodTypeInstanceID.identifier()). authMethodTypeInstanceID.identifier()).
From(authMethodTypeTable.identifier()). From(authMethodTypeTable.identifier()).
@ -513,15 +513,26 @@ func prepareAuthMethodQuery() (string, []interface{}, error) {
ToSql() ToSql()
} }
func prepareAggAuthMethodsQuery() (string, []interface{}, error) {
return sq.Select(
"array_agg(DISTINCT("+authMethodTypeType.identifier()+")) as method_types",
authMethodTypeUserID.identifier(),
authMethodTypeInstanceID.identifier()).
From(authMethodTypeTable.identifier()).
Where(sq.Eq{authMethodTypeState.identifier(): domain.MFAStateReady}).
GroupBy(authMethodTypeInstanceID.identifier(), authMethodTypeUserID.identifier()).
ToSql()
}
func prepareAuthMethodsForceMFAQuery() (string, error) { func prepareAuthMethodsForceMFAQuery() (string, error) {
loginPolicyQuery, _, err := sq.Select( loginPolicyQuery, _, err := sq.Select(
forceMFAForce.identifier(), forceMFAForce.identifier(),
forceMFAForceLocalOnly.identifier(), forceMFAForceLocalOnly.identifier(),
forceMFAInstanceID.identifier(), forceMFAInstanceID.identifier(),
forceMFAOrgID.identifier(), forceMFAOrgID.identifier(),
forceMFAIsDefault.identifier(),
). ).
From(forceMFATable.identifier()). From(forceMFATable.identifier()).
OrderBy(forceMFAIsDefault.identifier()).
ToSql() ToSql()
return loginPolicyQuery, err return loginPolicyQuery, err
} }

View File

@ -11,7 +11,9 @@ import (
sq "github.com/Masterminds/squirrel" sq "github.com/Masterminds/squirrel"
"github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/zerrors"
) )
var ( var (
@ -57,27 +59,27 @@ var (
"idps_count", "idps_count",
} }
prepareAuthMethodTypesRequiredStmt = `SELECT projections.users12_notifications.password_set,` + prepareAuthMethodTypesRequiredStmt = `SELECT projections.users12_notifications.password_set,` +
` auth_method_types.method_type,` + ` auth_method_types.method_types,` +
` user_idps_count.count,` + ` user_idps_count.count,` +
` projections.users12.type,` + ` projections.users12.type,` +
` auth_methods_force_mfa.force_mfa,` + ` auth_methods_force_mfa.force_mfa,` +
` auth_methods_force_mfa.force_mfa_local_only` + ` auth_methods_force_mfa.force_mfa_local_only` +
` FROM projections.users12` + ` FROM projections.users12` +
` LEFT JOIN projections.users12_notifications ON projections.users12.id = projections.users12_notifications.user_id AND projections.users12.instance_id = projections.users12_notifications.instance_id` + ` LEFT JOIN projections.users12_notifications ON projections.users12.id = projections.users12_notifications.user_id AND projections.users12.instance_id = projections.users12_notifications.instance_id` +
` LEFT JOIN (SELECT DISTINCT(auth_method_types.method_type), auth_method_types.user_id, auth_method_types.instance_id FROM projections.user_auth_methods4 AS auth_method_types` + ` LEFT JOIN (SELECT array_agg(DISTINCT(auth_method_types.method_type)) as method_types, auth_method_types.user_id, auth_method_types.instance_id FROM projections.user_auth_methods4 AS auth_method_types` +
` WHERE auth_method_types.state = $1) AS auth_method_types` + ` WHERE auth_method_types.state = $1 GROUP BY auth_method_types.instance_id, auth_method_types.user_id) AS auth_method_types` +
` ON auth_method_types.user_id = projections.users12.id AND auth_method_types.instance_id = projections.users12.instance_id` + ` ON auth_method_types.user_id = projections.users12.id AND auth_method_types.instance_id = projections.users12.instance_id` +
` LEFT JOIN (SELECT user_idps_count.user_id, user_idps_count.instance_id, COUNT(user_idps_count.user_id) AS count FROM projections.idp_user_links3 AS user_idps_count` + ` LEFT JOIN (SELECT user_idps_count.user_id, user_idps_count.instance_id, COUNT(user_idps_count.user_id) AS count FROM projections.idp_user_links3 AS user_idps_count` +
` GROUP BY user_idps_count.user_id, user_idps_count.instance_id) AS user_idps_count` + ` GROUP BY user_idps_count.user_id, user_idps_count.instance_id) AS user_idps_count` +
` ON user_idps_count.user_id = projections.users12.id AND user_idps_count.instance_id = projections.users12.instance_id` + ` ON user_idps_count.user_id = projections.users12.id AND user_idps_count.instance_id = projections.users12.instance_id` +
` LEFT JOIN (SELECT auth_methods_force_mfa.force_mfa, auth_methods_force_mfa.force_mfa_local_only, auth_methods_force_mfa.instance_id, auth_methods_force_mfa.aggregate_id FROM projections.login_policies5 AS auth_methods_force_mfa ORDER BY auth_methods_force_mfa.is_default) AS auth_methods_force_mfa` + ` LEFT JOIN (SELECT auth_methods_force_mfa.force_mfa, auth_methods_force_mfa.force_mfa_local_only, auth_methods_force_mfa.instance_id, auth_methods_force_mfa.aggregate_id, auth_methods_force_mfa.is_default FROM projections.login_policies5 AS auth_methods_force_mfa) AS auth_methods_force_mfa` +
` ON (auth_methods_force_mfa.aggregate_id = projections.users12.instance_id OR auth_methods_force_mfa.aggregate_id = projections.users12.resource_owner) AND auth_methods_force_mfa.instance_id = projections.users12.instance_id` + ` ON (auth_methods_force_mfa.aggregate_id = projections.users12.instance_id OR auth_methods_force_mfa.aggregate_id = projections.users12.resource_owner) AND auth_methods_force_mfa.instance_id = projections.users12.instance_id` +
` AS OF SYSTEM TIME '-1 ms ` ORDER BY auth_methods_force_mfa.is_default LIMIT 1
` `
prepareAuthMethodTypesRequiredCols = []string{ prepareAuthMethodTypesRequiredCols = []string{
"password_set", "password_set",
"type", "type",
"method_type", "method_types",
"idps_count", "idps_count",
"force_mfa", "force_mfa",
"force_mfa_local_only", "force_mfa_local_only",
@ -319,27 +321,33 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
}, },
{ {
name: "prepareUserAuthMethodTypesRequiredQuery no result", name: "prepareUserAuthMethodTypesRequiredQuery no result",
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) { prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) {
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db) builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
return builder, func(rows *sql.Rows) (*UserAuthMethodRequirements, error) { return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) {
return scan(rows) return scan(row)
} }
}, },
want: want{ want: want{
sqlExpectations: mockQueries( sqlExpectations: mockQueriesScanErr(
regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt), regexp.QuoteMeta(prepareAuthMethodTypesRequiredStmt),
nil, nil,
nil, nil,
), ),
err: func(err error) (error, bool) {
if !zerrors.IsNotFound(err) {
return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false
}
return nil, true
},
}, },
object: &UserAuthMethodRequirements{AuthMethods: []domain.UserAuthMethodType{}, ForceMFA: false}, object: (*UserAuthMethodRequirements)(nil),
}, },
{ {
name: "prepareUserAuthMethodTypesRequiredQuery one second factor", name: "prepareUserAuthMethodTypesRequiredQuery one second factor",
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) { prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) {
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db) builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
return builder, func(rows *sql.Rows) (*UserAuthMethodRequirements, error) { return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) {
return scan(rows) return scan(row)
} }
}, },
want: want{ want: want{
@ -349,7 +357,7 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
[][]driver.Value{ [][]driver.Value{
{ {
true, true,
domain.UserAuthMethodTypePasswordless, database.NumberArray[domain.UserAuthMethodType]{domain.UserAuthMethodTypePasswordless},
1, 1,
domain.UserTypeHuman, domain.UserTypeHuman,
true, true,
@ -371,10 +379,10 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
}, },
{ {
name: "prepareUserAuthMethodTypesRequiredQuery multiple second factors", name: "prepareUserAuthMethodTypesRequiredQuery multiple second factors",
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) { prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) {
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db) builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
return builder, func(rows *sql.Rows) (*UserAuthMethodRequirements, error) { return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) {
return scan(rows) return scan(row)
} }
}, },
want: want{ want: want{
@ -384,15 +392,7 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
[][]driver.Value{ [][]driver.Value{
{ {
true, true,
domain.UserAuthMethodTypePasswordless, database.NumberArray[domain.UserAuthMethodType]{domain.UserAuthMethodTypePasswordless, domain.UserAuthMethodTypeTOTP},
1,
domain.UserTypeHuman,
true,
true,
},
{
true,
domain.UserAuthMethodTypeTOTP,
1, 1,
domain.UserTypeHuman, domain.UserTypeHuman,
true, true,
@ -416,10 +416,10 @@ func Test_UserAuthMethodPrepares(t *testing.T) {
}, },
{ {
name: "prepareUserAuthMethodTypesRequiredQuery sql err", name: "prepareUserAuthMethodTypesRequiredQuery sql err",
prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Rows) (*UserAuthMethodRequirements, error)) { prepare: func(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*UserAuthMethodRequirements, error)) {
builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db) builder, scan := prepareUserAuthMethodTypesRequiredQuery(ctx, db)
return builder, func(rows *sql.Rows) (*UserAuthMethodRequirements, error) { return builder, func(row *sql.Row) (*UserAuthMethodRequirements, error) {
return scan(rows) return scan(row)
} }
}, },
want: want{ want: want{

View File

@ -35,12 +35,12 @@ const (
) )
type UserSessionView struct { type UserSessionView struct {
CreationDate time.Time `json:"-" gorm:"column:creation_date"` CreationDate time.Time `json:"-" gorm:"column:creation_date"`
ChangeDate time.Time `json:"-" gorm:"column:change_date"` ChangeDate time.Time `json:"-" gorm:"column:change_date"`
ResourceOwner string `json:"-" gorm:"column:resource_owner"` ResourceOwner string `json:"-" gorm:"column:resource_owner"`
State int32 `json:"-" gorm:"column:state"` State sql.Null[domain.UserSessionState] `json:"-" gorm:"column:state"`
UserAgentID string `json:"userAgentID" gorm:"column:user_agent_id;primary_key"` UserAgentID string `json:"userAgentID" gorm:"column:user_agent_id;primary_key"`
UserID string `json:"userID" gorm:"column:user_id;primary_key"` UserID string `json:"userID" gorm:"column:user_id;primary_key"`
// As of https://github.com/zitadel/zitadel/pull/7199 the following 4 attributes // As of https://github.com/zitadel/zitadel/pull/7199 the following 4 attributes
// are not projected in the user session handler anymore // are not projected in the user session handler anymore
// and are therefore annotated with a `gorm:"-"`. // and are therefore annotated with a `gorm:"-"`.
@ -79,7 +79,7 @@ func UserSessionToModel(userSession *UserSessionView) *model.UserSessionView {
ChangeDate: userSession.ChangeDate, ChangeDate: userSession.ChangeDate,
CreationDate: userSession.CreationDate, CreationDate: userSession.CreationDate,
ResourceOwner: userSession.ResourceOwner, ResourceOwner: userSession.ResourceOwner,
State: domain.UserSessionState(userSession.State), State: userSession.State.V,
UserAgentID: userSession.UserAgentID, UserAgentID: userSession.UserAgentID,
UserID: userSession.UserID, UserID: userSession.UserID,
UserName: userSession.UserName.String, UserName: userSession.UserName.String,
@ -114,7 +114,7 @@ func (v *UserSessionView) AppendEvent(event eventstore.Event) error {
case user.UserV1PasswordCheckSucceededType, case user.UserV1PasswordCheckSucceededType,
user.HumanPasswordCheckSucceededType: user.HumanPasswordCheckSucceededType:
v.PasswordVerification = sql.NullTime{Time: event.CreatedAt(), Valid: true} v.PasswordVerification = sql.NullTime{Time: event.CreatedAt(), Valid: true}
v.State = int32(domain.UserSessionStateActive) v.State.V = domain.UserSessionStateActive
case user.UserIDPLoginCheckSucceededType: case user.UserIDPLoginCheckSucceededType:
data := new(es_model.AuthRequest) data := new(es_model.AuthRequest)
err := data.SetData(event) err := data.SetData(event)
@ -123,12 +123,12 @@ func (v *UserSessionView) AppendEvent(event eventstore.Event) error {
} }
v.ExternalLoginVerification = sql.NullTime{Time: event.CreatedAt(), Valid: true} v.ExternalLoginVerification = sql.NullTime{Time: event.CreatedAt(), Valid: true}
v.SelectedIDPConfigID = sql.NullString{String: data.SelectedIDPConfigID, Valid: true} v.SelectedIDPConfigID = sql.NullString{String: data.SelectedIDPConfigID, Valid: true}
v.State = int32(domain.UserSessionStateActive) v.State.V = domain.UserSessionStateActive
case user.HumanPasswordlessTokenCheckSucceededType: case user.HumanPasswordlessTokenCheckSucceededType:
v.PasswordlessVerification = sql.NullTime{Time: event.CreatedAt(), Valid: true} v.PasswordlessVerification = sql.NullTime{Time: event.CreatedAt(), Valid: true}
v.MultiFactorVerification = sql.NullTime{Time: event.CreatedAt(), Valid: true} v.MultiFactorVerification = sql.NullTime{Time: event.CreatedAt(), Valid: true}
v.MultiFactorVerificationType = sql.NullInt32{Int32: int32(domain.MFATypeU2FUserVerification)} v.MultiFactorVerificationType = sql.NullInt32{Int32: int32(domain.MFATypeU2FUserVerification)}
v.State = int32(domain.UserSessionStateActive) v.State.V = domain.UserSessionStateActive
case user.HumanPasswordlessTokenCheckFailedType, case user.HumanPasswordlessTokenCheckFailedType,
user.HumanPasswordlessTokenRemovedType: user.HumanPasswordlessTokenRemovedType:
v.PasswordlessVerification = sql.NullTime{Time: time.Time{}, Valid: true} v.PasswordlessVerification = sql.NullTime{Time: time.Time{}, Valid: true}
@ -207,7 +207,7 @@ func (v *UserSessionView) AppendEvent(event eventstore.Event) error {
v.MultiFactorVerification = sql.NullTime{Time: time.Time{}, Valid: true} v.MultiFactorVerification = sql.NullTime{Time: time.Time{}, Valid: true}
v.MultiFactorVerificationType = sql.NullInt32{Int32: int32(domain.MFALevelNotSetUp)} v.MultiFactorVerificationType = sql.NullInt32{Int32: int32(domain.MFALevelNotSetUp)}
v.ExternalLoginVerification = sql.NullTime{Time: time.Time{}, Valid: true} v.ExternalLoginVerification = sql.NullTime{Time: time.Time{}, Valid: true}
v.State = int32(domain.UserSessionStateTerminated) v.State.V = domain.UserSessionStateTerminated
case user.UserIDPLinkRemovedType, user.UserIDPLinkCascadeRemovedType: case user.UserIDPLinkRemovedType, user.UserIDPLinkCascadeRemovedType:
v.ExternalLoginVerification = sql.NullTime{Time: time.Time{}, Valid: true} v.ExternalLoginVerification = sql.NullTime{Time: time.Time{}, Valid: true}
v.SelectedIDPConfigID = sql.NullString{String: "", Valid: true} v.SelectedIDPConfigID = sql.NullString{String: "", Valid: true}
@ -218,7 +218,7 @@ func (v *UserSessionView) AppendEvent(event eventstore.Event) error {
func (v *UserSessionView) setSecondFactorVerification(verificationTime time.Time, mfaType domain.MFAType) { func (v *UserSessionView) setSecondFactorVerification(verificationTime time.Time, mfaType domain.MFAType) {
v.SecondFactorVerification = sql.NullTime{Time: verificationTime, Valid: true} v.SecondFactorVerification = sql.NullTime{Time: verificationTime, Valid: true}
v.SecondFactorVerificationType = sql.NullInt32{Int32: int32(mfaType)} v.SecondFactorVerificationType = sql.NullInt32{Int32: int32(mfaType)}
v.State = int32(domain.UserSessionStateActive) v.State.V = domain.UserSessionStateActive
} }
func (v *UserSessionView) EventTypes() []eventstore.EventType { func (v *UserSessionView) EventTypes() []eventstore.EventType {

View File

@ -9,6 +9,7 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models" es_models "github.com/zitadel/zitadel/internal/eventstore/v1/models"
"github.com/zitadel/zitadel/internal/repository/user" "github.com/zitadel/zitadel/internal/repository/user"
es_model "github.com/zitadel/zitadel/internal/user/repository/eventsourcing/model" es_model "github.com/zitadel/zitadel/internal/user/repository/eventsourcing/model"
@ -209,7 +210,7 @@ func TestAppendEvent(t *testing.T) {
ExternalLoginVerification: sql.NullTime{Time: time.Time{}, Valid: true}, ExternalLoginVerification: sql.NullTime{Time: time.Time{}, Valid: true},
PasswordlessVerification: sql.NullTime{Time: time.Time{}, Valid: true}, PasswordlessVerification: sql.NullTime{Time: time.Time{}, Valid: true},
MultiFactorVerification: sql.NullTime{Time: time.Time{}, Valid: true}, MultiFactorVerification: sql.NullTime{Time: time.Time{}, Valid: true},
State: 1, State: sql.Null[domain.UserSessionState]{V: domain.UserSessionStateTerminated},
}, },
}, },
{ {
@ -228,7 +229,7 @@ func TestAppendEvent(t *testing.T) {
ExternalLoginVerification: sql.NullTime{Time: time.Time{}, Valid: true}, ExternalLoginVerification: sql.NullTime{Time: time.Time{}, Valid: true},
PasswordlessVerification: sql.NullTime{Time: time.Time{}, Valid: true}, PasswordlessVerification: sql.NullTime{Time: time.Time{}, Valid: true},
MultiFactorVerification: sql.NullTime{Time: time.Time{}, Valid: true}, MultiFactorVerification: sql.NullTime{Time: time.Time{}, Valid: true},
State: 1, State: sql.Null[domain.UserSessionState]{V: domain.UserSessionStateTerminated},
}, },
}, },
} }