mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-21 12:27:47 +00:00
Merge commit from fork
This commit is contained in:
@@ -578,7 +578,7 @@ func (l *Login) checkAutoLinking(r *http.Request, authReq *domain.AuthRequest, p
|
||||
}
|
||||
|
||||
func (l *Login) autoLinkUser(r *http.Request, authReq *domain.AuthRequest, user *query.NotifyUser) error {
|
||||
if err := l.authRepo.SelectUser(r.Context(), authReq.ID, user.ID, authReq.AgentID); err != nil {
|
||||
if err := l.authRepo.SelectUser(r.Context(), authReq.ID, user.ID, authReq.AgentID, false); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := l.authRepo.LinkExternalUsers(r.Context(), authReq.ID, authReq.AgentID, domain.BrowserInfoFromRequest(r)); err != nil {
|
||||
|
@@ -133,7 +133,7 @@ func (l *Login) handleRegisterCheck(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
|
||||
err = l.authRepo.SelectUser(r.Context(), authRequest.ID, human.ID, userAgentID)
|
||||
err = l.authRepo.SelectUser(r.Context(), authRequest.ID, human.ID, userAgentID, false)
|
||||
if err != nil {
|
||||
l.renderRegister(w, r, authRequest, data, err)
|
||||
return
|
||||
|
@@ -46,7 +46,7 @@ func (l *Login) handleSelectUser(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
|
||||
err = l.authRepo.SelectUser(r.Context(), authSession.ID, data.UserID, userAgentID)
|
||||
err = l.authRepo.SelectUser(r.Context(), authSession.ID, data.UserID, userAgentID, true)
|
||||
if err != nil {
|
||||
l.renderError(w, r, authSession, err)
|
||||
return
|
||||
|
@@ -19,7 +19,7 @@ type AuthRequestRepository interface {
|
||||
CheckExternalUserLogin(ctx context.Context, authReqID, userAgentID string, user *domain.ExternalUser, info *domain.BrowserInfo, migrationCheck bool) error
|
||||
SetExternalUserLogin(ctx context.Context, authReqID, userAgentID string, user *domain.ExternalUser) error
|
||||
SetLinkingUser(ctx context.Context, request *domain.AuthRequest, externalUser *domain.ExternalUser) error
|
||||
SelectUser(ctx context.Context, authReqID, userID, userAgentID string) error
|
||||
SelectUser(ctx context.Context, authReqID, userID, userAgentID string, enforceExistingSession bool) error
|
||||
SelectExternalIDP(ctx context.Context, authReqID, idpConfigID, userAgentID string, idpArguments map[string]any) error
|
||||
VerifyPassword(ctx context.Context, id, userID, resourceOwner, password, userAgentID string, info *domain.BrowserInfo) error
|
||||
|
||||
|
@@ -332,13 +332,24 @@ func (repo *AuthRequestRepo) setLinkingUser(ctx context.Context, request *domain
|
||||
return repo.AuthRequests.UpdateAuthRequest(ctx, request)
|
||||
}
|
||||
|
||||
func (repo *AuthRequestRepo) SelectUser(ctx context.Context, authReqID, userID, userAgentID string) (err error) {
|
||||
func (repo *AuthRequestRepo) SelectUser(ctx context.Context, authReqID, userID, userAgentID string, enforceExistingSession bool) (err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
request, err := repo.getAuthRequest(ctx, authReqID, userAgentID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Check if the session already exists in the same user agent, e.g. when selecting the user from the user selection page.
|
||||
// This is to prevent username enumeration attacks by checking if the user exists in the system.
|
||||
if enforceExistingSession {
|
||||
userSession, err := userSessionByIDs(ctx, repo.UserSessionViewProvider, repo.UserEventProvider, request.AgentID, userID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if userSession.Sequence == 0 {
|
||||
return zerrors.ThrowNotFound(nil, "AUTH-2d3f4", "Errors.UserSession.NotFound")
|
||||
}
|
||||
}
|
||||
user, err := activeUserByID(ctx, repo.UserViewProvider, repo.UserEventProvider, repo.OrgViewProvider, repo.LockoutPolicyViewProvider, userID, false)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -1061,7 +1072,7 @@ func (repo *AuthRequestRepo) nextSteps(ctx context.Context, request *domain.Auth
|
||||
if request.UserOrgID == "" {
|
||||
request.UserOrgID = user.ResourceOwner
|
||||
}
|
||||
userSession, err := userSessionByIDs(ctx, repo.UserSessionViewProvider, repo.UserEventProvider, request.AgentID, user)
|
||||
userSession, err := userSessionByIDs(ctx, repo.UserSessionViewProvider, repo.UserEventProvider, request.AgentID, user.ID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -1641,27 +1652,27 @@ var (
|
||||
}
|
||||
)
|
||||
|
||||
func userSessionByIDs(ctx context.Context, provider userSessionViewProvider, eventProvider userEventProvider, agentID string, user *user_model.UserView) (*user_model.UserSessionView, error) {
|
||||
func userSessionByIDs(ctx context.Context, provider userSessionViewProvider, eventProvider userEventProvider, agentID, userID string) (*user_model.UserSessionView, error) {
|
||||
instanceID := authz.GetInstance(ctx).InstanceID()
|
||||
|
||||
// always load the latest sequence first, so in case the session was not found by id,
|
||||
// the sequence will be equal or lower than the actual projection and no events are lost
|
||||
sequence, err := provider.GetLatestUserSessionSequence(ctx, instanceID)
|
||||
logging.WithFields("instanceID", instanceID, "userID", user.ID).
|
||||
logging.WithFields("instanceID", instanceID, "userID", userID).
|
||||
OnError(err).
|
||||
Errorf("could not get current sequence for userSessionByIDs")
|
||||
|
||||
session, err := provider.UserSessionByIDs(ctx, agentID, user.ID, instanceID)
|
||||
session, err := provider.UserSessionByIDs(ctx, agentID, userID, instanceID)
|
||||
if err != nil {
|
||||
if !zerrors.IsNotFound(err) {
|
||||
return nil, err
|
||||
}
|
||||
session = &user_view_model.UserSessionView{UserAgentID: agentID, UserID: user.ID}
|
||||
session = &user_view_model.UserSessionView{UserAgentID: agentID, UserID: userID}
|
||||
if sequence != nil {
|
||||
session.ChangeDate = sequence.EventCreatedAt
|
||||
}
|
||||
}
|
||||
events, err := eventProvider.UserEventsByID(ctx, user.ID, session.ChangeDate, append(session.EventTypes(), userSessionEventTypes...))
|
||||
events, err := eventProvider.UserEventsByID(ctx, userID, session.ChangeDate, append(session.EventTypes(), userSessionEventTypes...))
|
||||
if err != nil {
|
||||
logging.WithFields("traceID", tracing.TraceIDFromCtx(ctx)).WithError(err).Debug("error retrieving new events")
|
||||
return user_view_model.UserSessionToModel(session), nil
|
||||
|
@@ -2796,7 +2796,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
userProvider userSessionViewProvider
|
||||
eventProvider userEventProvider
|
||||
agentID string
|
||||
user *user_model.UserView
|
||||
userID string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
@@ -2809,7 +2809,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
args{
|
||||
userProvider: &mockViewNoUserSession{},
|
||||
eventProvider: &mockEventErrUser{},
|
||||
user: &user_model.UserView{ID: "id"},
|
||||
userID: "id",
|
||||
},
|
||||
&user_model.UserSessionView{UserID: "id"},
|
||||
nil,
|
||||
@@ -2818,7 +2818,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
"internal error, internal error",
|
||||
args{
|
||||
userProvider: &mockViewErrUserSession{},
|
||||
user: &user_model.UserView{ID: "id"},
|
||||
userID: "id",
|
||||
},
|
||||
nil,
|
||||
zerrors.IsInternal,
|
||||
@@ -2829,7 +2829,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
userProvider: &mockViewUserSession{
|
||||
PasswordVerification: testNow,
|
||||
},
|
||||
user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}},
|
||||
userID: "id",
|
||||
eventProvider: &mockEventErrUser{},
|
||||
},
|
||||
&user_model.UserSessionView{
|
||||
@@ -2846,7 +2846,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
PasswordVerification: testNow,
|
||||
},
|
||||
agentID: "agentID",
|
||||
user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}},
|
||||
userID: "id",
|
||||
eventProvider: &mockEventUser{
|
||||
Events: []eventstore.Event{
|
||||
&es_models.Event{
|
||||
@@ -2871,7 +2871,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
PasswordVerification: testNow,
|
||||
},
|
||||
agentID: "agentID",
|
||||
user: &user_model.UserView{ID: "id"},
|
||||
userID: "id",
|
||||
eventProvider: &mockEventUser{
|
||||
Events: []eventstore.Event{
|
||||
&es_models.Event{
|
||||
@@ -2900,7 +2900,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
PasswordVerification: testNow,
|
||||
},
|
||||
agentID: "agentID",
|
||||
user: &user_model.UserView{ID: "id", HumanView: &user_model.HumanView{FirstName: "FirstName"}},
|
||||
userID: "id",
|
||||
eventProvider: &mockEventUser{
|
||||
Events: []eventstore.Event{
|
||||
&es_models.Event{
|
||||
@@ -2929,7 +2929,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
PasswordVerification: testNow,
|
||||
},
|
||||
agentID: "agentID",
|
||||
user: &user_model.UserView{ID: "id"},
|
||||
userID: "id",
|
||||
eventProvider: &mockEventUser{
|
||||
Events: []eventstore.Event{
|
||||
&es_models.Event{
|
||||
@@ -2953,7 +2953,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
PasswordVerification: testNow,
|
||||
},
|
||||
agentID: "agentID",
|
||||
user: &user_model.UserView{ID: "id"},
|
||||
userID: "id",
|
||||
eventProvider: &mockEventUser{
|
||||
Events: []eventstore.Event{
|
||||
&es_models.Event{
|
||||
@@ -2986,7 +2986,7 @@ func Test_userSessionByIDs(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := userSessionByIDs(context.Background(), tt.args.userProvider, tt.args.eventProvider, tt.args.agentID, tt.args.user)
|
||||
got, err := userSessionByIDs(context.Background(), tt.args.userProvider, tt.args.eventProvider, tt.args.agentID, tt.args.userID)
|
||||
if (err != nil && tt.wantErr == nil) || (tt.wantErr != nil && !tt.wantErr(err)) {
|
||||
t.Errorf("nextSteps() wrong error = %v", err)
|
||||
return
|
||||
|
Reference in New Issue
Block a user