diff --git a/internal/api/ui/login/external_provider_handler.go b/internal/api/ui/login/external_provider_handler.go index 73b134d330..514ff450fe 100644 --- a/internal/api/ui/login/external_provider_handler.go +++ b/internal/api/ui/login/external_provider_handler.go @@ -578,7 +578,7 @@ func (l *Login) checkAutoLinking(r *http.Request, authReq *domain.AuthRequest, p } func (l *Login) autoLinkUser(r *http.Request, authReq *domain.AuthRequest, user *query.NotifyUser) error { - if err := l.authRepo.SelectUser(r.Context(), authReq.ID, user.ID, authReq.AgentID); err != nil { + if err := l.authRepo.SelectUser(r.Context(), authReq.ID, user.ID, authReq.AgentID, false); err != nil { return err } if err := l.authRepo.LinkExternalUsers(r.Context(), authReq.ID, authReq.AgentID, domain.BrowserInfoFromRequest(r)); err != nil { diff --git a/internal/api/ui/login/register_handler.go b/internal/api/ui/login/register_handler.go index bd5629c432..1c6af85519 100644 --- a/internal/api/ui/login/register_handler.go +++ b/internal/api/ui/login/register_handler.go @@ -133,7 +133,7 @@ func (l *Login) handleRegisterCheck(w http.ResponseWriter, r *http.Request) { return } userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) - err = l.authRepo.SelectUser(r.Context(), authRequest.ID, human.ID, userAgentID) + err = l.authRepo.SelectUser(r.Context(), authRequest.ID, human.ID, userAgentID, false) if err != nil { l.renderRegister(w, r, authRequest, data, err) return diff --git a/internal/api/ui/login/select_user_handler.go b/internal/api/ui/login/select_user_handler.go index b15366baa1..01a4f060c2 100644 --- a/internal/api/ui/login/select_user_handler.go +++ b/internal/api/ui/login/select_user_handler.go @@ -46,7 +46,7 @@ func (l *Login) handleSelectUser(w http.ResponseWriter, r *http.Request) { return } userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) - err = l.authRepo.SelectUser(r.Context(), authSession.ID, data.UserID, userAgentID) + err = l.authRepo.SelectUser(r.Context(), authSession.ID, data.UserID, userAgentID, true) if err != nil { l.renderError(w, r, authSession, err) return diff --git a/internal/auth/repository/auth_request.go b/internal/auth/repository/auth_request.go index 53272d2d2f..6c607bc1a5 100644 --- a/internal/auth/repository/auth_request.go +++ b/internal/auth/repository/auth_request.go @@ -19,7 +19,7 @@ type AuthRequestRepository interface { CheckExternalUserLogin(ctx context.Context, authReqID, userAgentID string, user *domain.ExternalUser, info *domain.BrowserInfo, migrationCheck bool) error SetExternalUserLogin(ctx context.Context, authReqID, userAgentID string, user *domain.ExternalUser) error SetLinkingUser(ctx context.Context, request *domain.AuthRequest, externalUser *domain.ExternalUser) error - SelectUser(ctx context.Context, authReqID, userID, userAgentID string) error + SelectUser(ctx context.Context, authReqID, userID, userAgentID string, enforceExistingSession bool) error SelectExternalIDP(ctx context.Context, authReqID, idpConfigID, userAgentID string, idpArguments map[string]any) error VerifyPassword(ctx context.Context, id, userID, resourceOwner, password, userAgentID string, info *domain.BrowserInfo) error diff --git a/internal/auth/repository/eventsourcing/eventstore/auth_request.go b/internal/auth/repository/eventsourcing/eventstore/auth_request.go index 4081b66d64..ea260eeb1b 100644 --- a/internal/auth/repository/eventsourcing/eventstore/auth_request.go +++ b/internal/auth/repository/eventsourcing/eventstore/auth_request.go @@ -332,13 +332,24 @@ func (repo *AuthRequestRepo) setLinkingUser(ctx context.Context, request *domain return repo.AuthRequests.UpdateAuthRequest(ctx, request) } -func (repo *AuthRequestRepo) SelectUser(ctx context.Context, authReqID, userID, userAgentID string) (err error) { +func (repo *AuthRequestRepo) SelectUser(ctx context.Context, authReqID, userID, userAgentID string, enforceExistingSession bool) (err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() request, err := repo.getAuthRequest(ctx, authReqID, userAgentID) if err != nil { return err } + // Check if the session already exists in the same user agent, e.g. when selecting the user from the user selection page. + // This is to prevent username enumeration attacks by checking if the user exists in the system. + if enforceExistingSession { + userSession, err := userSessionByIDs(ctx, repo.UserSessionViewProvider, repo.UserEventProvider, request.AgentID, userID) + if err != nil { + return err + } + if userSession.Sequence == 0 { + return zerrors.ThrowNotFound(nil, "AUTH-2d3f4", "Errors.UserSession.NotFound") + } + } user, err := activeUserByID(ctx, repo.UserViewProvider, repo.UserEventProvider, repo.OrgViewProvider, repo.LockoutPolicyViewProvider, userID, false) if err != nil { return err @@ -1061,7 +1072,7 @@ func (repo *AuthRequestRepo) nextSteps(ctx context.Context, request *domain.Auth if request.UserOrgID == "" { request.UserOrgID = user.ResourceOwner } - userSession, err := userSessionByIDs(ctx, repo.UserSessionViewProvider, repo.UserEventProvider, request.AgentID, user) + userSession, err := userSessionByIDs(ctx, repo.UserSessionViewProvider, repo.UserEventProvider, request.AgentID, user.ID) if err != nil { return nil, err } @@ -1641,27 +1652,27 @@ var ( } ) -func userSessionByIDs(ctx context.Context, provider userSessionViewProvider, eventProvider userEventProvider, agentID string, user *user_model.UserView) (*user_model.UserSessionView, error) { +func userSessionByIDs(ctx context.Context, provider userSessionViewProvider, eventProvider userEventProvider, agentID, userID string) (*user_model.UserSessionView, error) { instanceID := authz.GetInstance(ctx).InstanceID() // always load the latest sequence first, so in case the session was not found by id, // the sequence will be equal or lower than the actual projection and no events are lost sequence, err := provider.GetLatestUserSessionSequence(ctx, instanceID) - logging.WithFields("instanceID", instanceID, "userID", user.ID). + logging.WithFields("instanceID", instanceID, "userID", userID). OnError(err). Errorf("could not get current sequence for userSessionByIDs") - session, err := provider.UserSessionByIDs(ctx, agentID, user.ID, instanceID) + session, err := provider.UserSessionByIDs(ctx, agentID, userID, instanceID) if err != nil { if !zerrors.IsNotFound(err) { return nil, err } - session = &user_view_model.UserSessionView{UserAgentID: agentID, UserID: user.ID} + session = &user_view_model.UserSessionView{UserAgentID: agentID, UserID: userID} if sequence != nil { session.ChangeDate = sequence.EventCreatedAt } } - events, err := eventProvider.UserEventsByID(ctx, user.ID, session.ChangeDate, append(session.EventTypes(), userSessionEventTypes...)) + events, err := eventProvider.UserEventsByID(ctx, userID, session.ChangeDate, append(session.EventTypes(), userSessionEventTypes...)) if err != nil { logging.WithFields("traceID", tracing.TraceIDFromCtx(ctx)).WithError(err).Debug("error retrieving new events") return user_view_model.UserSessionToModel(session), nil diff --git a/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go b/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go index 079d65a94f..0956c3801e 100644 --- a/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go +++ b/internal/auth/repository/eventsourcing/eventstore/auth_request_test.go @@ -2796,7 +2796,7 @@ func Test_userSessionByIDs(t *testing.T) { userProvider userSessionViewProvider eventProvider userEventProvider agentID string - user *user_model.UserView + userID string } tests := []struct { name string @@ -2809,7 +2809,7 @@ func Test_userSessionByIDs(t *testing.T) { args{ userProvider: &mockViewNoUserSession{}, eventProvider: &mockEventErrUser{}, - user: &user_model.UserView{ID: "id"}, + userID: "id", }, &user_model.UserSessionView{UserID: "id"}, nil, @@ -2818,7 +2818,7 @@ func Test_userSessionByIDs(t *testing.T) { "internal error, internal error", args{ userProvider: &mockViewErrUserSession{}, - user: &user_model.UserView{ID: "id"}, + userID: "id", }, nil, zerrors.IsInternal, @@ -2829,7 +2829,7 @@ func Test_userSessionByIDs(t *testing.T) { userProvider: &mockViewUserSession{ PasswordVerification: testNow, }, - user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}}, + userID: "id", eventProvider: &mockEventErrUser{}, }, &user_model.UserSessionView{ @@ -2846,7 +2846,7 @@ func Test_userSessionByIDs(t *testing.T) { PasswordVerification: testNow, }, agentID: "agentID", - user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}}, + userID: "id", eventProvider: &mockEventUser{ Events: []eventstore.Event{ &es_models.Event{ @@ -2871,7 +2871,7 @@ func Test_userSessionByIDs(t *testing.T) { PasswordVerification: testNow, }, agentID: "agentID", - user: &user_model.UserView{ID: "id"}, + userID: "id", eventProvider: &mockEventUser{ Events: []eventstore.Event{ &es_models.Event{ @@ -2900,7 +2900,7 @@ func Test_userSessionByIDs(t *testing.T) { PasswordVerification: testNow, }, agentID: "agentID", - user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}}, + userID: "id", eventProvider: &mockEventUser{ Events: []eventstore.Event{ &es_models.Event{ @@ -2929,7 +2929,7 @@ func Test_userSessionByIDs(t *testing.T) { PasswordVerification: testNow, }, agentID: "agentID", - user: &user_model.UserView{ID: "id"}, + userID: "id", eventProvider: &mockEventUser{ Events: []eventstore.Event{ &es_models.Event{ @@ -2953,7 +2953,7 @@ func Test_userSessionByIDs(t *testing.T) { PasswordVerification: testNow, }, agentID: "agentID", - user: &user_model.UserView{ID: "id"}, + userID: "id", eventProvider: &mockEventUser{ Events: []eventstore.Event{ &es_models.Event{ @@ -2986,7 +2986,7 @@ func Test_userSessionByIDs(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := userSessionByIDs(context.Background(), tt.args.userProvider, tt.args.eventProvider, tt.args.agentID, tt.args.user) + got, err := userSessionByIDs(context.Background(), tt.args.userProvider, tt.args.eventProvider, tt.args.agentID, tt.args.userID) if (err != nil && tt.wantErr == nil) || (tt.wantErr != nil && !tt.wantErr(err)) { t.Errorf("nextSteps() wrong error = %v", err) return