mirror of
https://github.com/zitadel/zitadel.git
synced 2025-11-03 12:32:52 +00:00
fix(login): only allow previously authenticated users on select account page
# Which Problems Are Solved
User enumeration was possible on the select account page by passing any userID as part of the form POST. Existing users could be selected even if they never authenticated on the same user agent (browser).
# How the Problems Are Solved
A check for an existing session on the same user agent was added to the select user function, resp. only required for the account selection page, since in other cases there doesn't have to be an existing session and the user agent integrity is already checked.
# Additional Changes
None
# Additional Context
None
(cherry picked from commit 7abe759c95)
This commit is contained in:
@@ -332,13 +332,24 @@ func (repo *AuthRequestRepo) setLinkingUser(ctx context.Context, request *domain
|
||||
return repo.AuthRequests.UpdateAuthRequest(ctx, request)
|
||||
}
|
||||
|
||||
func (repo *AuthRequestRepo) SelectUser(ctx context.Context, authReqID, userID, userAgentID string) (err error) {
|
||||
func (repo *AuthRequestRepo) SelectUser(ctx context.Context, authReqID, userID, userAgentID string, enforceExistingSession bool) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Check if the session already exists in the same user agent, e.g. when selecting the user from the user selection page.
|
||||
// This is to prevent username enumeration attacks by checking if the user exists in the system.
|
||||
if enforceExistingSession {
|
||||
userSession, err := userSessionByIDs(ctx, repo.UserSessionViewProvider, repo.UserEventProvider, request.AgentID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if userSession.Sequence == 0 {
|
||||
return zerrors.ThrowNotFound(nil, "AUTH-2d3f4", "Errors.UserSession.NotFound")
|
||||
}
|
||||
}
|
||||
user, err := activeUserByID(ctx, repo.UserViewProvider, repo.UserEventProvider, repo.OrgViewProvider, repo.LockoutPolicyViewProvider, userID, false)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1061,7 +1072,7 @@ func (repo *AuthRequestRepo) nextSteps(ctx context.Context, request *domain.Auth
|
||||
if request.UserOrgID == "" {
|
||||
request.UserOrgID = user.ResourceOwner
|
||||
}
|
||||
userSession, err := userSessionByIDs(ctx, repo.UserSessionViewProvider, repo.UserEventProvider, request.AgentID, user)
|
||||
userSession, err := userSessionByIDs(ctx, repo.UserSessionViewProvider, repo.UserEventProvider, request.AgentID, user.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1641,27 +1652,27 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
func userSessionByIDs(ctx context.Context, provider userSessionViewProvider, eventProvider userEventProvider, agentID string, user *user_model.UserView) (*user_model.UserSessionView, error) {
|
||||
func userSessionByIDs(ctx context.Context, provider userSessionViewProvider, eventProvider userEventProvider, agentID, userID string) (*user_model.UserSessionView, error) {
|
||||
instanceID := authz.GetInstance(ctx).InstanceID()
|
||||
|
||||
// always load the latest sequence first, so in case the session was not found by id,
|
||||
// the sequence will be equal or lower than the actual projection and no events are lost
|
||||
sequence, err := provider.GetLatestUserSessionSequence(ctx, instanceID)
|
||||
logging.WithFields("instanceID", instanceID, "userID", user.ID).
|
||||
logging.WithFields("instanceID", instanceID, "userID", userID).
|
||||
OnError(err).
|
||||
Errorf("could not get current sequence for userSessionByIDs")
|
||||
|
||||
session, err := provider.UserSessionByIDs(ctx, agentID, user.ID, instanceID)
|
||||
session, err := provider.UserSessionByIDs(ctx, agentID, userID, instanceID)
|
||||
if err != nil {
|
||||
if !zerrors.IsNotFound(err) {
|
||||
return nil, err
|
||||
}
|
||||
session = &user_view_model.UserSessionView{UserAgentID: agentID, UserID: user.ID}
|
||||
session = &user_view_model.UserSessionView{UserAgentID: agentID, UserID: userID}
|
||||
if sequence != nil {
|
||||
session.ChangeDate = sequence.EventCreatedAt
|
||||
}
|
||||
}
|
||||
events, err := eventProvider.UserEventsByID(ctx, user.ID, session.ChangeDate, append(session.EventTypes(), userSessionEventTypes...))
|
||||
events, err := eventProvider.UserEventsByID(ctx, userID, session.ChangeDate, append(session.EventTypes(), userSessionEventTypes...))
|
||||
if err != nil {
|
||||
logging.WithFields("traceID", tracing.TraceIDFromCtx(ctx)).WithError(err).Debug("error retrieving new events")
|
||||
return user_view_model.UserSessionToModel(session), nil
|
||||
|
||||
@@ -2796,7 +2796,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
userProvider userSessionViewProvider
|
||||
eventProvider userEventProvider
|
||||
agentID string
|
||||
user *user_model.UserView
|
||||
userID string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -2809,7 +2809,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
args{
|
||||
userProvider: &mockViewNoUserSession{},
|
||||
eventProvider: &mockEventErrUser{},
|
||||
user: &user_model.UserView{ID: "id"},
|
||||
userID: "id",
|
||||
},
|
||||
&user_model.UserSessionView{UserID: "id"},
|
||||
nil,
|
||||
@@ -2818,7 +2818,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
"internal error, internal error",
|
||||
args{
|
||||
userProvider: &mockViewErrUserSession{},
|
||||
user: &user_model.UserView{ID: "id"},
|
||||
userID: "id",
|
||||
},
|
||||
nil,
|
||||
zerrors.IsInternal,
|
||||
@@ -2829,7 +2829,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
userProvider: &mockViewUserSession{
|
||||
PasswordVerification: testNow,
|
||||
},
|
||||
user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}},
|
||||
userID: "id",
|
||||
eventProvider: &mockEventErrUser{},
|
||||
},
|
||||
&user_model.UserSessionView{
|
||||
@@ -2846,7 +2846,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
PasswordVerification: testNow,
|
||||
},
|
||||
agentID: "agentID",
|
||||
user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}},
|
||||
userID: "id",
|
||||
eventProvider: &mockEventUser{
|
||||
Events: []eventstore.Event{
|
||||
&es_models.Event{
|
||||
@@ -2871,7 +2871,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
PasswordVerification: testNow,
|
||||
},
|
||||
agentID: "agentID",
|
||||
user: &user_model.UserView{ID: "id"},
|
||||
userID: "id",
|
||||
eventProvider: &mockEventUser{
|
||||
Events: []eventstore.Event{
|
||||
&es_models.Event{
|
||||
@@ -2900,7 +2900,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
PasswordVerification: testNow,
|
||||
},
|
||||
agentID: "agentID",
|
||||
user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}},
|
||||
userID: "id",
|
||||
eventProvider: &mockEventUser{
|
||||
Events: []eventstore.Event{
|
||||
&es_models.Event{
|
||||
@@ -2929,7 +2929,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
PasswordVerification: testNow,
|
||||
},
|
||||
agentID: "agentID",
|
||||
user: &user_model.UserView{ID: "id"},
|
||||
userID: "id",
|
||||
eventProvider: &mockEventUser{
|
||||
Events: []eventstore.Event{
|
||||
&es_models.Event{
|
||||
@@ -2953,7 +2953,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
PasswordVerification: testNow,
|
||||
},
|
||||
agentID: "agentID",
|
||||
user: &user_model.UserView{ID: "id"},
|
||||
userID: "id",
|
||||
eventProvider: &mockEventUser{
|
||||
Events: []eventstore.Event{
|
||||
&es_models.Event{
|
||||
@@ -2986,7 +2986,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := userSessionByIDs(context.Background(), tt.args.userProvider, tt.args.eventProvider, tt.args.agentID, tt.args.user)
|
||||
got, err := userSessionByIDs(context.Background(), tt.args.userProvider, tt.args.eventProvider, tt.args.agentID, tt.args.userID)
|
||||
if (err != nil && tt.wantErr == nil) || (tt.wantErr != nil && !tt.wantErr(err)) {
|
||||
t.Errorf("nextSteps() wrong error = %v", err)
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user