mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 08:27:32 +00:00
feat: pass and handle auth request context for email links (#7815)
* pass and handle auth request context * tests and cleanup * cleanup
This commit is contained in:
@@ -3,6 +3,8 @@ package login
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
http_mw "github.com/zitadel/zitadel/internal/api/http/middleware"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
)
|
||||
@@ -33,3 +35,23 @@ func (l *Login) getAuthRequestAndParseData(r *http.Request, data interface{}) (*
|
||||
func (l *Login) getParseData(r *http.Request, data interface{}) error {
|
||||
return l.parser.Parse(r, data)
|
||||
}
|
||||
|
||||
// checkOptionalAuthRequestOfEmailLinks tries to get the [domain.AuthRequest] from the request.
|
||||
// In case any error occurs, e.g. if the user agent does not correspond, the `authRequestID` query parameter will be
|
||||
// removed from the request URL and form to ensure subsequent functions and pages do not use it.
|
||||
// This function is used for handling links in emails, which could possibly be opened on another device than the
|
||||
// auth request was initiated.
|
||||
func (l *Login) checkOptionalAuthRequestOfEmailLinks(r *http.Request) *domain.AuthRequest {
|
||||
authReq, err := l.getAuthRequest(r)
|
||||
if err == nil {
|
||||
return authReq
|
||||
}
|
||||
logging.WithError(err).Infof("authrequest could not be found for email link on path %s", r.URL.RequestURI())
|
||||
queries := r.URL.Query()
|
||||
queries.Del(QueryAuthRequestID)
|
||||
r.URL.RawQuery = queries.Encode()
|
||||
r.RequestURI = r.URL.RequestURI()
|
||||
r.Form.Del(QueryAuthRequestID)
|
||||
r.PostForm.Del(QueryAuthRequestID)
|
||||
return nil
|
||||
}
|
||||
|
@@ -1,8 +1,8 @@
|
||||
package login
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
http_mw "github.com/zitadel/zitadel/internal/api/http/middleware"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
@@ -38,14 +38,20 @@ type initPasswordData struct {
|
||||
HasSymbol string
|
||||
}
|
||||
|
||||
func InitPasswordLink(origin, userID, code, orgID string) string {
|
||||
return fmt.Sprintf("%s%s?userID=%s&code=%s&orgID=%s", externalLink(origin), EndpointInitPassword, userID, code, orgID)
|
||||
func InitPasswordLink(origin, userID, code, orgID, authRequestID string) string {
|
||||
v := url.Values{}
|
||||
v.Set(queryInitPWUserID, userID)
|
||||
v.Set(queryInitPWCode, code)
|
||||
v.Set(queryOrgID, orgID)
|
||||
v.Set(QueryAuthRequestID, authRequestID)
|
||||
return externalLink(origin) + EndpointInitPassword + "?" + v.Encode()
|
||||
}
|
||||
|
||||
func (l *Login) handleInitPassword(w http.ResponseWriter, r *http.Request) {
|
||||
authReq := l.checkOptionalAuthRequestOfEmailLinks(r)
|
||||
userID := r.FormValue(queryInitPWUserID)
|
||||
code := r.FormValue(queryInitPWCode)
|
||||
l.renderInitPassword(w, r, nil, userID, code, nil)
|
||||
l.renderInitPassword(w, r, authReq, userID, code, nil)
|
||||
}
|
||||
|
||||
func (l *Login) handleInitPasswordCheck(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -94,7 +100,7 @@ func (l *Login) resendPasswordSet(w http.ResponseWriter, r *http.Request, authRe
|
||||
l.renderInitPassword(w, r, authReq, userID, "", err)
|
||||
return
|
||||
}
|
||||
_, err = l.command.RequestSetPassword(setContext(r.Context(), userOrg), userID, userOrg, domain.NotificationTypeEmail, passwordCodeGenerator)
|
||||
_, err = l.command.RequestSetPassword(setContext(r.Context(), userOrg), userID, userOrg, domain.NotificationTypeEmail, passwordCodeGenerator, authReq.ID)
|
||||
l.renderInitPassword(w, r, authReq, userID, "", err)
|
||||
}
|
||||
|
||||
|
@@ -1,8 +1,8 @@
|
||||
package login
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
|
||||
http_mw "github.com/zitadel/zitadel/internal/api/http/middleware"
|
||||
@@ -44,16 +44,24 @@ type initUserData struct {
|
||||
HasSymbol string
|
||||
}
|
||||
|
||||
func InitUserLink(origin, userID, loginName, code, orgID string, passwordSet bool) string {
|
||||
return fmt.Sprintf("%s%s?userID=%s&loginname=%s&code=%s&orgID=%s&passwordset=%t", externalLink(origin), EndpointInitUser, userID, loginName, code, orgID, passwordSet)
|
||||
func InitUserLink(origin, userID, loginName, code, orgID string, passwordSet bool, authRequestID string) string {
|
||||
v := url.Values{}
|
||||
v.Set(queryInitUserUserID, userID)
|
||||
v.Set(queryInitUserLoginName, loginName)
|
||||
v.Set(queryInitUserCode, code)
|
||||
v.Set(queryOrgID, orgID)
|
||||
v.Set(queryInitUserPassword, strconv.FormatBool(passwordSet))
|
||||
v.Set(QueryAuthRequestID, authRequestID)
|
||||
return externalLink(origin) + EndpointInitUser + "?" + v.Encode()
|
||||
}
|
||||
|
||||
func (l *Login) handleInitUser(w http.ResponseWriter, r *http.Request) {
|
||||
authReq := l.checkOptionalAuthRequestOfEmailLinks(r)
|
||||
userID := r.FormValue(queryInitUserUserID)
|
||||
code := r.FormValue(queryInitUserCode)
|
||||
loginName := r.FormValue(queryInitUserLoginName)
|
||||
passwordSet, _ := strconv.ParseBool(r.FormValue(queryInitUserPassword))
|
||||
l.renderInitUser(w, r, nil, userID, loginName, code, passwordSet, nil)
|
||||
l.renderInitUser(w, r, authReq, userID, loginName, code, passwordSet, nil)
|
||||
}
|
||||
|
||||
func (l *Login) handleInitUserCheck(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -105,7 +113,7 @@ func (l *Login) resendUserInit(w http.ResponseWriter, r *http.Request, authReq *
|
||||
l.renderInitUser(w, r, authReq, userID, loginName, "", showPassword, err)
|
||||
return
|
||||
}
|
||||
_, err = l.command.ResendInitialMail(setContext(r.Context(), userOrgID), userID, "", userOrgID, initCodeGenerator)
|
||||
_, err = l.command.ResendInitialMail(setContext(r.Context(), userOrgID), userID, "", userOrgID, initCodeGenerator, authReq.ID)
|
||||
l.renderInitUser(w, r, authReq, userID, loginName, "", showPassword, err)
|
||||
}
|
||||
|
||||
|
@@ -1,8 +1,8 @@
|
||||
package login
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
)
|
||||
@@ -27,18 +27,24 @@ type mailVerificationData struct {
|
||||
UserID string
|
||||
}
|
||||
|
||||
func MailVerificationLink(origin, userID, code, orgID string) string {
|
||||
return fmt.Sprintf("%s%s?userID=%s&code=%s&orgID=%s", externalLink(origin), EndpointMailVerification, userID, code, orgID)
|
||||
func MailVerificationLink(origin, userID, code, orgID, authRequestID string) string {
|
||||
v := url.Values{}
|
||||
v.Set(queryUserID, userID)
|
||||
v.Set(queryCode, code)
|
||||
v.Set(queryOrgID, orgID)
|
||||
v.Set(QueryAuthRequestID, authRequestID)
|
||||
return externalLink(origin) + EndpointMailVerification + "?" + v.Encode()
|
||||
}
|
||||
|
||||
func (l *Login) handleMailVerification(w http.ResponseWriter, r *http.Request) {
|
||||
authReq := l.checkOptionalAuthRequestOfEmailLinks(r)
|
||||
userID := r.FormValue(queryUserID)
|
||||
code := r.FormValue(queryCode)
|
||||
if code != "" {
|
||||
l.checkMailCode(w, r, nil, userID, code)
|
||||
l.checkMailCode(w, r, authReq, userID, code)
|
||||
return
|
||||
}
|
||||
l.renderMailVerification(w, r, nil, userID, nil)
|
||||
l.renderMailVerification(w, r, authReq, userID, nil)
|
||||
}
|
||||
|
||||
func (l *Login) handleMailVerificationCheck(w http.ResponseWriter, r *http.Request) {
|
||||
@@ -61,7 +67,7 @@ func (l *Login) handleMailVerificationCheck(w http.ResponseWriter, r *http.Reque
|
||||
l.checkMailCode(w, r, authReq, data.UserID, data.Code)
|
||||
return
|
||||
}
|
||||
_, err = l.command.CreateHumanEmailVerificationCode(setContext(r.Context(), userOrg), data.UserID, userOrg, emailCodeGenerator)
|
||||
_, err = l.command.CreateHumanEmailVerificationCode(setContext(r.Context(), userOrg), data.UserID, userOrg, emailCodeGenerator, authReq.ID)
|
||||
l.renderMailVerification(w, r, authReq, data.UserID, err)
|
||||
}
|
||||
|
||||
|
@@ -33,7 +33,7 @@ func (l *Login) handlePasswordReset(w http.ResponseWriter, r *http.Request) {
|
||||
l.renderPasswordResetDone(w, r, authReq, err)
|
||||
return
|
||||
}
|
||||
_, err = l.command.RequestSetPassword(setContext(r.Context(), authReq.UserOrgID), user.ID, authReq.UserOrgID, domain.NotificationTypeEmail, passwordCodeGenerator)
|
||||
_, err = l.command.RequestSetPassword(setContext(r.Context(), authReq.UserOrgID), user.ID, authReq.UserOrgID, domain.NotificationTypeEmail, passwordCodeGenerator, authReq.ID)
|
||||
l.renderPasswordResetDone(w, r, authReq, err)
|
||||
}
|
||||
|
||||
|
@@ -7,6 +7,7 @@ import (
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
http_mw "github.com/zitadel/zitadel/internal/api/http/middleware"
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
@@ -67,22 +68,6 @@ func (l *Login) handleRegisterCheck(w http.ResponseWriter, r *http.Request) {
|
||||
if authRequest != nil && authRequest.RequestedOrgID != "" && authRequest.RequestedOrgID != resourceOwner {
|
||||
resourceOwner = authRequest.RequestedOrgID
|
||||
}
|
||||
initCodeGenerator, err := l.query.InitEncryptionGenerator(r.Context(), domain.SecretGeneratorTypeInitCode, l.userCodeAlg)
|
||||
if err != nil {
|
||||
l.renderRegister(w, r, authRequest, data, err)
|
||||
return
|
||||
}
|
||||
emailCodeGenerator, err := l.query.InitEncryptionGenerator(r.Context(), domain.SecretGeneratorTypeVerifyEmailCode, l.userCodeAlg)
|
||||
if err != nil {
|
||||
l.renderRegister(w, r, authRequest, data, err)
|
||||
return
|
||||
}
|
||||
phoneCodeGenerator, err := l.query.InitEncryptionGenerator(r.Context(), domain.SecretGeneratorTypeVerifyPhoneCode, l.userCodeAlg)
|
||||
if err != nil {
|
||||
l.renderRegister(w, r, authRequest, data, err)
|
||||
return
|
||||
}
|
||||
|
||||
// For consistency with the external authentication flow,
|
||||
// the setMetadata() function is provided on the pre creation hook, for now,
|
||||
// like for the ExternalAuthentication flow.
|
||||
@@ -96,22 +81,14 @@ func (l *Login) handleRegisterCheck(w http.ResponseWriter, r *http.Request) {
|
||||
l.renderRegister(w, r, authRequest, data, err)
|
||||
return
|
||||
}
|
||||
user, err = l.command.RegisterHuman(setContext(r.Context(), resourceOwner), resourceOwner, user, nil, nil, initCodeGenerator, emailCodeGenerator, phoneCodeGenerator)
|
||||
|
||||
human := command.AddHumanFromDomain(user, metadatas, authRequest, nil)
|
||||
err = l.command.AddUserHuman(setContext(r.Context(), resourceOwner), resourceOwner, human, true, l.userCodeAlg)
|
||||
if err != nil {
|
||||
l.renderRegister(w, r, authRequest, data, err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(metadatas) > 0 {
|
||||
_, err = l.command.BulkSetUserMetadata(r.Context(), user.AggregateID, resourceOwner, metadatas...)
|
||||
if err != nil {
|
||||
// TODO: What if action is configured to be allowed to fail? Same question for external registration.
|
||||
l.renderRegister(w, r, authRequest, data, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
userGrants, err := l.runPostCreationActions(user.AggregateID, authRequest, r, resourceOwner, domain.FlowTypeInternalAuthentication)
|
||||
userGrants, err := l.runPostCreationActions(human.ID, authRequest, r, resourceOwner, domain.FlowTypeInternalAuthentication)
|
||||
if err != nil {
|
||||
l.renderError(w, r, authRequest, err)
|
||||
return
|
||||
@@ -128,7 +105,7 @@ func (l *Login) handleRegisterCheck(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
|
||||
err = l.authRepo.SelectUser(r.Context(), authRequest.ID, user.AggregateID, userAgentID)
|
||||
err = l.authRepo.SelectUser(r.Context(), authRequest.ID, human.ID, userAgentID)
|
||||
if err != nil {
|
||||
l.renderRegister(w, r, authRequest, data, err)
|
||||
return
|
||||
|
Reference in New Issue
Block a user