Merge commit from fork

This commit is contained in:
Livio Spring
2025-08-21 09:02:32 +02:00
committed by GitHub
parent 6c8d027e72
commit 7abe759c95
6 changed files with 32 additions and 21 deletions

View File

@@ -578,7 +578,7 @@ func (l *Login) checkAutoLinking(r *http.Request, authReq *domain.AuthRequest, p
} }
func (l *Login) autoLinkUser(r *http.Request, authReq *domain.AuthRequest, user *query.NotifyUser) error { func (l *Login) autoLinkUser(r *http.Request, authReq *domain.AuthRequest, user *query.NotifyUser) error {
if err := l.authRepo.SelectUser(r.Context(), authReq.ID, user.ID, authReq.AgentID); err != nil { if err := l.authRepo.SelectUser(r.Context(), authReq.ID, user.ID, authReq.AgentID, false); err != nil {
return err return err
} }
if err := l.authRepo.LinkExternalUsers(r.Context(), authReq.ID, authReq.AgentID, domain.BrowserInfoFromRequest(r)); err != nil { if err := l.authRepo.LinkExternalUsers(r.Context(), authReq.ID, authReq.AgentID, domain.BrowserInfoFromRequest(r)); err != nil {

View File

@@ -133,7 +133,7 @@ func (l *Login) handleRegisterCheck(w http.ResponseWriter, r *http.Request) {
return return
} }
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.SelectUser(r.Context(), authRequest.ID, human.ID, userAgentID) err = l.authRepo.SelectUser(r.Context(), authRequest.ID, human.ID, userAgentID, false)
if err != nil { if err != nil {
l.renderRegister(w, r, authRequest, data, err) l.renderRegister(w, r, authRequest, data, err)
return return

View File

@@ -46,7 +46,7 @@ func (l *Login) handleSelectUser(w http.ResponseWriter, r *http.Request) {
return return
} }
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context()) userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.SelectUser(r.Context(), authSession.ID, data.UserID, userAgentID) err = l.authRepo.SelectUser(r.Context(), authSession.ID, data.UserID, userAgentID, true)
if err != nil { if err != nil {
l.renderError(w, r, authSession, err) l.renderError(w, r, authSession, err)
return return

View File

@@ -19,7 +19,7 @@ type AuthRequestRepository interface {
CheckExternalUserLogin(ctx context.Context, authReqID, userAgentID string, user *domain.ExternalUser, info *domain.BrowserInfo, migrationCheck bool) error CheckExternalUserLogin(ctx context.Context, authReqID, userAgentID string, user *domain.ExternalUser, info *domain.BrowserInfo, migrationCheck bool) error
SetExternalUserLogin(ctx context.Context, authReqID, userAgentID string, user *domain.ExternalUser) error SetExternalUserLogin(ctx context.Context, authReqID, userAgentID string, user *domain.ExternalUser) error
SetLinkingUser(ctx context.Context, request *domain.AuthRequest, externalUser *domain.ExternalUser) error SetLinkingUser(ctx context.Context, request *domain.AuthRequest, externalUser *domain.ExternalUser) error
SelectUser(ctx context.Context, authReqID, userID, userAgentID string) error SelectUser(ctx context.Context, authReqID, userID, userAgentID string, enforceExistingSession bool) error
SelectExternalIDP(ctx context.Context, authReqID, idpConfigID, userAgentID string, idpArguments map[string]any) error SelectExternalIDP(ctx context.Context, authReqID, idpConfigID, userAgentID string, idpArguments map[string]any) error
VerifyPassword(ctx context.Context, id, userID, resourceOwner, password, userAgentID string, info *domain.BrowserInfo) error VerifyPassword(ctx context.Context, id, userID, resourceOwner, password, userAgentID string, info *domain.BrowserInfo) error

View File

@@ -332,13 +332,24 @@ func (repo *AuthRequestRepo) setLinkingUser(ctx context.Context, request *domain
return repo.AuthRequests.UpdateAuthRequest(ctx, request) return repo.AuthRequests.UpdateAuthRequest(ctx, request)
} }
func (repo *AuthRequestRepo) SelectUser(ctx context.Context, authReqID, userID, userAgentID string) (err error) { func (repo *AuthRequestRepo) SelectUser(ctx context.Context, authReqID, userID, userAgentID string, enforceExistingSession bool) (err error) {
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID) request, err := repo.getAuthRequest(ctx, authReqID, userAgentID)
if err != nil { if err != nil {
return err return err
} }
// Check if the session already exists in the same user agent, e.g. when selecting the user from the user selection page.
// This is to prevent username enumeration attacks by checking if the user exists in the system.
if enforceExistingSession {
userSession, err := userSessionByIDs(ctx, repo.UserSessionViewProvider, repo.UserEventProvider, request.AgentID, userID)
if err != nil {
return err
}
if userSession.Sequence == 0 {
return zerrors.ThrowNotFound(nil, "AUTH-2d3f4", "Errors.UserSession.NotFound")
}
}
user, err := activeUserByID(ctx, repo.UserViewProvider, repo.UserEventProvider, repo.OrgViewProvider, repo.LockoutPolicyViewProvider, userID, false) user, err := activeUserByID(ctx, repo.UserViewProvider, repo.UserEventProvider, repo.OrgViewProvider, repo.LockoutPolicyViewProvider, userID, false)
if err != nil { if err != nil {
return err return err
@@ -1061,7 +1072,7 @@ func (repo *AuthRequestRepo) nextSteps(ctx context.Context, request *domain.Auth
if request.UserOrgID == "" { if request.UserOrgID == "" {
request.UserOrgID = user.ResourceOwner request.UserOrgID = user.ResourceOwner
} }
userSession, err := userSessionByIDs(ctx, repo.UserSessionViewProvider, repo.UserEventProvider, request.AgentID, user) userSession, err := userSessionByIDs(ctx, repo.UserSessionViewProvider, repo.UserEventProvider, request.AgentID, user.ID)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1641,27 +1652,27 @@ var (
} }
) )
func userSessionByIDs(ctx context.Context, provider userSessionViewProvider, eventProvider userEventProvider, agentID string, user *user_model.UserView) (*user_model.UserSessionView, error) { func userSessionByIDs(ctx context.Context, provider userSessionViewProvider, eventProvider userEventProvider, agentID, userID string) (*user_model.UserSessionView, error) {
instanceID := authz.GetInstance(ctx).InstanceID() instanceID := authz.GetInstance(ctx).InstanceID()
// always load the latest sequence first, so in case the session was not found by id, // always load the latest sequence first, so in case the session was not found by id,
// the sequence will be equal or lower than the actual projection and no events are lost // the sequence will be equal or lower than the actual projection and no events are lost
sequence, err := provider.GetLatestUserSessionSequence(ctx, instanceID) sequence, err := provider.GetLatestUserSessionSequence(ctx, instanceID)
logging.WithFields("instanceID", instanceID, "userID", user.ID). logging.WithFields("instanceID", instanceID, "userID", userID).
OnError(err). OnError(err).
Errorf("could not get current sequence for userSessionByIDs") Errorf("could not get current sequence for userSessionByIDs")
session, err := provider.UserSessionByIDs(ctx, agentID, user.ID, instanceID) session, err := provider.UserSessionByIDs(ctx, agentID, userID, instanceID)
if err != nil { if err != nil {
if !zerrors.IsNotFound(err) { if !zerrors.IsNotFound(err) {
return nil, err return nil, err
} }
session = &user_view_model.UserSessionView{UserAgentID: agentID, UserID: user.ID} session = &user_view_model.UserSessionView{UserAgentID: agentID, UserID: userID}
if sequence != nil { if sequence != nil {
session.ChangeDate = sequence.EventCreatedAt session.ChangeDate = sequence.EventCreatedAt
} }
} }
events, err := eventProvider.UserEventsByID(ctx, user.ID, session.ChangeDate, append(session.EventTypes(), userSessionEventTypes...)) events, err := eventProvider.UserEventsByID(ctx, userID, session.ChangeDate, append(session.EventTypes(), userSessionEventTypes...))
if err != nil { if err != nil {
logging.WithFields("traceID", tracing.TraceIDFromCtx(ctx)).WithError(err).Debug("error retrieving new events") logging.WithFields("traceID", tracing.TraceIDFromCtx(ctx)).WithError(err).Debug("error retrieving new events")
return user_view_model.UserSessionToModel(session), nil return user_view_model.UserSessionToModel(session), nil

View File

@@ -2796,7 +2796,7 @@ func Test_userSessionByIDs(t *testing.T) {
userProvider userSessionViewProvider userProvider userSessionViewProvider
eventProvider userEventProvider eventProvider userEventProvider
agentID string agentID string
user *user_model.UserView userID string
} }
tests := []struct { tests := []struct {
name string name string
@@ -2809,7 +2809,7 @@ func Test_userSessionByIDs(t *testing.T) {
args{ args{
userProvider: &mockViewNoUserSession{}, userProvider: &mockViewNoUserSession{},
eventProvider: &mockEventErrUser{}, eventProvider: &mockEventErrUser{},
user: &user_model.UserView{ID: "id"}, userID: "id",
}, },
&user_model.UserSessionView{UserID: "id"}, &user_model.UserSessionView{UserID: "id"},
nil, nil,
@@ -2818,7 +2818,7 @@ func Test_userSessionByIDs(t *testing.T) {
"internal error, internal error", "internal error, internal error",
args{ args{
userProvider: &mockViewErrUserSession{}, userProvider: &mockViewErrUserSession{},
user: &user_model.UserView{ID: "id"}, userID: "id",
}, },
nil, nil,
zerrors.IsInternal, zerrors.IsInternal,
@@ -2829,7 +2829,7 @@ func Test_userSessionByIDs(t *testing.T) {
userProvider: &mockViewUserSession{ userProvider: &mockViewUserSession{
PasswordVerification: testNow, PasswordVerification: testNow,
}, },
user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}}, userID: "id",
eventProvider: &mockEventErrUser{}, eventProvider: &mockEventErrUser{},
}, },
&user_model.UserSessionView{ &user_model.UserSessionView{
@@ -2846,7 +2846,7 @@ func Test_userSessionByIDs(t *testing.T) {
PasswordVerification: testNow, PasswordVerification: testNow,
}, },
agentID: "agentID", agentID: "agentID",
user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}}, userID: "id",
eventProvider: &mockEventUser{ eventProvider: &mockEventUser{
Events: []eventstore.Event{ Events: []eventstore.Event{
&es_models.Event{ &es_models.Event{
@@ -2871,7 +2871,7 @@ func Test_userSessionByIDs(t *testing.T) {
PasswordVerification: testNow, PasswordVerification: testNow,
}, },
agentID: "agentID", agentID: "agentID",
user: &user_model.UserView{ID: "id"}, userID: "id",
eventProvider: &mockEventUser{ eventProvider: &mockEventUser{
Events: []eventstore.Event{ Events: []eventstore.Event{
&es_models.Event{ &es_models.Event{
@@ -2900,7 +2900,7 @@ func Test_userSessionByIDs(t *testing.T) {
PasswordVerification: testNow, PasswordVerification: testNow,
}, },
agentID: "agentID", agentID: "agentID",
user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}}, userID: "id",
eventProvider: &mockEventUser{ eventProvider: &mockEventUser{
Events: []eventstore.Event{ Events: []eventstore.Event{
&es_models.Event{ &es_models.Event{
@@ -2929,7 +2929,7 @@ func Test_userSessionByIDs(t *testing.T) {
PasswordVerification: testNow, PasswordVerification: testNow,
}, },
agentID: "agentID", agentID: "agentID",
user: &user_model.UserView{ID: "id"}, userID: "id",
eventProvider: &mockEventUser{ eventProvider: &mockEventUser{
Events: []eventstore.Event{ Events: []eventstore.Event{
&es_models.Event{ &es_models.Event{
@@ -2953,7 +2953,7 @@ func Test_userSessionByIDs(t *testing.T) {
PasswordVerification: testNow, PasswordVerification: testNow,
}, },
agentID: "agentID", agentID: "agentID",
user: &user_model.UserView{ID: "id"}, userID: "id",
eventProvider: &mockEventUser{ eventProvider: &mockEventUser{
Events: []eventstore.Event{ Events: []eventstore.Event{
&es_models.Event{ &es_models.Event{
@@ -2986,7 +2986,7 @@ func Test_userSessionByIDs(t *testing.T) {
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
got, err := userSessionByIDs(context.Background(), tt.args.userProvider, tt.args.eventProvider, tt.args.agentID, tt.args.user) got, err := userSessionByIDs(context.Background(), tt.args.userProvider, tt.args.eventProvider, tt.args.agentID, tt.args.userID)
if (err != nil && tt.wantErr == nil) || (tt.wantErr != nil && !tt.wantErr(err)) { if (err != nil && tt.wantErr == nil) || (tt.wantErr != nil && !tt.wantErr(err)) {
t.Errorf("nextSteps() wrong error = %v", err) t.Errorf("nextSteps() wrong error = %v", err)
return return