diff --git a/internal/auth/repository/eventsourcing/eventstore/auth_request.go b/internal/auth/repository/eventsourcing/eventstore/auth_request.go index 3017210447..950809f0c8 100644 --- a/internal/auth/repository/eventsourcing/eventstore/auth_request.go +++ b/internal/auth/repository/eventsourcing/eventstore/auth_request.go @@ -50,6 +50,7 @@ type AuthRequestRepo struct { UserGrantProvider userGrantProvider ProjectProvider projectProvider ApplicationProvider applicationProvider + CustomTextProvider customTextProvider IdGenerator id.Generator } @@ -115,6 +116,10 @@ type applicationProvider interface { AppByOIDCClientID(context.Context, string, bool) (*query.App, error) } +type customTextProvider interface { + CustomTextListByTemplate(ctx context.Context, aggregateID string, text string, withOwnerRemoved bool) (texts *query.CustomTexts, err error) +} + func (repo *AuthRequestRepo) Health(ctx context.Context) error { return repo.AuthRequests.Health(ctx) } @@ -1113,8 +1118,18 @@ func (repo *AuthRequestRepo) nextStepsUser(ctx context.Context, request *domain. if len(steps) > 0 { return steps, nil } - // a single user session was found, use that automatically + // the single user session was inactive + if users[0].UserSessionState != domain.UserSessionStateActive { + return append(steps, &domain.SelectUserStep{Users: users}), nil + } + // a single active user session was found, use that automatically request.SetUserInfo(users[0].UserID, users[0].UserName, users[0].LoginName, users[0].DisplayName, users[0].AvatarKey, users[0].ResourceOwner) + if err = repo.fillPolicies(ctx, request); err != nil { + return nil, err + } + if err = repo.AuthRequests.UpdateAuthRequest(ctx, request); err != nil { + return nil, err + } } return steps, nil } @@ -1315,7 +1330,7 @@ func labelPolicyToDomain(p *query.LabelPolicy) *domain.LabelPolicy { } func (repo *AuthRequestRepo) getLoginTexts(ctx context.Context, aggregateID string) ([]*domain.CustomText, error) { - loginTexts, err := repo.Query.CustomTextListByTemplate(ctx, aggregateID, domain.LoginCustomText, false) + loginTexts, err := repo.CustomTextProvider.CustomTextListByTemplate(ctx, aggregateID, domain.LoginCustomText, false) if err != nil { return nil, err } diff --git a/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go b/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go index 7e94c61ef6..7a469e4729 100644 --- a/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go +++ b/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go @@ -6,10 +6,12 @@ import ( "testing" "time" + "github.com/golang/mock/gomock" "github.com/stretchr/testify/assert" "github.com/zitadel/zitadel/internal/auth/repository/eventsourcing/view" - "github.com/zitadel/zitadel/internal/auth_request/repository/cache" + cache "github.com/zitadel/zitadel/internal/auth_request/repository" + "github.com/zitadel/zitadel/internal/auth_request/repository/mock" "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/errors" @@ -67,6 +69,7 @@ type mockUser struct { UserID string LoginName string ResourceOwner string + SessionState domain.UserSessionState } func (m *mockViewUserSession) UserSessionByIDs(string, string, string) (*user_view_model.UserSessionView, error) { @@ -83,9 +86,10 @@ func (m *mockViewUserSession) UserSessionsByAgentID(string, string) ([]*user_vie sessions := make([]*user_view_model.UserSessionView, len(m.Users)) for i, user := range m.Users { sessions[i] = &user_view_model.UserSessionView{ + ResourceOwner: user.ResourceOwner, + State: int32(user.SessionState), UserID: user.UserID, LoginName: user.LoginName, - ResourceOwner: user.ResourceOwner, } } return sessions, nil @@ -148,6 +152,30 @@ func (m *mockLoginPolicy) LoginPolicyByID(ctx context.Context, _ bool, id string return m.policy, nil } +type mockPrivacyPolicy struct { + policy *query.PrivacyPolicy +} + +func (m mockPrivacyPolicy) PrivacyPolicyByOrg(ctx context.Context, b bool, s string, b2 bool) (*query.PrivacyPolicy, error) { + return m.policy, nil +} + +type mockLabelPolicy struct { + policy *query.LabelPolicy +} + +func (m mockLabelPolicy) ActiveLabelPolicyByOrg(ctx context.Context, s string, b bool) (*query.LabelPolicy, error) { + return m.policy, nil +} + +type mockCustomText struct { + texts *query.CustomTexts +} + +func (m *mockCustomText) CustomTextListByTemplate(ctx context.Context, aggregateID string, text string, withOwnerRemoved bool) (texts *query.CustomTexts, err error) { + return m.texts, nil +} + type mockLockoutPolicy struct { policy *query.LockoutPolicy } @@ -258,7 +286,7 @@ func (m *mockIDPUserLinks) IDPUserLinks(ctx context.Context, queries *query.IDPU func TestAuthRequestRepo_nextSteps(t *testing.T) { type fields struct { - AuthRequests *cache.AuthRequestCache + AuthRequests cache.AuthRequestCache View *view.View userSessionViewProvider userSessionViewProvider userViewProvider userViewProvider @@ -270,6 +298,9 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) { loginPolicyProvider loginPolicyViewProvider lockoutPolicyProvider lockoutPolicyViewProvider idpUserLinksProvider idpUserLinksProvider + privacyPolicyProvider privacyPolicyProvider + labelPolicyProvider labelPolicyProvider + customTextProvider customTextProvider } type args struct { request *domain.AuthRequest @@ -363,11 +394,13 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) { "id1", "loginname1", "orgID1", + domain.UserSessionStateActive, }, { "id2", "loginname2", "orgID2", + domain.UserSessionStateActive, }, }, }, @@ -402,11 +435,13 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) { "id1", "loginname1", "orgID1", + domain.UserSessionStateActive, }, { "id2", "loginname2", "orgID2", + domain.UserSessionStateActive, }, }, }, @@ -444,6 +479,11 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) { { "user not set single active session, callback step", fields{ + AuthRequests: func() cache.AuthRequestCache { + m := mock.NewMockAuthRequestCache(gomock.NewController(t)) + m.EXPECT().UpdateAuthRequest(gomock.Any(), gomock.Any()) + return m + }(), userSessionViewProvider: &mockViewUserSession{ PasswordVerification: time.Now().Add(-5 * time.Minute), SecondFactorVerification: time.Now().Add(-5 * time.Minute), @@ -452,6 +492,66 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) { "id1", "loginname1", "orgID1", + domain.UserSessionStateActive, + }, + }, + }, + userViewProvider: &mockViewUser{ + PasswordSet: true, + IsEmailVerified: true, + MFAMaxSetUp: int32(domain.MFALevelSecondFactor), + }, + userEventProvider: &mockEventUser{}, + orgViewProvider: &mockViewOrg{State: domain.OrgStateActive}, + userGrantProvider: &mockUserGrants{}, + projectProvider: &mockProject{}, + applicationProvider: &mockApp{app: &query.App{OIDCConfig: &query.OIDCApp{AppType: domain.OIDCApplicationTypeWeb}}}, + lockoutPolicyProvider: &mockLockoutPolicy{ + policy: &query.LockoutPolicy{ + ShowFailures: true, + }, + }, + idpUserLinksProvider: &mockIDPUserLinks{}, + loginPolicyProvider: &mockLoginPolicy{ + policy: &query.LoginPolicy{ + SecondFactors: []domain.SecondFactorType{domain.SecondFactorTypeTOTP}, + PasswordCheckLifetime: 10 * 24 * time.Hour, + SecondFactorCheckLifetime: 18 * time.Hour, + }, + }, + privacyPolicyProvider: &mockPrivacyPolicy{ + policy: &query.PrivacyPolicy{}, + }, + labelPolicyProvider: &mockLabelPolicy{ + policy: &query.LabelPolicy{}, + }, + customTextProvider: &mockCustomText{ + texts: &query.CustomTexts{}, + }, + }, + args{&domain.AuthRequest{ + Request: &domain.AuthRequestOIDC{}, + LoginPolicy: &domain.LoginPolicy{ + SecondFactors: []domain.SecondFactorType{domain.SecondFactorTypeTOTP}, + PasswordCheckLifetime: 10 * 24 * time.Hour, + SecondFactorCheckLifetime: 18 * time.Hour, + }, + }, false}, + []domain.NextStep{&domain.RedirectToCallbackStep{}}, + nil, + }, + { + "user not set single inactive session, callback step", + fields{ + userSessionViewProvider: &mockViewUserSession{ + PasswordVerification: time.Now().Add(-5 * time.Minute), + SecondFactorVerification: time.Now().Add(-5 * time.Minute), + Users: []mockUser{ + { + "id1", + "loginname1", + "orgID1", + domain.UserSessionStateTerminated, }, }, }, @@ -480,11 +580,19 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) { SecondFactorCheckLifetime: 18 * time.Hour, }, }, false}, - []domain.NextStep{&domain.RedirectToCallbackStep{}}, + []domain.NextStep{&domain.SelectUserStep{Users: []domain.UserSelection{ + { + UserID: "id1", + LoginName: "loginname1", + ResourceOwner: "orgID1", + UserSessionState: domain.UserSessionStateTerminated, + SelectionPossible: true, + }, + }}}, nil, }, { - "user not set multiple active sessions, select account step", + "user not set multiple sessions, select account step", fields{ userSessionViewProvider: &mockViewUserSession{ Users: []mockUser{ @@ -492,11 +600,13 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) { "id1", "loginname1", "orgID1", + domain.UserSessionStateActive, }, { "id2", "loginname2", "orgID2", + domain.UserSessionStateTerminated, }, }, }, @@ -532,12 +642,14 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) { LoginName: "loginname1", SelectionPossible: true, ResourceOwner: "orgID1", + UserSessionState: domain.UserSessionStateActive, }, { UserID: "id2", LoginName: "loginname2", SelectionPossible: true, ResourceOwner: "orgID2", + UserSessionState: domain.UserSessionStateTerminated, }, }, }}, @@ -1544,6 +1656,9 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) { LoginPolicyViewProvider: tt.fields.loginPolicyProvider, LockoutPolicyViewProvider: tt.fields.lockoutPolicyProvider, IDPUserLinksProvider: tt.fields.idpUserLinksProvider, + PrivacyPolicyProvider: tt.fields.privacyPolicyProvider, + LabelPolicyProvider: tt.fields.labelPolicyProvider, + CustomTextProvider: tt.fields.customTextProvider, } got, err := repo.nextSteps(context.Background(), tt.args.request, tt.args.checkLoggedIn) if (err != nil && tt.wantErr == nil) || (tt.wantErr != nil && !tt.wantErr(err)) { diff --git a/internal/auth/repository/eventsourcing/repository.go b/internal/auth/repository/eventsourcing/repository.go index 454daffd11..257b49d706 100644 --- a/internal/auth/repository/eventsourcing/repository.go +++ b/internal/auth/repository/eventsourcing/repository.go @@ -87,6 +87,7 @@ func Start(ctx context.Context, conf Config, systemDefaults sd.SystemDefaults, c UserGrantProvider: queryView, ProjectProvider: queryView, ApplicationProvider: queries, + CustomTextProvider: queries, IdGenerator: idGenerator, }, eventstore.TokenRepo{ diff --git a/internal/auth_request/repository/mock/repository.mock.go b/internal/auth_request/repository/mock/repository.mock.go index 929e9aba11..74186fcf1e 100644 --- a/internal/auth_request/repository/mock/repository.mock.go +++ b/internal/auth_request/repository/mock/repository.mock.go @@ -8,8 +8,8 @@ import ( context "context" reflect "reflect" - domain "github.com/zitadel/zitadel/internal/domain" gomock "github.com/golang/mock/gomock" + domain "github.com/zitadel/zitadel/internal/domain" ) // MockAuthRequestCache is a mock of AuthRequestCache interface. @@ -36,47 +36,47 @@ func (m *MockAuthRequestCache) EXPECT() *MockAuthRequestCacheMockRecorder { } // DeleteAuthRequest mocks base method. -func (m *MockAuthRequestCache) DeleteAuthRequest(arg0 context.Context, arg1, arg2 string) error { +func (m *MockAuthRequestCache) DeleteAuthRequest(arg0 context.Context, arg1 string) error { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0, arg1) ret0, _ := ret[0].(error) return ret0 } // DeleteAuthRequest indicates an expected call of DeleteAuthRequest. -func (mr *MockAuthRequestCacheMockRecorder) DeleteAuthRequest(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockAuthRequestCacheMockRecorder) DeleteAuthRequest(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockAuthRequestCache)(nil).DeleteAuthRequest), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockAuthRequestCache)(nil).DeleteAuthRequest), arg0, arg1) } // GetAuthRequestByCode mocks base method. -func (m *MockAuthRequestCache) GetAuthRequestByCode(arg0 context.Context, arg1, arg2 string) (*domain.AuthRequest, error) { +func (m *MockAuthRequestCache) GetAuthRequestByCode(arg0 context.Context, arg1 string) (*domain.AuthRequest, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAuthRequestByCode", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetAuthRequestByCode", arg0, arg1) ret0, _ := ret[0].(*domain.AuthRequest) ret1, _ := ret[1].(error) return ret0, ret1 } // GetAuthRequestByCode indicates an expected call of GetAuthRequestByCode. -func (mr *MockAuthRequestCacheMockRecorder) GetAuthRequestByCode(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockAuthRequestCacheMockRecorder) GetAuthRequestByCode(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByCode", reflect.TypeOf((*MockAuthRequestCache)(nil).GetAuthRequestByCode), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByCode", reflect.TypeOf((*MockAuthRequestCache)(nil).GetAuthRequestByCode), arg0, arg1) } // GetAuthRequestByID mocks base method. -func (m *MockAuthRequestCache) GetAuthRequestByID(arg0 context.Context, arg1, arg2 string) (*domain.AuthRequest, error) { +func (m *MockAuthRequestCache) GetAuthRequestByID(arg0 context.Context, arg1 string) (*domain.AuthRequest, error) { m.ctrl.T.Helper() - ret := m.ctrl.Call(m, "GetAuthRequestByID", arg0, arg1, arg2) + ret := m.ctrl.Call(m, "GetAuthRequestByID", arg0, arg1) ret0, _ := ret[0].(*domain.AuthRequest) ret1, _ := ret[1].(error) return ret0, ret1 } // GetAuthRequestByID indicates an expected call of GetAuthRequestByID. -func (mr *MockAuthRequestCacheMockRecorder) GetAuthRequestByID(arg0, arg1, arg2 interface{}) *gomock.Call { +func (mr *MockAuthRequestCacheMockRecorder) GetAuthRequestByID(arg0, arg1 interface{}) *gomock.Call { mr.mock.ctrl.T.Helper() - return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByID", reflect.TypeOf((*MockAuthRequestCache)(nil).GetAuthRequestByID), arg0, arg1, arg2) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByID", reflect.TypeOf((*MockAuthRequestCache)(nil).GetAuthRequestByID), arg0, arg1) } // Health mocks base method.