fix: only reuse active session and use correct policies (from user org) (#6603)

This commit is contained in:
Livio Spring 2023-09-21 16:45:41 +02:00 committed by GitHub
parent 7faab0378f
commit 593d1605ab
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 151 additions and 20 deletions

View File

@ -50,6 +50,7 @@ type AuthRequestRepo struct {
UserGrantProvider userGrantProvider
ProjectProvider projectProvider
ApplicationProvider applicationProvider
CustomTextProvider customTextProvider
IdGenerator id.Generator
}
@ -115,6 +116,10 @@ type applicationProvider interface {
AppByOIDCClientID(context.Context, string, bool) (*query.App, error)
}
type customTextProvider interface {
CustomTextListByTemplate(ctx context.Context, aggregateID string, text string, withOwnerRemoved bool) (texts *query.CustomTexts, err error)
}
func (repo *AuthRequestRepo) Health(ctx context.Context) error {
return repo.AuthRequests.Health(ctx)
}
@ -1113,8 +1118,18 @@ func (repo *AuthRequestRepo) nextStepsUser(ctx context.Context, request *domain.
if len(steps) > 0 {
return steps, nil
}
// a single user session was found, use that automatically
// the single user session was inactive
if users[0].UserSessionState != domain.UserSessionStateActive {
return append(steps, &domain.SelectUserStep{Users: users}), nil
}
// a single active user session was found, use that automatically
request.SetUserInfo(users[0].UserID, users[0].UserName, users[0].LoginName, users[0].DisplayName, users[0].AvatarKey, users[0].ResourceOwner)
if err = repo.fillPolicies(ctx, request); err != nil {
return nil, err
}
if err = repo.AuthRequests.UpdateAuthRequest(ctx, request); err != nil {
return nil, err
}
}
return steps, nil
}
@ -1315,7 +1330,7 @@ func labelPolicyToDomain(p *query.LabelPolicy) *domain.LabelPolicy {
}
func (repo *AuthRequestRepo) getLoginTexts(ctx context.Context, aggregateID string) ([]*domain.CustomText, error) {
loginTexts, err := repo.Query.CustomTextListByTemplate(ctx, aggregateID, domain.LoginCustomText, false)
loginTexts, err := repo.CustomTextProvider.CustomTextListByTemplate(ctx, aggregateID, domain.LoginCustomText, false)
if err != nil {
return nil, err
}

View File

@ -6,10 +6,12 @@ import (
"testing"
"time"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/auth/repository/eventsourcing/view"
"github.com/zitadel/zitadel/internal/auth_request/repository/cache"
cache "github.com/zitadel/zitadel/internal/auth_request/repository"
"github.com/zitadel/zitadel/internal/auth_request/repository/mock"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
@ -67,6 +69,7 @@ type mockUser struct {
UserID string
LoginName string
ResourceOwner string
SessionState domain.UserSessionState
}
func (m *mockViewUserSession) UserSessionByIDs(string, string, string) (*user_view_model.UserSessionView, error) {
@ -83,9 +86,10 @@ func (m *mockViewUserSession) UserSessionsByAgentID(string, string) ([]*user_vie
sessions := make([]*user_view_model.UserSessionView, len(m.Users))
for i, user := range m.Users {
sessions[i] = &user_view_model.UserSessionView{
ResourceOwner: user.ResourceOwner,
State: int32(user.SessionState),
UserID: user.UserID,
LoginName: user.LoginName,
ResourceOwner: user.ResourceOwner,
}
}
return sessions, nil
@ -148,6 +152,30 @@ func (m *mockLoginPolicy) LoginPolicyByID(ctx context.Context, _ bool, id string
return m.policy, nil
}
type mockPrivacyPolicy struct {
policy *query.PrivacyPolicy
}
func (m mockPrivacyPolicy) PrivacyPolicyByOrg(ctx context.Context, b bool, s string, b2 bool) (*query.PrivacyPolicy, error) {
return m.policy, nil
}
type mockLabelPolicy struct {
policy *query.LabelPolicy
}
func (m mockLabelPolicy) ActiveLabelPolicyByOrg(ctx context.Context, s string, b bool) (*query.LabelPolicy, error) {
return m.policy, nil
}
type mockCustomText struct {
texts *query.CustomTexts
}
func (m *mockCustomText) CustomTextListByTemplate(ctx context.Context, aggregateID string, text string, withOwnerRemoved bool) (texts *query.CustomTexts, err error) {
return m.texts, nil
}
type mockLockoutPolicy struct {
policy *query.LockoutPolicy
}
@ -258,7 +286,7 @@ func (m *mockIDPUserLinks) IDPUserLinks(ctx context.Context, queries *query.IDPU
func TestAuthRequestRepo_nextSteps(t *testing.T) {
type fields struct {
AuthRequests *cache.AuthRequestCache
AuthRequests cache.AuthRequestCache
View *view.View
userSessionViewProvider userSessionViewProvider
userViewProvider userViewProvider
@ -270,6 +298,9 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
loginPolicyProvider loginPolicyViewProvider
lockoutPolicyProvider lockoutPolicyViewProvider
idpUserLinksProvider idpUserLinksProvider
privacyPolicyProvider privacyPolicyProvider
labelPolicyProvider labelPolicyProvider
customTextProvider customTextProvider
}
type args struct {
request *domain.AuthRequest
@ -363,11 +394,13 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
"id1",
"loginname1",
"orgID1",
domain.UserSessionStateActive,
},
{
"id2",
"loginname2",
"orgID2",
domain.UserSessionStateActive,
},
},
},
@ -402,11 +435,13 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
"id1",
"loginname1",
"orgID1",
domain.UserSessionStateActive,
},
{
"id2",
"loginname2",
"orgID2",
domain.UserSessionStateActive,
},
},
},
@ -444,6 +479,11 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
{
"user not set single active session, callback step",
fields{
AuthRequests: func() cache.AuthRequestCache {
m := mock.NewMockAuthRequestCache(gomock.NewController(t))
m.EXPECT().UpdateAuthRequest(gomock.Any(), gomock.Any())
return m
}(),
userSessionViewProvider: &mockViewUserSession{
PasswordVerification: time.Now().Add(-5 * time.Minute),
SecondFactorVerification: time.Now().Add(-5 * time.Minute),
@ -452,6 +492,66 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
"id1",
"loginname1",
"orgID1",
domain.UserSessionStateActive,
},
},
},
userViewProvider: &mockViewUser{
PasswordSet: true,
IsEmailVerified: true,
MFAMaxSetUp: int32(domain.MFALevelSecondFactor),
},
userEventProvider: &mockEventUser{},
orgViewProvider: &mockViewOrg{State: domain.OrgStateActive},
userGrantProvider: &mockUserGrants{},
projectProvider: &mockProject{},
applicationProvider: &mockApp{app: &query.App{OIDCConfig: &query.OIDCApp{AppType: domain.OIDCApplicationTypeWeb}}},
lockoutPolicyProvider: &mockLockoutPolicy{
policy: &query.LockoutPolicy{
ShowFailures: true,
},
},
idpUserLinksProvider: &mockIDPUserLinks{},
loginPolicyProvider: &mockLoginPolicy{
policy: &query.LoginPolicy{
SecondFactors: []domain.SecondFactorType{domain.SecondFactorTypeTOTP},
PasswordCheckLifetime: 10 * 24 * time.Hour,
SecondFactorCheckLifetime: 18 * time.Hour,
},
},
privacyPolicyProvider: &mockPrivacyPolicy{
policy: &query.PrivacyPolicy{},
},
labelPolicyProvider: &mockLabelPolicy{
policy: &query.LabelPolicy{},
},
customTextProvider: &mockCustomText{
texts: &query.CustomTexts{},
},
},
args{&domain.AuthRequest{
Request: &domain.AuthRequestOIDC{},
LoginPolicy: &domain.LoginPolicy{
SecondFactors: []domain.SecondFactorType{domain.SecondFactorTypeTOTP},
PasswordCheckLifetime: 10 * 24 * time.Hour,
SecondFactorCheckLifetime: 18 * time.Hour,
},
}, false},
[]domain.NextStep{&domain.RedirectToCallbackStep{}},
nil,
},
{
"user not set single inactive session, callback step",
fields{
userSessionViewProvider: &mockViewUserSession{
PasswordVerification: time.Now().Add(-5 * time.Minute),
SecondFactorVerification: time.Now().Add(-5 * time.Minute),
Users: []mockUser{
{
"id1",
"loginname1",
"orgID1",
domain.UserSessionStateTerminated,
},
},
},
@ -480,11 +580,19 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
SecondFactorCheckLifetime: 18 * time.Hour,
},
}, false},
[]domain.NextStep{&domain.RedirectToCallbackStep{}},
[]domain.NextStep{&domain.SelectUserStep{Users: []domain.UserSelection{
{
UserID: "id1",
LoginName: "loginname1",
ResourceOwner: "orgID1",
UserSessionState: domain.UserSessionStateTerminated,
SelectionPossible: true,
},
}}},
nil,
},
{
"user not set multiple active sessions, select account step",
"user not set multiple sessions, select account step",
fields{
userSessionViewProvider: &mockViewUserSession{
Users: []mockUser{
@ -492,11 +600,13 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
"id1",
"loginname1",
"orgID1",
domain.UserSessionStateActive,
},
{
"id2",
"loginname2",
"orgID2",
domain.UserSessionStateTerminated,
},
},
},
@ -532,12 +642,14 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
LoginName: "loginname1",
SelectionPossible: true,
ResourceOwner: "orgID1",
UserSessionState: domain.UserSessionStateActive,
},
{
UserID: "id2",
LoginName: "loginname2",
SelectionPossible: true,
ResourceOwner: "orgID2",
UserSessionState: domain.UserSessionStateTerminated,
},
},
}},
@ -1544,6 +1656,9 @@ func TestAuthRequestRepo_nextSteps(t *testing.T) {
LoginPolicyViewProvider: tt.fields.loginPolicyProvider,
LockoutPolicyViewProvider: tt.fields.lockoutPolicyProvider,
IDPUserLinksProvider: tt.fields.idpUserLinksProvider,
PrivacyPolicyProvider: tt.fields.privacyPolicyProvider,
LabelPolicyProvider: tt.fields.labelPolicyProvider,
CustomTextProvider: tt.fields.customTextProvider,
}
got, err := repo.nextSteps(context.Background(), tt.args.request, tt.args.checkLoggedIn)
if (err != nil && tt.wantErr == nil) || (tt.wantErr != nil && !tt.wantErr(err)) {

View File

@ -87,6 +87,7 @@ func Start(ctx context.Context, conf Config, systemDefaults sd.SystemDefaults, c
UserGrantProvider: queryView,
ProjectProvider: queryView,
ApplicationProvider: queries,
CustomTextProvider: queries,
IdGenerator: idGenerator,
},
eventstore.TokenRepo{

View File

@ -8,8 +8,8 @@ import (
context "context"
reflect "reflect"
domain "github.com/zitadel/zitadel/internal/domain"
gomock "github.com/golang/mock/gomock"
domain "github.com/zitadel/zitadel/internal/domain"
)
// MockAuthRequestCache is a mock of AuthRequestCache interface.
@ -36,47 +36,47 @@ func (m *MockAuthRequestCache) EXPECT() *MockAuthRequestCacheMockRecorder {
}
// DeleteAuthRequest mocks base method.
func (m *MockAuthRequestCache) DeleteAuthRequest(arg0 context.Context, arg1, arg2 string) error {
func (m *MockAuthRequestCache) DeleteAuthRequest(arg0 context.Context, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0, arg1, arg2)
ret := m.ctrl.Call(m, "DeleteAuthRequest", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// DeleteAuthRequest indicates an expected call of DeleteAuthRequest.
func (mr *MockAuthRequestCacheMockRecorder) DeleteAuthRequest(arg0, arg1, arg2 interface{}) *gomock.Call {
func (mr *MockAuthRequestCacheMockRecorder) DeleteAuthRequest(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockAuthRequestCache)(nil).DeleteAuthRequest), arg0, arg1, arg2)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DeleteAuthRequest", reflect.TypeOf((*MockAuthRequestCache)(nil).DeleteAuthRequest), arg0, arg1)
}
// GetAuthRequestByCode mocks base method.
func (m *MockAuthRequestCache) GetAuthRequestByCode(arg0 context.Context, arg1, arg2 string) (*domain.AuthRequest, error) {
func (m *MockAuthRequestCache) GetAuthRequestByCode(arg0 context.Context, arg1 string) (*domain.AuthRequest, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAuthRequestByCode", arg0, arg1, arg2)
ret := m.ctrl.Call(m, "GetAuthRequestByCode", arg0, arg1)
ret0, _ := ret[0].(*domain.AuthRequest)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAuthRequestByCode indicates an expected call of GetAuthRequestByCode.
func (mr *MockAuthRequestCacheMockRecorder) GetAuthRequestByCode(arg0, arg1, arg2 interface{}) *gomock.Call {
func (mr *MockAuthRequestCacheMockRecorder) GetAuthRequestByCode(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByCode", reflect.TypeOf((*MockAuthRequestCache)(nil).GetAuthRequestByCode), arg0, arg1, arg2)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByCode", reflect.TypeOf((*MockAuthRequestCache)(nil).GetAuthRequestByCode), arg0, arg1)
}
// GetAuthRequestByID mocks base method.
func (m *MockAuthRequestCache) GetAuthRequestByID(arg0 context.Context, arg1, arg2 string) (*domain.AuthRequest, error) {
func (m *MockAuthRequestCache) GetAuthRequestByID(arg0 context.Context, arg1 string) (*domain.AuthRequest, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetAuthRequestByID", arg0, arg1, arg2)
ret := m.ctrl.Call(m, "GetAuthRequestByID", arg0, arg1)
ret0, _ := ret[0].(*domain.AuthRequest)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetAuthRequestByID indicates an expected call of GetAuthRequestByID.
func (mr *MockAuthRequestCacheMockRecorder) GetAuthRequestByID(arg0, arg1, arg2 interface{}) *gomock.Call {
func (mr *MockAuthRequestCacheMockRecorder) GetAuthRequestByID(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByID", reflect.TypeOf((*MockAuthRequestCache)(nil).GetAuthRequestByID), arg0, arg1, arg2)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetAuthRequestByID", reflect.TypeOf((*MockAuthRequestCache)(nil).GetAuthRequestByID), arg0, arg1)
}
// Health mocks base method.