mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 19:07:30 +00:00
feat: handle instanceID in projections (#3442)
* feat: handle instanceID in projections * rename functions * fix key lock * fix import
This commit is contained in:
@@ -99,8 +99,8 @@ func (_ *Notification) AggregateTypes() []models.AggregateType {
|
||||
return []models.AggregateType{user_repo.AggregateType}
|
||||
}
|
||||
|
||||
func (n *Notification) CurrentSequence() (uint64, error) {
|
||||
sequence, err := n.view.GetLatestNotificationSequence()
|
||||
func (n *Notification) CurrentSequence(instanceID string) (uint64, error) {
|
||||
sequence, err := n.view.GetLatestNotificationSequence(instanceID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -108,11 +108,29 @@ func (n *Notification) CurrentSequence() (uint64, error) {
|
||||
}
|
||||
|
||||
func (n *Notification) EventQuery() (*models.SearchQuery, error) {
|
||||
sequence, err := n.view.GetLatestNotificationSequence()
|
||||
sequences, err := n.view.GetLatestNotificationSequences()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return view.UserQuery(sequence.CurrentSequence), nil
|
||||
query := models.NewSearchQuery()
|
||||
instances := make([]string, 0)
|
||||
for _, sequence := range sequences {
|
||||
for _, instance := range instances {
|
||||
if sequence.InstanceID == instance {
|
||||
break
|
||||
}
|
||||
}
|
||||
instances = append(instances, sequence.InstanceID)
|
||||
query.AddQuery().
|
||||
AggregateTypeFilter(n.AggregateTypes()...).
|
||||
LatestSequenceFilter(sequence.CurrentSequence).
|
||||
InstanceIDFilter(sequence.InstanceID)
|
||||
}
|
||||
return query.AddQuery().
|
||||
AggregateTypeFilter(n.AggregateTypes()...).
|
||||
LatestSequenceFilter(0).
|
||||
ExcludedInstanceIDsFilter(instances...).
|
||||
SearchQuery(), nil
|
||||
}
|
||||
|
||||
func (n *Notification) Reduce(event *models.Event) (err error) {
|
||||
@@ -162,7 +180,7 @@ func (n *Notification) handleInitUserCode(event *models.Event) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := n.getUserByID(event.AggregateID)
|
||||
user, err := n.getUserByID(event.AggregateID, event.InstanceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -201,7 +219,7 @@ func (n *Notification) handlePasswordCode(event *models.Event) (err error) {
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := n.getUserByID(event.AggregateID)
|
||||
user, err := n.getUserByID(event.AggregateID, event.InstanceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -239,7 +257,7 @@ func (n *Notification) handleEmailVerificationCode(event *models.Event) (err err
|
||||
return err
|
||||
}
|
||||
|
||||
user, err := n.getUserByID(event.AggregateID)
|
||||
user, err := n.getUserByID(event.AggregateID, event.InstanceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -268,7 +286,7 @@ func (n *Notification) handlePhoneVerificationCode(event *models.Event) (err err
|
||||
if err != nil || alreadyHandled {
|
||||
return nil
|
||||
}
|
||||
user, err := n.getUserByID(event.AggregateID)
|
||||
user, err := n.getUserByID(event.AggregateID, event.InstanceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -285,7 +303,7 @@ func (n *Notification) handlePhoneVerificationCode(event *models.Event) (err err
|
||||
|
||||
func (n *Notification) handleDomainClaimed(event *models.Event) (err error) {
|
||||
ctx := getSetNotifyContextData(event.InstanceID, event.ResourceOwner)
|
||||
alreadyHandled, err := n.checkIfAlreadyHandled(ctx, event.AggregateID, event.Sequence, user_repo.UserDomainClaimedType, user_repo.UserDomainClaimedSentType)
|
||||
alreadyHandled, err := n.checkIfAlreadyHandled(ctx, event.AggregateID, event.InstanceID, event.Sequence, user_repo.UserDomainClaimedType, user_repo.UserDomainClaimedSentType)
|
||||
if err != nil || alreadyHandled {
|
||||
return nil
|
||||
}
|
||||
@@ -294,7 +312,7 @@ func (n *Notification) handleDomainClaimed(event *models.Event) (err error) {
|
||||
logging.Log("HANDLE-Gghq2").WithError(err).Error("could not unmarshal event data")
|
||||
return errors.ThrowInternal(err, "HANDLE-7hgj3", "could not unmarshal event")
|
||||
}
|
||||
user, err := n.getUserByID(event.AggregateID)
|
||||
user, err := n.getUserByID(event.AggregateID, event.InstanceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -329,7 +347,7 @@ func (n *Notification) handlePasswordlessRegistrationLink(event *models.Event) (
|
||||
return err
|
||||
}
|
||||
ctx := getSetNotifyContextData(event.InstanceID, event.ResourceOwner)
|
||||
events, err := n.getUserEvents(ctx, event.AggregateID, event.Sequence)
|
||||
events, err := n.getUserEvents(ctx, event.AggregateID, event.InstanceID, event.Sequence)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -344,7 +362,7 @@ func (n *Notification) handlePasswordlessRegistrationLink(event *models.Event) (
|
||||
}
|
||||
}
|
||||
}
|
||||
user, err := n.getUserByID(event.AggregateID)
|
||||
user, err := n.getUserByID(event.AggregateID, event.InstanceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -374,11 +392,11 @@ func (n *Notification) checkIfCodeAlreadyHandledOrExpired(ctx context.Context, e
|
||||
if event.CreationDate.Add(expiry).Before(time.Now().UTC()) {
|
||||
return true, nil
|
||||
}
|
||||
return n.checkIfAlreadyHandled(ctx, event.AggregateID, event.Sequence, eventTypes...)
|
||||
return n.checkIfAlreadyHandled(ctx, event.AggregateID, event.InstanceID, event.Sequence, eventTypes...)
|
||||
}
|
||||
|
||||
func (n *Notification) checkIfAlreadyHandled(ctx context.Context, userID string, sequence uint64, eventTypes ...eventstore.EventType) (bool, error) {
|
||||
events, err := n.getUserEvents(ctx, userID, sequence)
|
||||
func (n *Notification) checkIfAlreadyHandled(ctx context.Context, userID, instanceID string, sequence uint64, eventTypes ...eventstore.EventType) (bool, error) {
|
||||
events, err := n.getUserEvents(ctx, userID, instanceID, sequence)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -392,8 +410,8 @@ func (n *Notification) checkIfAlreadyHandled(ctx context.Context, userID string,
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (n *Notification) getUserEvents(ctx context.Context, userID string, sequence uint64) ([]*models.Event, error) {
|
||||
query, err := view.UserByIDQuery(userID, sequence)
|
||||
func (n *Notification) getUserEvents(ctx context.Context, userID, instanceID string, sequence uint64) ([]*models.Event, error) {
|
||||
query, err := view.UserByIDQuery(userID, instanceID, sequence)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -514,6 +532,6 @@ func (n *Notification) getTranslatorWithOrgTexts(ctx context.Context, orgID, tex
|
||||
return translator, nil
|
||||
}
|
||||
|
||||
func (n *Notification) getUserByID(userID string) (*model.NotifyUser, error) {
|
||||
return n.view.NotifyUserByID(userID)
|
||||
func (n *Notification) getUserByID(userID, instanceID string) (*model.NotifyUser, error) {
|
||||
return n.view.NotifyUserByID(userID, instanceID)
|
||||
}
|
||||
|
@@ -67,8 +67,8 @@ func (_ *NotifyUser) AggregateTypes() []es_models.AggregateType {
|
||||
return []es_models.AggregateType{user.AggregateType, org.AggregateType}
|
||||
}
|
||||
|
||||
func (p *NotifyUser) CurrentSequence() (uint64, error) {
|
||||
sequence, err := p.view.GetLatestNotifyUserSequence()
|
||||
func (p *NotifyUser) CurrentSequence(instanceID string) (uint64, error) {
|
||||
sequence, err := p.view.GetLatestNotifyUserSequence(instanceID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@@ -76,13 +76,29 @@ func (p *NotifyUser) CurrentSequence() (uint64, error) {
|
||||
}
|
||||
|
||||
func (p *NotifyUser) EventQuery() (*es_models.SearchQuery, error) {
|
||||
sequence, err := p.view.GetLatestNotifyUserSequence()
|
||||
sequences, err := p.view.GetLatestNotifyUserSequences()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return es_models.NewSearchQuery().
|
||||
query := es_models.NewSearchQuery()
|
||||
instances := make([]string, 0)
|
||||
for _, sequence := range sequences {
|
||||
for _, instance := range instances {
|
||||
if sequence.InstanceID == instance {
|
||||
break
|
||||
}
|
||||
}
|
||||
instances = append(instances, sequence.InstanceID)
|
||||
query.AddQuery().
|
||||
AggregateTypeFilter(p.AggregateTypes()...).
|
||||
LatestSequenceFilter(sequence.CurrentSequence).
|
||||
InstanceIDFilter(sequence.InstanceID)
|
||||
}
|
||||
return query.AddQuery().
|
||||
AggregateTypeFilter(p.AggregateTypes()...).
|
||||
LatestSequenceFilter(sequence.CurrentSequence), nil
|
||||
LatestSequenceFilter(0).
|
||||
ExcludedInstanceIDsFilter(instances...).
|
||||
SearchQuery(), nil
|
||||
}
|
||||
|
||||
func (u *NotifyUser) Reduce(event *es_models.Event) (err error) {
|
||||
@@ -122,14 +138,14 @@ func (u *NotifyUser) ProcessUser(event *es_models.Event) (err error) {
|
||||
user.HumanPhoneVerifiedType,
|
||||
user.HumanPhoneRemovedType,
|
||||
user.MachineChangedEventType:
|
||||
notifyUser, err = u.view.NotifyUserByID(event.AggregateID)
|
||||
notifyUser, err = u.view.NotifyUserByID(event.AggregateID, event.InstanceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = notifyUser.AppendEvent(event)
|
||||
case user.UserDomainClaimedType,
|
||||
user.UserUserNameChangedType:
|
||||
notifyUser, err = u.view.NotifyUserByID(event.AggregateID)
|
||||
notifyUser, err = u.view.NotifyUserByID(event.AggregateID, event.InstanceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -139,7 +155,7 @@ func (u *NotifyUser) ProcessUser(event *es_models.Event) (err error) {
|
||||
}
|
||||
err = u.fillLoginNames(notifyUser)
|
||||
case user.UserRemovedType:
|
||||
return u.view.DeleteNotifyUser(event.AggregateID, event)
|
||||
return u.view.DeleteNotifyUser(event.AggregateID, event.InstanceID, event)
|
||||
default:
|
||||
return u.view.ProcessedNotifyUserSequence(event)
|
||||
}
|
||||
@@ -169,7 +185,7 @@ func (u *NotifyUser) fillLoginNamesOnOrgUsers(event *es_models.Event) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
users, err := u.view.NotifyUsersByOrgID(event.AggregateID)
|
||||
users, err := u.view.NotifyUsersByOrgID(event.AggregateID, event.InstanceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -191,7 +207,7 @@ func (u *NotifyUser) fillPreferredLoginNamesOnOrgUsers(event *es_models.Event) e
|
||||
if !userLoginMustBeDomain {
|
||||
return nil
|
||||
}
|
||||
users, err := u.view.NotifyUsersByOrgID(event.AggregateID)
|
||||
users, err := u.view.NotifyUsersByOrgID(event.AggregateID, event.InstanceID)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@@ -2,8 +2,9 @@ package spooler
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
es_locker "github.com/caos/zitadel/internal/eventstore/v1/locker"
|
||||
"time"
|
||||
|
||||
es_locker "github.com/caos/zitadel/internal/eventstore/v1/locker"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -14,6 +15,6 @@ type locker struct {
|
||||
dbClient *sql.DB
|
||||
}
|
||||
|
||||
func (l *locker) Renew(lockerID, viewModel string, waitTime time.Duration) error {
|
||||
return es_locker.Renew(l.dbClient, lockTable, lockerID, viewModel, waitTime)
|
||||
func (l *locker) Renew(lockerID, viewModel, instanceID string, waitTime time.Duration) error {
|
||||
return es_locker.Renew(l.dbClient, lockTable, lockerID, viewModel, instanceID, waitTime)
|
||||
}
|
||||
|
@@ -12,6 +12,6 @@ func (v *View) saveFailedEvent(failedEvent *repository.FailedEvent) error {
|
||||
return repository.SaveFailedEvent(v.Db, errTable, failedEvent)
|
||||
}
|
||||
|
||||
func (v *View) latestFailedEvent(viewName string, sequence uint64) (*repository.FailedEvent, error) {
|
||||
return repository.LatestFailedEvent(v.Db, errTable, viewName, sequence)
|
||||
func (v *View) latestFailedEvent(viewName, instanceID string, sequence uint64) (*repository.FailedEvent, error) {
|
||||
return repository.LatestFailedEvent(v.Db, errTable, viewName, instanceID, sequence)
|
||||
}
|
||||
|
@@ -9,8 +9,12 @@ const (
|
||||
notificationTable = "notification.notifications"
|
||||
)
|
||||
|
||||
func (v *View) GetLatestNotificationSequence() (*repository.CurrentSequence, error) {
|
||||
return v.latestSequence(notificationTable)
|
||||
func (v *View) GetLatestNotificationSequence(instanceID string) (*repository.CurrentSequence, error) {
|
||||
return v.latestSequence(notificationTable, instanceID)
|
||||
}
|
||||
|
||||
func (v *View) GetLatestNotificationSequences() ([]*repository.CurrentSequence, error) {
|
||||
return v.latestSequences(notificationTable)
|
||||
}
|
||||
|
||||
func (v *View) ProcessedNotificationSequence(event *models.Event) error {
|
||||
@@ -21,8 +25,8 @@ func (v *View) UpdateNotificationSpoolerRunTimestamp() error {
|
||||
return v.updateSpoolerRunSequence(notificationTable)
|
||||
}
|
||||
|
||||
func (v *View) GetLatestNotificationFailedEvent(sequence uint64) (*repository.FailedEvent, error) {
|
||||
return v.latestFailedEvent(notificationTable, sequence)
|
||||
func (v *View) GetLatestNotificationFailedEvent(sequence uint64, instanceID string) (*repository.FailedEvent, error) {
|
||||
return v.latestFailedEvent(notificationTable, instanceID, sequence)
|
||||
}
|
||||
|
||||
func (v *View) ProcessedNotificationFailedEvent(failedEvent *repository.FailedEvent) error {
|
||||
|
@@ -12,8 +12,8 @@ const (
|
||||
notifyUserTable = "notification.notify_users"
|
||||
)
|
||||
|
||||
func (v *View) NotifyUserByID(userID string) (*model.NotifyUser, error) {
|
||||
return view.NotifyUserByID(v.Db, notifyUserTable, userID)
|
||||
func (v *View) NotifyUserByID(userID, instanceID string) (*model.NotifyUser, error) {
|
||||
return view.NotifyUserByID(v.Db, notifyUserTable, userID, instanceID)
|
||||
}
|
||||
|
||||
func (v *View) PutNotifyUser(user *model.NotifyUser, event *models.Event) error {
|
||||
@@ -24,20 +24,24 @@ func (v *View) PutNotifyUser(user *model.NotifyUser, event *models.Event) error
|
||||
return v.ProcessedNotifyUserSequence(event)
|
||||
}
|
||||
|
||||
func (v *View) NotifyUsersByOrgID(orgID string) ([]*model.NotifyUser, error) {
|
||||
return view.NotifyUsersByOrgID(v.Db, notifyUserTable, orgID)
|
||||
func (v *View) NotifyUsersByOrgID(orgID, instanceID string) ([]*model.NotifyUser, error) {
|
||||
return view.NotifyUsersByOrgID(v.Db, notifyUserTable, orgID, instanceID)
|
||||
}
|
||||
|
||||
func (v *View) DeleteNotifyUser(userID string, event *models.Event) error {
|
||||
err := view.DeleteNotifyUser(v.Db, notifyUserTable, userID)
|
||||
func (v *View) DeleteNotifyUser(userID, instanceID string, event *models.Event) error {
|
||||
err := view.DeleteNotifyUser(v.Db, notifyUserTable, userID, instanceID)
|
||||
if err != nil && !errors.IsNotFound(err) {
|
||||
return err
|
||||
}
|
||||
return v.ProcessedNotifyUserSequence(event)
|
||||
}
|
||||
|
||||
func (v *View) GetLatestNotifyUserSequence() (*repository.CurrentSequence, error) {
|
||||
return v.latestSequence(notifyUserTable)
|
||||
func (v *View) GetLatestNotifyUserSequence(instanceID string) (*repository.CurrentSequence, error) {
|
||||
return v.latestSequence(notifyUserTable, instanceID)
|
||||
}
|
||||
|
||||
func (v *View) GetLatestNotifyUserSequences() ([]*repository.CurrentSequence, error) {
|
||||
return v.latestSequences(notifyUserTable)
|
||||
}
|
||||
|
||||
func (v *View) ProcessedNotifyUserSequence(event *models.Event) error {
|
||||
@@ -48,8 +52,8 @@ func (v *View) UpdateNotifyUserSpoolerRunTimestamp() error {
|
||||
return v.updateSpoolerRunSequence(notifyUserTable)
|
||||
}
|
||||
|
||||
func (v *View) GetLatestNotifyUserFailedEvent(sequence uint64) (*repository.FailedEvent, error) {
|
||||
return v.latestFailedEvent(notifyUserTable, sequence)
|
||||
func (v *View) GetLatestNotifyUserFailedEvent(sequence uint64, instanceID string) (*repository.FailedEvent, error) {
|
||||
return v.latestFailedEvent(notifyUserTable, instanceID, sequence)
|
||||
}
|
||||
|
||||
func (v *View) ProcessedNotifyUserFailedEvent(failedEvent *repository.FailedEvent) error {
|
||||
|
@@ -12,21 +12,27 @@ const (
|
||||
)
|
||||
|
||||
func (v *View) saveCurrentSequence(viewName string, event *models.Event) error {
|
||||
return repository.SaveCurrentSequence(v.Db, sequencesTable, viewName, event.Sequence, event.CreationDate)
|
||||
return repository.SaveCurrentSequence(v.Db, sequencesTable, viewName, event.InstanceID, event.Sequence, event.CreationDate)
|
||||
}
|
||||
|
||||
func (v *View) latestSequence(viewName string) (*repository.CurrentSequence, error) {
|
||||
return repository.LatestSequence(v.Db, sequencesTable, viewName)
|
||||
func (v *View) latestSequence(viewName, instanceID string) (*repository.CurrentSequence, error) {
|
||||
return repository.LatestSequence(v.Db, sequencesTable, viewName, instanceID)
|
||||
}
|
||||
|
||||
func (v *View) latestSequences(viewName string) ([]*repository.CurrentSequence, error) {
|
||||
return repository.LatestSequences(v.Db, sequencesTable, viewName)
|
||||
}
|
||||
|
||||
func (v *View) updateSpoolerRunSequence(viewName string) error {
|
||||
currentSequence, err := repository.LatestSequence(v.Db, sequencesTable, viewName)
|
||||
currentSequences, err := repository.LatestSequences(v.Db, sequencesTable, viewName)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if currentSequence.ViewName == "" {
|
||||
currentSequence.ViewName = viewName
|
||||
for _, currentSequence := range currentSequences {
|
||||
if currentSequence.ViewName == "" {
|
||||
currentSequence.ViewName = viewName
|
||||
}
|
||||
currentSequence.LastSuccessfulSpoolerRun = time.Now()
|
||||
}
|
||||
currentSequence.LastSuccessfulSpoolerRun = time.Now()
|
||||
return repository.UpdateCurrentSequence(v.Db, sequencesTable, currentSequence)
|
||||
return repository.UpdateCurrentSequences(v.Db, sequencesTable, currentSequences)
|
||||
}
|
||||
|
Reference in New Issue
Block a user