fix: only reuse active session and use correct policies (from user org) (#6603)

This commit is contained in:
Livio Spring 2023-09-21 16:45:41 +02:00 committed by GitHub
parent 7faab0378f
commit 593d1605ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 151 additions and 20 deletions

View File

@ -50,6 +50,7 @@ type AuthRequestRepo struct {
UserGrantProvider userGrantProvider UserGrantProvider userGrantProvider
ProjectProvider projectProvider ProjectProvider projectProvider
ApplicationProvider applicationProvider ApplicationProvider applicationProvider
CustomTextProvider customTextProvider
IdGenerator id.Generator IdGenerator id.Generator
} }
@ -115,6 +116,10 @@ type applicationProvider interface {
AppByOIDCClientID(context.Context, string, bool) (*query.App, error) AppByOIDCClientID(context.Context, string, bool) (*query.App, error)
} }
type customTextProvider interface {
CustomTextListByTemplate(ctx context.Context, aggregateID string, text string, withOwnerRemoved bool) (texts *query.CustomTexts, err error)
}
func (repo *AuthRequestRepo) Health(ctx context.Context) error { func (repo *AuthRequestRepo) Health(ctx context.Context) error {
return repo.AuthRequests.Health(ctx) return repo.AuthRequests.Health(ctx)
} }
@ -1113,8 +1118,18 @@ func (repo *AuthRequestRepo) nextStepsUser(ctx context.Context, request *domain.
if len(steps) > 0 { if len(steps) > 0 {
return steps, nil return steps, nil
} }
// a single user session was found, use that automatically // the single user session was inactive
if users[0].UserSessionState != domain.UserSessionStateActive {
return append(steps, &domain.SelectUserStep{Users: users}), nil
}
// a single active user session was found, use that automatically
request.SetUserInfo(users[0].UserID, users[0].UserName, users[0].LoginName, users[0].DisplayName, users[0].AvatarKey, users[0].ResourceOwner) request.SetUserInfo(users[0].UserID, users[0].UserName, users[0].LoginName, users[0].DisplayName, users[0].AvatarKey, users[0].ResourceOwner)
if err = repo.fillPolicies(ctx, request); err != nil {
return nil, err
}
if err = repo.AuthRequests.UpdateAuthRequest(ctx, request); err != nil {
return nil, err
}
} }
return steps, nil return steps, nil
} }
@ -1315,7 +1330,7 @@ func labelPolicyToDomain(p *query.LabelPolicy) *domain.LabelPolicy {
} }
func (repo *AuthRequestRepo) getLoginTexts(ctx context.Context, aggregateID string) ([]*domain.CustomText, error) { func (repo *AuthRequestRepo) getLoginTexts(ctx context.Context, aggregateID string) ([]*domain.CustomText, error) {
loginTexts, err := repo.Query.CustomTextListByTemplate(ctx, aggregateID, domain.LoginCustomText, false) loginTexts, err := repo.CustomTextProvider.CustomTextListByTemplate(ctx, aggregateID, domain.LoginCustomText, false)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -6,10 +6,12 @@ import (
"testing" "testing"
"time" "time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/auth/repository/eventsourcing/view" "github.com/zitadel/zitadel/internal/auth/repository/eventsourcing/view"
"github.com/zitadel/zitadel/internal/auth_request/repository/cache" cache "github.com/zitadel/zitadel/internal/auth_request/repository"
"github.com/zitadel/zitadel/internal/auth_request/repository/mock"
"github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/errors"
@ -67,6 +69,7 @@ type mockUser struct {
UserID string UserID string
LoginName string LoginName string
ResourceOwner string ResourceOwner string
SessionState domain.UserSessionState
} }
func (m *mockViewUserSession) UserSessionByIDs(string, string, string) (*user_view_model.UserSessionView, error) { func (m *mockViewUserSession) UserSessionByIDs(string, string, string) (*user_view_model.UserSessionView, error) {
@ -83,9 +86,10 @@ func (m *mockViewUserSession) UserSessionsByAgentID(string, string) ([]*user_vie
sessions := make([]*user_view_model.UserSessionView, len(m.Users)) sessions := make([]*user_view_model.UserSessionView, len(m.Users))
for i, user := range m.Users { for i, user := range m.Users {
sessions[i] = &user_view_model.UserSessionView{ sessions[i] = &user_view_model.UserSessionView{
ResourceOwner: user.ResourceOwner,
State: int32(user.SessionState),
UserID: user.UserID, UserID: user.UserID,
LoginName: user.LoginName, LoginName: user.LoginName,
ResourceOwner: user.ResourceOwner,
} }
} }
return sessions, nil return sessions, nil
@ -148,6 +152,30 @@ func (m *mockLoginPolicy) LoginPolicyByID(ctx context.Context, _ bool, id string
return m.policy, nil return m.policy, nil
} }
type mockPrivacyPolicy struct {
policy *query.PrivacyPolicy
}
func (m mockPrivacyPolicy) PrivacyPolicyByOrg(ctx context.Context, b bool, s string, b2 bool) (*query.PrivacyPolicy, error) {
return m.policy, nil
}
type mockLabelPolicy struct {
policy *query.LabelPolicy
}
func (m mockLabelPolicy) ActiveLabelPolicyByOrg(ctx context.Context, s string, b bool) (*query.LabelPolicy, error) {
return m.policy, nil
}
type mockCustomText struct {
texts *query.CustomTexts
}
func (m *mockCustomText) CustomTextListByTemplate(ctx context.Context, aggregateID string, text string, withOwnerRemoved bool) (texts *query.CustomTexts, err error) {
return m.texts, nil
}
type mockLockoutPolicy struct { type mockLockoutPolicy struct {
policy *query.LockoutPolicy policy *query.LockoutPolicy
} }
@ -258,7 +286,7 @@ func (m *mockIDPUserLinks) IDPUserLinks(ctx context.Context, queries *query.IDPU
func TestAuthRequestRepo_nextSteps(t *testing.T) { func TestAuthRequestRepo_nextSteps(t *testing.T) {
type fields struct { type fields struct {
AuthRequests *cache.AuthRequestCache AuthRequests cache.AuthRequestCache
View *view.View View *view.View
userSessionViewProvider userSessionViewProvider userSessionViewProvider userSessionViewProvider
userViewProvider userViewProvider userViewProvider userViewProvider
@ -270,6 +298,9 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
loginPolicyProvider loginPolicyViewProvider loginPolicyProvider loginPolicyViewProvider
lockoutPolicyProvider lockoutPolicyViewProvider lockoutPolicyProvider lockoutPolicyViewProvider
idpUserLinksProvider idpUserLinksProvider idpUserLinksProvider idpUserLinksProvider
privacyPolicyProvider privacyPolicyProvider
labelPolicyProvider labelPolicyProvider
customTextProvider customTextProvider
} }
type args struct { type args struct {
request *domain.AuthRequest request *domain.AuthRequest
@ -363,11 +394,13 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
"id1", "id1",
"loginname1", "loginname1",
"orgID1", "orgID1",
domain.UserSessionStateActive,
}, },
{ {
"id2", "id2",
"loginname2", "loginname2",
"orgID2", "orgID2",
domain.UserSessionStateActive,
}, },
}, },
}, },
@ -402,11 +435,13 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
"id1", "id1",
"loginname1", "loginname1",
"orgID1", "orgID1",
domain.UserSessionStateActive,
}, },
{ {
"id2", "id2",
"loginname2", "loginname2",
"orgID2", "orgID2",
domain.UserSessionStateActive,
}, },
}, },
}, },
@ -444,6 +479,11 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
{ {
"user not set single active session, callback step", "user not set single active session, callback step",
fields{ fields{
AuthRequests: func() cache.AuthRequestCache {
m := mock.NewMockAuthRequestCache(gomock.NewController(t))
m.EXPECT().UpdateAuthRequest(gomock.Any(), gomock.Any())
return m
}(),
userSessionViewProvider: &mockViewUserSession{ userSessionViewProvider: &mockViewUserSession{
PasswordVerification: time.Now().Add(-5 * time.Minute), PasswordVerification: time.Now().Add(-5 * time.Minute),
SecondFactorVerification: time.Now().Add(-5 * time.Minute), SecondFactorVerification: time.Now().Add(-5 * time.Minute),
@ -452,6 +492,66 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
"id1", "id1",
"loginname1", "loginname1",
"orgID1", "orgID1",
domain.UserSessionStateActive,
},
},
},
userViewProvider: &mockViewUser{
PasswordSet: true,
IsEmailVerified: true,
MFAMaxSetUp: int32(domain.MFALevelSecondFactor),
},
userEventProvider: &mockEventUser{},
orgViewProvider: &mockViewOrg{State: domain.OrgStateActive},
userGrantProvider: &mockUserGrants{},
projectProvider: &mockProject{},
applicationProvider: &mockApp{app: &query.App{OIDCConfig: &query.OIDCApp{AppType: domain.OIDCApplicationTypeWeb}}},
lockoutPolicyProvider: &mockLockoutPolicy{
policy: &query.LockoutPolicy{
ShowFailures: true,
},
},
idpUserLinksProvider: &mockIDPUserLinks{},
loginPolicyProvider: &mockLoginPolicy{
policy: &query.LoginPolicy{
SecondFactors: []domain.SecondFactorType{domain.SecondFactorTypeTOTP},
PasswordCheckLifetime: 10 * 24 * time.Hour,
SecondFactorCheckLifetime: 18 * time.Hour,
},
},
privacyPolicyProvider: &mockPrivacyPolicy{
policy: &query.PrivacyPolicy{},
},
labelPolicyProvider: &mockLabelPolicy{
policy: &query.LabelPolicy{},
},
customTextProvider: &mockCustomText{
texts: &query.CustomTexts{},
},
},
args{&domain.AuthRequest{
Request: &domain.AuthRequestOIDC{},
LoginPolicy: &domain.LoginPolicy{
SecondFactors: []domain.SecondFactorType{domain.SecondFactorTypeTOTP},
PasswordCheckLifetime: 10 * 24 * time.Hour,
SecondFactorCheckLifetime: 18 * time.Hour,
},
}, false},
[]domain.NextStep{&domain.RedirectToCallbackStep{}},
nil,
},
{
"user not set single inactive session, callback step",
fields{
userSessionViewProvider: &mockViewUserSession{
PasswordVerification: time.Now().Add(-5 * time.Minute),
SecondFactorVerification: time.Now().Add(-5 * time.Minute),
Users: []mockUser{
{
"id1",
"loginname1",
"orgID1",
domain.UserSessionStateTerminated,
}, },
}, },
}, },
@ -480,11 +580,19 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
SecondFactorCheckLifetime: 18 * time.Hour, SecondFactorCheckLifetime: 18 * time.Hour,
}, },
}, false}, }, false},
[]domain.NextStep{&domain.RedirectToCallbackStep{}}, []domain.NextStep{&domain.SelectUserStep{Users: []domain.UserSelection{
{
UserID: "id1",
LoginName: "loginname1",
ResourceOwner: "orgID1",
UserSessionState: domain.UserSessionStateTerminated,
SelectionPossible: true,
},
}}},
nil, nil,
}, },
{ {
"user not set multiple active sessions, select account step", "user not set multiple sessions, select account step",
fields{ fields{
userSessionViewProvider: &mockViewUserSession{ userSessionViewProvider: &mockViewUserSession{
Users: []mockUser{ Users: []mockUser{
@ -492,11 +600,13 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
"id1", "id1",
"loginname1", "loginname1",
"orgID1", "orgID1",
domain.UserSessionStateActive,
}, },
{ {
"id2", "id2",
"loginname2", "loginname2",
"orgID2", "orgID2",
domain.UserSessionStateTerminated,
}, },
}, },
}, },
@ -532,12 +642,14 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
LoginName: "loginname1", LoginName: "loginname1",
SelectionPossible: true, SelectionPossible: true,
ResourceOwner: "orgID1", ResourceOwner: "orgID1",
UserSessionState: domain.UserSessionStateActive,
}, },
{ {
UserID: "id2", UserID: "id2",
LoginName: "loginname2", LoginName: "loginname2",
SelectionPossible: true, SelectionPossible: true,
ResourceOwner: "orgID2", ResourceOwner: "orgID2",
UserSessionState: domain.UserSessionStateTerminated,
}, },
}, },
}}, }},
@ -1544,6 +1656,9 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
LoginPolicyViewProvider: tt.fields.loginPolicyProvider, LoginPolicyViewProvider: tt.fields.loginPolicyProvider,
LockoutPolicyViewProvider: tt.fields.lockoutPolicyProvider, LockoutPolicyViewProvider: tt.fields.lockoutPolicyProvider,
IDPUserLinksProvider: tt.fields.idpUserLinksProvider, IDPUserLinksProvider: tt.fields.idpUserLinksProvider,
PrivacyPolicyProvider: tt.fields.privacyPolicyProvider,
LabelPolicyProvider: tt.fields.labelPolicyProvider,
CustomTextProvider: tt.fields.customTextProvider,
} }
got, err := repo.nextSteps(context.Background(), tt.args.request, tt.args.checkLoggedIn) got, err := repo.nextSteps(context.Background(), tt.args.request, tt.args.checkLoggedIn)
if (err != nil && tt.wantErr == nil) || (tt.wantErr != nil && !tt.wantErr(err)) { if (err != nil && tt.wantErr == nil) || (tt.wantErr != nil && !tt.wantErr(err)) {

View File

@ -87,6 +87,7 @@ func Start(ctx context.Context, conf Config, systemDefaults sd.SystemDefaults, c
UserGrantProvider: queryView, UserGrantProvider: queryView,
ProjectProvider: queryView, ProjectProvider: queryView,
ApplicationProvider: queries, ApplicationProvider: queries,
CustomTextProvider: queries,
IdGenerator: idGenerator, IdGenerator: idGenerator,
}, },
eventstore.TokenRepo{ eventstore.TokenRepo{

View File

@ -8,8 +8,8 @@ import (
context "context" context "context"
reflect "reflect" reflect "reflect"
domain "github.com/zitadel/zitadel/internal/domain"
gomock "github.com/golang/mock/gomock" gomock "github.com/golang/mock/gomock"
domain "github.com/zitadel/zitadel/internal/domain"
) )
// MockAuthRequestCache is a mock of AuthRequestCache interface. // MockAuthRequestCache is a mock of AuthRequestCache interface.
@ -36,47 +36,47 @@ func (m *MockAuthRequestCache) EXPECT() *MockAuthRequestCacheMockRecorder {
} }
// DeleteAuthRequest mocks base method. // DeleteAuthRequest mocks base method.
func (m *MockAuthRequestCache) DeleteAuthRequest(arg0 context.Context, arg1, arg2 string) error { func (m *MockAuthRequestCache) DeleteAuthRequest(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0, arg1, arg2) ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0, arg1)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// DeleteAuthRequest indicates an expected call of DeleteAuthRequest. // DeleteAuthRequest indicates an expected call of DeleteAuthRequest.
func (mr *MockAuthRequestCacheMockRecorder) DeleteAuthRequest(arg0, arg1, arg2 interface{}) *gomock.Call { func (mr *MockAuthRequestCacheMockRecorder) DeleteAuthRequest(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockAuthRequestCache)(nil).DeleteAuthRequest), arg0, arg1, arg2) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockAuthRequestCache)(nil).DeleteAuthRequest), arg0, arg1)
} }
// GetAuthRequestByCode mocks base method. // GetAuthRequestByCode mocks base method.
func (m *MockAuthRequestCache) GetAuthRequestByCode(arg0 context.Context, arg1, arg2 string) (*domain.AuthRequest, error) { func (m *MockAuthRequestCache) GetAuthRequestByCode(arg0 context.Context, arg1 string) (*domain.AuthRequest, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAuthRequestByCode", arg0, arg1, arg2) ret := m.ctrl.Call(m, "GetAuthRequestByCode", arg0, arg1)
ret0, _ := ret[0].(*domain.AuthRequest) ret0, _ := ret[0].(*domain.AuthRequest)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// GetAuthRequestByCode indicates an expected call of GetAuthRequestByCode. // GetAuthRequestByCode indicates an expected call of GetAuthRequestByCode.
func (mr *MockAuthRequestCacheMockRecorder) GetAuthRequestByCode(arg0, arg1, arg2 interface{}) *gomock.Call { func (mr *MockAuthRequestCacheMockRecorder) GetAuthRequestByCode(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByCode", reflect.TypeOf((*MockAuthRequestCache)(nil).GetAuthRequestByCode), arg0, arg1, arg2) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByCode", reflect.TypeOf((*MockAuthRequestCache)(nil).GetAuthRequestByCode), arg0, arg1)
} }
// GetAuthRequestByID mocks base method. // GetAuthRequestByID mocks base method.
func (m *MockAuthRequestCache) GetAuthRequestByID(arg0 context.Context, arg1, arg2 string) (*domain.AuthRequest, error) { func (m *MockAuthRequestCache) GetAuthRequestByID(arg0 context.Context, arg1 string) (*domain.AuthRequest, error) {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAuthRequestByID", arg0, arg1, arg2) ret := m.ctrl.Call(m, "GetAuthRequestByID", arg0, arg1)
ret0, _ := ret[0].(*domain.AuthRequest) ret0, _ := ret[0].(*domain.AuthRequest)
ret1, _ := ret[1].(error) ret1, _ := ret[1].(error)
return ret0, ret1 return ret0, ret1
} }
// GetAuthRequestByID indicates an expected call of GetAuthRequestByID. // GetAuthRequestByID indicates an expected call of GetAuthRequestByID.
func (mr *MockAuthRequestCacheMockRecorder) GetAuthRequestByID(arg0, arg1, arg2 interface{}) *gomock.Call { func (mr *MockAuthRequestCacheMockRecorder) GetAuthRequestByID(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByID", reflect.TypeOf((*MockAuthRequestCache)(nil).GetAuthRequestByID), arg0, arg1, arg2) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByID", reflect.TypeOf((*MockAuthRequestCache)(nil).GetAuthRequestByID), arg0, arg1)
} }
// Health mocks base method. // Health mocks base method.