mirror of
https://github.com/zitadel/zitadel.git
synced 2024-12-12 02:54:20 +00:00
fix: only reuse active session and use correct policies (from user org) (#6603)
This commit is contained in:
parent
7faab0378f
commit
593d1605ab
@ -50,6 +50,7 @@ type AuthRequestRepo struct {
|
||||
UserGrantProvider userGrantProvider
|
||||
ProjectProvider projectProvider
|
||||
ApplicationProvider applicationProvider
|
||||
CustomTextProvider customTextProvider
|
||||
|
||||
IdGenerator id.Generator
|
||||
}
|
||||
@ -115,6 +116,10 @@ type applicationProvider interface {
|
||||
AppByOIDCClientID(context.Context, string, bool) (*query.App, error)
|
||||
}
|
||||
|
||||
type customTextProvider interface {
|
||||
CustomTextListByTemplate(ctx context.Context, aggregateID string, text string, withOwnerRemoved bool) (texts *query.CustomTexts, err error)
|
||||
}
|
||||
|
||||
func (repo *AuthRequestRepo) Health(ctx context.Context) error {
|
||||
return repo.AuthRequests.Health(ctx)
|
||||
}
|
||||
@ -1113,8 +1118,18 @@ func (repo *AuthRequestRepo) nextStepsUser(ctx context.Context, request *domain.
|
||||
if len(steps) > 0 {
|
||||
return steps, nil
|
||||
}
|
||||
// a single user session was found, use that automatically
|
||||
// the single user session was inactive
|
||||
if users[0].UserSessionState != domain.UserSessionStateActive {
|
||||
return append(steps, &domain.SelectUserStep{Users: users}), nil
|
||||
}
|
||||
// a single active user session was found, use that automatically
|
||||
request.SetUserInfo(users[0].UserID, users[0].UserName, users[0].LoginName, users[0].DisplayName, users[0].AvatarKey, users[0].ResourceOwner)
|
||||
if err = repo.fillPolicies(ctx, request); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = repo.AuthRequests.UpdateAuthRequest(ctx, request); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return steps, nil
|
||||
}
|
||||
@ -1315,7 +1330,7 @@ func labelPolicyToDomain(p *query.LabelPolicy) *domain.LabelPolicy {
|
||||
}
|
||||
|
||||
func (repo *AuthRequestRepo) getLoginTexts(ctx context.Context, aggregateID string) ([]*domain.CustomText, error) {
|
||||
loginTexts, err := repo.Query.CustomTextListByTemplate(ctx, aggregateID, domain.LoginCustomText, false)
|
||||
loginTexts, err := repo.CustomTextProvider.CustomTextListByTemplate(ctx, aggregateID, domain.LoginCustomText, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -6,10 +6,12 @@ import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/auth/repository/eventsourcing/view"
|
||||
"github.com/zitadel/zitadel/internal/auth_request/repository/cache"
|
||||
cache "github.com/zitadel/zitadel/internal/auth_request/repository"
|
||||
"github.com/zitadel/zitadel/internal/auth_request/repository/mock"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
@ -67,6 +69,7 @@ type mockUser struct {
|
||||
UserID string
|
||||
LoginName string
|
||||
ResourceOwner string
|
||||
SessionState domain.UserSessionState
|
||||
}
|
||||
|
||||
func (m *mockViewUserSession) UserSessionByIDs(string, string, string) (*user_view_model.UserSessionView, error) {
|
||||
@ -83,9 +86,10 @@ func (m *mockViewUserSession) UserSessionsByAgentID(string, string) ([]*user_vie
|
||||
sessions := make([]*user_view_model.UserSessionView, len(m.Users))
|
||||
for i, user := range m.Users {
|
||||
sessions[i] = &user_view_model.UserSessionView{
|
||||
ResourceOwner: user.ResourceOwner,
|
||||
State: int32(user.SessionState),
|
||||
UserID: user.UserID,
|
||||
LoginName: user.LoginName,
|
||||
ResourceOwner: user.ResourceOwner,
|
||||
}
|
||||
}
|
||||
return sessions, nil
|
||||
@ -148,6 +152,30 @@ func (m *mockLoginPolicy) LoginPolicyByID(ctx context.Context, _ bool, id string
|
||||
return m.policy, nil
|
||||
}
|
||||
|
||||
type mockPrivacyPolicy struct {
|
||||
policy *query.PrivacyPolicy
|
||||
}
|
||||
|
||||
func (m mockPrivacyPolicy) PrivacyPolicyByOrg(ctx context.Context, b bool, s string, b2 bool) (*query.PrivacyPolicy, error) {
|
||||
return m.policy, nil
|
||||
}
|
||||
|
||||
type mockLabelPolicy struct {
|
||||
policy *query.LabelPolicy
|
||||
}
|
||||
|
||||
func (m mockLabelPolicy) ActiveLabelPolicyByOrg(ctx context.Context, s string, b bool) (*query.LabelPolicy, error) {
|
||||
return m.policy, nil
|
||||
}
|
||||
|
||||
type mockCustomText struct {
|
||||
texts *query.CustomTexts
|
||||
}
|
||||
|
||||
func (m *mockCustomText) CustomTextListByTemplate(ctx context.Context, aggregateID string, text string, withOwnerRemoved bool) (texts *query.CustomTexts, err error) {
|
||||
return m.texts, nil
|
||||
}
|
||||
|
||||
type mockLockoutPolicy struct {
|
||||
policy *query.LockoutPolicy
|
||||
}
|
||||
@ -258,7 +286,7 @@ func (m *mockIDPUserLinks) IDPUserLinks(ctx context.Context, queries *query.IDPU
|
||||
|
||||
func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
type fields struct {
|
||||
AuthRequests *cache.AuthRequestCache
|
||||
AuthRequests cache.AuthRequestCache
|
||||
View *view.View
|
||||
userSessionViewProvider userSessionViewProvider
|
||||
userViewProvider userViewProvider
|
||||
@ -270,6 +298,9 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
loginPolicyProvider loginPolicyViewProvider
|
||||
lockoutPolicyProvider lockoutPolicyViewProvider
|
||||
idpUserLinksProvider idpUserLinksProvider
|
||||
privacyPolicyProvider privacyPolicyProvider
|
||||
labelPolicyProvider labelPolicyProvider
|
||||
customTextProvider customTextProvider
|
||||
}
|
||||
type args struct {
|
||||
request *domain.AuthRequest
|
||||
@ -363,11 +394,13 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"id1",
|
||||
"loginname1",
|
||||
"orgID1",
|
||||
domain.UserSessionStateActive,
|
||||
},
|
||||
{
|
||||
"id2",
|
||||
"loginname2",
|
||||
"orgID2",
|
||||
domain.UserSessionStateActive,
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -402,11 +435,13 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"id1",
|
||||
"loginname1",
|
||||
"orgID1",
|
||||
domain.UserSessionStateActive,
|
||||
},
|
||||
{
|
||||
"id2",
|
||||
"loginname2",
|
||||
"orgID2",
|
||||
domain.UserSessionStateActive,
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -444,6 +479,11 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
{
|
||||
"user not set single active session, callback step",
|
||||
fields{
|
||||
AuthRequests: func() cache.AuthRequestCache {
|
||||
m := mock.NewMockAuthRequestCache(gomock.NewController(t))
|
||||
m.EXPECT().UpdateAuthRequest(gomock.Any(), gomock.Any())
|
||||
return m
|
||||
}(),
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().Add(-5 * time.Minute),
|
||||
SecondFactorVerification: time.Now().Add(-5 * time.Minute),
|
||||
@ -452,6 +492,66 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"id1",
|
||||
"loginname1",
|
||||
"orgID1",
|
||||
domain.UserSessionStateActive,
|
||||
},
|
||||
},
|
||||
},
|
||||
userViewProvider: &mockViewUser{
|
||||
PasswordSet: true,
|
||||
IsEmailVerified: true,
|
||||
MFAMaxSetUp: int32(domain.MFALevelSecondFactor),
|
||||
},
|
||||
userEventProvider: &mockEventUser{},
|
||||
orgViewProvider: &mockViewOrg{State: domain.OrgStateActive},
|
||||
userGrantProvider: &mockUserGrants{},
|
||||
projectProvider: &mockProject{},
|
||||
applicationProvider: &mockApp{app: &query.App{OIDCConfig: &query.OIDCApp{AppType: domain.OIDCApplicationTypeWeb}}},
|
||||
lockoutPolicyProvider: &mockLockoutPolicy{
|
||||
policy: &query.LockoutPolicy{
|
||||
ShowFailures: true,
|
||||
},
|
||||
},
|
||||
idpUserLinksProvider: &mockIDPUserLinks{},
|
||||
loginPolicyProvider: &mockLoginPolicy{
|
||||
policy: &query.LoginPolicy{
|
||||
SecondFactors: []domain.SecondFactorType{domain.SecondFactorTypeTOTP},
|
||||
PasswordCheckLifetime: 10 * 24 * time.Hour,
|
||||
SecondFactorCheckLifetime: 18 * time.Hour,
|
||||
},
|
||||
},
|
||||
privacyPolicyProvider: &mockPrivacyPolicy{
|
||||
policy: &query.PrivacyPolicy{},
|
||||
},
|
||||
labelPolicyProvider: &mockLabelPolicy{
|
||||
policy: &query.LabelPolicy{},
|
||||
},
|
||||
customTextProvider: &mockCustomText{
|
||||
texts: &query.CustomTexts{},
|
||||
},
|
||||
},
|
||||
args{&domain.AuthRequest{
|
||||
Request: &domain.AuthRequestOIDC{},
|
||||
LoginPolicy: &domain.LoginPolicy{
|
||||
SecondFactors: []domain.SecondFactorType{domain.SecondFactorTypeTOTP},
|
||||
PasswordCheckLifetime: 10 * 24 * time.Hour,
|
||||
SecondFactorCheckLifetime: 18 * time.Hour,
|
||||
},
|
||||
}, false},
|
||||
[]domain.NextStep{&domain.RedirectToCallbackStep{}},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"user not set single inactive session, callback step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
PasswordVerification: time.Now().Add(-5 * time.Minute),
|
||||
SecondFactorVerification: time.Now().Add(-5 * time.Minute),
|
||||
Users: []mockUser{
|
||||
{
|
||||
"id1",
|
||||
"loginname1",
|
||||
"orgID1",
|
||||
domain.UserSessionStateTerminated,
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -480,11 +580,19 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
SecondFactorCheckLifetime: 18 * time.Hour,
|
||||
},
|
||||
}, false},
|
||||
[]domain.NextStep{&domain.RedirectToCallbackStep{}},
|
||||
[]domain.NextStep{&domain.SelectUserStep{Users: []domain.UserSelection{
|
||||
{
|
||||
UserID: "id1",
|
||||
LoginName: "loginname1",
|
||||
ResourceOwner: "orgID1",
|
||||
UserSessionState: domain.UserSessionStateTerminated,
|
||||
SelectionPossible: true,
|
||||
},
|
||||
}}},
|
||||
nil,
|
||||
},
|
||||
{
|
||||
"user not set multiple active sessions, select account step",
|
||||
"user not set multiple sessions, select account step",
|
||||
fields{
|
||||
userSessionViewProvider: &mockViewUserSession{
|
||||
Users: []mockUser{
|
||||
@ -492,11 +600,13 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
"id1",
|
||||
"loginname1",
|
||||
"orgID1",
|
||||
domain.UserSessionStateActive,
|
||||
},
|
||||
{
|
||||
"id2",
|
||||
"loginname2",
|
||||
"orgID2",
|
||||
domain.UserSessionStateTerminated,
|
||||
},
|
||||
},
|
||||
},
|
||||
@ -532,12 +642,14 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
LoginName: "loginname1",
|
||||
SelectionPossible: true,
|
||||
ResourceOwner: "orgID1",
|
||||
UserSessionState: domain.UserSessionStateActive,
|
||||
},
|
||||
{
|
||||
UserID: "id2",
|
||||
LoginName: "loginname2",
|
||||
SelectionPossible: true,
|
||||
ResourceOwner: "orgID2",
|
||||
UserSessionState: domain.UserSessionStateTerminated,
|
||||
},
|
||||
},
|
||||
}},
|
||||
@ -1544,6 +1656,9 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
|
||||
LoginPolicyViewProvider: tt.fields.loginPolicyProvider,
|
||||
LockoutPolicyViewProvider: tt.fields.lockoutPolicyProvider,
|
||||
IDPUserLinksProvider: tt.fields.idpUserLinksProvider,
|
||||
PrivacyPolicyProvider: tt.fields.privacyPolicyProvider,
|
||||
LabelPolicyProvider: tt.fields.labelPolicyProvider,
|
||||
CustomTextProvider: tt.fields.customTextProvider,
|
||||
}
|
||||
got, err := repo.nextSteps(context.Background(), tt.args.request, tt.args.checkLoggedIn)
|
||||
if (err != nil && tt.wantErr == nil) || (tt.wantErr != nil && !tt.wantErr(err)) {
|
||||
|
@ -87,6 +87,7 @@ func Start(ctx context.Context, conf Config, systemDefaults sd.SystemDefaults, c
|
||||
UserGrantProvider: queryView,
|
||||
ProjectProvider: queryView,
|
||||
ApplicationProvider: queries,
|
||||
CustomTextProvider: queries,
|
||||
IdGenerator: idGenerator,
|
||||
},
|
||||
eventstore.TokenRepo{
|
||||
|
@ -8,8 +8,8 @@ import (
|
||||
context "context"
|
||||
reflect "reflect"
|
||||
|
||||
domain "github.com/zitadel/zitadel/internal/domain"
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
domain "github.com/zitadel/zitadel/internal/domain"
|
||||
)
|
||||
|
||||
// MockAuthRequestCache is a mock of AuthRequestCache interface.
|
||||
@ -36,47 +36,47 @@ func (m *MockAuthRequestCache) EXPECT() *MockAuthRequestCacheMockRecorder {
|
||||
}
|
||||
|
||||
// DeleteAuthRequest mocks base method.
|
||||
func (m *MockAuthRequestCache) DeleteAuthRequest(arg0 context.Context, arg1, arg2 string) error {
|
||||
func (m *MockAuthRequestCache) DeleteAuthRequest(arg0 context.Context, arg1 string) error {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0, arg1, arg2)
|
||||
ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0, arg1)
|
||||
ret0, _ := ret[0].(error)
|
||||
return ret0
|
||||
}
|
||||
|
||||
// DeleteAuthRequest indicates an expected call of DeleteAuthRequest.
|
||||
func (mr *MockAuthRequestCacheMockRecorder) DeleteAuthRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
func (mr *MockAuthRequestCacheMockRecorder) DeleteAuthRequest(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockAuthRequestCache)(nil).DeleteAuthRequest), arg0, arg1, arg2)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockAuthRequestCache)(nil).DeleteAuthRequest), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetAuthRequestByCode mocks base method.
|
||||
func (m *MockAuthRequestCache) GetAuthRequestByCode(arg0 context.Context, arg1, arg2 string) (*domain.AuthRequest, error) {
|
||||
func (m *MockAuthRequestCache) GetAuthRequestByCode(arg0 context.Context, arg1 string) (*domain.AuthRequest, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAuthRequestByCode", arg0, arg1, arg2)
|
||||
ret := m.ctrl.Call(m, "GetAuthRequestByCode", arg0, arg1)
|
||||
ret0, _ := ret[0].(*domain.AuthRequest)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAuthRequestByCode indicates an expected call of GetAuthRequestByCode.
|
||||
func (mr *MockAuthRequestCacheMockRecorder) GetAuthRequestByCode(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
func (mr *MockAuthRequestCacheMockRecorder) GetAuthRequestByCode(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByCode", reflect.TypeOf((*MockAuthRequestCache)(nil).GetAuthRequestByCode), arg0, arg1, arg2)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByCode", reflect.TypeOf((*MockAuthRequestCache)(nil).GetAuthRequestByCode), arg0, arg1)
|
||||
}
|
||||
|
||||
// GetAuthRequestByID mocks base method.
|
||||
func (m *MockAuthRequestCache) GetAuthRequestByID(arg0 context.Context, arg1, arg2 string) (*domain.AuthRequest, error) {
|
||||
func (m *MockAuthRequestCache) GetAuthRequestByID(arg0 context.Context, arg1 string) (*domain.AuthRequest, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "GetAuthRequestByID", arg0, arg1, arg2)
|
||||
ret := m.ctrl.Call(m, "GetAuthRequestByID", arg0, arg1)
|
||||
ret0, _ := ret[0].(*domain.AuthRequest)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// GetAuthRequestByID indicates an expected call of GetAuthRequestByID.
|
||||
func (mr *MockAuthRequestCacheMockRecorder) GetAuthRequestByID(arg0, arg1, arg2 interface{}) *gomock.Call {
|
||||
func (mr *MockAuthRequestCacheMockRecorder) GetAuthRequestByID(arg0, arg1 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByID", reflect.TypeOf((*MockAuthRequestCache)(nil).GetAuthRequestByID), arg0, arg1, arg2)
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByID", reflect.TypeOf((*MockAuthRequestCache)(nil).GetAuthRequestByID), arg0, arg1)
|
||||
}
|
||||
|
||||
// Health mocks base method.
|
||||
|
Loading…
Reference in New Issue
Block a user