feat: handle instanceID in projections (#3442)

* feat: handle instanceID in projections

* rename functions

* fix key lock

* fix import
This commit is contained in:
Livio Amstutz
2022-04-19 08:26:12 +02:00
committed by GitHub
parent c25d853820
commit 1305c14e49
120 changed files with 2078 additions and 1209 deletions

View File

@@ -99,8 +99,8 @@ func (_ *Notification) AggregateTypes() []models.AggregateType {
return []models.AggregateType{user_repo.AggregateType}
}
func (n *Notification) CurrentSequence() (uint64, error) {
sequence, err := n.view.GetLatestNotificationSequence()
func (n *Notification) CurrentSequence(instanceID string) (uint64, error) {
sequence, err := n.view.GetLatestNotificationSequence(instanceID)
if err != nil {
return 0, err
}
@@ -108,11 +108,29 @@ func (n *Notification) CurrentSequence() (uint64, error) {
}
func (n *Notification) EventQuery() (*models.SearchQuery, error) {
sequence, err := n.view.GetLatestNotificationSequence()
sequences, err := n.view.GetLatestNotificationSequences()
if err != nil {
return nil, err
}
return view.UserQuery(sequence.CurrentSequence), nil
query := models.NewSearchQuery()
instances := make([]string, 0)
for _, sequence := range sequences {
for _, instance := range instances {
if sequence.InstanceID == instance {
break
}
}
instances = append(instances, sequence.InstanceID)
query.AddQuery().
AggregateTypeFilter(n.AggregateTypes()...).
LatestSequenceFilter(sequence.CurrentSequence).
InstanceIDFilter(sequence.InstanceID)
}
return query.AddQuery().
AggregateTypeFilter(n.AggregateTypes()...).
LatestSequenceFilter(0).
ExcludedInstanceIDsFilter(instances...).
SearchQuery(), nil
}
func (n *Notification) Reduce(event *models.Event) (err error) {
@@ -162,7 +180,7 @@ func (n *Notification) handleInitUserCode(event *models.Event) (err error) {
return err
}
user, err := n.getUserByID(event.AggregateID)
user, err := n.getUserByID(event.AggregateID, event.InstanceID)
if err != nil {
return err
}
@@ -201,7 +219,7 @@ func (n *Notification) handlePasswordCode(event *models.Event) (err error) {
return err
}
user, err := n.getUserByID(event.AggregateID)
user, err := n.getUserByID(event.AggregateID, event.InstanceID)
if err != nil {
return err
}
@@ -239,7 +257,7 @@ func (n *Notification) handleEmailVerificationCode(event *models.Event) (err err
return err
}
user, err := n.getUserByID(event.AggregateID)
user, err := n.getUserByID(event.AggregateID, event.InstanceID)
if err != nil {
return err
}
@@ -268,7 +286,7 @@ func (n *Notification) handlePhoneVerificationCode(event *models.Event) (err err
if err != nil || alreadyHandled {
return nil
}
user, err := n.getUserByID(event.AggregateID)
user, err := n.getUserByID(event.AggregateID, event.InstanceID)
if err != nil {
return err
}
@@ -285,7 +303,7 @@ func (n *Notification) handlePhoneVerificationCode(event *models.Event) (err err
func (n *Notification) handleDomainClaimed(event *models.Event) (err error) {
ctx := getSetNotifyContextData(event.InstanceID, event.ResourceOwner)
alreadyHandled, err := n.checkIfAlreadyHandled(ctx, event.AggregateID, event.Sequence, user_repo.UserDomainClaimedType, user_repo.UserDomainClaimedSentType)
alreadyHandled, err := n.checkIfAlreadyHandled(ctx, event.AggregateID, event.InstanceID, event.Sequence, user_repo.UserDomainClaimedType, user_repo.UserDomainClaimedSentType)
if err != nil || alreadyHandled {
return nil
}
@@ -294,7 +312,7 @@ func (n *Notification) handleDomainClaimed(event *models.Event) (err error) {
logging.Log("HANDLE-Gghq2").WithError(err).Error("could not unmarshal event data")
return errors.ThrowInternal(err, "HANDLE-7hgj3", "could not unmarshal event")
}
user, err := n.getUserByID(event.AggregateID)
user, err := n.getUserByID(event.AggregateID, event.InstanceID)
if err != nil {
return err
}
@@ -329,7 +347,7 @@ func (n *Notification) handlePasswordlessRegistrationLink(event *models.Event) (
return err
}
ctx := getSetNotifyContextData(event.InstanceID, event.ResourceOwner)
events, err := n.getUserEvents(ctx, event.AggregateID, event.Sequence)
events, err := n.getUserEvents(ctx, event.AggregateID, event.InstanceID, event.Sequence)
if err != nil {
return err
}
@@ -344,7 +362,7 @@ func (n *Notification) handlePasswordlessRegistrationLink(event *models.Event) (
}
}
}
user, err := n.getUserByID(event.AggregateID)
user, err := n.getUserByID(event.AggregateID, event.InstanceID)
if err != nil {
return err
}
@@ -374,11 +392,11 @@ func (n *Notification) checkIfCodeAlreadyHandledOrExpired(ctx context.Context, e
if event.CreationDate.Add(expiry).Before(time.Now().UTC()) {
return true, nil
}
return n.checkIfAlreadyHandled(ctx, event.AggregateID, event.Sequence, eventTypes...)
return n.checkIfAlreadyHandled(ctx, event.AggregateID, event.InstanceID, event.Sequence, eventTypes...)
}
func (n *Notification) checkIfAlreadyHandled(ctx context.Context, userID string, sequence uint64, eventTypes ...eventstore.EventType) (bool, error) {
events, err := n.getUserEvents(ctx, userID, sequence)
func (n *Notification) checkIfAlreadyHandled(ctx context.Context, userID, instanceID string, sequence uint64, eventTypes ...eventstore.EventType) (bool, error) {
events, err := n.getUserEvents(ctx, userID, instanceID, sequence)
if err != nil {
return false, err
}
@@ -392,8 +410,8 @@ func (n *Notification) checkIfAlreadyHandled(ctx context.Context, userID string,
return false, nil
}
func (n *Notification) getUserEvents(ctx context.Context, userID string, sequence uint64) ([]*models.Event, error) {
query, err := view.UserByIDQuery(userID, sequence)
func (n *Notification) getUserEvents(ctx context.Context, userID, instanceID string, sequence uint64) ([]*models.Event, error) {
query, err := view.UserByIDQuery(userID, instanceID, sequence)
if err != nil {
return nil, err
}
@@ -514,6 +532,6 @@ func (n *Notification) getTranslatorWithOrgTexts(ctx context.Context, orgID, tex
return translator, nil
}
func (n *Notification) getUserByID(userID string) (*model.NotifyUser, error) {
return n.view.NotifyUserByID(userID)
func (n *Notification) getUserByID(userID, instanceID string) (*model.NotifyUser, error) {
return n.view.NotifyUserByID(userID, instanceID)
}

View File

@@ -67,8 +67,8 @@ func (_ *NotifyUser) AggregateTypes() []es_models.AggregateType {
return []es_models.AggregateType{user.AggregateType, org.AggregateType}
}
func (p *NotifyUser) CurrentSequence() (uint64, error) {
sequence, err := p.view.GetLatestNotifyUserSequence()
func (p *NotifyUser) CurrentSequence(instanceID string) (uint64, error) {
sequence, err := p.view.GetLatestNotifyUserSequence(instanceID)
if err != nil {
return 0, err
}
@@ -76,13 +76,29 @@ func (p *NotifyUser) CurrentSequence() (uint64, error) {
}
func (p *NotifyUser) EventQuery() (*es_models.SearchQuery, error) {
sequence, err := p.view.GetLatestNotifyUserSequence()
sequences, err := p.view.GetLatestNotifyUserSequences()
if err != nil {
return nil, err
}
return es_models.NewSearchQuery().
query := es_models.NewSearchQuery()
instances := make([]string, 0)
for _, sequence := range sequences {
for _, instance := range instances {
if sequence.InstanceID == instance {
break
}
}
instances = append(instances, sequence.InstanceID)
query.AddQuery().
AggregateTypeFilter(p.AggregateTypes()...).
LatestSequenceFilter(sequence.CurrentSequence).
InstanceIDFilter(sequence.InstanceID)
}
return query.AddQuery().
AggregateTypeFilter(p.AggregateTypes()...).
LatestSequenceFilter(sequence.CurrentSequence), nil
LatestSequenceFilter(0).
ExcludedInstanceIDsFilter(instances...).
SearchQuery(), nil
}
func (u *NotifyUser) Reduce(event *es_models.Event) (err error) {
@@ -122,14 +138,14 @@ func (u *NotifyUser) ProcessUser(event *es_models.Event) (err error) {
user.HumanPhoneVerifiedType,
user.HumanPhoneRemovedType,
user.MachineChangedEventType:
notifyUser, err = u.view.NotifyUserByID(event.AggregateID)
notifyUser, err = u.view.NotifyUserByID(event.AggregateID, event.InstanceID)
if err != nil {
return err
}
err = notifyUser.AppendEvent(event)
case user.UserDomainClaimedType,
user.UserUserNameChangedType:
notifyUser, err = u.view.NotifyUserByID(event.AggregateID)
notifyUser, err = u.view.NotifyUserByID(event.AggregateID, event.InstanceID)
if err != nil {
return err
}
@@ -139,7 +155,7 @@ func (u *NotifyUser) ProcessUser(event *es_models.Event) (err error) {
}
err = u.fillLoginNames(notifyUser)
case user.UserRemovedType:
return u.view.DeleteNotifyUser(event.AggregateID, event)
return u.view.DeleteNotifyUser(event.AggregateID, event.InstanceID, event)
default:
return u.view.ProcessedNotifyUserSequence(event)
}
@@ -169,7 +185,7 @@ func (u *NotifyUser) fillLoginNamesOnOrgUsers(event *es_models.Event) error {
if err != nil {
return err
}
users, err := u.view.NotifyUsersByOrgID(event.AggregateID)
users, err := u.view.NotifyUsersByOrgID(event.AggregateID, event.InstanceID)
if err != nil {
return err
}
@@ -191,7 +207,7 @@ func (u *NotifyUser) fillPreferredLoginNamesOnOrgUsers(event *es_models.Event) e
if !userLoginMustBeDomain {
return nil
}
users, err := u.view.NotifyUsersByOrgID(event.AggregateID)
users, err := u.view.NotifyUsersByOrgID(event.AggregateID, event.InstanceID)
if err != nil {
return err
}

View File

@@ -2,8 +2,9 @@ package spooler
import (
"database/sql"
es_locker "github.com/caos/zitadel/internal/eventstore/v1/locker"
"time"
es_locker "github.com/caos/zitadel/internal/eventstore/v1/locker"
)
const (
@@ -14,6 +15,6 @@ type locker struct {
dbClient *sql.DB
}
func (l *locker) Renew(lockerID, viewModel string, waitTime time.Duration) error {
return es_locker.Renew(l.dbClient, lockTable, lockerID, viewModel, waitTime)
func (l *locker) Renew(lockerID, viewModel, instanceID string, waitTime time.Duration) error {
return es_locker.Renew(l.dbClient, lockTable, lockerID, viewModel, instanceID, waitTime)
}

View File

@@ -12,6 +12,6 @@ func (v *View) saveFailedEvent(failedEvent *repository.FailedEvent) error {
return repository.SaveFailedEvent(v.Db, errTable, failedEvent)
}
func (v *View) latestFailedEvent(viewName string, sequence uint64) (*repository.FailedEvent, error) {
return repository.LatestFailedEvent(v.Db, errTable, viewName, sequence)
func (v *View) latestFailedEvent(viewName, instanceID string, sequence uint64) (*repository.FailedEvent, error) {
return repository.LatestFailedEvent(v.Db, errTable, viewName, instanceID, sequence)
}

View File

@@ -9,8 +9,12 @@ const (
notificationTable = "notification.notifications"
)
func (v *View) GetLatestNotificationSequence() (*repository.CurrentSequence, error) {
return v.latestSequence(notificationTable)
func (v *View) GetLatestNotificationSequence(instanceID string) (*repository.CurrentSequence, error) {
return v.latestSequence(notificationTable, instanceID)
}
func (v *View) GetLatestNotificationSequences() ([]*repository.CurrentSequence, error) {
return v.latestSequences(notificationTable)
}
func (v *View) ProcessedNotificationSequence(event *models.Event) error {
@@ -21,8 +25,8 @@ func (v *View) UpdateNotificationSpoolerRunTimestamp() error {
return v.updateSpoolerRunSequence(notificationTable)
}
func (v *View) GetLatestNotificationFailedEvent(sequence uint64) (*repository.FailedEvent, error) {
return v.latestFailedEvent(notificationTable, sequence)
func (v *View) GetLatestNotificationFailedEvent(sequence uint64, instanceID string) (*repository.FailedEvent, error) {
return v.latestFailedEvent(notificationTable, instanceID, sequence)
}
func (v *View) ProcessedNotificationFailedEvent(failedEvent *repository.FailedEvent) error {

View File

@@ -12,8 +12,8 @@ const (
notifyUserTable = "notification.notify_users"
)
func (v *View) NotifyUserByID(userID string) (*model.NotifyUser, error) {
return view.NotifyUserByID(v.Db, notifyUserTable, userID)
func (v *View) NotifyUserByID(userID, instanceID string) (*model.NotifyUser, error) {
return view.NotifyUserByID(v.Db, notifyUserTable, userID, instanceID)
}
func (v *View) PutNotifyUser(user *model.NotifyUser, event *models.Event) error {
@@ -24,20 +24,24 @@ func (v *View) PutNotifyUser(user *model.NotifyUser, event *models.Event) error
return v.ProcessedNotifyUserSequence(event)
}
func (v *View) NotifyUsersByOrgID(orgID string) ([]*model.NotifyUser, error) {
return view.NotifyUsersByOrgID(v.Db, notifyUserTable, orgID)
func (v *View) NotifyUsersByOrgID(orgID, instanceID string) ([]*model.NotifyUser, error) {
return view.NotifyUsersByOrgID(v.Db, notifyUserTable, orgID, instanceID)
}
func (v *View) DeleteNotifyUser(userID string, event *models.Event) error {
err := view.DeleteNotifyUser(v.Db, notifyUserTable, userID)
func (v *View) DeleteNotifyUser(userID, instanceID string, event *models.Event) error {
err := view.DeleteNotifyUser(v.Db, notifyUserTable, userID, instanceID)
if err != nil && !errors.IsNotFound(err) {
return err
}
return v.ProcessedNotifyUserSequence(event)
}
func (v *View) GetLatestNotifyUserSequence() (*repository.CurrentSequence, error) {
return v.latestSequence(notifyUserTable)
func (v *View) GetLatestNotifyUserSequence(instanceID string) (*repository.CurrentSequence, error) {
return v.latestSequence(notifyUserTable, instanceID)
}
func (v *View) GetLatestNotifyUserSequences() ([]*repository.CurrentSequence, error) {
return v.latestSequences(notifyUserTable)
}
func (v *View) ProcessedNotifyUserSequence(event *models.Event) error {
@@ -48,8 +52,8 @@ func (v *View) UpdateNotifyUserSpoolerRunTimestamp() error {
return v.updateSpoolerRunSequence(notifyUserTable)
}
func (v *View) GetLatestNotifyUserFailedEvent(sequence uint64) (*repository.FailedEvent, error) {
return v.latestFailedEvent(notifyUserTable, sequence)
func (v *View) GetLatestNotifyUserFailedEvent(sequence uint64, instanceID string) (*repository.FailedEvent, error) {
return v.latestFailedEvent(notifyUserTable, instanceID, sequence)
}
func (v *View) ProcessedNotifyUserFailedEvent(failedEvent *repository.FailedEvent) error {

View File

@@ -12,21 +12,27 @@ const (
)
func (v *View) saveCurrentSequence(viewName string, event *models.Event) error {
return repository.SaveCurrentSequence(v.Db, sequencesTable, viewName, event.Sequence, event.CreationDate)
return repository.SaveCurrentSequence(v.Db, sequencesTable, viewName, event.InstanceID, event.Sequence, event.CreationDate)
}
func (v *View) latestSequence(viewName string) (*repository.CurrentSequence, error) {
return repository.LatestSequence(v.Db, sequencesTable, viewName)
func (v *View) latestSequence(viewName, instanceID string) (*repository.CurrentSequence, error) {
return repository.LatestSequence(v.Db, sequencesTable, viewName, instanceID)
}
func (v *View) latestSequences(viewName string) ([]*repository.CurrentSequence, error) {
return repository.LatestSequences(v.Db, sequencesTable, viewName)
}
func (v *View) updateSpoolerRunSequence(viewName string) error {
currentSequence, err := repository.LatestSequence(v.Db, sequencesTable, viewName)
currentSequences, err := repository.LatestSequences(v.Db, sequencesTable, viewName)
if err != nil {
return err
}
if currentSequence.ViewName == "" {
currentSequence.ViewName = viewName
for _, currentSequence := range currentSequences {
if currentSequence.ViewName == "" {
currentSequence.ViewName = viewName
}
currentSequence.LastSuccessfulSpoolerRun = time.Now()
}
currentSequence.LastSuccessfulSpoolerRun = time.Now()
return repository.UpdateCurrentSequence(v.Db, sequencesTable, currentSequence)
return repository.UpdateCurrentSequences(v.Db, sequencesTable, currentSequences)
}