feat: pass and handle auth request context for email links (#7815)

* pass and handle auth request context

* tests and cleanup

* cleanup
This commit is contained in:
Livio Spring
2024-04-24 17:50:58 +02:00
committed by GitHub
parent ac985e2dfb
commit d016379e2a
38 changed files with 851 additions and 2018 deletions

View File

@@ -65,7 +65,7 @@ func (s *Server) ResendMyEmailVerification(ctx context.Context, _ *auth_pb.Resen
if err != nil {
return nil, err
}
objectDetails, err := s.command.CreateHumanEmailVerificationCode(ctx, ctxData.UserID, ctxData.ResourceOwner, emailCodeGenerator)
objectDetails, err := s.command.CreateHumanEmailVerificationCode(ctx, ctxData.UserID, ctxData.ResourceOwner, emailCodeGenerator, "")
if err != nil {
return nil, err
}

View File

@@ -473,7 +473,7 @@ func (s *Server) ResendHumanInitialization(ctx context.Context, req *mgmt_pb.Res
if err != nil {
return nil, err
}
details, err := s.command.ResendInitialMail(ctx, req.UserId, domain.EmailAddress(req.Email), authz.GetCtxData(ctx).OrgID, initCodeGenerator)
details, err := s.command.ResendInitialMail(ctx, req.UserId, domain.EmailAddress(req.Email), authz.GetCtxData(ctx).OrgID, initCodeGenerator, "")
if err != nil {
return nil, err
}
@@ -487,7 +487,7 @@ func (s *Server) ResendHumanEmailVerification(ctx context.Context, req *mgmt_pb.
if err != nil {
return nil, err
}
objectDetails, err := s.command.CreateHumanEmailVerificationCode(ctx, req.UserId, authz.GetCtxData(ctx).OrgID, emailCodeGenerator)
objectDetails, err := s.command.CreateHumanEmailVerificationCode(ctx, req.UserId, authz.GetCtxData(ctx).OrgID, emailCodeGenerator, "")
if err != nil {
return nil, err
}
@@ -590,7 +590,7 @@ func (s *Server) SendHumanResetPasswordNotification(ctx context.Context, req *mg
if err != nil {
return nil, err
}
objectDetails, err := s.command.RequestSetPassword(ctx, req.UserId, authz.GetCtxData(ctx).OrgID, notifyTypeToDomain(req.Type), passwordCodeGenerator)
objectDetails, err := s.command.RequestSetPassword(ctx, req.UserId, authz.GetCtxData(ctx).OrgID, notifyTypeToDomain(req.Type), passwordCodeGenerator, "")
if err != nil {
return nil, err
}

View File

@@ -3,6 +3,8 @@ package login
import (
"net/http"
"github.com/zitadel/logging"
http_mw "github.com/zitadel/zitadel/internal/api/http/middleware"
"github.com/zitadel/zitadel/internal/domain"
)
@@ -33,3 +35,23 @@ func (l *Login) getAuthRequestAndParseData(r *http.Request, data interface{}) (*
func (l *Login) getParseData(r *http.Request, data interface{}) error {
return l.parser.Parse(r, data)
}
// checkOptionalAuthRequestOfEmailLinks tries to get the [domain.AuthRequest] from the request.
// In case any error occurs, e.g. if the user agent does not correspond, the `authRequestID` query parameter will be
// removed from the request URL and form to ensure subsequent functions and pages do not use it.
// This function is used for handling links in emails, which could possibly be opened on another device than the
// auth request was initiated.
func (l *Login) checkOptionalAuthRequestOfEmailLinks(r *http.Request) *domain.AuthRequest {
authReq, err := l.getAuthRequest(r)
if err == nil {
return authReq
}
logging.WithError(err).Infof("authrequest could not be found for email link on path %s", r.URL.RequestURI())
queries := r.URL.Query()
queries.Del(QueryAuthRequestID)
r.URL.RawQuery = queries.Encode()
r.RequestURI = r.URL.RequestURI()
r.Form.Del(QueryAuthRequestID)
r.PostForm.Del(QueryAuthRequestID)
return nil
}

View File

@@ -1,8 +1,8 @@
package login
import (
"fmt"
"net/http"
"net/url"
http_mw "github.com/zitadel/zitadel/internal/api/http/middleware"
"github.com/zitadel/zitadel/internal/domain"
@@ -38,14 +38,20 @@ type initPasswordData struct {
HasSymbol string
}
func InitPasswordLink(origin, userID, code, orgID string) string {
return fmt.Sprintf("%s%s?userID=%s&code=%s&orgID=%s", externalLink(origin), EndpointInitPassword, userID, code, orgID)
func InitPasswordLink(origin, userID, code, orgID, authRequestID string) string {
v := url.Values{}
v.Set(queryInitPWUserID, userID)
v.Set(queryInitPWCode, code)
v.Set(queryOrgID, orgID)
v.Set(QueryAuthRequestID, authRequestID)
return externalLink(origin) + EndpointInitPassword + "?" + v.Encode()
}
func (l *Login) handleInitPassword(w http.ResponseWriter, r *http.Request) {
authReq := l.checkOptionalAuthRequestOfEmailLinks(r)
userID := r.FormValue(queryInitPWUserID)
code := r.FormValue(queryInitPWCode)
l.renderInitPassword(w, r, nil, userID, code, nil)
l.renderInitPassword(w, r, authReq, userID, code, nil)
}
func (l *Login) handleInitPasswordCheck(w http.ResponseWriter, r *http.Request) {
@@ -94,7 +100,7 @@ func (l *Login) resendPasswordSet(w http.ResponseWriter, r *http.Request, authRe
l.renderInitPassword(w, r, authReq, userID, "", err)
return
}
_, err = l.command.RequestSetPassword(setContext(r.Context(), userOrg), userID, userOrg, domain.NotificationTypeEmail, passwordCodeGenerator)
_, err = l.command.RequestSetPassword(setContext(r.Context(), userOrg), userID, userOrg, domain.NotificationTypeEmail, passwordCodeGenerator, authReq.ID)
l.renderInitPassword(w, r, authReq, userID, "", err)
}

View File

@@ -1,8 +1,8 @@
package login
import (
"fmt"
"net/http"
"net/url"
"strconv"
http_mw "github.com/zitadel/zitadel/internal/api/http/middleware"
@@ -44,16 +44,24 @@ type initUserData struct {
HasSymbol string
}
func InitUserLink(origin, userID, loginName, code, orgID string, passwordSet bool) string {
return fmt.Sprintf("%s%s?userID=%s&loginname=%s&code=%s&orgID=%s&passwordset=%t", externalLink(origin), EndpointInitUser, userID, loginName, code, orgID, passwordSet)
func InitUserLink(origin, userID, loginName, code, orgID string, passwordSet bool, authRequestID string) string {
v := url.Values{}
v.Set(queryInitUserUserID, userID)
v.Set(queryInitUserLoginName, loginName)
v.Set(queryInitUserCode, code)
v.Set(queryOrgID, orgID)
v.Set(queryInitUserPassword, strconv.FormatBool(passwordSet))
v.Set(QueryAuthRequestID, authRequestID)
return externalLink(origin) + EndpointInitUser + "?" + v.Encode()
}
func (l *Login) handleInitUser(w http.ResponseWriter, r *http.Request) {
authReq := l.checkOptionalAuthRequestOfEmailLinks(r)
userID := r.FormValue(queryInitUserUserID)
code := r.FormValue(queryInitUserCode)
loginName := r.FormValue(queryInitUserLoginName)
passwordSet, _ := strconv.ParseBool(r.FormValue(queryInitUserPassword))
l.renderInitUser(w, r, nil, userID, loginName, code, passwordSet, nil)
l.renderInitUser(w, r, authReq, userID, loginName, code, passwordSet, nil)
}
func (l *Login) handleInitUserCheck(w http.ResponseWriter, r *http.Request) {
@@ -105,7 +113,7 @@ func (l *Login) resendUserInit(w http.ResponseWriter, r *http.Request, authReq *
l.renderInitUser(w, r, authReq, userID, loginName, "", showPassword, err)
return
}
_, err = l.command.ResendInitialMail(setContext(r.Context(), userOrgID), userID, "", userOrgID, initCodeGenerator)
_, err = l.command.ResendInitialMail(setContext(r.Context(), userOrgID), userID, "", userOrgID, initCodeGenerator, authReq.ID)
l.renderInitUser(w, r, authReq, userID, loginName, "", showPassword, err)
}

View File

@@ -1,8 +1,8 @@
package login
import (
"fmt"
"net/http"
"net/url"
"github.com/zitadel/zitadel/internal/domain"
)
@@ -27,18 +27,24 @@ type mailVerificationData struct {
UserID string
}
func MailVerificationLink(origin, userID, code, orgID string) string {
return fmt.Sprintf("%s%s?userID=%s&code=%s&orgID=%s", externalLink(origin), EndpointMailVerification, userID, code, orgID)
func MailVerificationLink(origin, userID, code, orgID, authRequestID string) string {
v := url.Values{}
v.Set(queryUserID, userID)
v.Set(queryCode, code)
v.Set(queryOrgID, orgID)
v.Set(QueryAuthRequestID, authRequestID)
return externalLink(origin) + EndpointMailVerification + "?" + v.Encode()
}
func (l *Login) handleMailVerification(w http.ResponseWriter, r *http.Request) {
authReq := l.checkOptionalAuthRequestOfEmailLinks(r)
userID := r.FormValue(queryUserID)
code := r.FormValue(queryCode)
if code != "" {
l.checkMailCode(w, r, nil, userID, code)
l.checkMailCode(w, r, authReq, userID, code)
return
}
l.renderMailVerification(w, r, nil, userID, nil)
l.renderMailVerification(w, r, authReq, userID, nil)
}
func (l *Login) handleMailVerificationCheck(w http.ResponseWriter, r *http.Request) {
@@ -61,7 +67,7 @@ func (l *Login) handleMailVerificationCheck(w http.ResponseWriter, r *http.Reque
l.checkMailCode(w, r, authReq, data.UserID, data.Code)
return
}
_, err = l.command.CreateHumanEmailVerificationCode(setContext(r.Context(), userOrg), data.UserID, userOrg, emailCodeGenerator)
_, err = l.command.CreateHumanEmailVerificationCode(setContext(r.Context(), userOrg), data.UserID, userOrg, emailCodeGenerator, authReq.ID)
l.renderMailVerification(w, r, authReq, data.UserID, err)
}

View File

@@ -33,7 +33,7 @@ func (l *Login) handlePasswordReset(w http.ResponseWriter, r *http.Request) {
l.renderPasswordResetDone(w, r, authReq, err)
return
}
_, err = l.command.RequestSetPassword(setContext(r.Context(), authReq.UserOrgID), user.ID, authReq.UserOrgID, domain.NotificationTypeEmail, passwordCodeGenerator)
_, err = l.command.RequestSetPassword(setContext(r.Context(), authReq.UserOrgID), user.ID, authReq.UserOrgID, domain.NotificationTypeEmail, passwordCodeGenerator, authReq.ID)
l.renderPasswordResetDone(w, r, authReq, err)
}

View File

@@ -7,6 +7,7 @@ import (
"github.com/zitadel/zitadel/internal/api/authz"
http_mw "github.com/zitadel/zitadel/internal/api/http/middleware"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/zerrors"
)
@@ -67,22 +68,6 @@ func (l *Login) handleRegisterCheck(w http.ResponseWriter, r *http.Request) {
if authRequest != nil && authRequest.RequestedOrgID != "" && authRequest.RequestedOrgID != resourceOwner {
resourceOwner = authRequest.RequestedOrgID
}
initCodeGenerator, err := l.query.InitEncryptionGenerator(r.Context(), domain.SecretGeneratorTypeInitCode, l.userCodeAlg)
if err != nil {
l.renderRegister(w, r, authRequest, data, err)
return
}
emailCodeGenerator, err := l.query.InitEncryptionGenerator(r.Context(), domain.SecretGeneratorTypeVerifyEmailCode, l.userCodeAlg)
if err != nil {
l.renderRegister(w, r, authRequest, data, err)
return
}
phoneCodeGenerator, err := l.query.InitEncryptionGenerator(r.Context(), domain.SecretGeneratorTypeVerifyPhoneCode, l.userCodeAlg)
if err != nil {
l.renderRegister(w, r, authRequest, data, err)
return
}
// For consistency with the external authentication flow,
// the setMetadata() function is provided on the pre creation hook, for now,
// like for the ExternalAuthentication flow.
@@ -96,22 +81,14 @@ func (l *Login) handleRegisterCheck(w http.ResponseWriter, r *http.Request) {
l.renderRegister(w, r, authRequest, data, err)
return
}
user, err = l.command.RegisterHuman(setContext(r.Context(), resourceOwner), resourceOwner, user, nil, nil, initCodeGenerator, emailCodeGenerator, phoneCodeGenerator)
human := command.AddHumanFromDomain(user, metadatas, authRequest, nil)
err = l.command.AddUserHuman(setContext(r.Context(), resourceOwner), resourceOwner, human, true, l.userCodeAlg)
if err != nil {
l.renderRegister(w, r, authRequest, data, err)
return
}
if len(metadatas) > 0 {
_, err = l.command.BulkSetUserMetadata(r.Context(), user.AggregateID, resourceOwner, metadatas...)
if err != nil {
// TODO: What if action is configured to be allowed to fail? Same question for external registration.
l.renderRegister(w, r, authRequest, data, err)
return
}
}
userGrants, err := l.runPostCreationActions(user.AggregateID, authRequest, r, resourceOwner, domain.FlowTypeInternalAuthentication)
userGrants, err := l.runPostCreationActions(human.ID, authRequest, r, resourceOwner, domain.FlowTypeInternalAuthentication)
if err != nil {
l.renderError(w, r, authRequest, err)
return
@@ -128,7 +105,7 @@ func (l *Login) handleRegisterCheck(w http.ResponseWriter, r *http.Request) {
return
}
userAgentID, _ := http_mw.UserAgentIDFromCtx(r.Context())
err = l.authRepo.SelectUser(r.Context(), authRequest.ID, user.AggregateID, userAgentID)
err = l.authRepo.SelectUser(r.Context(), authRequest.ID, human.ID, userAgentID)
if err != nil {
l.renderRegister(w, r, authRequest, data, err)
return