diff --git a/internal/notification/repository/eventsourcing/handler/notification.go b/internal/notification/repository/eventsourcing/handler/notification.go index 99e79effa4..e591b2fd15 100644 --- a/internal/notification/repository/eventsourcing/handler/notification.go +++ b/internal/notification/repository/eventsourcing/handler/notification.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "net/http" + "time" "github.com/caos/logging" @@ -79,7 +80,13 @@ func (n *Notification) Reduce(event *models.Event) (err error) { } func (n *Notification) handleInitUserCode(event *models.Event) (err error) { - alreadyHandled, err := n.checkIfCodeAlreadyHandled(event.AggregateID, event.Sequence, es_model.InitializedUserCodeAdded, es_model.InitializedUserCodeSent) + initCode := new(es_model.InitUserCode) + if err := initCode.SetData(event); err != nil { + return err + } + alreadyHandled, err := n.checkIfCodeAlreadyHandledOrExpired(event, initCode.Expiry, + es_model.InitializedUserCodeAdded, es_model.InitializedUserCodeSent, + es_model.InitializedHumanCodeAdded, es_model.InitializedHumanCodeSent) if err != nil || alreadyHandled { return err } @@ -89,8 +96,6 @@ func (n *Notification) handleInitUserCode(event *models.Event) (err error) { return err } - initCode := new(es_model.InitUserCode) - initCode.SetData(event) user, err := n.view.NotifyUserByID(event.AggregateID) if err != nil { return err @@ -103,7 +108,13 @@ func (n *Notification) handleInitUserCode(event *models.Event) (err error) { } func (n *Notification) handlePasswordCode(event *models.Event) (err error) { - alreadyHandled, err := n.checkIfCodeAlreadyHandled(event.AggregateID, event.Sequence, es_model.UserPasswordCodeAdded, es_model.UserPasswordCodeSent) + pwCode := new(es_model.PasswordCode) + if err := pwCode.SetData(event); err != nil { + return err + } + alreadyHandled, err := n.checkIfCodeAlreadyHandledOrExpired(event, pwCode.Expiry, + es_model.UserPasswordCodeAdded, es_model.UserPasswordCodeSent, + es_model.HumanPasswordCodeAdded, es_model.HumanPasswordCodeSent) if err != nil || alreadyHandled { return err } @@ -113,8 +124,6 @@ func (n *Notification) handlePasswordCode(event *models.Event) (err error) { return err } - pwCode := new(es_model.PasswordCode) - pwCode.SetData(event) user, err := n.view.NotifyUserByID(event.AggregateID) if err != nil { return err @@ -127,7 +136,13 @@ func (n *Notification) handlePasswordCode(event *models.Event) (err error) { } func (n *Notification) handleEmailVerificationCode(event *models.Event) (err error) { - alreadyHandled, err := n.checkIfCodeAlreadyHandled(event.AggregateID, event.Sequence, es_model.UserEmailCodeAdded, es_model.UserEmailCodeSent) + emailCode := new(es_model.EmailCode) + if err := emailCode.SetData(event); err != nil { + return err + } + alreadyHandled, err := n.checkIfCodeAlreadyHandledOrExpired(event, emailCode.Expiry, + es_model.UserEmailCodeAdded, es_model.UserEmailCodeSent, + es_model.HumanEmailCodeAdded, es_model.HumanEmailCodeSent) if err != nil || alreadyHandled { return nil } @@ -137,8 +152,6 @@ func (n *Notification) handleEmailVerificationCode(event *models.Event) (err err return err } - emailCode := new(es_model.EmailCode) - emailCode.SetData(event) user, err := n.view.NotifyUserByID(event.AggregateID) if err != nil { return err @@ -151,12 +164,16 @@ func (n *Notification) handleEmailVerificationCode(event *models.Event) (err err } func (n *Notification) handlePhoneVerificationCode(event *models.Event) (err error) { - alreadyHandled, err := n.checkIfCodeAlreadyHandled(event.AggregateID, event.Sequence, es_model.UserPhoneCodeAdded, es_model.UserPhoneCodeSent) + phoneCode := new(es_model.PhoneCode) + if err := phoneCode.SetData(event); err != nil { + return err + } + alreadyHandled, err := n.checkIfCodeAlreadyHandledOrExpired(event, phoneCode.Expiry, + es_model.UserPhoneCodeAdded, es_model.UserPhoneCodeSent, + es_model.HumanPhoneCodeAdded, es_model.HumanPhoneCodeSent) if err != nil || alreadyHandled { return nil } - phoneCode := new(es_model.PhoneCode) - phoneCode.SetData(event) user, err := n.view.NotifyUserByID(event.AggregateID) if err != nil { return err @@ -169,7 +186,7 @@ func (n *Notification) handlePhoneVerificationCode(event *models.Event) (err err } func (n *Notification) handleDomainClaimed(event *models.Event) (err error) { - alreadyHandled, err := n.checkIfCodeAlreadyHandled(event.AggregateID, event.Sequence, es_model.DomainClaimed, es_model.DomainClaimedSent) + alreadyHandled, err := n.checkIfAlreadyHandled(event.AggregateID, event.Sequence, es_model.DomainClaimed, es_model.DomainClaimedSent) if err != nil || alreadyHandled { return nil } @@ -189,14 +206,23 @@ func (n *Notification) handleDomainClaimed(event *models.Event) (err error) { return n.userEvents.DomainClaimedSent(getSetNotifyContextData(event.ResourceOwner), event.AggregateID) } -func (n *Notification) checkIfCodeAlreadyHandled(userID string, sequence uint64, addedType, sentType models.EventType) (bool, error) { +func (n *Notification) checkIfCodeAlreadyHandledOrExpired(event *models.Event, expiry time.Duration, eventTypes ...models.EventType) (bool, error) { + if event.CreationDate.Add(expiry).Before(time.Now().UTC()) { + return true, nil + } + return n.checkIfAlreadyHandled(event.AggregateID, event.Sequence, eventTypes...) +} + +func (n *Notification) checkIfAlreadyHandled(userID string, sequence uint64, eventTypes ...models.EventType) (bool, error) { events, err := n.getUserEvents(userID, sequence) if err != nil { return false, err } for _, event := range events { - if event.Type == addedType || event.Type == sentType { - return true, nil + for _, eventType := range eventTypes { + if event.Type == eventType { + return true, nil + } } } return false, nil