mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 08:27:32 +00:00
fix: user session with external login (#797)
* fix: user session with external login * fix: tests * fix: tests * fix: change idp config name
This commit is contained in:
@@ -43,10 +43,11 @@ type AuthRequestRepo struct {
|
||||
|
||||
IdGenerator id.Generator
|
||||
|
||||
PasswordCheckLifeTime time.Duration
|
||||
MfaInitSkippedLifeTime time.Duration
|
||||
MfaSoftwareCheckLifeTime time.Duration
|
||||
MfaHardwareCheckLifeTime time.Duration
|
||||
PasswordCheckLifeTime time.Duration
|
||||
ExternalLoginCheckLifeTime time.Duration
|
||||
MfaInitSkippedLifeTime time.Duration
|
||||
MfaSoftwareCheckLifeTime time.Duration
|
||||
MfaHardwareCheckLifeTime time.Duration
|
||||
|
||||
IAMID string
|
||||
}
|
||||
@@ -164,7 +165,7 @@ func (repo *AuthRequestRepo) SelectExternalIDP(ctx context.Context, authReqID, i
|
||||
return repo.AuthRequests.UpdateAuthRequest(ctx, request)
|
||||
}
|
||||
|
||||
func (repo *AuthRequestRepo) CheckExternalUserLogin(ctx context.Context, authReqID, userAgentID string, externalUser *model.ExternalUser) error {
|
||||
func (repo *AuthRequestRepo) CheckExternalUserLogin(ctx context.Context, authReqID, userAgentID string, externalUser *model.ExternalUser, info *model.BrowserInfo) error {
|
||||
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -176,6 +177,11 @@ func (repo *AuthRequestRepo) CheckExternalUserLogin(ctx context.Context, authReq
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = repo.UserEvents.ExternalLoginChecked(ctx, request.UserID, request.WithCurrentInfo(info))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return repo.AuthRequests.UpdateAuthRequest(ctx, request)
|
||||
}
|
||||
|
||||
@@ -219,7 +225,7 @@ func (repo *AuthRequestRepo) VerifyMfaOTP(ctx context.Context, authRequestID, us
|
||||
return repo.UserEvents.CheckMfaOTP(ctx, userID, code, request.WithCurrentInfo(info))
|
||||
}
|
||||
|
||||
func (repo *AuthRequestRepo) LinkExternalUsers(ctx context.Context, authReqID, userAgentID string) error {
|
||||
func (repo *AuthRequestRepo) LinkExternalUsers(ctx context.Context, authReqID, userAgentID string, info *model.BrowserInfo) error {
|
||||
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -228,6 +234,10 @@ func (repo *AuthRequestRepo) LinkExternalUsers(ctx context.Context, authReqID, u
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = repo.UserEvents.ExternalLoginChecked(ctx, request.UserID, request.WithCurrentInfo(info))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
request.LinkingUsers = nil
|
||||
return repo.AuthRequests.UpdateAuthRequest(ctx, request)
|
||||
}
|
||||
@@ -242,7 +252,7 @@ func (repo *AuthRequestRepo) ResetLinkingUsers(ctx context.Context, authReqID, u
|
||||
return repo.AuthRequests.UpdateAuthRequest(ctx, request)
|
||||
}
|
||||
|
||||
func (repo *AuthRequestRepo) AutoRegisterExternalUser(ctx context.Context, registerUser *user_model.User, externalIDP *user_model.ExternalIDP, orgMember *org_model.OrgMember, authReqID, userAgentID, resourceOwner string) error {
|
||||
func (repo *AuthRequestRepo) AutoRegisterExternalUser(ctx context.Context, registerUser *user_model.User, externalIDP *user_model.ExternalIDP, orgMember *org_model.OrgMember, authReqID, userAgentID, resourceOwner string, info *model.BrowserInfo) error {
|
||||
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -277,8 +287,13 @@ func (repo *AuthRequestRepo) AutoRegisterExternalUser(ctx context.Context, regis
|
||||
return err
|
||||
}
|
||||
request.UserID = user.AggregateID
|
||||
request.UserOrgID = user.ResourceOwner
|
||||
request.SelectedIDPConfigID = externalIDP.IDPConfigID
|
||||
request.LinkingUsers = nil
|
||||
err = repo.UserEvents.ExternalLoginChecked(ctx, request.UserID, request.WithCurrentInfo(info))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return repo.AuthRequests.UpdateAuthRequest(ctx, request)
|
||||
}
|
||||
|
||||
@@ -475,7 +490,11 @@ func (repo *AuthRequestRepo) nextSteps(ctx context.Context, request *model.AuthR
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if request.SelectedIDPConfigID == "" || (request.SelectedIDPConfigID != "" && request.LinkingUsers != nil && len(request.LinkingUsers) > 0) {
|
||||
if (request.SelectedIDPConfigID != "" || userSession.SelectedIDPConfigID != "") && (request.LinkingUsers == nil || len(request.LinkingUsers) == 0) {
|
||||
if !checkVerificationTime(userSession.ExternalLoginVerification, repo.ExternalLoginCheckLifeTime) {
|
||||
return append(steps, &model.ExternalLoginStep{}), nil
|
||||
}
|
||||
} else if (request.SelectedIDPConfigID == "" && userSession.SelectedIDPConfigID == "") || (request.SelectedIDPConfigID != "" && request.LinkingUsers != nil && len(request.LinkingUsers) > 0) {
|
||||
if user.InitRequired {
|
||||
return append(steps, &model.InitUserStep{PasswordSet: user.PasswordSet}), nil
|
||||
}
|
||||
@@ -643,6 +662,7 @@ func userSessionByIDs(ctx context.Context, provider userSessionViewProvider, eve
|
||||
es_model.UserDeactivated,
|
||||
es_model.HumanPasswordCheckSucceeded,
|
||||
es_model.HumanPasswordCheckFailed,
|
||||
es_model.HumanExternalLoginCheckSucceeded,
|
||||
es_model.HumanMFAOTPCheckSucceeded,
|
||||
es_model.HumanMFAOTPCheckFailed,
|
||||
es_model.HumanSignedOut:
|
||||
@@ -689,15 +709,23 @@ func activeUserByID(ctx context.Context, userViewProvider userViewProvider, user
|
||||
}
|
||||
|
||||
func userByID(ctx context.Context, viewProvider userViewProvider, eventProvider userEventProvider, userID string) (*user_model.UserView, error) {
|
||||
user, err := viewProvider.UserByID(userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
user, viewErr := viewProvider.UserByID(userID)
|
||||
if viewErr != nil && !errors.IsNotFound(viewErr) {
|
||||
return nil, viewErr
|
||||
} else if user == nil {
|
||||
user = new(user_view_model.UserView)
|
||||
}
|
||||
events, err := eventProvider.UserEventsByID(ctx, userID, user.Sequence)
|
||||
if err != nil {
|
||||
logging.Log("EVENT-dfg42").WithError(err).Debug("error retrieving new events")
|
||||
return user_view_model.UserToModel(user), nil
|
||||
}
|
||||
if len(events) == 0 {
|
||||
if viewErr != nil {
|
||||
return nil, viewErr
|
||||
}
|
||||
return user_view_model.UserToModel(user), viewErr
|
||||
}
|
||||
userCopy := *user
|
||||
for _, event := range events {
|
||||
if err := userCopy.AppendEvent(event); err != nil {
|
||||
|
@@ -42,9 +42,10 @@ func (m *mockViewErrUserSession) UserSessionsByAgentID(string) ([]*user_view_mod
|
||||
}
|
||||
|
||||
type mockViewUserSession struct {
|
||||
PasswordVerification time.Time
|
||||
MfaSoftwareVerification time.Time
|
||||
Users []mockUser
|
||||
ExternalLoginVerification time.Time
|
||||
PasswordVerification time.Time
|
||||
MfaSoftwareVerification time.Time
|
||||
Users []mockUser
|
||||
}
|
||||
|
||||
type mockUser struct {
|
||||
@@ -54,8 +55,9 @@ type mockUser struct {
|
||||
|
||||
func (m *mockViewUserSession) UserSessionByIDs(string, string) (*user_view_model.UserSessionView, error) {
|
||||
return &user_view_model.UserSessionView{
|
||||
PasswordVerification: m.PasswordVerification,
|
||||
MfaSoftwareVerification: m.MfaSoftwareVerification,
|
||||
ExternalLoginVerification: m.ExternalLoginVerification,
|
||||
PasswordVerification: m.PasswordVerification,
|
||||
MfaSoftwareVerification: m.MfaSoftwareVerification,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -157,17 +159,18 @@ func (m *mockViewErrOrg) OrgByPrimaryDomain(string) (*org_view_model.OrgView, er
|
||||
|
||||
func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
type fields struct {
|
||||
UserEvents *user_event.UserEventstore
|
||||
AuthRequests *cache.AuthRequestCache
|
||||
View *view.View
|
||||
userSessionViewProvider userSessionViewProvider
|
||||
userViewProvider userViewProvider
|
||||
userEventProvider userEventProvider
|
||||
orgViewProvider orgViewProvider
|
||||
PasswordCheckLifeTime time.Duration
|
||||
MfaInitSkippedLifeTime time.Duration
|
||||
MfaSoftwareCheckLifeTime time.Duration
|
||||
MfaHardwareCheckLifeTime time.Duration
|
||||
UserEvents *user_event.UserEventstore
|
||||
AuthRequests *cache.AuthRequestCache
|
||||
View *view.View
|
||||
userSessionViewProvider userSessionViewProvider
|
||||
userViewProvider userViewProvider
|
||||
userEventProvider userEventProvider
|
||||
orgViewProvider orgViewProvider
|
||||
PasswordCheckLifeTime time.Duration
|
||||
ExternalLoginCheckLifeTime time.Duration
|
||||
MfaInitSkippedLifeTime time.Duration
|
||||
MfaSoftwareCheckLifeTime time.Duration
|
||||
MfaHardwareCheckLifeTime time.Duration
|
||||
}
|
||||
type args struct {
|
||||
request *model.AuthRequest
|
||||
@@ -391,7 +394,7 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"external user (no password set), callback",
|
||||
"external user (no external verification), external login step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
MfaSoftwareVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
@@ -405,6 +408,26 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
MfaSoftwareCheckLifeTime: 18 * time.Hour,
|
||||
},
|
||||
args{&model.AuthRequest{UserID: "UserID", SelectedIDPConfigID: "IDPConfigID"}, false},
|
||||
[]model.NextStep{&model.ExternalLoginStep{}},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"external user (external verification set), callback",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
ExternalLoginVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
MfaSoftwareVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
IsEmailVerified: true,
|
||||
MfaMaxSetUp: int32(model.MfaLevelSoftware),
|
||||
},
|
||||
userEventProvider: &mockEventUser{},
|
||||
orgViewProvider: &mockViewOrg{State: org_model.OrgStateActive},
|
||||
ExternalLoginCheckLifeTime: 10 * 24 * time.Hour,
|
||||
MfaSoftwareCheckLifeTime: 18 * time.Hour,
|
||||
},
|
||||
args{&model.AuthRequest{UserID: "UserID", SelectedIDPConfigID: "IDPConfigID"}, false},
|
||||
[]model.NextStep{&model.RedirectToCallbackStep{}},
|
||||
nil,
|
||||
},
|
||||
@@ -427,16 +450,18 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"external user (no password check needed), callback",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
MfaSoftwareVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
MfaSoftwareVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
ExternalLoginVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
IsEmailVerified: true,
|
||||
MfaMaxSetUp: int32(model.MfaLevelSoftware),
|
||||
},
|
||||
userEventProvider: &mockEventUser{},
|
||||
orgViewProvider: &mockViewOrg{State: org_model.OrgStateActive},
|
||||
MfaSoftwareCheckLifeTime: 18 * time.Hour,
|
||||
userEventProvider: &mockEventUser{},
|
||||
orgViewProvider: &mockViewOrg{State: org_model.OrgStateActive},
|
||||
MfaSoftwareCheckLifeTime: 18 * time.Hour,
|
||||
ExternalLoginCheckLifeTime: 10 * 24 * time.Hour,
|
||||
},
|
||||
args{&model.AuthRequest{UserID: "UserID", SelectedIDPConfigID: "IDPConfigID"}, false},
|
||||
[]model.NextStep{&model.RedirectToCallbackStep{}},
|
||||
@@ -468,17 +493,19 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"external user, mfa not verified, mfa check step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
PasswordVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
ExternalLoginVerification: time.Now().UTC().Add(-5 * time.Minute),
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
OTPState: int32(user_model.MfaStateReady),
|
||||
MfaMaxSetUp: int32(model.MfaLevelSoftware),
|
||||
},
|
||||
userEventProvider: &mockEventUser{},
|
||||
orgViewProvider: &mockViewOrg{State: org_model.OrgStateActive},
|
||||
PasswordCheckLifeTime: 10 * 24 * time.Hour,
|
||||
MfaSoftwareCheckLifeTime: 18 * time.Hour,
|
||||
userEventProvider: &mockEventUser{},
|
||||
orgViewProvider: &mockViewOrg{State: org_model.OrgStateActive},
|
||||
PasswordCheckLifeTime: 10 * 24 * time.Hour,
|
||||
ExternalLoginCheckLifeTime: 10 * 24 * time.Hour,
|
||||
MfaSoftwareCheckLifeTime: 18 * time.Hour,
|
||||
},
|
||||
args{&model.AuthRequest{UserID: "UserID", SelectedIDPConfigID: "IDPConfigID"}, false},
|
||||
[]model.NextStep{&model.MfaVerificationStep{
|
||||
@@ -645,17 +672,18 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
repo := &AuthRequestRepo{
|
||||
UserEvents: tt.fields.UserEvents,
|
||||
AuthRequests: tt.fields.AuthRequests,
|
||||
View: tt.fields.View,
|
||||
UserSessionViewProvider: tt.fields.userSessionViewProvider,
|
||||
UserViewProvider: tt.fields.userViewProvider,
|
||||
UserEventProvider: tt.fields.userEventProvider,
|
||||
OrgViewProvider: tt.fields.orgViewProvider,
|
||||
PasswordCheckLifeTime: tt.fields.PasswordCheckLifeTime,
|
||||
MfaInitSkippedLifeTime: tt.fields.MfaInitSkippedLifeTime,
|
||||
MfaSoftwareCheckLifeTime: tt.fields.MfaSoftwareCheckLifeTime,
|
||||
MfaHardwareCheckLifeTime: tt.fields.MfaHardwareCheckLifeTime,
|
||||
UserEvents: tt.fields.UserEvents,
|
||||
AuthRequests: tt.fields.AuthRequests,
|
||||
View: tt.fields.View,
|
||||
UserSessionViewProvider: tt.fields.userSessionViewProvider,
|
||||
UserViewProvider: tt.fields.userViewProvider,
|
||||
UserEventProvider: tt.fields.userEventProvider,
|
||||
OrgViewProvider: tt.fields.orgViewProvider,
|
||||
PasswordCheckLifeTime: tt.fields.PasswordCheckLifeTime,
|
||||
ExternalLoginCheckLifeTime: tt.fields.ExternalLoginCheckLifeTime,
|
||||
MfaInitSkippedLifeTime: tt.fields.MfaInitSkippedLifeTime,
|
||||
MfaSoftwareCheckLifeTime: tt.fields.MfaSoftwareCheckLifeTime,
|
||||
MfaHardwareCheckLifeTime: tt.fields.MfaHardwareCheckLifeTime,
|
||||
}
|
||||
got, err := repo.nextSteps(context.Background(), tt.args.request, tt.args.checkLoggedIn)
|
||||
if (err != nil && tt.wantErr == nil) || (tt.wantErr != nil && !tt.wantErr(err)) {
|
||||
@@ -1024,7 +1052,9 @@ func Test_userByID(t *testing.T) {
|
||||
{
|
||||
"not found, not found error",
|
||||
args{
|
||||
viewProvider: &mockViewNoUser{},
|
||||
userID: "userID",
|
||||
viewProvider: &mockViewNoUser{},
|
||||
eventProvider: &mockEventUser{},
|
||||
},
|
||||
nil,
|
||||
errors.IsNotFound,
|
||||
|
Reference in New Issue
Block a user