fix(login): only allow previously authenticated users on select account page

# Which Problems Are Solved

User enumeration was possible on the select account page by passing any userID as part of the form POST. Existing users could be selected even if they never authenticated on the same user agent (browser).

# How the Problems Are Solved

A check for an existing session on the same user agent was added to the select user function, resp. only required for the account selection page, since in other cases there doesn't have to be an existing session and the user agent integrity is already checked.

# Additional Changes

None

# Additional Context

None

(cherry picked from commit 7abe759c95)
This commit is contained in:
Livio Spring
2025-08-21 09:02:32 +02:00
parent 95848219d5
commit 1df24bebfe
6 changed files with 32 additions and 21 deletions

View File

@@ -332,13 +332,24 @@ func (repo *AuthRequestRepo) setLinkingUser(ctx context.Context, request *domain
return repo.AuthRequests.UpdateAuthRequest(ctx, request)
}
func (repo *AuthRequestRepo) SelectUser(ctx context.Context, authReqID, userID, userAgentID string) (err error) {
func (repo *AuthRequestRepo) SelectUser(ctx context.Context, authReqID, userID, userAgentID string, enforceExistingSession bool) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID)
if err != nil {
return err
}
// Check if the session already exists in the same user agent, e.g. when selecting the user from the user selection page.
// This is to prevent username enumeration attacks by checking if the user exists in the system.
if enforceExistingSession {
userSession, err := userSessionByIDs(ctx, repo.UserSessionViewProvider, repo.UserEventProvider, request.AgentID, userID)
if err != nil {
return err
}
if userSession.Sequence == 0 {
return zerrors.ThrowNotFound(nil, "AUTH-2d3f4", "Errors.UserSession.NotFound")
}
}
user, err := activeUserByID(ctx, repo.UserViewProvider, repo.UserEventProvider, repo.OrgViewProvider, repo.LockoutPolicyViewProvider, userID, false)
if err != nil {
return err
@@ -1061,7 +1072,7 @@ func (repo *AuthRequestRepo) nextSteps(ctx context.Context, request *domain.Auth
if request.UserOrgID == "" {
request.UserOrgID = user.ResourceOwner
}
userSession, err := userSessionByIDs(ctx, repo.UserSessionViewProvider, repo.UserEventProvider, request.AgentID, user)
userSession, err := userSessionByIDs(ctx, repo.UserSessionViewProvider, repo.UserEventProvider, request.AgentID, user.ID)
if err != nil {
return nil, err
}
@@ -1641,27 +1652,27 @@ var (
}
)
func userSessionByIDs(ctx context.Context, provider userSessionViewProvider, eventProvider userEventProvider, agentID string, user *user_model.UserView) (*user_model.UserSessionView, error) {
func userSessionByIDs(ctx context.Context, provider userSessionViewProvider, eventProvider userEventProvider, agentID, userID string) (*user_model.UserSessionView, error) {
instanceID := authz.GetInstance(ctx).InstanceID()
// always load the latest sequence first, so in case the session was not found by id,
// the sequence will be equal or lower than the actual projection and no events are lost
sequence, err := provider.GetLatestUserSessionSequence(ctx, instanceID)
logging.WithFields("instanceID", instanceID, "userID", user.ID).
logging.WithFields("instanceID", instanceID, "userID", userID).
OnError(err).
Errorf("could not get current sequence for userSessionByIDs")
session, err := provider.UserSessionByIDs(ctx, agentID, user.ID, instanceID)
session, err := provider.UserSessionByIDs(ctx, agentID, userID, instanceID)
if err != nil {
if !zerrors.IsNotFound(err) {
return nil, err
}
session = &user_view_model.UserSessionView{UserAgentID: agentID, UserID: user.ID}
session = &user_view_model.UserSessionView{UserAgentID: agentID, UserID: userID}
if sequence != nil {
session.ChangeDate = sequence.EventCreatedAt
}
}
events, err := eventProvider.UserEventsByID(ctx, user.ID, session.ChangeDate, append(session.EventTypes(), userSessionEventTypes...))
events, err := eventProvider.UserEventsByID(ctx, userID, session.ChangeDate, append(session.EventTypes(), userSessionEventTypes...))
if err != nil {
logging.WithFields("traceID", tracing.TraceIDFromCtx(ctx)).WithError(err).Debug("error retrieving new events")
return user_view_model.UserSessionToModel(session), nil

View File

@@ -2796,7 +2796,7 @@ func Test_userSessionByIDs(t *testing.T) {
userProvider userSessionViewProvider
eventProvider userEventProvider
agentID string
user *user_model.UserView
userID string
}
tests := []struct {
name string
@@ -2809,7 +2809,7 @@ func Test_userSessionByIDs(t *testing.T) {
args{
userProvider: &mockViewNoUserSession{},
eventProvider: &mockEventErrUser{},
user: &user_model.UserView{ID: "id"},
userID: "id",
},
&user_model.UserSessionView{UserID: "id"},
nil,
@@ -2818,7 +2818,7 @@ func Test_userSessionByIDs(t *testing.T) {
"internal error, internal error",
args{
userProvider: &mockViewErrUserSession{},
user: &user_model.UserView{ID: "id"},
userID: "id",
},
nil,
zerrors.IsInternal,
@@ -2829,7 +2829,7 @@ func Test_userSessionByIDs(t *testing.T) {
userProvider: &mockViewUserSession{
PasswordVerification: testNow,
},
user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}},
userID: "id",
eventProvider: &mockEventErrUser{},
},
&user_model.UserSessionView{
@@ -2846,7 +2846,7 @@ func Test_userSessionByIDs(t *testing.T) {
PasswordVerification: testNow,
},
agentID: "agentID",
user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}},
userID: "id",
eventProvider: &mockEventUser{
Events: []eventstore.Event{
&es_models.Event{
@@ -2871,7 +2871,7 @@ func Test_userSessionByIDs(t *testing.T) {
PasswordVerification: testNow,
},
agentID: "agentID",
user: &user_model.UserView{ID: "id"},
userID: "id",
eventProvider: &mockEventUser{
Events: []eventstore.Event{
&es_models.Event{
@@ -2900,7 +2900,7 @@ func Test_userSessionByIDs(t *testing.T) {
PasswordVerification: testNow,
},
agentID: "agentID",
user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}},
userID: "id",
eventProvider: &mockEventUser{
Events: []eventstore.Event{
&es_models.Event{
@@ -2929,7 +2929,7 @@ func Test_userSessionByIDs(t *testing.T) {
PasswordVerification: testNow,
},
agentID: "agentID",
user: &user_model.UserView{ID: "id"},
userID: "id",
eventProvider: &mockEventUser{
Events: []eventstore.Event{
&es_models.Event{
@@ -2953,7 +2953,7 @@ func Test_userSessionByIDs(t *testing.T) {
PasswordVerification: testNow,
},
agentID: "agentID",
user: &user_model.UserView{ID: "id"},
userID: "id",
eventProvider: &mockEventUser{
Events: []eventstore.Event{
&es_models.Event{
@@ -2986,7 +2986,7 @@ func Test_userSessionByIDs(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := userSessionByIDs(context.Background(), tt.args.userProvider, tt.args.eventProvider, tt.args.agentID, tt.args.user)
got, err := userSessionByIDs(context.Background(), tt.args.userProvider, tt.args.eventProvider, tt.args.agentID, tt.args.userID)
if (err != nil && tt.wantErr == nil) || (tt.wantErr != nil && !tt.wantErr(err)) {
t.Errorf("nextSteps() wrong error = %v", err)
return