Merge commit from fork

This commit is contained in:
Livio Spring
2025-08-21 09:02:32 +02:00
committed by GitHub
parent 6c8d027e72
commit 7abe759c95
6 changed files with 32 additions and 21 deletions

View File

@@ -578,7 +578,7 @@ func (l *Login) checkAutoLinking(r *http.Request, authReq *domain.AuthRequest, p
}
func (l *Login) autoLinkUser(r *http.Request, authReq *domain.AuthRequest, user *query.NotifyUser) error {
if err := l.authRepo.SelectUser(r.Context(), authReq.ID, user.ID, authReq.AgentID); err != nil {
if err := l.authRepo.SelectUser(r.Context(), authReq.ID, user.ID, authReq.AgentID, false); err != nil {
return err
}
if err := l.authRepo.LinkExternalUsers(r.Context(), authReq.ID, authReq.AgentID, domain.BrowserInfoFromRequest(r)); err != nil {

View File

@@ -133,7 +133,7 @@ func (l *Login) handleRegisterCheck(w http.ResponseWriter, r *http.Request) {
return
}
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.SelectUser(r.Context(), authRequest.ID, human.ID, userAgentID)
err = l.authRepo.SelectUser(r.Context(), authRequest.ID, human.ID, userAgentID, false)
if err != nil {
l.renderRegister(w, r, authRequest, data, err)
return

View File

@@ -46,7 +46,7 @@ func (l *Login) handleSelectUser(w http.ResponseWriter, r *http.Request) {
return
}
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.SelectUser(r.Context(), authSession.ID, data.UserID, userAgentID)
err = l.authRepo.SelectUser(r.Context(), authSession.ID, data.UserID, userAgentID, true)
if err != nil {
l.renderError(w, r, authSession, err)
return

View File

@@ -19,7 +19,7 @@ type AuthRequestRepository interface {
CheckExternalUserLogin(ctx context.Context, authReqID, userAgentID string, user *domain.ExternalUser, info *domain.BrowserInfo, migrationCheck bool) error
SetExternalUserLogin(ctx context.Context, authReqID, userAgentID string, user *domain.ExternalUser) error
SetLinkingUser(ctx context.Context, request *domain.AuthRequest, externalUser *domain.ExternalUser) error
SelectUser(ctx context.Context, authReqID, userID, userAgentID string) error
SelectUser(ctx context.Context, authReqID, userID, userAgentID string, enforceExistingSession bool) error
SelectExternalIDP(ctx context.Context, authReqID, idpConfigID, userAgentID string, idpArguments map[string]any) error
VerifyPassword(ctx context.Context, id, userID, resourceOwner, password, userAgentID string, info *domain.BrowserInfo) error

View File

@@ -332,13 +332,24 @@ func (repo *AuthRequestRepo) setLinkingUser(ctx context.Context, request *domain
return repo.AuthRequests.UpdateAuthRequest(ctx, request)
}
func (repo *AuthRequestRepo) SelectUser(ctx context.Context, authReqID, userID, userAgentID string) (err error) {
func (repo *AuthRequestRepo) SelectUser(ctx context.Context, authReqID, userID, userAgentID string, enforceExistingSession bool) (err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID)
if err != nil {
return err
}
// Check if the session already exists in the same user agent, e.g. when selecting the user from the user selection page.
// This is to prevent username enumeration attacks by checking if the user exists in the system.
if enforceExistingSession {
userSession, err := userSessionByIDs(ctx, repo.UserSessionViewProvider, repo.UserEventProvider, request.AgentID, userID)
if err != nil {
return err
}
if userSession.Sequence == 0 {
return zerrors.ThrowNotFound(nil, "AUTH-2d3f4", "Errors.UserSession.NotFound")
}
}
user, err := activeUserByID(ctx, repo.UserViewProvider, repo.UserEventProvider, repo.OrgViewProvider, repo.LockoutPolicyViewProvider, userID, false)
if err != nil {
return err
@@ -1061,7 +1072,7 @@ func (repo *AuthRequestRepo) nextSteps(ctx context.Context, request *domain.Auth
if request.UserOrgID == "" {
request.UserOrgID = user.ResourceOwner
}
userSession, err := userSessionByIDs(ctx, repo.UserSessionViewProvider, repo.UserEventProvider, request.AgentID, user)
userSession, err := userSessionByIDs(ctx, repo.UserSessionViewProvider, repo.UserEventProvider, request.AgentID, user.ID)
if err != nil {
return nil, err
}
@@ -1641,27 +1652,27 @@ var (
}
)
func userSessionByIDs(ctx context.Context, provider userSessionViewProvider, eventProvider userEventProvider, agentID string, user *user_model.UserView) (*user_model.UserSessionView, error) {
func userSessionByIDs(ctx context.Context, provider userSessionViewProvider, eventProvider userEventProvider, agentID, userID string) (*user_model.UserSessionView, error) {
instanceID := authz.GetInstance(ctx).InstanceID()
// always load the latest sequence first, so in case the session was not found by id,
// the sequence will be equal or lower than the actual projection and no events are lost
sequence, err := provider.GetLatestUserSessionSequence(ctx, instanceID)
logging.WithFields("instanceID", instanceID, "userID", user.ID).
logging.WithFields("instanceID", instanceID, "userID", userID).
OnError(err).
Errorf("could not get current sequence for userSessionByIDs")
session, err := provider.UserSessionByIDs(ctx, agentID, user.ID, instanceID)
session, err := provider.UserSessionByIDs(ctx, agentID, userID, instanceID)
if err != nil {
if !zerrors.IsNotFound(err) {
return nil, err
}
session = &user_view_model.UserSessionView{UserAgentID: agentID, UserID: user.ID}
session = &user_view_model.UserSessionView{UserAgentID: agentID, UserID: userID}
if sequence != nil {
session.ChangeDate = sequence.EventCreatedAt
}
}
events, err := eventProvider.UserEventsByID(ctx, user.ID, session.ChangeDate, append(session.EventTypes(), userSessionEventTypes...))
events, err := eventProvider.UserEventsByID(ctx, userID, session.ChangeDate, append(session.EventTypes(), userSessionEventTypes...))
if err != nil {
logging.WithFields("traceID", tracing.TraceIDFromCtx(ctx)).WithError(err).Debug("error retrieving new events")
return user_view_model.UserSessionToModel(session), nil

View File

@@ -2796,7 +2796,7 @@ func Test_userSessionByIDs(t *testing.T) {
userProvider userSessionViewProvider
eventProvider userEventProvider
agentID string
user *user_model.UserView
userID string
}
tests := []struct {
name string
@@ -2809,7 +2809,7 @@ func Test_userSessionByIDs(t *testing.T) {
args{
userProvider: &mockViewNoUserSession{},
eventProvider: &mockEventErrUser{},
user: &user_model.UserView{ID: "id"},
userID: "id",
},
&user_model.UserSessionView{UserID: "id"},
nil,
@@ -2818,7 +2818,7 @@ func Test_userSessionByIDs(t *testing.T) {
"internal error, internal error",
args{
userProvider: &mockViewErrUserSession{},
user: &user_model.UserView{ID: "id"},
userID: "id",
},
nil,
zerrors.IsInternal,
@@ -2829,7 +2829,7 @@ func Test_userSessionByIDs(t *testing.T) {
userProvider: &mockViewUserSession{
PasswordVerification: testNow,
},
user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}},
userID: "id",
eventProvider: &mockEventErrUser{},
},
&user_model.UserSessionView{
@@ -2846,7 +2846,7 @@ func Test_userSessionByIDs(t *testing.T) {
PasswordVerification: testNow,
},
agentID: "agentID",
user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}},
userID: "id",
eventProvider: &mockEventUser{
Events: []eventstore.Event{
&es_models.Event{
@@ -2871,7 +2871,7 @@ func Test_userSessionByIDs(t *testing.T) {
PasswordVerification: testNow,
},
agentID: "agentID",
user: &user_model.UserView{ID: "id"},
userID: "id",
eventProvider: &mockEventUser{
Events: []eventstore.Event{
&es_models.Event{
@@ -2900,7 +2900,7 @@ func Test_userSessionByIDs(t *testing.T) {
PasswordVerification: testNow,
},
agentID: "agentID",
user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}},
userID: "id",
eventProvider: &mockEventUser{
Events: []eventstore.Event{
&es_models.Event{
@@ -2929,7 +2929,7 @@ func Test_userSessionByIDs(t *testing.T) {
PasswordVerification: testNow,
},
agentID: "agentID",
user: &user_model.UserView{ID: "id"},
userID: "id",
eventProvider: &mockEventUser{
Events: []eventstore.Event{
&es_models.Event{
@@ -2953,7 +2953,7 @@ func Test_userSessionByIDs(t *testing.T) {
PasswordVerification: testNow,
},
agentID: "agentID",
user: &user_model.UserView{ID: "id"},
userID: "id",
eventProvider: &mockEventUser{
Events: []eventstore.Event{
&es_models.Event{
@@ -2986,7 +2986,7 @@ func Test_userSessionByIDs(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := userSessionByIDs(context.Background(), tt.args.userProvider, tt.args.eventProvider, tt.args.agentID, tt.args.user)
got, err := userSessionByIDs(context.Background(), tt.args.userProvider, tt.args.eventProvider, tt.args.agentID, tt.args.userID)
if (err != nil && tt.wantErr == nil) || (tt.wantErr != nil && !tt.wantErr(err)) {
t.Errorf("nextSteps() wrong error = %v", err)
return