fix(login): only allow previously authenticated users on select account page

# Which Problems Are Solved

User enumeration was possible on the select account page by passing any userID as part of the form POST. Existing users could be selected even if they never authenticated on the same user agent (browser).

# How the Problems Are Solved

A check for an existing session on the same user agent was added to the select user function, resp. only required for the account selection page, since in other cases there doesn't have to be an existing session and the user agent integrity is already checked.

# Additional Changes

None

# Additional Context

None

(cherry picked from commit 7abe759c95)
This commit is contained in:
Livio Spring
2025-08-21 09:02:32 +02:00
parent 8cb28190ee
commit 9daec3dd5e
6 changed files with 32 additions and 21 deletions

View File

@@ -560,7 +560,7 @@ func (l *Login) checkAutoLinking(r *http.Request, authReq *domain.AuthRequest, p
} }
func (l *Login) autoLinkUser(r *http.Request, authReq *domain.AuthRequest, user *query.NotifyUser) error { func (l *Login) autoLinkUser(r *http.Request, authReq *domain.AuthRequest, user *query.NotifyUser) error {
if err := l.authRepo.SelectUser(r.Context(), authReq.ID, user.ID, authReq.AgentID); err != nil { if err := l.authRepo.SelectUser(r.Context(), authReq.ID, user.ID, authReq.AgentID, false); err != nil {
return err return err
} }
if err := l.authRepo.LinkExternalUsers(r.Context(), authReq.ID, authReq.AgentID, domain.BrowserInfoFromRequest(r)); err != nil { if err := l.authRepo.LinkExternalUsers(r.Context(), authReq.ID, authReq.AgentID, domain.BrowserInfoFromRequest(r)); err != nil {

View File

@@ -133,7 +133,7 @@ func (l *Login) handleRegisterCheck(w http.ResponseWriter, r *http.Request) {
return return
} }
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.SelectUser(r.Context(), authRequest.ID, human.ID, userAgentID) err = l.authRepo.SelectUser(r.Context(), authRequest.ID, human.ID, userAgentID, false)
if err != nil { if err != nil {
l.renderRegister(w, r, authRequest, data, err) l.renderRegister(w, r, authRequest, data, err)
return return

View File

@@ -46,7 +46,7 @@ func (l *Login) handleSelectUser(w http.ResponseWriter, r *http.Request) {
return return
} }
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.SelectUser(r.Context(), authSession.ID, data.UserID, userAgentID) err = l.authRepo.SelectUser(r.Context(), authSession.ID, data.UserID, userAgentID, true)
if err != nil { if err != nil {
l.renderError(w, r, authSession, err) l.renderError(w, r, authSession, err)
return return

View File

@@ -19,7 +19,7 @@ type AuthRequestRepository interface {
CheckExternalUserLogin(ctx context.Context, authReqID, userAgentID string, user *domain.ExternalUser, info *domain.BrowserInfo, migrationCheck bool) error CheckExternalUserLogin(ctx context.Context, authReqID, userAgentID string, user *domain.ExternalUser, info *domain.BrowserInfo, migrationCheck bool) error
SetExternalUserLogin(ctx context.Context, authReqID, userAgentID string, user *domain.ExternalUser) error SetExternalUserLogin(ctx context.Context, authReqID, userAgentID string, user *domain.ExternalUser) error
SetLinkingUser(ctx context.Context, request *domain.AuthRequest, externalUser *domain.ExternalUser) error SetLinkingUser(ctx context.Context, request *domain.AuthRequest, externalUser *domain.ExternalUser) error
SelectUser(ctx context.Context, authReqID, userID, userAgentID string) error SelectUser(ctx context.Context, authReqID, userID, userAgentID string, enforceExistingSession bool) error
SelectExternalIDP(ctx context.Context, authReqID, idpConfigID, userAgentID string, idpArguments map[string]any) error SelectExternalIDP(ctx context.Context, authReqID, idpConfigID, userAgentID string, idpArguments map[string]any) error
VerifyPassword(ctx context.Context, id, userID, resourceOwner, password, userAgentID string, info *domain.BrowserInfo) error VerifyPassword(ctx context.Context, id, userID, resourceOwner, password, userAgentID string, info *domain.BrowserInfo) error

View File

@@ -332,13 +332,24 @@ func (repo *AuthRequestRepo) setLinkingUser(ctx context.Context, request *domain
return repo.AuthRequests.UpdateAuthRequest(ctx, request) return repo.AuthRequests.UpdateAuthRequest(ctx, request)
} }
func (repo *AuthRequestRepo) SelectUser(ctx context.Context, authReqID, userID, userAgentID string) (err error) { func (repo *AuthRequestRepo) SelectUser(ctx context.Context, authReqID, userID, userAgentID string, enforceExistingSession bool) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID) request, err := repo.getAuthRequest(ctx, authReqID, userAgentID)
if err != nil { if err != nil {
return err return err
} }
// Check if the session already exists in the same user agent, e.g. when selecting the user from the user selection page.
// This is to prevent username enumeration attacks by checking if the user exists in the system.
if enforceExistingSession {
userSession, err := userSessionByIDs(ctx, repo.UserSessionViewProvider, repo.UserEventProvider, request.AgentID, userID)
if err != nil {
return err
}
if userSession.Sequence == 0 {
return zerrors.ThrowNotFound(nil, "AUTH-2d3f4", "Errors.UserSession.NotFound")
}
}
user, err := activeUserByID(ctx, repo.UserViewProvider, repo.UserEventProvider, repo.OrgViewProvider, repo.LockoutPolicyViewProvider, userID, false) user, err := activeUserByID(ctx, repo.UserViewProvider, repo.UserEventProvider, repo.OrgViewProvider, repo.LockoutPolicyViewProvider, userID, false)
if err != nil { if err != nil {
return err return err
@@ -1059,7 +1070,7 @@ func (repo *AuthRequestRepo) nextSteps(ctx context.Context, request *domain.Auth
if request.UserOrgID == "" { if request.UserOrgID == "" {
request.UserOrgID = user.ResourceOwner request.UserOrgID = user.ResourceOwner
} }
userSession, err := userSessionByIDs(ctx, repo.UserSessionViewProvider, repo.UserEventProvider, request.AgentID, user) userSession, err := userSessionByIDs(ctx, repo.UserSessionViewProvider, repo.UserEventProvider, request.AgentID, user.ID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1639,27 +1650,27 @@ var (
} }
) )
func userSessionByIDs(ctx context.Context, provider userSessionViewProvider, eventProvider userEventProvider, agentID string, user *user_model.UserView) (*user_model.UserSessionView, error) { func userSessionByIDs(ctx context.Context, provider userSessionViewProvider, eventProvider userEventProvider, agentID, userID string) (*user_model.UserSessionView, error) {
instanceID := authz.GetInstance(ctx).InstanceID() instanceID := authz.GetInstance(ctx).InstanceID()
// always load the latest sequence first, so in case the session was not found by id, // always load the latest sequence first, so in case the session was not found by id,
// the sequence will be equal or lower than the actual projection and no events are lost // the sequence will be equal or lower than the actual projection and no events are lost
sequence, err := provider.GetLatestUserSessionSequence(ctx, instanceID) sequence, err := provider.GetLatestUserSessionSequence(ctx, instanceID)
logging.WithFields("instanceID", instanceID, "userID", user.ID). logging.WithFields("instanceID", instanceID, "userID", userID).
OnError(err). OnError(err).
Errorf("could not get current sequence for userSessionByIDs") Errorf("could not get current sequence for userSessionByIDs")
session, err := provider.UserSessionByIDs(ctx, agentID, user.ID, instanceID) session, err := provider.UserSessionByIDs(ctx, agentID, userID, instanceID)
if err != nil { if err != nil {
if !zerrors.IsNotFound(err) { if !zerrors.IsNotFound(err) {
return nil, err return nil, err
} }
session = &user_view_model.UserSessionView{UserAgentID: agentID, UserID: user.ID} session = &user_view_model.UserSessionView{UserAgentID: agentID, UserID: userID}
if sequence != nil { if sequence != nil {
session.ChangeDate = sequence.EventCreatedAt session.ChangeDate = sequence.EventCreatedAt
} }
} }
events, err := eventProvider.UserEventsByID(ctx, user.ID, session.ChangeDate, append(session.EventTypes(), userSessionEventTypes...)) events, err := eventProvider.UserEventsByID(ctx, userID, session.ChangeDate, append(session.EventTypes(), userSessionEventTypes...))
if err != nil { if err != nil {
logging.WithFields("traceID", tracing.TraceIDFromCtx(ctx)).WithError(err).Debug("error retrieving new events") logging.WithFields("traceID", tracing.TraceIDFromCtx(ctx)).WithError(err).Debug("error retrieving new events")
return user_view_model.UserSessionToModel(session), nil return user_view_model.UserSessionToModel(session), nil

View File

@@ -2796,7 +2796,7 @@ func Test_userSessionByIDs(t *testing.T) {
userProvider userSessionViewProvider userProvider userSessionViewProvider
eventProvider userEventProvider eventProvider userEventProvider
agentID string agentID string
user *user_model.UserView userID string
} }
tests := []struct { tests := []struct {
name string name string
@@ -2809,7 +2809,7 @@ func Test_userSessionByIDs(t *testing.T) {
args{ args{
userProvider: &mockViewNoUserSession{}, userProvider: &mockViewNoUserSession{},
eventProvider: &mockEventErrUser{}, eventProvider: &mockEventErrUser{},
user: &user_model.UserView{ID: "id"}, userID: "id",
}, },
&user_model.UserSessionView{UserID: "id"}, &user_model.UserSessionView{UserID: "id"},
nil, nil,
@@ -2818,7 +2818,7 @@ func Test_userSessionByIDs(t *testing.T) {
"internal error, internal error", "internal error, internal error",
args{ args{
userProvider: &mockViewErrUserSession{}, userProvider: &mockViewErrUserSession{},
user: &user_model.UserView{ID: "id"}, userID: "id",
}, },
nil, nil,
zerrors.IsInternal, zerrors.IsInternal,
@@ -2829,7 +2829,7 @@ func Test_userSessionByIDs(t *testing.T) {
userProvider: &mockViewUserSession{ userProvider: &mockViewUserSession{
PasswordVerification: testNow, PasswordVerification: testNow,
}, },
user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}}, userID: "id",
eventProvider: &mockEventErrUser{}, eventProvider: &mockEventErrUser{},
}, },
&user_model.UserSessionView{ &user_model.UserSessionView{
@@ -2846,7 +2846,7 @@ func Test_userSessionByIDs(t *testing.T) {
PasswordVerification: testNow, PasswordVerification: testNow,
}, },
agentID: "agentID", agentID: "agentID",
user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}}, userID: "id",
eventProvider: &mockEventUser{ eventProvider: &mockEventUser{
Events: []eventstore.Event{ Events: []eventstore.Event{
&es_models.Event{ &es_models.Event{
@@ -2871,7 +2871,7 @@ func Test_userSessionByIDs(t *testing.T) {
PasswordVerification: testNow, PasswordVerification: testNow,
}, },
agentID: "agentID", agentID: "agentID",
user: &user_model.UserView{ID: "id"}, userID: "id",
eventProvider: &mockEventUser{ eventProvider: &mockEventUser{
Events: []eventstore.Event{ Events: []eventstore.Event{
&es_models.Event{ &es_models.Event{
@@ -2900,7 +2900,7 @@ func Test_userSessionByIDs(t *testing.T) {
PasswordVerification: testNow, PasswordVerification: testNow,
}, },
agentID: "agentID", agentID: "agentID",
user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}}, userID: "id",
eventProvider: &mockEventUser{ eventProvider: &mockEventUser{
Events: []eventstore.Event{ Events: []eventstore.Event{
&es_models.Event{ &es_models.Event{
@@ -2929,7 +2929,7 @@ func Test_userSessionByIDs(t *testing.T) {
PasswordVerification: testNow, PasswordVerification: testNow,
}, },
agentID: "agentID", agentID: "agentID",
user: &user_model.UserView{ID: "id"}, userID: "id",
eventProvider: &mockEventUser{ eventProvider: &mockEventUser{
Events: []eventstore.Event{ Events: []eventstore.Event{
&es_models.Event{ &es_models.Event{
@@ -2953,7 +2953,7 @@ func Test_userSessionByIDs(t *testing.T) {
PasswordVerification: testNow, PasswordVerification: testNow,
}, },
agentID: "agentID", agentID: "agentID",
user: &user_model.UserView{ID: "id"}, userID: "id",
eventProvider: &mockEventUser{ eventProvider: &mockEventUser{
Events: []eventstore.Event{ Events: []eventstore.Event{
&es_models.Event{ &es_models.Event{
@@ -2986,7 +2986,7 @@ func Test_userSessionByIDs(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := userSessionByIDs(context.Background(), tt.args.userProvider, tt.args.eventProvider, tt.args.agentID, tt.args.user) got, err := userSessionByIDs(context.Background(), tt.args.userProvider, tt.args.eventProvider, tt.args.agentID, tt.args.userID)
if (err != nil && tt.wantErr == nil) || (tt.wantErr != nil && !tt.wantErr(err)) { if (err != nil && tt.wantErr == nil) || (tt.wantErr != nil && !tt.wantErr(err)) {
t.Errorf("nextSteps() wrong error = %v", err) t.Errorf("nextSteps() wrong error = %v", err)
return return