refactor(notification): use new queue package (#9360)

# Which Problems Are Solved

The recently introduced notification queue have potential race conditions.

# How the Problems Are Solved

Current code is refactored to use the queue package, which is safe in
regards of concurrency.

# Additional Changes

- the queue is included in startup
- improved code quality of queue

# Additional Context

- closes https://github.com/zitadel/zitadel/issues/9278
This commit is contained in:
Silvan 2025-02-27 11:49:12 +01:00 committed by GitHub
parent 83614562a2
commit 444f682e25
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
45 changed files with 1936 additions and 2818 deletions

View File

@ -416,12 +416,10 @@ Projections:
TransactionDuration: 0s TransactionDuration: 0s
BulkLimit: 2000 BulkLimit: 2000
# The Notifications projection is used for sending emails and SMS to users # The Notifications projection is used for preparing the messages (emails and SMS) to be sent to users
Notifications: Notifications:
# As notification projections don't result in database statements, retries don't have an effect # As notification projections don't result in database statements, retries don't have an effect
MaxFailureCount: 10 # ZITADEL_PROJECTIONS_CUSTOMIZATIONS_NOTIFICATIONS_MAXFAILURECOUNT MaxFailureCount: 10 # ZITADEL_PROJECTIONS_CUSTOMIZATIONS_NOTIFICATIONS_MAXFAILURECOUNT
# Sending emails can take longer than 500ms
TransactionDuration: 5s # ZITADEL_PROJECTIONS_CUSTOMIZATIONS_NOTIFICATIONS_TRANSACTIONDURATION
password_complexities: password_complexities:
TransactionDuration: 2s # ZITADEL_PROJECTIONS_CUSTOMIZATIONS_PASSWORD_COMPLEXITIES_TRANSACTIONDURATION TransactionDuration: 2s # ZITADEL_PROJECTIONS_CUSTOMIZATIONS_PASSWORD_COMPLEXITIES_TRANSACTIONDURATION
lockout_policy: lockout_policy:
@ -453,34 +451,12 @@ Notifications:
# If set to 0, no notification request events will be handled. This can be useful when running in # If set to 0, no notification request events will be handled. This can be useful when running in
# multi binary / pod setup and allowing only certain executables to process the events. # multi binary / pod setup and allowing only certain executables to process the events.
Workers: 1 # ZITADEL_NOTIFIACATIONS_WORKERS Workers: 1 # ZITADEL_NOTIFIACATIONS_WORKERS
# The amount of events a single worker will process in a run. # The maximum duration a job can do it's work before it is considered as failed.
BulkLimit: 10 # ZITADEL_NOTIFIACATIONS_BULKLIMIT
# Time interval between scheduled notifications for request events
RequeueEvery: 5s # ZITADEL_NOTIFIACATIONS_REQUEUEEVERY
# The amount of workers processing the notification retry events.
# If set to 0, no notification retry events will be handled. This can be useful when running in
# multi binary / pod setup and allowing only certain executables to process the events.
RetryWorkers: 1 # ZITADEL_NOTIFIACATIONS_RETRYWORKERS
# Time interval between scheduled notifications for retry events
RetryRequeueEvery: 5s # ZITADEL_NOTIFIACATIONS_RETRYREQUEUEEVERY
# Only instances are projected, for which at least a projection-relevant event exists within the timeframe
# from HandleActiveInstances duration in the past until the projection's current time
# If set to 0 (default), every instance is always considered active
HandleActiveInstances: 0s # ZITADEL_NOTIFIACATIONS_HANDLEACTIVEINSTANCES
# The maximum duration a transaction remains open
# before it spots left folding additional events
# and updates the table.
TransactionDuration: 10s # ZITADEL_NOTIFIACATIONS_TRANSACTIONDURATION TransactionDuration: 10s # ZITADEL_NOTIFIACATIONS_TRANSACTIONDURATION
# Automatically cancel the notification after the amount of failed attempts # Automatically cancel the notification after the amount of failed attempts
MaxAttempts: 3 # ZITADEL_NOTIFIACATIONS_MAXATTEMPTS MaxAttempts: 3 # ZITADEL_NOTIFIACATIONS_MAXATTEMPTS
# Automatically cancel the notification if it cannot be handled within a specific time # Automatically cancel the notification if it cannot be handled within a specific time
MaxTtl: 5m # ZITADEL_NOTIFIACATIONS_MAXTTL MaxTtl: 5m # ZITADEL_NOTIFIACATIONS_MAXTTL
# Failed attempts are retried after a confogired delay (with exponential backoff).
# Set a minimum and maximum delay and a factor for the backoff
MinRetryDelay: 5s # ZITADEL_NOTIFIACATIONS_MINRETRYDELAY
MaxRetryDelay: 1m # ZITADEL_NOTIFIACATIONS_MAXRETRYDELAY
# Any factor below 1 will be set to 1
RetryDelayFactor: 1.5 # ZITADEL_NOTIFIACATIONS_RETRYDELAYFACTOR
Auth: Auth:
# See Projections.BulkLimit # See Projections.BulkLimit

View File

@ -221,6 +221,7 @@ func projections(
keys.OIDC, keys.OIDC,
config.OIDC.DefaultBackChannelLogoutLifetime, config.OIDC.DefaultBackChannelLogoutLifetime,
client, client,
nil,
) )
config.Auth.Spooler.Client = client config.Auth.Spooler.Client = client

View File

@ -16,7 +16,7 @@ func (mig *RiverMigrateRepeatable) Execute(ctx context.Context, _ eventstore.Eve
if mig.client.Type() != "postgres" { if mig.client.Type() != "postgres" {
return nil return nil
} }
return queue.New(mig.client).ExecuteMigrations(ctx) return queue.NewMigrator(mig.client).Execute(ctx)
} }
func (mig *RiverMigrateRepeatable) String() string { func (mig *RiverMigrateRepeatable) String() string {

View File

@ -37,6 +37,7 @@ import (
notify_handler "github.com/zitadel/zitadel/internal/notification" notify_handler "github.com/zitadel/zitadel/internal/notification"
"github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/queue"
es_v4 "github.com/zitadel/zitadel/internal/v2/eventstore" es_v4 "github.com/zitadel/zitadel/internal/v2/eventstore"
es_v4_pg "github.com/zitadel/zitadel/internal/v2/eventstore/postgres" es_v4_pg "github.com/zitadel/zitadel/internal/v2/eventstore/postgres"
"github.com/zitadel/zitadel/internal/webauthn" "github.com/zitadel/zitadel/internal/webauthn"
@ -466,6 +467,10 @@ func startCommandsQueries(
config.DefaultInstance.SecretGenerators, config.DefaultInstance.SecretGenerators,
) )
logging.OnError(err).Fatal("unable to start commands") logging.OnError(err).Fatal("unable to start commands")
q, err := queue.NewQueue(&queue.Config{
Client: dbClient,
})
logging.OnError(err).Fatal("unable to start queue")
notify_handler.Register( notify_handler.Register(
ctx, ctx,
@ -489,6 +494,7 @@ func startCommandsQueries(
keys.OIDC, keys.OIDC,
config.OIDC.DefaultBackChannelLogoutLifetime, config.OIDC.DefaultBackChannelLogoutLifetime,
dbClient, dbClient,
q,
) )
return commands, queries, adminView, authView return commands, queries, adminView, authView

View File

@ -92,6 +92,7 @@ import (
"github.com/zitadel/zitadel/internal/net" "github.com/zitadel/zitadel/internal/net"
"github.com/zitadel/zitadel/internal/notification" "github.com/zitadel/zitadel/internal/notification"
"github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/queue"
"github.com/zitadel/zitadel/internal/static" "github.com/zitadel/zitadel/internal/static"
es_v4 "github.com/zitadel/zitadel/internal/v2/eventstore" es_v4 "github.com/zitadel/zitadel/internal/v2/eventstore"
es_v4_pg "github.com/zitadel/zitadel/internal/v2/eventstore/postgres" es_v4_pg "github.com/zitadel/zitadel/internal/v2/eventstore/postgres"
@ -267,6 +268,13 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server
actionsLogstoreSvc := logstore.New(queries, actionsExecutionDBEmitter, actionsExecutionStdoutEmitter) actionsLogstoreSvc := logstore.New(queries, actionsExecutionDBEmitter, actionsExecutionStdoutEmitter)
actions.SetLogstoreService(actionsLogstoreSvc) actions.SetLogstoreService(actionsLogstoreSvc)
q, err := queue.NewQueue(&queue.Config{
Client: dbClient,
})
if err != nil {
return err
}
notification.Register( notification.Register(
ctx, ctx,
config.Projections.Customizations["notifications"], config.Projections.Customizations["notifications"],
@ -289,9 +297,14 @@ func startZitadel(ctx context.Context, config *Config, masterKey string, server
keys.OIDC, keys.OIDC,
config.OIDC.DefaultBackChannelLogoutLifetime, config.OIDC.DefaultBackChannelLogoutLifetime,
dbClient, dbClient,
q,
) )
notification.Start(ctx) notification.Start(ctx)
if err = q.Start(ctx); err != nil {
return err
}
router := mux.NewRouter() router := mux.NewRouter()
tlsConfig, err := config.TLS.Config() tlsConfig, err := config.TLS.Config()
if err != nil { if err != nil {

View File

@ -851,9 +851,6 @@ func (c *Commands) prepareSetDefaultLanguage(a *instance.Aggregate, defaultLangu
if err := domain.LanguageIsAllowed(false, restrictionsWM.allowedLanguages, defaultLanguage); err != nil { if err := domain.LanguageIsAllowed(false, restrictionsWM.allowedLanguages, defaultLanguage); err != nil {
return nil, err return nil, err
} }
if err != nil {
return nil, err
}
return []eventstore.Command{instance.NewDefaultLanguageSetEvent(ctx, &a.Aggregate, defaultLanguage)}, nil return []eventstore.Command{instance.NewDefaultLanguageSetEvent(ctx, &a.Aggregate, defaultLanguage)}, nil
}, nil }, nil
} }

View File

@ -1,162 +0,0 @@
package command
import (
"context"
"database/sql"
"time"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/repository/notification"
)
type NotificationRequest struct {
UserID string
UserResourceOwner string
TriggerOrigin string
URLTemplate string
Code *crypto.CryptoValue
CodeExpiry time.Duration
EventType eventstore.EventType
NotificationType domain.NotificationType
MessageType string
UnverifiedNotificationChannel bool
Args *domain.NotificationArguments
AggregateID string
AggregateResourceOwner string
IsOTP bool
RequiresPreviousDomain bool
}
type NotificationRetryRequest struct {
NotificationRequest
BackOff time.Duration
NotifyUser *query.NotifyUser
}
func NewNotificationRequest(
userID, resourceOwner, triggerOrigin string,
eventType eventstore.EventType,
notificationType domain.NotificationType,
messageType string,
) *NotificationRequest {
return &NotificationRequest{
UserID: userID,
UserResourceOwner: resourceOwner,
TriggerOrigin: triggerOrigin,
EventType: eventType,
NotificationType: notificationType,
MessageType: messageType,
}
}
func (r *NotificationRequest) WithCode(code *crypto.CryptoValue, expiry time.Duration) *NotificationRequest {
r.Code = code
r.CodeExpiry = expiry
return r
}
func (r *NotificationRequest) WithURLTemplate(urlTemplate string) *NotificationRequest {
r.URLTemplate = urlTemplate
return r
}
func (r *NotificationRequest) WithUnverifiedChannel() *NotificationRequest {
r.UnverifiedNotificationChannel = true
return r
}
func (r *NotificationRequest) WithArgs(args *domain.NotificationArguments) *NotificationRequest {
r.Args = args
return r
}
func (r *NotificationRequest) WithAggregate(id, resourceOwner string) *NotificationRequest {
r.AggregateID = id
r.AggregateResourceOwner = resourceOwner
return r
}
func (r *NotificationRequest) WithOTP() *NotificationRequest {
r.IsOTP = true
return r
}
func (r *NotificationRequest) WithPreviousDomain() *NotificationRequest {
r.RequiresPreviousDomain = true
return r
}
// RequestNotification writes a new notification.RequestEvent with the notification.Aggregate to the eventstore
func (c *Commands) RequestNotification(
ctx context.Context,
resourceOwner string,
request *NotificationRequest,
) error {
id, err := c.idGenerator.Next()
if err != nil {
return err
}
_, err = c.eventstore.Push(ctx, notification.NewRequestedEvent(ctx, &notification.NewAggregate(id, resourceOwner).Aggregate,
request.UserID,
request.UserResourceOwner,
request.AggregateID,
request.AggregateResourceOwner,
request.TriggerOrigin,
request.URLTemplate,
request.Code,
request.CodeExpiry,
request.EventType,
request.NotificationType,
request.MessageType,
request.UnverifiedNotificationChannel,
request.IsOTP,
request.RequiresPreviousDomain,
request.Args))
return err
}
// NotificationCanceled writes a new notification.CanceledEvent with the notification.Aggregate to the eventstore
func (c *Commands) NotificationCanceled(ctx context.Context, tx *sql.Tx, id, resourceOwner string, requestError error) error {
var errorMessage string
if requestError != nil {
errorMessage = requestError.Error()
}
_, err := c.eventstore.PushWithClient(ctx, tx, notification.NewCanceledEvent(ctx, &notification.NewAggregate(id, resourceOwner).Aggregate, errorMessage))
return err
}
// NotificationSent writes a new notification.SentEvent with the notification.Aggregate to the eventstore
func (c *Commands) NotificationSent(ctx context.Context, tx *sql.Tx, id, resourceOwner string) error {
_, err := c.eventstore.PushWithClient(ctx, tx, notification.NewSentEvent(ctx, &notification.NewAggregate(id, resourceOwner).Aggregate))
return err
}
// NotificationRetryRequested writes a new notification.RetryRequestEvent with the notification.Aggregate to the eventstore
func (c *Commands) NotificationRetryRequested(ctx context.Context, tx *sql.Tx, id, resourceOwner string, request *NotificationRetryRequest, requestError error) error {
var errorMessage string
if requestError != nil {
errorMessage = requestError.Error()
}
_, err := c.eventstore.PushWithClient(ctx, tx, notification.NewRetryRequestedEvent(ctx, &notification.NewAggregate(id, resourceOwner).Aggregate,
request.UserID,
request.UserResourceOwner,
request.AggregateID,
request.AggregateResourceOwner,
request.TriggerOrigin,
request.URLTemplate,
request.Code,
request.CodeExpiry,
request.EventType,
request.NotificationType,
request.MessageType,
request.UnverifiedNotificationChannel,
request.IsOTP,
request.Args,
request.NotifyUser,
request.BackOff,
errorMessage))
return err
}

View File

@ -78,15 +78,15 @@ func AggregateFromWriteModelCtx(
// Aggregate is the basic implementation of Aggregater // Aggregate is the basic implementation of Aggregater
type Aggregate struct { type Aggregate struct {
// ID is the unique identitfier of this aggregate // ID is the unique identitfier of this aggregate
ID string `json:"-"` ID string `json:"id"`
// Type is the name of the aggregate. // Type is the name of the aggregate.
Type AggregateType `json:"-"` Type AggregateType `json:"type"`
// ResourceOwner is the org this aggregates belongs to // ResourceOwner is the org this aggregates belongs to
ResourceOwner string `json:"-"` ResourceOwner string `json:"resourceOwner"`
// InstanceID is the instance this aggregate belongs to // InstanceID is the instance this aggregate belongs to
InstanceID string `json:"-"` InstanceID string `json:"instanceId"`
// Version is the semver this aggregate represents // Version is the semver this aggregate represents
Version Version `json:"-"` Version Version `json:"version"`
} }
// AggregateType is the object name // AggregateType is the object name

View File

@ -22,7 +22,7 @@ type BaseEvent struct {
ID string ID string
EventType EventType `json:"-"` EventType EventType `json:"-"`
Agg *Aggregate Agg *Aggregate `json:"-"`
Seq uint64 Seq uint64
Pos float64 Pos float64

View File

@ -97,22 +97,6 @@ type Instance struct {
WebAuthN *webauthn.Client WebAuthN *webauthn.Client
} }
// GetFirstInstance returns the default instance and org information,
// with authorized machine users.
// Using the first instance is not recommended as parallel test might
// interfere with each other.
// It is recommended to use [NewInstance] instead.
func GetFirstInstance(ctx context.Context) *Instance {
i := &Instance{
Config: loadedConfig,
Domain: loadedConfig.Hostname,
}
token := loadInstanceOwnerPAT()
i.setClient(ctx)
i.setupInstance(ctx, token)
return i
}
// NewInstance returns a new instance that can be used for integration tests. // NewInstance returns a new instance that can be used for integration tests.
// The instance contains a gRPC client connected to the domain of this instance. // The instance contains a gRPC client connected to the domain of this instance.
// The included users are the IAM_OWNER, ORG_OWNER of the default org and // The included users are the IAM_OWNER, ORG_OWNER of the default org and

View File

@ -3,7 +3,7 @@ package channels
import "github.com/zitadel/zitadel/internal/eventstore" import "github.com/zitadel/zitadel/internal/eventstore"
type Message interface { type Message interface {
GetTriggeringEvent() eventstore.Event GetTriggeringEventType() eventstore.EventType
GetContent() (string, error) GetContent() (string, error)
} }

View File

@ -13,7 +13,7 @@ func logMessages(ctx context.Context, channel channels.NotificationChannel) chan
return channels.HandleMessageFunc(func(message channels.Message) error { return channels.HandleMessageFunc(func(message channels.Message) error {
logEntry := logging.WithFields( logEntry := logging.WithFields(
"instance", authz.GetInstance(ctx).InstanceID(), "instance", authz.GetInstance(ctx).InstanceID(),
"triggering_event_type", message.GetTriggeringEvent().Type(), "triggering_event_type", message.GetTriggeringEventType(),
) )
logEntry.Debug("sending notification") logEntry.Debug("sending notification")
err := channel.HandleMessage(message) err := channel.HandleMessage(message)

View File

@ -24,7 +24,7 @@ func countMessages(ctx context.Context, channel channels.NotificationChannel, su
func addCount(ctx context.Context, metricName string, message channels.Message) { func addCount(ctx context.Context, metricName string, message channels.Message) {
labels := map[string]attribute.Value{ labels := map[string]attribute.Value{
"triggering_event_type": attribute.StringValue(string(message.GetTriggeringEvent().Type())), "triggering_event_type": attribute.StringValue(string(message.GetTriggeringEventType())),
} }
addCountErr := metrics.AddCount(ctx, metricName, 1, labels) addCountErr := metrics.AddCount(ctx, metricName, 1, labels)
logging.WithFields("name", metricName, "labels", labels).OnError(addCountErr).Error("incrementing counter metric failed") logging.WithFields("name", metricName, "labels", labels).OnError(addCountErr).Error("incrementing counter metric failed")

View File

@ -9,7 +9,6 @@ import (
verify "github.com/twilio/twilio-go/rest/verify/v2" verify "github.com/twilio/twilio-go/rest/verify/v2"
"github.com/zitadel/logging" "github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/notification/channels" "github.com/zitadel/zitadel/internal/notification/channels"
"github.com/zitadel/zitadel/internal/notification/messages" "github.com/zitadel/zitadel/internal/notification/messages"
"github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/internal/zerrors"
@ -39,15 +38,14 @@ func InitChannel(config Config) channels.NotificationChannel {
// as it would be a waste of resources and could potentially result in a rate limit. // as it would be a waste of resources and could potentially result in a rate limit.
var twilioErr *twilioClient.TwilioRestError var twilioErr *twilioClient.TwilioRestError
if errors.As(err, &twilioErr) && twilioErr.Status >= 400 && twilioErr.Status < 500 { if errors.As(err, &twilioErr) && twilioErr.Status >= 400 && twilioErr.Status < 500 {
userID, notificationID := userAndNotificationIDsFromEvent(twilioMsg.TriggeringEvent)
logging.WithFields( logging.WithFields(
"error", twilioErr.Message, "error", twilioErr.Message,
"status", twilioErr.Status, "status", twilioErr.Status,
"code", twilioErr.Code, "code", twilioErr.Code,
"instanceID", twilioMsg.TriggeringEvent.Aggregate().InstanceID, "instanceID", twilioMsg.InstanceID,
"userID", userID, "jobID", twilioMsg.JobID,
"notificationID", notificationID). "userID", twilioMsg.UserID,
Warn("twilio create verification error") ).Warn("twilio create verification error")
return channels.NewCancelError(twilioErr) return channels.NewCancelError(twilioErr)
} }
@ -76,24 +74,3 @@ func InitChannel(config Config) channels.NotificationChannel {
return nil return nil
}) })
} }
func userAndNotificationIDsFromEvent(event eventstore.Event) (userID, notificationID string) {
aggID := event.Aggregate().ID
// we cannot cast to the actual event type because of circular dependencies
// so we just check the type...
if event.Aggregate().Type != aggregateTypeNotification {
// in case it's not a notification event, we can directly return the aggregate ID (as it's a user event)
return aggID, ""
}
// ...and unmarshal the event data from the notification event into a struct that contains the fields we need
var data struct {
Request struct {
UserID string `json:"userID"`
} `json:"request"`
}
if err := event.Unmarshal(&data); err != nil {
return "", aggID
}
return data.Request.UserID, aggID
}

View File

@ -191,7 +191,7 @@ func (u *backChannelLogoutNotifier) sendLogoutToken(ctx context.Context, oidcSes
if err != nil { if err != nil {
return err return err
} }
err = types.SendSecurityTokenEvent(ctx, set.Config{CallURL: oidcSession.BackChannelLogoutURI}, u.channels, &LogoutTokenMessage{LogoutToken: token}, e).WithoutTemplate() err = types.SendSecurityTokenEvent(ctx, set.Config{CallURL: oidcSession.BackChannelLogoutURI}, u.channels, &LogoutTokenMessage{LogoutToken: token}, e.Type()).WithoutTemplate()
if err != nil { if err != nil {
return err return err
} }
@ -247,7 +247,7 @@ func (b *backChannelLogoutSession) AppendEvents(events ...eventstore.Event) {
BackChannelLogoutURI: e.BackChannelLogoutURI, BackChannelLogoutURI: e.BackChannelLogoutURI,
}) })
case *sessionlogout.BackChannelLogoutSentEvent: case *sessionlogout.BackChannelLogoutSentEvent:
slices.DeleteFunc(b.sessions, func(session backChannelLogoutOIDCSessions) bool { b.sessions = slices.DeleteFunc(b.sessions, func(session backChannelLogoutOIDCSessions) bool {
return session.OIDCSessionID == e.OIDCSessionID return session.OIDCSessionID == e.OIDCSessionID
}) })
} }

View File

@ -2,19 +2,13 @@ package handlers
import ( import (
"context" "context"
"database/sql"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/notification/senders" "github.com/zitadel/zitadel/internal/notification/senders"
"github.com/zitadel/zitadel/internal/repository/milestone" "github.com/zitadel/zitadel/internal/repository/milestone"
"github.com/zitadel/zitadel/internal/repository/quota" "github.com/zitadel/zitadel/internal/repository/quota"
) )
type Commands interface { type Commands interface {
RequestNotification(ctx context.Context, instanceID string, request *command.NotificationRequest) error
NotificationCanceled(ctx context.Context, tx *sql.Tx, id, resourceOwner string, err error) error
NotificationRetryRequested(ctx context.Context, tx *sql.Tx, id, resourceOwner string, request *command.NotificationRetryRequest, err error) error
NotificationSent(ctx context.Context, tx *sql.Tx, id, instanceID string) error
HumanInitCodeSent(ctx context.Context, orgID, userID string) error HumanInitCodeSent(ctx context.Context, orgID, userID string) error
HumanEmailVerificationCodeSent(ctx context.Context, orgID, userID string) error HumanEmailVerificationCodeSent(ctx context.Context, orgID, userID string) error
PasswordCodeSent(ctx context.Context, orgID, userID string, generatorInfo *senders.CodeGeneratorInfo) error PasswordCodeSent(ctx context.Context, orgID, userID string, generatorInfo *senders.CodeGeneratorInfo) error

View File

@ -15,7 +15,7 @@ func HandlerContext(event *eventstore.Aggregate) context.Context {
} }
func ContextWithNotifier(ctx context.Context, aggregate *eventstore.Aggregate) context.Context { func ContextWithNotifier(ctx context.Context, aggregate *eventstore.Aggregate) context.Context {
return authz.SetCtxData(ctx, authz.CtxData{UserID: NotifyUserID, OrgID: aggregate.ResourceOwner}) return authz.WithInstanceID(authz.SetCtxData(ctx, authz.CtxData{UserID: NotifyUserID, OrgID: aggregate.ResourceOwner}), aggregate.InstanceID)
} }
func (n *NotificationQueries) HandlerContext(event *eventstore.Aggregate) (context.Context, error) { func (n *NotificationQueries) HandlerContext(event *eventstore.Aggregate) (context.Context, error) {

View File

@ -2,3 +2,4 @@ package handlers
//go:generate mockgen -package mock -destination ./mock/queries.mock.go github.com/zitadel/zitadel/internal/notification/handlers Queries //go:generate mockgen -package mock -destination ./mock/queries.mock.go github.com/zitadel/zitadel/internal/notification/handlers Queries
//go:generate mockgen -package mock -destination ./mock/commands.mock.go github.com/zitadel/zitadel/internal/notification/handlers Commands //go:generate mockgen -package mock -destination ./mock/commands.mock.go github.com/zitadel/zitadel/internal/notification/handlers Commands
//go:generate mockgen -package mock -destination ./mock/queue.mock.go github.com/zitadel/zitadel/internal/notification/handlers Queue

View File

@ -11,10 +11,8 @@ package mock
import ( import (
context "context" context "context"
sql "database/sql"
reflect "reflect" reflect "reflect"
command "github.com/zitadel/zitadel/internal/command"
senders "github.com/zitadel/zitadel/internal/notification/senders" senders "github.com/zitadel/zitadel/internal/notification/senders"
milestone "github.com/zitadel/zitadel/internal/repository/milestone" milestone "github.com/zitadel/zitadel/internal/repository/milestone"
quota "github.com/zitadel/zitadel/internal/repository/quota" quota "github.com/zitadel/zitadel/internal/repository/quota"
@ -156,48 +154,6 @@ func (mr *MockCommandsMockRecorder) MilestonePushed(arg0, arg1, arg2, arg3 any)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MilestonePushed", reflect.TypeOf((*MockCommands)(nil).MilestonePushed), arg0, arg1, arg2, arg3) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MilestonePushed", reflect.TypeOf((*MockCommands)(nil).MilestonePushed), arg0, arg1, arg2, arg3)
} }
// NotificationCanceled mocks base method.
func (m *MockCommands) NotificationCanceled(arg0 context.Context, arg1 *sql.Tx, arg2, arg3 string, arg4 error) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NotificationCanceled", arg0, arg1, arg2, arg3, arg4)
ret0, _ := ret[0].(error)
return ret0
}
// NotificationCanceled indicates an expected call of NotificationCanceled.
func (mr *MockCommandsMockRecorder) NotificationCanceled(arg0, arg1, arg2, arg3, arg4 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationCanceled", reflect.TypeOf((*MockCommands)(nil).NotificationCanceled), arg0, arg1, arg2, arg3, arg4)
}
// NotificationRetryRequested mocks base method.
func (m *MockCommands) NotificationRetryRequested(arg0 context.Context, arg1 *sql.Tx, arg2, arg3 string, arg4 *command.NotificationRetryRequest, arg5 error) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NotificationRetryRequested", arg0, arg1, arg2, arg3, arg4, arg5)
ret0, _ := ret[0].(error)
return ret0
}
// NotificationRetryRequested indicates an expected call of NotificationRetryRequested.
func (mr *MockCommandsMockRecorder) NotificationRetryRequested(arg0, arg1, arg2, arg3, arg4, arg5 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationRetryRequested", reflect.TypeOf((*MockCommands)(nil).NotificationRetryRequested), arg0, arg1, arg2, arg3, arg4, arg5)
}
// NotificationSent mocks base method.
func (m *MockCommands) NotificationSent(arg0 context.Context, arg1 *sql.Tx, arg2, arg3 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NotificationSent", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(error)
return ret0
}
// NotificationSent indicates an expected call of NotificationSent.
func (mr *MockCommandsMockRecorder) NotificationSent(arg0, arg1, arg2, arg3 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationSent", reflect.TypeOf((*MockCommands)(nil).NotificationSent), arg0, arg1, arg2, arg3)
}
// OTPEmailSent mocks base method. // OTPEmailSent mocks base method.
func (m *MockCommands) OTPEmailSent(arg0 context.Context, arg1, arg2 string) error { func (m *MockCommands) OTPEmailSent(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
@ -254,20 +210,6 @@ func (mr *MockCommandsMockRecorder) PasswordCodeSent(arg0, arg1, arg2, arg3 any)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordCodeSent", reflect.TypeOf((*MockCommands)(nil).PasswordCodeSent), arg0, arg1, arg2, arg3) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordCodeSent", reflect.TypeOf((*MockCommands)(nil).PasswordCodeSent), arg0, arg1, arg2, arg3)
} }
// RequestNotification mocks base method.
func (m *MockCommands) RequestNotification(arg0 context.Context, arg1 string, arg2 *command.NotificationRequest) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "RequestNotification", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// RequestNotification indicates an expected call of RequestNotification.
func (mr *MockCommandsMockRecorder) RequestNotification(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RequestNotification", reflect.TypeOf((*MockCommands)(nil).RequestNotification), arg0, arg1, arg2)
}
// UsageNotificationSent mocks base method. // UsageNotificationSent mocks base method.
func (m *MockCommands) UsageNotificationSent(arg0 context.Context, arg1 *quota.NotificationDueEvent) error { func (m *MockCommands) UsageNotificationSent(arg0 context.Context, arg1 *quota.NotificationDueEvent) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()

View File

@ -0,0 +1,61 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/zitadel/internal/notification/handlers (interfaces: Queue)
//
// Generated by this command:
//
// mockgen -package mock -destination ./mock/queue.mock.go github.com/zitadel/zitadel/internal/notification/handlers Queue
//
// Package mock is a generated GoMock package.
package mock
import (
context "context"
reflect "reflect"
river "github.com/riverqueue/river"
queue "github.com/zitadel/zitadel/internal/queue"
gomock "go.uber.org/mock/gomock"
)
// MockQueue is a mock of Queue interface.
type MockQueue struct {
ctrl *gomock.Controller
recorder *MockQueueMockRecorder
}
// MockQueueMockRecorder is the mock recorder for MockQueue.
type MockQueueMockRecorder struct {
mock *MockQueue
}
// NewMockQueue creates a new mock instance.
func NewMockQueue(ctrl *gomock.Controller) *MockQueue {
mock := &MockQueue{ctrl: ctrl}
mock.recorder = &MockQueueMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockQueue) EXPECT() *MockQueueMockRecorder {
return m.recorder
}
// Insert mocks base method.
func (m *MockQueue) Insert(arg0 context.Context, arg1 river.JobArgs, arg2 ...queue.InsertOpt) error {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1}
for _, a := range arg2 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Insert", varargs...)
ret0, _ := ret[0].(error)
return ret0
}
// Insert indicates an expected call of Insert.
func (mr *MockQueueMockRecorder) Insert(arg0, arg1 any, arg2 ...any) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]any{arg0, arg1}, arg2...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Insert", reflect.TypeOf((*MockQueue)(nil).Insert), varargs...)
}

View File

@ -2,18 +2,16 @@ package handlers
import ( import (
"context" "context"
"database/sql"
"errors" "errors"
"math/rand/v2" "fmt"
"slices" "strconv"
"strings" "strings"
"time" "time"
"github.com/riverqueue/river"
"github.com/zitadel/logging" "github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
@ -22,7 +20,7 @@ import (
"github.com/zitadel/zitadel/internal/notification/senders" "github.com/zitadel/zitadel/internal/notification/senders"
"github.com/zitadel/zitadel/internal/notification/types" "github.com/zitadel/zitadel/internal/notification/types"
"github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/repository/instance" "github.com/zitadel/zitadel/internal/queue"
"github.com/zitadel/zitadel/internal/repository/notification" "github.com/zitadel/zitadel/internal/repository/notification"
) )
@ -32,6 +30,8 @@ const (
) )
type NotificationWorker struct { type NotificationWorker struct {
river.WorkerDefaults[*notification.Request]
commands Commands commands Commands
queries *NotificationQueries queries *NotificationQueries
es *eventstore.Eventstore es *eventstore.Eventstore
@ -39,22 +39,53 @@ type NotificationWorker struct {
channels types.ChannelChains channels types.ChannelChains
config WorkerConfig config WorkerConfig
now nowFunc now nowFunc
backOff func(current time.Duration) time.Duration }
func (w *NotificationWorker) Timeout(*river.Job[*notification.Request]) time.Duration {
return w.config.TransactionDuration
}
// Work implements [river.Worker].
func (w *NotificationWorker) Work(ctx context.Context, job *river.Job[*notification.Request]) error {
ctx = ContextWithNotifier(ctx, job.Args.Aggregate)
// if the notification is too old, we can directly cancel
if job.CreatedAt.Add(w.config.MaxTtl).Before(w.now()) {
return river.JobCancel(errors.New("notification is too old"))
}
// We do not trigger the projection to reduce load on the database. By the time the notification is processed,
// the user should be projected anyway. If not, it will just wait for the next run.
// We are aware that the user can change during the time the notification is in the queue.
notifyUser, err := w.queries.GetNotifyUserByID(ctx, false, job.Args.UserID)
if err != nil {
return err
}
// The domain claimed event requires the domain as argument, but lacks the user when creating the request event.
// Since we set it into the request arguments, it will be passed into a potential retry event.
if job.Args.RequiresPreviousDomain && job.Args.Args != nil && job.Args.Args.Domain == "" {
index := strings.LastIndex(notifyUser.LastEmail, "@")
job.Args.Args.Domain = notifyUser.LastEmail[index+1:]
}
err = w.sendNotificationQueue(ctx, job.Args, strconv.Itoa(int(job.ID)), notifyUser)
if err == nil {
return nil
}
// if the error explicitly specifies, we cancel the notification
if errors.Is(err, &channels.CancelError{}) {
return river.JobCancel(err)
}
return err
} }
type WorkerConfig struct { type WorkerConfig struct {
LegacyEnabled bool LegacyEnabled bool
Workers uint8 Workers uint8
BulkLimit uint16
RequeueEvery time.Duration
RetryWorkers uint8
RetryRequeueEvery time.Duration
TransactionDuration time.Duration TransactionDuration time.Duration
MaxAttempts uint8
MaxTtl time.Duration MaxTtl time.Duration
MinRetryDelay time.Duration MaxAttempts uint8
MaxRetryDelay time.Duration
RetryDelayFactor float32
} }
// nowFunc makes [time.Now] mockable // nowFunc makes [time.Now] mockable
@ -78,11 +109,8 @@ func NewNotificationWorker(
es *eventstore.Eventstore, es *eventstore.Eventstore,
client *database.DB, client *database.DB,
channels types.ChannelChains, channels types.ChannelChains,
queue *queue.Queue,
) *NotificationWorker { ) *NotificationWorker {
// make sure the delay does not get less
if config.RetryDelayFactor < 1 {
config.RetryDelayFactor = 1
}
w := &NotificationWorker{ w := &NotificationWorker{
config: config, config: config,
commands: commands, commands: commands,
@ -92,102 +120,31 @@ func NewNotificationWorker(
channels: channels, channels: channels,
now: time.Now, now: time.Now,
} }
w.backOff = w.exponentialBackOff if !config.LegacyEnabled {
queue.AddWorkers(w)
}
return w return w
} }
func (w *NotificationWorker) Start(ctx context.Context) { var _ river.Worker[*notification.Request] = (*NotificationWorker)(nil)
if w.config.LegacyEnabled {
return func (w *NotificationWorker) Register(workers *river.Workers, queues map[string]river.QueueConfig) {
} river.AddWorker(workers, w)
for i := 0; i < int(w.config.Workers); i++ { queues[notification.QueueName] = river.QueueConfig{
go w.schedule(ctx, i, false) MaxWorkers: int(w.config.Workers),
}
for i := 0; i < int(w.config.RetryWorkers); i++ {
go w.schedule(ctx, i, true)
} }
} }
func (w *NotificationWorker) reduceNotificationRequested(ctx, txCtx context.Context, tx *sql.Tx, event *notification.RequestedEvent) (err error) { func (w *NotificationWorker) sendNotificationQueue(ctx context.Context, request *notification.Request, jobID string, notifyUser *query.NotifyUser) error {
ctx = ContextWithNotifier(ctx, event.Aggregate())
// if the notification is too old, we can directly cancel
if event.CreatedAt().Add(w.config.MaxTtl).Before(w.now()) {
return w.commands.NotificationCanceled(txCtx, tx, event.Aggregate().ID, event.Aggregate().ResourceOwner, nil)
}
// Get the notify user first, so if anything fails afterward we have the current state of the user
// and can pass that to the retry request.
// We do not trigger the projection to reduce load on the database. By the time the notification is processed,
// the user should be projected anyway. If not, it will just wait for the next run.
notifyUser, err := w.queries.GetNotifyUserByID(ctx, false, event.UserID)
if err != nil {
return err
}
// The domain claimed event requires the domain as argument, but lacks the user when creating the request event.
// Since we set it into the request arguments, it will be passed into a potential retry event.
if event.RequiresPreviousDomain && event.Request.Args != nil && event.Request.Args.Domain == "" {
index := strings.LastIndex(notifyUser.LastEmail, "@")
event.Request.Args.Domain = notifyUser.LastEmail[index+1:]
}
err = w.sendNotification(ctx, txCtx, tx, event.Request, notifyUser, event)
if err == nil {
return nil
}
// if retries are disabled or if the error explicitly specifies, we cancel the notification
if w.config.MaxAttempts <= 1 || errors.Is(err, &channels.CancelError{}) {
return w.commands.NotificationCanceled(txCtx, tx, event.Aggregate().ID, event.Aggregate().ResourceOwner, err)
}
// otherwise we retry after a backoff delay
return w.commands.NotificationRetryRequested(
txCtx,
tx,
event.Aggregate().ID,
event.Aggregate().ResourceOwner,
notificationEventToRequest(event.Request, notifyUser, w.backOff(0)),
err,
)
}
func (w *NotificationWorker) reduceNotificationRetry(ctx, txCtx context.Context, tx *sql.Tx, event *notification.RetryRequestedEvent) (err error) {
ctx = ContextWithNotifier(ctx, event.Aggregate())
// if the notification is too old, we can directly cancel
if event.CreatedAt().Add(w.config.MaxTtl).Before(w.now()) {
return w.commands.NotificationCanceled(txCtx, tx, event.Aggregate().ID, event.Aggregate().ResourceOwner, err)
}
if event.CreatedAt().Add(event.BackOff).After(w.now()) {
return nil
}
err = w.sendNotification(ctx, txCtx, tx, event.Request, event.NotifyUser, event)
if err == nil {
return nil
}
// if the max attempts are reached or if the error explicitly specifies, we cancel the notification
if event.Sequence() >= uint64(w.config.MaxAttempts) || errors.Is(err, &channels.CancelError{}) {
return w.commands.NotificationCanceled(txCtx, tx, event.Aggregate().ID, event.Aggregate().ResourceOwner, err)
}
// otherwise we retry after a backoff delay
return w.commands.NotificationRetryRequested(txCtx, tx, event.Aggregate().ID, event.Aggregate().ResourceOwner, notificationEventToRequest(
event.Request,
event.NotifyUser,
w.backOff(event.BackOff),
), err)
}
func (w *NotificationWorker) sendNotification(ctx, txCtx context.Context, tx *sql.Tx, request notification.Request, notifyUser *query.NotifyUser, e eventstore.Event) error {
ctx, err := enrichCtx(ctx, request.TriggeredAtOrigin)
if err != nil {
return channels.NewCancelError(err)
}
// check early that a "sent" handler exists, otherwise we can cancel early // check early that a "sent" handler exists, otherwise we can cancel early
sentHandler, ok := sentHandlers[request.EventType] sentHandler, ok := sentHandlers[request.EventType]
if !ok { if !ok {
logging.Errorf(`no "sent" handler registered for %s`, request.EventType) logging.Errorf(`no "sent" handler registered for %s`, request.EventType)
return channels.NewCancelError(fmt.Errorf("no sent handler registered for %s", request.EventType))
}
ctx, err := enrichCtx(ctx, request.TriggeredAtOrigin)
if err != nil {
return channels.NewCancelError(err) return channels.NewCancelError(err)
} }
@ -217,9 +174,9 @@ func (w *NotificationWorker) sendNotification(ctx, txCtx context.Context, tx *sq
if err != nil { if err != nil {
return err return err
} }
notify = types.SendEmail(ctx, w.channels, string(template.Template), translator, notifyUser, colors, e) notify = types.SendEmail(ctx, w.channels, string(template.Template), translator, notifyUser, colors, request.EventType)
case domain.NotificationTypeSms: case domain.NotificationTypeSms:
notify = types.SendSMS(ctx, w.channels, translator, notifyUser, colors, e, generatorInfo) notify = types.SendSMS(ctx, w.channels, translator, notifyUser, colors, request.EventType, request.Aggregate.InstanceID, jobID, generatorInfo)
} }
args := request.Args.ToMap() args := request.Args.ToMap()
@ -229,272 +186,12 @@ func (w *NotificationWorker) sendNotification(ctx, txCtx context.Context, tx *sq
args[OTP] = code args[OTP] = code
} }
if err := notify(request.URLTemplate, args, request.MessageType, request.UnverifiedNotificationChannel); err != nil { if err = notify(request.URLTemplate, args, request.MessageType, request.UnverifiedNotificationChannel); err != nil {
return err return err
} }
err = w.commands.NotificationSent(txCtx, tx, e.Aggregate().ID, e.Aggregate().ResourceOwner)
if err != nil { err = sentHandler(authz.WithInstanceID(ctx, request.Aggregate.InstanceID), w.commands, request.Aggregate.ID, request.Aggregate.ResourceOwner, generatorInfo, args)
// In case the notification event cannot be pushed, we most likely cannot create a retry or cancel event. logging.WithFields("instanceID", request.Aggregate.InstanceID, "notification", request.Aggregate.ID).
// Therefore, we'll only log the error and also do not need to try to push to the user / session.
logging.WithFields("instanceID", authz.GetInstance(ctx).InstanceID(), "notification", e.Aggregate().ID).
OnError(err).Error("could not set sent notification event")
return nil
}
err = sentHandler(txCtx, w.commands, request.NotificationAggregateID(), request.NotificationAggregateResourceOwner(), generatorInfo, args)
logging.WithFields("instanceID", authz.GetInstance(ctx).InstanceID(), "notification", e.Aggregate().ID).
OnError(err).Error("could not set notification event on aggregate") OnError(err).Error("could not set notification event on aggregate")
return nil return nil
} }
func (w *NotificationWorker) exponentialBackOff(current time.Duration) time.Duration {
if current >= w.config.MaxRetryDelay {
return w.config.MaxRetryDelay
}
if current < w.config.MinRetryDelay {
current = w.config.MinRetryDelay
}
t := time.Duration(rand.Int64N(int64(w.config.RetryDelayFactor*float32(current.Nanoseconds()))-current.Nanoseconds()) + current.Nanoseconds())
if t > w.config.MaxRetryDelay {
return w.config.MaxRetryDelay
}
return t
}
func notificationEventToRequest(e notification.Request, notifyUser *query.NotifyUser, backoff time.Duration) *command.NotificationRetryRequest {
return &command.NotificationRetryRequest{
NotificationRequest: command.NotificationRequest{
UserID: e.UserID,
UserResourceOwner: e.UserResourceOwner,
TriggerOrigin: e.TriggeredAtOrigin,
URLTemplate: e.URLTemplate,
Code: e.Code,
CodeExpiry: e.CodeExpiry,
EventType: e.EventType,
NotificationType: e.NotificationType,
MessageType: e.MessageType,
UnverifiedNotificationChannel: e.UnverifiedNotificationChannel,
Args: e.Args,
AggregateID: e.AggregateID,
AggregateResourceOwner: e.AggregateResourceOwner,
IsOTP: e.IsOTP,
},
BackOff: backoff,
NotifyUser: notifyUser,
}
}
func (w *NotificationWorker) schedule(ctx context.Context, workerID int, retry bool) {
t := time.NewTimer(0)
for {
select {
case <-ctx.Done():
t.Stop()
w.log(workerID, retry).Info("scheduler stopped")
return
case <-t.C:
instances, err := w.queryInstances(ctx, retry)
w.log(workerID, retry).OnError(err).Error("unable to query instances")
w.triggerInstances(call.WithTimestamp(ctx), instances, workerID, retry)
if retry {
t.Reset(w.config.RetryRequeueEvery)
continue
}
t.Reset(w.config.RequeueEvery)
}
}
}
func (w *NotificationWorker) log(workerID int, retry bool) *logging.Entry {
return logging.WithFields("notification worker", workerID, "retries", retry)
}
func (w *NotificationWorker) queryInstances(ctx context.Context, retry bool) ([]string, error) {
return w.queries.ActiveInstances(), nil
}
func (w *NotificationWorker) triggerInstances(ctx context.Context, instances []string, workerID int, retry bool) {
for _, instance := range instances {
instanceCtx := authz.WithInstanceID(ctx, instance)
err := w.trigger(instanceCtx, workerID, retry)
w.log(workerID, retry).WithField("instance", instance).OnError(err).Info("trigger failed")
}
}
func (w *NotificationWorker) trigger(ctx context.Context, workerID int, retry bool) (err error) {
txCtx := ctx
if w.config.TransactionDuration > 0 {
var cancel, cancelTx func()
txCtx, cancelTx = context.WithCancel(ctx)
defer cancelTx()
ctx, cancel = context.WithTimeout(ctx, w.config.TransactionDuration)
defer cancel()
}
tx, err := w.client.BeginTx(txCtx, nil)
if err != nil {
return err
}
defer func() {
err = database.CloseTransaction(tx, err)
}()
events, err := w.searchEvents(txCtx, tx, retry)
if err != nil {
return err
}
// If there aren't any events or no unlocked event terminate early and start a new run.
if len(events) == 0 {
return nil
}
w.log(workerID, retry).
WithField("instanceID", authz.GetInstance(ctx).InstanceID()).
WithField("events", len(events)).
Info("handling notification events")
for _, event := range events {
var err error
switch e := event.(type) {
case *notification.RequestedEvent:
w.createSavepoint(txCtx, tx, event, workerID, retry)
err = w.reduceNotificationRequested(ctx, txCtx, tx, e)
case *notification.RetryRequestedEvent:
w.createSavepoint(txCtx, tx, event, workerID, retry)
err = w.reduceNotificationRetry(ctx, txCtx, tx, e)
}
if err != nil {
w.log(workerID, retry).OnError(err).
WithField("instanceID", authz.GetInstance(ctx).InstanceID()).
WithField("notificationID", event.Aggregate().ID).
WithField("sequence", event.Sequence()).
WithField("type", event.Type()).
Error("could not handle notification event")
// if we have an error, we rollback to the savepoint and continue with the next event
// we use the txCtx to make sure we can rollback the transaction in case the ctx is canceled
w.rollbackToSavepoint(txCtx, tx, event, workerID, retry)
}
// if the context is canceled, we stop the processing
if ctx.Err() != nil {
return nil
}
}
return nil
}
func (w *NotificationWorker) latestRetries(events []eventstore.Event) []eventstore.Event {
for i := len(events) - 1; i > 0; i-- {
// since we delete during the iteration, we need to make sure we don't panic
if len(events) <= i {
continue
}
// delete all the previous retries of the same notification
events = slices.DeleteFunc(events, func(e eventstore.Event) bool {
return e.Aggregate().ID == events[i].Aggregate().ID &&
e.Sequence() < events[i].Sequence()
})
}
return events
}
func (w *NotificationWorker) createSavepoint(ctx context.Context, tx *sql.Tx, event eventstore.Event, workerID int, retry bool) {
_, err := tx.ExecContext(ctx, "SAVEPOINT notification_send")
w.log(workerID, retry).OnError(err).
WithField("instanceID", authz.GetInstance(ctx).InstanceID()).
WithField("notificationID", event.Aggregate().ID).
WithField("sequence", event.Sequence()).
WithField("type", event.Type()).
Error("could not create savepoint for notification event")
}
func (w *NotificationWorker) rollbackToSavepoint(ctx context.Context, tx *sql.Tx, event eventstore.Event, workerID int, retry bool) {
_, err := tx.ExecContext(ctx, "ROLLBACK TO SAVEPOINT notification_send")
w.log(workerID, retry).OnError(err).
WithField("instanceID", authz.GetInstance(ctx).InstanceID()).
WithField("notificationID", event.Aggregate().ID).
WithField("sequence", event.Sequence()).
WithField("type", event.Type()).
Error("could not rollback to savepoint for notification event")
}
func (w *NotificationWorker) searchEvents(ctx context.Context, tx *sql.Tx, retry bool) ([]eventstore.Event, error) {
if retry {
return w.searchRetryEvents(ctx, tx)
}
// query events and lock them for update (with skip locked)
searchQuery := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
LockRowsDuringTx(tx, eventstore.LockOptionSkipLocked).
// Messages older than the MaxTTL, we can be ignored.
// The first attempt of a retry might still be older than the TTL and needs to be filtered out later on.
CreationDateAfter(w.now().Add(-1*w.config.MaxTtl)).
Limit(uint64(w.config.BulkLimit)).
AddQuery().
AggregateTypes(notification.AggregateType).
EventTypes(notification.RequestedType).
Builder().
ExcludeAggregateIDs().
EventTypes(notification.RetryRequestedType, notification.CanceledType, notification.SentType).
AggregateTypes(notification.AggregateType).
Builder()
//nolint:staticcheck
return w.es.Filter(ctx, searchQuery)
}
func (w *NotificationWorker) searchRetryEvents(ctx context.Context, tx *sql.Tx) ([]eventstore.Event, error) {
// query events and lock them for update (with skip locked)
searchQuery := eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
LockRowsDuringTx(tx, eventstore.LockOptionSkipLocked).
// Messages older than the MaxTTL, we can be ignored.
// The first attempt of a retry might still be older than the TTL and needs to be filtered out later on.
CreationDateAfter(w.now().Add(-1*w.config.MaxTtl)).
AddQuery().
AggregateTypes(notification.AggregateType).
EventTypes(notification.RetryRequestedType).
Builder().
ExcludeAggregateIDs().
EventTypes(notification.CanceledType, notification.SentType).
AggregateTypes(notification.AggregateType).
Builder()
//nolint:staticcheck
events, err := w.es.Filter(ctx, searchQuery)
if err != nil {
return nil, err
}
return w.latestRetries(events), nil
}
type existingInstances []string
// AppendEvents implements eventstore.QueryReducer.
func (ai *existingInstances) AppendEvents(events ...eventstore.Event) {
for _, event := range events {
switch event.Type() {
case instance.InstanceAddedEventType:
*ai = append(*ai, event.Aggregate().InstanceID)
case instance.InstanceRemovedEventType:
*ai = slices.DeleteFunc(*ai, func(s string) bool {
return s == event.Aggregate().InstanceID
})
}
}
}
// Query implements eventstore.QueryReducer.
func (*existingInstances) Query() *eventstore.SearchQueryBuilder {
return eventstore.NewSearchQueryBuilder(eventstore.ColumnsEvent).
AddQuery().
AggregateTypes(instance.AggregateType).
EventTypes(
instance.InstanceAddedEventType,
instance.InstanceRemovedEventType,
).
Builder()
}
// Reduce implements eventstore.QueryReducer.
// reduce is not used as events are reduced during AppendEvents
func (*existingInstances) Reduce() error {
return nil
}

View File

@ -2,22 +2,21 @@ package handlers
import ( import (
"context" "context"
"database/sql"
"errors" "errors"
"fmt" "fmt"
"testing" "testing"
"time" "time"
"github.com/muhlemmer/gu" "github.com/muhlemmer/gu"
"github.com/riverqueue/river"
"github.com/riverqueue/river/rivertype"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock" "go.uber.org/mock/gomock"
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/ui/login" "github.com/zitadel/zitadel/internal/api/ui/login"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/repository"
es_repo_mock "github.com/zitadel/zitadel/internal/eventstore/repository/mock" es_repo_mock "github.com/zitadel/zitadel/internal/eventstore/repository/mock"
"github.com/zitadel/zitadel/internal/notification/channels/email" "github.com/zitadel/zitadel/internal/notification/channels/email"
channel_mock "github.com/zitadel/zitadel/internal/notification/channels/mock" channel_mock "github.com/zitadel/zitadel/internal/notification/channels/mock"
@ -51,7 +50,6 @@ func Test_userNotifier_reduceNotificationRequested(t *testing.T) {
name: "too old", name: "too old",
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fieldsWorker, a argsWorker, w wantWorker) { test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fieldsWorker, a argsWorker, w wantWorker) {
codeAlg, code := cryptoValue(t, ctrl, "testcode") codeAlg, code := cryptoValue(t, ctrl, "testcode")
commands.EXPECT().NotificationCanceled(gomock.Any(), gomock.Any(), notificationID, instanceID, nil).Return(nil)
return fieldsWorker{ return fieldsWorker{
queries: queries, queries: queries,
commands: commands, commands: commands,
@ -62,19 +60,18 @@ func Test_userNotifier_reduceNotificationRequested(t *testing.T) {
now: testNow, now: testNow,
}, },
argsWorker{ argsWorker{
event: &notification.RequestedEvent{ job: &river.Job[*notification.Request]{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{ JobRow: &rivertype.JobRow{
InstanceID: instanceID, CreatedAt: time.Now().Add(-1 * time.Hour),
AggregateID: notificationID, },
ResourceOwner: sql.NullString{String: instanceID}, Args: &notification.Request{
CreationDate: time.Now().Add(-1 * time.Hour), Aggregate: &eventstore.Aggregate{
Typ: notification.RequestedType, InstanceID: instanceID,
}), ID: notificationID,
Request: notification.Request{ ResourceOwner: instanceID,
},
UserID: userID, UserID: userID,
UserResourceOwner: orgID, UserResourceOwner: orgID,
AggregateID: "",
AggregateResourceOwner: "",
TriggeredAtOrigin: eventOrigin, TriggeredAtOrigin: eventOrigin,
EventType: user.HumanInviteCodeAddedType, EventType: user.HumanInviteCodeAddedType,
MessageType: domain.InviteUserMessageType, MessageType: domain.InviteUserMessageType,
@ -90,7 +87,12 @@ func Test_userNotifier_reduceNotificationRequested(t *testing.T) {
}, },
}, },
}, },
}, w },
wantWorker{
err: func(tt assert.TestingT, err error, i ...interface{}) bool {
return errors.Is(err, new(river.JobCancelError))
},
}
}, },
}, },
{ {
@ -99,13 +101,13 @@ func Test_userNotifier_reduceNotificationRequested(t *testing.T) {
givenTemplate := "{{.LogoURL}}" givenTemplate := "{{.LogoURL}}"
expectContent := fmt.Sprintf("%s%s/%s/%s", eventOrigin, assetsPath, policyID, logoURL) expectContent := fmt.Sprintf("%s%s/%s/%s", eventOrigin, assetsPath, policyID, logoURL)
w.message = &messages.Email{ w.message = &messages.Email{
Recipients: []string{lastEmail}, Recipients: []string{lastEmail},
Subject: "Invitation to APP", Subject: "Invitation to APP",
Content: expectContent, Content: expectContent,
TriggeringEventType: user.HumanInviteCodeAddedType,
} }
codeAlg, code := cryptoValue(t, ctrl, "testcode") codeAlg, code := cryptoValue(t, ctrl, "testcode")
expectTemplateWithNotifyUserQueries(queries, givenTemplate) expectTemplateWithNotifyUserQueries(queries, givenTemplate)
commands.EXPECT().NotificationSent(gomock.Any(), gomock.Any(), notificationID, instanceID).Return(nil)
commands.EXPECT().InviteCodeSent(gomock.Any(), orgID, userID).Return(nil) commands.EXPECT().InviteCodeSent(gomock.Any(), orgID, userID).Return(nil)
return fieldsWorker{ return fieldsWorker{
queries: queries, queries: queries,
@ -117,19 +119,18 @@ func Test_userNotifier_reduceNotificationRequested(t *testing.T) {
now: testNow, now: testNow,
}, },
argsWorker{ argsWorker{
event: &notification.RequestedEvent{ job: &river.Job[*notification.Request]{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{ JobRow: &rivertype.JobRow{
InstanceID: instanceID, CreatedAt: time.Now(),
AggregateID: notificationID, },
ResourceOwner: sql.NullString{String: instanceID}, Args: &notification.Request{
CreationDate: time.Now().UTC(), Aggregate: &eventstore.Aggregate{
Typ: notification.RequestedType, InstanceID: instanceID,
}), ID: userID,
Request: notification.Request{ ResourceOwner: orgID,
},
UserID: userID, UserID: userID,
UserResourceOwner: orgID, UserResourceOwner: orgID,
AggregateID: "",
AggregateResourceOwner: "",
TriggeredAtOrigin: eventOrigin, TriggeredAtOrigin: eventOrigin,
EventType: user.HumanInviteCodeAddedType, EventType: user.HumanInviteCodeAddedType,
MessageType: domain.InviteUserMessageType, MessageType: domain.InviteUserMessageType,
@ -145,7 +146,8 @@ func Test_userNotifier_reduceNotificationRequested(t *testing.T) {
}, },
}, },
}, },
}, w },
w
}, },
}, },
{ {
@ -159,10 +161,13 @@ func Test_userNotifier_reduceNotificationRequested(t *testing.T) {
SenderPhoneNumber: "senderNumber", SenderPhoneNumber: "senderNumber",
RecipientPhoneNumber: verifiedPhone, RecipientPhoneNumber: verifiedPhone,
Content: expectContent, Content: expectContent,
TriggeringEventType: session.OTPSMSChallengedType,
InstanceID: instanceID,
JobID: "1",
UserID: userID,
} }
codeAlg, code := cryptoValue(t, ctrl, testCode) codeAlg, code := cryptoValue(t, ctrl, testCode)
expectTemplateWithNotifyUserQueriesSMS(queries) expectTemplateWithNotifyUserQueriesSMS(queries)
commands.EXPECT().NotificationSent(gomock.Any(), gomock.Any(), notificationID, instanceID).Return(nil)
commands.EXPECT().OTPSMSSent(gomock.Any(), sessionID, instanceID, &senders.CodeGeneratorInfo{ commands.EXPECT().OTPSMSSent(gomock.Any(), sessionID, instanceID, &senders.CodeGeneratorInfo{
ID: smsProviderID, ID: smsProviderID,
VerificationID: verificationID, VerificationID: verificationID,
@ -177,19 +182,19 @@ func Test_userNotifier_reduceNotificationRequested(t *testing.T) {
now: testNow, now: testNow,
}, },
argsWorker{ argsWorker{
event: &notification.RequestedEvent{ job: &river.Job[*notification.Request]{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{ JobRow: &rivertype.JobRow{
InstanceID: instanceID, CreatedAt: time.Now(),
AggregateID: notificationID, ID: 1,
ResourceOwner: sql.NullString{String: instanceID}, },
CreationDate: time.Now().UTC(), Args: &notification.Request{
Typ: notification.RequestedType, Aggregate: &eventstore.Aggregate{
}), InstanceID: instanceID,
Request: notification.Request{ ID: sessionID,
ResourceOwner: instanceID,
},
UserID: userID, UserID: userID,
UserResourceOwner: orgID, UserResourceOwner: orgID,
AggregateID: sessionID,
AggregateResourceOwner: instanceID,
TriggeredAtOrigin: eventOrigin, TriggeredAtOrigin: eventOrigin,
EventType: session.OTPSMSChallengedType, EventType: session.OTPSMSChallengedType,
MessageType: domain.VerifySMSOTPMessageType, MessageType: domain.VerifySMSOTPMessageType,
@ -216,12 +221,12 @@ func Test_userNotifier_reduceNotificationRequested(t *testing.T) {
givenTemplate := "{{.LogoURL}}" givenTemplate := "{{.LogoURL}}"
expectContent := fmt.Sprintf("%s%s/%s/%s", eventOrigin, assetsPath, policyID, logoURL) expectContent := fmt.Sprintf("%s%s/%s/%s", eventOrigin, assetsPath, policyID, logoURL)
w.message = &messages.Email{ w.message = &messages.Email{
Recipients: []string{verifiedEmail}, Recipients: []string{verifiedEmail},
Subject: "Domain has been claimed", Subject: "Domain has been claimed",
Content: expectContent, Content: expectContent,
TriggeringEventType: user.UserDomainClaimedType,
} }
expectTemplateWithNotifyUserQueries(queries, givenTemplate) expectTemplateWithNotifyUserQueries(queries, givenTemplate)
commands.EXPECT().NotificationSent(gomock.Any(), gomock.Any(), notificationID, instanceID).Return(nil)
commands.EXPECT().UserDomainClaimedSent(gomock.Any(), orgID, userID).Return(nil) commands.EXPECT().UserDomainClaimedSent(gomock.Any(), orgID, userID).Return(nil)
return fieldsWorker{ return fieldsWorker{
queries: queries, queries: queries,
@ -233,19 +238,18 @@ func Test_userNotifier_reduceNotificationRequested(t *testing.T) {
now: testNow, now: testNow,
}, },
argsWorker{ argsWorker{
event: &notification.RequestedEvent{ job: &river.Job[*notification.Request]{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{ JobRow: &rivertype.JobRow{
InstanceID: instanceID, CreatedAt: time.Now(),
AggregateID: notificationID, },
ResourceOwner: sql.NullString{String: instanceID}, Args: &notification.Request{
CreationDate: time.Now().UTC(), Aggregate: &eventstore.Aggregate{
Typ: notification.RequestedType, InstanceID: instanceID,
}), ID: userID,
Request: notification.Request{ ResourceOwner: orgID,
},
UserID: userID, UserID: userID,
UserResourceOwner: orgID, UserResourceOwner: orgID,
AggregateID: "",
AggregateResourceOwner: "",
TriggeredAtOrigin: eventOrigin, TriggeredAtOrigin: eventOrigin,
EventType: user.UserDomainClaimedType, EventType: user.UserDomainClaimedType,
MessageType: domain.DomainClaimedMessageType, MessageType: domain.DomainClaimedMessageType,
@ -270,47 +274,17 @@ func Test_userNotifier_reduceNotificationRequested(t *testing.T) {
givenTemplate := "{{.LogoURL}}" givenTemplate := "{{.LogoURL}}"
expectContent := fmt.Sprintf("%s%s/%s/%s", eventOrigin, assetsPath, policyID, logoURL) expectContent := fmt.Sprintf("%s%s/%s/%s", eventOrigin, assetsPath, policyID, logoURL)
w.message = &messages.Email{ w.message = &messages.Email{
Recipients: []string{lastEmail}, Recipients: []string{lastEmail},
Subject: "Invitation to APP", Subject: "Invitation to APP",
Content: expectContent, Content: expectContent,
TriggeringEventType: user.HumanInviteCodeAddedType,
} }
w.sendError = sendError w.sendError = sendError
w.err = func(tt assert.TestingT, err error, i ...interface{}) bool {
return errors.Is(err, sendError)
}
codeAlg, code := cryptoValue(t, ctrl, "testcode") codeAlg, code := cryptoValue(t, ctrl, "testcode")
expectTemplateWithNotifyUserQueries(queries, givenTemplate) expectTemplateWithNotifyUserQueries(queries, givenTemplate)
commands.EXPECT().NotificationRetryRequested(gomock.Any(), gomock.Any(), notificationID, instanceID,
&command.NotificationRetryRequest{
NotificationRequest: command.NotificationRequest{
UserID: userID,
UserResourceOwner: orgID,
AggregateID: "",
AggregateResourceOwner: "",
TriggerOrigin: eventOrigin,
EventType: user.HumanInviteCodeAddedType,
MessageType: domain.InviteUserMessageType,
NotificationType: domain.NotificationTypeEmail,
URLTemplate: fmt.Sprintf("%s/ui/login/user/invite?userID=%s&loginname={{.LoginName}}&code={{.Code}}&orgID=%s&authRequestID=%s", eventOrigin, userID, orgID, authRequestID),
CodeExpiry: 1 * time.Hour,
Code: code,
UnverifiedNotificationChannel: true,
IsOTP: false,
RequiresPreviousDomain: false,
Args: &domain.NotificationArguments{
ApplicationName: "APP",
},
},
BackOff: 1 * time.Second,
NotifyUser: &query.NotifyUser{
ID: userID,
ResourceOwner: orgID,
LastEmail: lastEmail,
VerifiedEmail: verifiedEmail,
PreferredLoginName: preferredLoginName,
LastPhone: lastPhone,
VerifiedPhone: verifiedPhone,
},
},
sendError,
).Return(nil)
return fieldsWorker{ return fieldsWorker{
queries: queries, queries: queries,
commands: commands, commands: commands,
@ -320,22 +294,21 @@ func Test_userNotifier_reduceNotificationRequested(t *testing.T) {
userDataCrypto: codeAlg, userDataCrypto: codeAlg,
now: testNow, now: testNow,
backOff: testBackOff, backOff: testBackOff,
maxAttempts: 2,
}, },
argsWorker{ argsWorker{
event: &notification.RequestedEvent{ job: &river.Job[*notification.Request]{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{ JobRow: &rivertype.JobRow{
InstanceID: instanceID, ID: 1,
AggregateID: notificationID, CreatedAt: time.Now(),
ResourceOwner: sql.NullString{String: instanceID}, },
CreationDate: time.Now().UTC(), Args: &notification.Request{
Typ: notification.RequestedType, Aggregate: &eventstore.Aggregate{
}), InstanceID: instanceID,
Request: notification.Request{ ID: notificationID,
ResourceOwner: instanceID,
},
UserID: userID, UserID: userID,
UserResourceOwner: orgID, UserResourceOwner: orgID,
AggregateID: "",
AggregateResourceOwner: "",
TriggeredAtOrigin: eventOrigin, TriggeredAtOrigin: eventOrigin,
EventType: user.HumanInviteCodeAddedType, EventType: user.HumanInviteCodeAddedType,
MessageType: domain.InviteUserMessageType, MessageType: domain.InviteUserMessageType,
@ -351,7 +324,8 @@ func Test_userNotifier_reduceNotificationRequested(t *testing.T) {
}, },
}, },
}, },
}, w },
w
}, },
}, },
{ {
@ -360,315 +334,18 @@ func Test_userNotifier_reduceNotificationRequested(t *testing.T) {
givenTemplate := "{{.LogoURL}}" givenTemplate := "{{.LogoURL}}"
expectContent := fmt.Sprintf("%s%s/%s/%s", eventOrigin, assetsPath, policyID, logoURL) expectContent := fmt.Sprintf("%s%s/%s/%s", eventOrigin, assetsPath, policyID, logoURL)
w.message = &messages.Email{ w.message = &messages.Email{
Recipients: []string{lastEmail}, Recipients: []string{lastEmail},
Subject: "Invitation to APP", Subject: "Invitation to APP",
Content: expectContent, Content: expectContent,
TriggeringEventType: user.HumanInviteCodeAddedType,
} }
w.sendError = sendError w.sendError = sendError
codeAlg, code := cryptoValue(t, ctrl, "testcode") w.err = func(tt assert.TestingT, err error, i ...interface{}) bool {
expectTemplateWithNotifyUserQueries(queries, givenTemplate) return err != nil
commands.EXPECT().NotificationCanceled(gomock.Any(), gomock.Any(), notificationID, instanceID, sendError).Return(nil) }
return fieldsWorker{
queries: queries,
commands: commands,
es: eventstore.NewEventstore(&eventstore.Config{
Querier: es_repo_mock.NewRepo(t).MockQuerier,
}),
userDataCrypto: codeAlg,
now: testNow,
backOff: testBackOff,
maxAttempts: 1,
},
argsWorker{
event: &notification.RequestedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
InstanceID: instanceID,
AggregateID: notificationID,
ResourceOwner: sql.NullString{String: instanceID},
CreationDate: time.Now().UTC(),
Seq: 1,
Typ: notification.RequestedType,
}),
Request: notification.Request{
UserID: userID,
UserResourceOwner: orgID,
AggregateID: "",
AggregateResourceOwner: "",
TriggeredAtOrigin: eventOrigin,
EventType: user.HumanInviteCodeAddedType,
MessageType: domain.InviteUserMessageType,
NotificationType: domain.NotificationTypeEmail,
URLTemplate: fmt.Sprintf("%s/ui/login/user/invite?userID=%s&loginname={{.LoginName}}&code={{.Code}}&orgID=%s&authRequestID=%s", eventOrigin, userID, orgID, authRequestID),
CodeExpiry: 1 * time.Hour,
Code: code,
UnverifiedNotificationChannel: true,
IsOTP: false,
RequiresPreviousDomain: false,
Args: &domain.NotificationArguments{
ApplicationName: "APP",
},
},
},
}, w
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
queries := mock.NewMockQueries(ctrl)
commands := mock.NewMockCommands(ctrl)
f, a, w := tt.test(ctrl, queries, commands)
err := newNotificationWorker(t, ctrl, queries, f, a, w).reduceNotificationRequested(
authz.WithInstanceID(context.Background(), instanceID),
authz.WithInstanceID(context.Background(), instanceID),
&sql.Tx{},
a.event.(*notification.RequestedEvent))
if w.err != nil {
w.err(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func Test_userNotifier_reduceNotificationRetry(t *testing.T) {
testNow := time.Now
testBackOff := func(current time.Duration) time.Duration {
return time.Second
}
sendError := errors.New("send error")
tests := []struct {
name string
test func(*gomock.Controller, *mock.MockQueries, *mock.MockCommands) (fieldsWorker, argsWorker, wantWorker)
}{
{
name: "too old",
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fieldsWorker, a argsWorker, w wantWorker) {
codeAlg, code := cryptoValue(t, ctrl, "testcode") codeAlg, code := cryptoValue(t, ctrl, "testcode")
commands.EXPECT().NotificationCanceled(gomock.Any(), gomock.Any(), notificationID, instanceID, nil).Return(nil) expectTemplateWithNotifyUserQueries(queries, givenTemplate)
return fieldsWorker{
queries: queries,
commands: commands,
es: eventstore.NewEventstore(&eventstore.Config{
Querier: es_repo_mock.NewRepo(t).MockQuerier,
}),
userDataCrypto: codeAlg,
now: testNow,
},
argsWorker{
event: &notification.RetryRequestedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
InstanceID: instanceID,
AggregateID: notificationID,
ResourceOwner: sql.NullString{String: instanceID},
CreationDate: time.Now().Add(-1 * time.Hour),
Typ: notification.RequestedType,
}),
Request: notification.Request{
UserID: userID,
UserResourceOwner: orgID,
AggregateID: "",
AggregateResourceOwner: "",
TriggeredAtOrigin: eventOrigin,
EventType: user.HumanInviteCodeAddedType,
MessageType: domain.InviteUserMessageType,
NotificationType: domain.NotificationTypeEmail,
URLTemplate: fmt.Sprintf("%s/ui/login/user/invite?userID=%s&loginname={{.LoginName}}&code={{.Code}}&orgID=%s&authRequestID=%s", eventOrigin, userID, orgID, authRequestID),
CodeExpiry: 1 * time.Hour,
Code: code,
UnverifiedNotificationChannel: true,
IsOTP: false,
RequiresPreviousDomain: false,
Args: &domain.NotificationArguments{
ApplicationName: "APP",
},
},
BackOff: 1 * time.Second,
NotifyUser: &query.NotifyUser{
ID: userID,
ResourceOwner: orgID,
LastEmail: lastEmail,
VerifiedEmail: verifiedEmail,
PreferredLoginName: preferredLoginName,
LastPhone: lastPhone,
VerifiedPhone: verifiedPhone,
},
},
}, w
},
},
{
name: "backoff not done",
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fieldsWorker, a argsWorker, w wantWorker) {
codeAlg, code := cryptoValue(t, ctrl, "testcode")
return fieldsWorker{
queries: queries,
commands: commands,
es: eventstore.NewEventstore(&eventstore.Config{
Querier: es_repo_mock.NewRepo(t).MockQuerier,
}),
userDataCrypto: codeAlg,
now: testNow,
},
argsWorker{
event: &notification.RetryRequestedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
InstanceID: instanceID,
AggregateID: notificationID,
ResourceOwner: sql.NullString{String: instanceID},
CreationDate: time.Now(),
Typ: notification.RequestedType,
Seq: 2,
}),
Request: notification.Request{
UserID: userID,
UserResourceOwner: orgID,
AggregateID: "",
AggregateResourceOwner: "",
TriggeredAtOrigin: eventOrigin,
EventType: user.HumanInviteCodeAddedType,
MessageType: domain.InviteUserMessageType,
NotificationType: domain.NotificationTypeEmail,
URLTemplate: fmt.Sprintf("%s/ui/login/user/invite?userID=%s&loginname={{.LoginName}}&code={{.Code}}&orgID=%s&authRequestID=%s", eventOrigin, userID, orgID, authRequestID),
CodeExpiry: 1 * time.Hour,
Code: code,
UnverifiedNotificationChannel: true,
IsOTP: false,
RequiresPreviousDomain: false,
Args: &domain.NotificationArguments{
ApplicationName: "APP",
},
},
BackOff: 10 * time.Second,
NotifyUser: &query.NotifyUser{
ID: userID,
ResourceOwner: orgID,
LastEmail: lastEmail,
VerifiedEmail: verifiedEmail,
PreferredLoginName: preferredLoginName,
LastPhone: lastPhone,
VerifiedPhone: verifiedPhone,
},
},
}, w
},
},
{
name: "send ok",
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fieldsWorker, a argsWorker, w wantWorker) {
givenTemplate := "{{.LogoURL}}"
expectContent := fmt.Sprintf("%s%s/%s/%s", eventOrigin, assetsPath, policyID, logoURL)
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: "Invitation to APP",
Content: expectContent,
}
codeAlg, code := cryptoValue(t, ctrl, "testcode")
expectTemplateQueries(queries, givenTemplate)
commands.EXPECT().NotificationSent(gomock.Any(), gomock.Any(), notificationID, instanceID).Return(nil)
commands.EXPECT().InviteCodeSent(gomock.Any(), orgID, userID).Return(nil)
return fieldsWorker{
queries: queries,
commands: commands,
es: eventstore.NewEventstore(&eventstore.Config{
Querier: es_repo_mock.NewRepo(t).MockQuerier,
}),
userDataCrypto: codeAlg,
now: testNow,
maxAttempts: 3,
},
argsWorker{
event: &notification.RetryRequestedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
InstanceID: instanceID,
AggregateID: notificationID,
ResourceOwner: sql.NullString{String: instanceID},
CreationDate: time.Now().Add(-2 * time.Second),
Typ: notification.RequestedType,
Seq: 2,
}),
Request: notification.Request{
UserID: userID,
UserResourceOwner: orgID,
AggregateID: "",
AggregateResourceOwner: "",
TriggeredAtOrigin: eventOrigin,
EventType: user.HumanInviteCodeAddedType,
MessageType: domain.InviteUserMessageType,
NotificationType: domain.NotificationTypeEmail,
URLTemplate: fmt.Sprintf("%s/ui/login/user/invite?userID=%s&loginname={{.LoginName}}&code={{.Code}}&orgID=%s&authRequestID=%s", eventOrigin, userID, orgID, authRequestID),
CodeExpiry: 1 * time.Hour,
Code: code,
UnverifiedNotificationChannel: true,
IsOTP: false,
RequiresPreviousDomain: false,
Args: &domain.NotificationArguments{
ApplicationName: "APP",
},
},
BackOff: 1 * time.Second,
NotifyUser: &query.NotifyUser{
ID: userID,
ResourceOwner: orgID,
LastEmail: lastEmail,
VerifiedEmail: verifiedEmail,
PreferredLoginName: preferredLoginName,
LastPhone: lastPhone,
VerifiedPhone: verifiedPhone,
},
},
}, w
},
},
{
name: "send failed, retry",
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fieldsWorker, a argsWorker, w wantWorker) {
givenTemplate := "{{.LogoURL}}"
expectContent := fmt.Sprintf("%s%s/%s/%s", eventOrigin, assetsPath, policyID, logoURL)
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: "Invitation to APP",
Content: expectContent,
}
w.sendError = sendError
codeAlg, code := cryptoValue(t, ctrl, "testcode")
expectTemplateQueries(queries, givenTemplate)
commands.EXPECT().NotificationRetryRequested(gomock.Any(), gomock.Any(), notificationID, instanceID,
&command.NotificationRetryRequest{
NotificationRequest: command.NotificationRequest{
UserID: userID,
UserResourceOwner: orgID,
AggregateID: "",
AggregateResourceOwner: "",
TriggerOrigin: eventOrigin,
EventType: user.HumanInviteCodeAddedType,
MessageType: domain.InviteUserMessageType,
NotificationType: domain.NotificationTypeEmail,
URLTemplate: fmt.Sprintf("%s/ui/login/user/invite?userID=%s&loginname={{.LoginName}}&code={{.Code}}&orgID=%s&authRequestID=%s", eventOrigin, userID, orgID, authRequestID),
CodeExpiry: 1 * time.Hour,
Code: code,
UnverifiedNotificationChannel: true,
IsOTP: false,
RequiresPreviousDomain: false,
Args: &domain.NotificationArguments{
ApplicationName: "APP",
},
},
BackOff: 1 * time.Second,
NotifyUser: &query.NotifyUser{
ID: userID,
ResourceOwner: orgID,
LastEmail: lastEmail,
VerifiedEmail: verifiedEmail,
PreferredLoginName: preferredLoginName,
LastPhone: lastPhone,
VerifiedPhone: verifiedPhone,
},
},
sendError,
).Return(nil)
return fieldsWorker{ return fieldsWorker{
queries: queries, queries: queries,
commands: commands, commands: commands,
@ -678,23 +355,20 @@ func Test_userNotifier_reduceNotificationRetry(t *testing.T) {
userDataCrypto: codeAlg, userDataCrypto: codeAlg,
now: testNow, now: testNow,
backOff: testBackOff, backOff: testBackOff,
maxAttempts: 3,
}, },
argsWorker{ argsWorker{
event: &notification.RetryRequestedEvent{ job: &river.Job[*notification.Request]{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{ JobRow: &rivertype.JobRow{
InstanceID: instanceID, CreatedAt: time.Now(),
AggregateID: notificationID, },
ResourceOwner: sql.NullString{String: instanceID}, Args: &notification.Request{
CreationDate: time.Now().Add(-2 * time.Second), Aggregate: &eventstore.Aggregate{
Typ: notification.RequestedType, InstanceID: instanceID,
Seq: 2, ID: userID,
}), ResourceOwner: orgID,
Request: notification.Request{ },
UserID: userID, UserID: userID,
UserResourceOwner: orgID, UserResourceOwner: orgID,
AggregateID: "",
AggregateResourceOwner: "",
TriggeredAtOrigin: eventOrigin, TriggeredAtOrigin: eventOrigin,
EventType: user.HumanInviteCodeAddedType, EventType: user.HumanInviteCodeAddedType,
MessageType: domain.InviteUserMessageType, MessageType: domain.InviteUserMessageType,
@ -709,86 +383,9 @@ func Test_userNotifier_reduceNotificationRetry(t *testing.T) {
ApplicationName: "APP", ApplicationName: "APP",
}, },
}, },
BackOff: 1 * time.Second,
NotifyUser: &query.NotifyUser{
ID: userID,
ResourceOwner: orgID,
LastEmail: lastEmail,
VerifiedEmail: verifiedEmail,
PreferredLoginName: preferredLoginName,
LastPhone: lastPhone,
VerifiedPhone: verifiedPhone,
},
}, },
}, w
},
},
{
name: "send failed (max attempts), cancel",
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fieldsWorker, a argsWorker, w wantWorker) {
givenTemplate := "{{.LogoURL}}"
expectContent := fmt.Sprintf("%s%s/%s/%s", eventOrigin, assetsPath, policyID, logoURL)
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: "Invitation to APP",
Content: expectContent,
}
w.sendError = sendError
codeAlg, code := cryptoValue(t, ctrl, "testcode")
expectTemplateQueries(queries, givenTemplate)
commands.EXPECT().NotificationCanceled(gomock.Any(), gomock.Any(), notificationID, instanceID, sendError).Return(nil)
return fieldsWorker{
queries: queries,
commands: commands,
es: eventstore.NewEventstore(&eventstore.Config{
Querier: es_repo_mock.NewRepo(t).MockQuerier,
}),
userDataCrypto: codeAlg,
now: testNow,
backOff: testBackOff,
maxAttempts: 2,
}, },
argsWorker{ w
event: &notification.RetryRequestedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
InstanceID: instanceID,
AggregateID: notificationID,
ResourceOwner: sql.NullString{String: instanceID},
CreationDate: time.Now().Add(-2 * time.Second),
Seq: 2,
Typ: notification.RequestedType,
}),
Request: notification.Request{
UserID: userID,
UserResourceOwner: orgID,
AggregateID: "",
AggregateResourceOwner: "",
TriggeredAtOrigin: eventOrigin,
EventType: user.HumanInviteCodeAddedType,
MessageType: domain.InviteUserMessageType,
NotificationType: domain.NotificationTypeEmail,
URLTemplate: fmt.Sprintf("%s/ui/login/user/invite?userID=%s&loginname={{.LoginName}}&code={{.Code}}&orgID=%s&authRequestID=%s", eventOrigin, userID, orgID, authRequestID),
CodeExpiry: 1 * time.Hour,
Code: code,
UnverifiedNotificationChannel: true,
IsOTP: false,
RequiresPreviousDomain: false,
Args: &domain.NotificationArguments{
ApplicationName: "APP",
},
},
BackOff: 1 * time.Second,
NotifyUser: &query.NotifyUser{
ID: userID,
ResourceOwner: orgID,
LastEmail: lastEmail,
VerifiedEmail: verifiedEmail,
PreferredLoginName: preferredLoginName,
LastPhone: lastPhone,
VerifiedPhone: verifiedPhone,
},
},
}, w
}, },
}, },
} }
@ -798,11 +395,9 @@ func Test_userNotifier_reduceNotificationRetry(t *testing.T) {
queries := mock.NewMockQueries(ctrl) queries := mock.NewMockQueries(ctrl)
commands := mock.NewMockCommands(ctrl) commands := mock.NewMockCommands(ctrl)
f, a, w := tt.test(ctrl, queries, commands) f, a, w := tt.test(ctrl, queries, commands)
err := newNotificationWorker(t, ctrl, queries, f, a, w).reduceNotificationRetry( err := newNotificationWorker(t, ctrl, queries, f, w).Work(
authz.WithInstanceID(context.Background(), instanceID), authz.WithInstanceID(context.Background(), instanceID),
authz.WithInstanceID(context.Background(), instanceID), a.job,
&sql.Tx{},
a.event.(*notification.RetryRequestedEvent),
) )
if w.err != nil { if w.err != nil {
w.err(t, err) w.err(t, err)
@ -813,22 +408,18 @@ func Test_userNotifier_reduceNotificationRetry(t *testing.T) {
} }
} }
func newNotificationWorker(t *testing.T, ctrl *gomock.Controller, queries *mock.MockQueries, f fieldsWorker, a argsWorker, w wantWorker) *NotificationWorker { func newNotificationWorker(t *testing.T, ctrl *gomock.Controller, queries *mock.MockQueries, f fieldsWorker, w wantWorker) *NotificationWorker {
queries.EXPECT().NotificationProviderByIDAndType(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(&query.DebugNotificationProvider{}, nil) queries.EXPECT().NotificationProviderByIDAndType(gomock.Any(), gomock.Any(), gomock.Any()).AnyTimes().Return(&query.DebugNotificationProvider{}, nil)
smtpAlg, _ := cryptoValue(t, ctrl, "smtppw") smtpAlg, _ := cryptoValue(t, ctrl, "smtppw")
channel := channel_mock.NewMockNotificationChannel(ctrl) channel := channel_mock.NewMockNotificationChannel(ctrl)
if w.err == nil { if w.message != nil {
if w.message != nil { channel.EXPECT().HandleMessage(w.message).Return(w.sendError)
w.message.TriggeringEvent = a.event }
channel.EXPECT().HandleMessage(w.message).Return(w.sendError) if w.messageSMS != nil {
} channel.EXPECT().HandleMessage(w.messageSMS).DoAndReturn(func(message *messages.SMS) error {
if w.messageSMS != nil { message.VerificationID = gu.Ptr(verificationID)
w.messageSMS.TriggeringEvent = a.event return w.sendError
channel.EXPECT().HandleMessage(w.messageSMS).DoAndReturn(func(message *messages.SMS) error { })
message.VerificationID = gu.Ptr(verificationID)
return w.sendError
})
}
} }
return &NotificationWorker{ return &NotificationWorker{
commands: f.commands, commands: f.commands,
@ -878,88 +469,9 @@ func newNotificationWorker(t *testing.T, ctrl *gomock.Controller, queries *mock.
}, },
config: WorkerConfig{ config: WorkerConfig{
Workers: 1, Workers: 1,
BulkLimit: 10,
RequeueEvery: 2 * time.Second,
TransactionDuration: 5 * time.Second, TransactionDuration: 5 * time.Second,
MaxAttempts: f.maxAttempts,
MaxTtl: 5 * time.Minute, MaxTtl: 5 * time.Minute,
MinRetryDelay: 1 * time.Second,
MaxRetryDelay: 10 * time.Second,
RetryDelayFactor: 2,
}, },
now: f.now, now: f.now,
backOff: f.backOff,
}
}
func TestNotificationWorker_exponentialBackOff(t *testing.T) {
type fields struct {
config WorkerConfig
}
type args struct {
current time.Duration
}
tests := []struct {
name string
fields fields
args args
wantMin time.Duration
wantMax time.Duration
}{
{
name: "less than min, min - 1.5*min",
fields: fields{
config: WorkerConfig{
MinRetryDelay: 1 * time.Second,
MaxRetryDelay: 5 * time.Second,
RetryDelayFactor: 1.5,
},
},
args: args{
current: 0,
},
wantMin: 1000 * time.Millisecond,
wantMax: 1500 * time.Millisecond,
},
{
name: "current, 1.5*current - max",
fields: fields{
config: WorkerConfig{
MinRetryDelay: 1 * time.Second,
MaxRetryDelay: 5 * time.Second,
RetryDelayFactor: 1.5,
},
},
args: args{
current: 4 * time.Second,
},
wantMin: 4000 * time.Millisecond,
wantMax: 5000 * time.Millisecond,
},
{
name: "max, max",
fields: fields{
config: WorkerConfig{
MinRetryDelay: 1 * time.Second,
MaxRetryDelay: 5 * time.Second,
RetryDelayFactor: 1.5,
},
},
args: args{
current: 5 * time.Second,
},
wantMin: 5000 * time.Millisecond,
wantMax: 5000 * time.Millisecond,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
w := &NotificationWorker{
config: tt.fields.config,
}
b := w.exponentialBackOff(tt.args.current)
assert.GreaterOrEqual(t, b, tt.wantMin)
assert.LessOrEqual(t, b, tt.wantMax)
})
} }
} }

View File

@ -0,0 +1,13 @@
package handlers
import (
"context"
"github.com/riverqueue/river"
"github.com/zitadel/zitadel/internal/queue"
)
type Queue interface {
Insert(ctx context.Context, args river.JobArgs, opts ...queue.InsertOpt) error
}

View File

@ -36,7 +36,6 @@ func NewQuotaNotifier(
queries: queries, queries: queries,
channels: channels, channels: channels,
}) })
} }
func (*quotaNotifier) Name() string { func (*quotaNotifier) Name() string {
@ -72,7 +71,7 @@ func (u *quotaNotifier) reduceNotificationDue(event eventstore.Event) (*handler.
if alreadyHandled { if alreadyHandled {
return nil return nil
} }
err = types.SendJSON(ctx, webhook.Config{CallURL: e.CallURL, Method: http.MethodPost}, u.channels, e, e).WithoutTemplate() err = types.SendJSON(ctx, webhook.Config{CallURL: e.CallURL, Method: http.MethodPost}, u.channels, e, e.Type()).WithoutTemplate()
if err != nil { if err != nil {
return err return err
} }

View File

@ -104,7 +104,7 @@ func (t *telemetryPusher) pushMilestone(ctx context.Context, e *milestone.Reache
Type: e.MilestoneType, Type: e.MilestoneType,
ReachedDate: e.GetReachedDate(), ReachedDate: e.GetReachedDate(),
}, },
e, e.EventType,
).WithoutTemplate(); err != nil { ).WithoutTemplate(); err != nil {
return err return err
} }

View File

@ -7,12 +7,13 @@ import (
http_util "github.com/zitadel/zitadel/internal/api/http" http_util "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/ui/console" "github.com/zitadel/zitadel/internal/api/ui/console"
"github.com/zitadel/zitadel/internal/api/ui/login" "github.com/zitadel/zitadel/internal/api/ui/login"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler/v2" "github.com/zitadel/zitadel/internal/eventstore/handler/v2"
"github.com/zitadel/zitadel/internal/notification/senders" "github.com/zitadel/zitadel/internal/notification/senders"
"github.com/zitadel/zitadel/internal/notification/types" "github.com/zitadel/zitadel/internal/notification/types"
"github.com/zitadel/zitadel/internal/queue"
"github.com/zitadel/zitadel/internal/repository/notification"
"github.com/zitadel/zitadel/internal/repository/session" "github.com/zitadel/zitadel/internal/repository/session"
"github.com/zitadel/zitadel/internal/repository/user" "github.com/zitadel/zitadel/internal/repository/user"
"github.com/zitadel/zitadel/internal/zerrors" "github.com/zitadel/zitadel/internal/zerrors"
@ -86,9 +87,11 @@ const (
) )
type userNotifier struct { type userNotifier struct {
commands Commands
queries *NotificationQueries queries *NotificationQueries
otpEmailTmpl string otpEmailTmpl string
queue Queue
maxAttempts uint8
} }
func NewUserNotifier( func NewUserNotifier(
@ -98,15 +101,17 @@ func NewUserNotifier(
queries *NotificationQueries, queries *NotificationQueries,
channels types.ChannelChains, channels types.ChannelChains,
otpEmailTmpl string, otpEmailTmpl string,
legacyMode bool, workerConfig WorkerConfig,
queue Queue,
) *handler.Handler { ) *handler.Handler {
if legacyMode { if workerConfig.LegacyEnabled {
return NewUserNotifierLegacy(ctx, config, commands, queries, channels, otpEmailTmpl) return NewUserNotifierLegacy(ctx, config, commands, queries, channels, otpEmailTmpl)
} }
return handler.NewHandler(ctx, &config, &userNotifier{ return handler.NewHandler(ctx, &config, &userNotifier{
commands: commands,
queries: queries, queries: queries,
otpEmailTmpl: otpEmailTmpl, otpEmailTmpl: otpEmailTmpl,
queue: queue,
maxAttempts: workerConfig.MaxAttempts,
}) })
} }
@ -198,7 +203,6 @@ func (u *userNotifier) reduceInitCodeAdded(event eventstore.Event) (*handler.Sta
if !ok { if !ok {
return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-EFe2f", "reduce.wrong.event.type %s", user.HumanInitialCodeAddedType) return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-EFe2f", "reduce.wrong.event.type %s", user.HumanInitialCodeAddedType)
} }
return handler.NewStatement(event, func(ex handler.Executer, projectionName string) error { return handler.NewStatement(event, func(ex handler.Executer, projectionName string) error {
ctx := HandlerContext(event.Aggregate()) ctx := HandlerContext(event.Aggregate())
alreadyHandled, err := u.checkIfCodeAlreadyHandledOrExpired(ctx, event, e.Expiry, nil, alreadyHandled, err := u.checkIfCodeAlreadyHandledOrExpired(ctx, event, e.Expiry, nil,
@ -215,23 +219,26 @@ func (u *userNotifier) reduceInitCodeAdded(event eventstore.Event) (*handler.Sta
return err return err
} }
origin := http_util.DomainContext(ctx).Origin() origin := http_util.DomainContext(ctx).Origin()
return u.commands.RequestNotification( return u.queue.Insert(ctx,
ctx, &notification.Request{
e.Aggregate().ResourceOwner, Aggregate: e.Aggregate(),
command.NewNotificationRequest( UserID: e.Aggregate().ID,
e.Aggregate().ID, UserResourceOwner: e.Aggregate().ResourceOwner,
e.Aggregate().ResourceOwner, TriggeredAtOrigin: origin,
origin, EventType: e.EventType,
e.EventType, NotificationType: domain.NotificationTypeEmail,
domain.NotificationTypeEmail, MessageType: domain.InitCodeMessageType,
domain.InitCodeMessageType, Code: e.Code,
). CodeExpiry: e.Expiry,
WithURLTemplate(login.InitUserLinkTemplate(origin, e.Aggregate().ID, e.Aggregate().ResourceOwner, e.AuthRequestID)). IsOTP: false,
WithCode(e.Code, e.Expiry). UnverifiedNotificationChannel: true,
WithArgs(&domain.NotificationArguments{ URLTemplate: login.InitUserLinkTemplate(origin, e.Aggregate().ID, e.Aggregate().ResourceOwner, e.AuthRequestID),
Args: &domain.NotificationArguments{
AuthRequestID: e.AuthRequestID, AuthRequestID: e.AuthRequestID,
}). },
WithUnverifiedChannel(), },
queue.WithQueueName(notification.QueueName),
queue.WithMaxAttempts(u.maxAttempts),
) )
}), nil }), nil
} }
@ -262,23 +269,26 @@ func (u *userNotifier) reduceEmailCodeAdded(event eventstore.Event) (*handler.St
return err return err
} }
origin := http_util.DomainContext(ctx).Origin() origin := http_util.DomainContext(ctx).Origin()
return u.queue.Insert(ctx,
return u.commands.RequestNotification(ctx, &notification.Request{
e.Aggregate().ResourceOwner, Aggregate: e.Aggregate(),
command.NewNotificationRequest( UserID: e.Aggregate().ID,
e.Aggregate().ID, UserResourceOwner: e.Aggregate().ResourceOwner,
e.Aggregate().ResourceOwner, TriggeredAtOrigin: origin,
origin, EventType: e.EventType,
e.EventType, NotificationType: domain.NotificationTypeEmail,
domain.NotificationTypeEmail, MessageType: domain.VerifyEmailMessageType,
domain.VerifyEmailMessageType, Code: e.Code,
). CodeExpiry: e.Expiry,
WithURLTemplate(u.emailCodeTemplate(origin, e)). IsOTP: false,
WithCode(e.Code, e.Expiry). UnverifiedNotificationChannel: true,
WithArgs(&domain.NotificationArguments{ URLTemplate: u.emailCodeTemplate(origin, e),
Args: &domain.NotificationArguments{
AuthRequestID: e.AuthRequestID, AuthRequestID: e.AuthRequestID,
}). },
WithUnverifiedChannel(), },
queue.WithQueueName(notification.QueueName),
queue.WithMaxAttempts(u.maxAttempts),
) )
}), nil }), nil
} }
@ -315,22 +325,26 @@ func (u *userNotifier) reducePasswordCodeAdded(event eventstore.Event) (*handler
return err return err
} }
origin := http_util.DomainContext(ctx).Origin() origin := http_util.DomainContext(ctx).Origin()
return u.commands.RequestNotification(ctx, return u.queue.Insert(ctx,
e.Aggregate().ResourceOwner, &notification.Request{
command.NewNotificationRequest( Aggregate: e.Aggregate(),
e.Aggregate().ID, UserID: e.Aggregate().ID,
e.Aggregate().ResourceOwner, UserResourceOwner: e.Aggregate().ResourceOwner,
origin, TriggeredAtOrigin: origin,
e.EventType, EventType: e.EventType,
e.NotificationType, NotificationType: e.NotificationType,
domain.PasswordResetMessageType, MessageType: domain.PasswordResetMessageType,
). Code: e.Code,
WithURLTemplate(u.passwordCodeTemplate(origin, e)). CodeExpiry: e.Expiry,
WithCode(e.Code, e.Expiry). IsOTP: false,
WithArgs(&domain.NotificationArguments{ UnverifiedNotificationChannel: true,
URLTemplate: u.passwordCodeTemplate(origin, e),
Args: &domain.NotificationArguments{
AuthRequestID: e.AuthRequestID, AuthRequestID: e.AuthRequestID,
}). },
WithUnverifiedChannel(), },
queue.WithQueueName(notification.QueueName),
queue.WithMaxAttempts(u.maxAttempts),
) )
}), nil }), nil
} }
@ -363,19 +377,22 @@ func (u *userNotifier) reduceOTPSMSCodeAdded(event eventstore.Event) (*handler.S
if err != nil { if err != nil {
return err return err
} }
return u.commands.RequestNotification(ctx, return u.queue.Insert(ctx,
e.Aggregate().ResourceOwner, &notification.Request{
command.NewNotificationRequest( Aggregate: e.Aggregate(),
e.Aggregate().ID, UserID: e.Aggregate().ID,
e.Aggregate().ResourceOwner, UserResourceOwner: e.Aggregate().ResourceOwner,
http_util.DomainContext(ctx).Origin(), TriggeredAtOrigin: http_util.DomainContext(ctx).Origin(),
e.EventType, EventType: e.EventType,
domain.NotificationTypeSms, NotificationType: domain.NotificationTypeSms,
domain.VerifySMSOTPMessageType, MessageType: domain.VerifySMSOTPMessageType,
). Code: e.Code,
WithCode(e.Code, e.Expiry). CodeExpiry: e.Expiry,
WithArgs(otpArgs(ctx, e.Expiry)). IsOTP: true,
WithOTP(), Args: otpArgs(ctx, e.Expiry),
},
queue.WithQueueName(notification.QueueName),
queue.WithMaxAttempts(u.maxAttempts),
) )
}), nil }), nil
} }
@ -412,20 +429,22 @@ func (u *userNotifier) reduceSessionOTPSMSChallenged(event eventstore.Event) (*h
args := otpArgs(ctx, e.Expiry) args := otpArgs(ctx, e.Expiry)
args.SessionID = e.Aggregate().ID args.SessionID = e.Aggregate().ID
return u.commands.RequestNotification(ctx, return u.queue.Insert(ctx,
s.UserFactor.ResourceOwner, &notification.Request{
command.NewNotificationRequest( Aggregate: e.Aggregate(),
s.UserFactor.UserID, UserID: s.UserFactor.UserID,
s.UserFactor.ResourceOwner, UserResourceOwner: s.UserFactor.ResourceOwner,
http_util.DomainContext(ctx).Origin(), TriggeredAtOrigin: http_util.DomainContext(ctx).Origin(),
e.EventType, EventType: e.EventType,
domain.NotificationTypeSms, NotificationType: domain.NotificationTypeSms,
domain.VerifySMSOTPMessageType, MessageType: domain.VerifySMSOTPMessageType,
). Code: e.Code,
WithAggregate(e.Aggregate().ID, e.Aggregate().ResourceOwner). CodeExpiry: e.Expiry,
WithCode(e.Code, e.Expiry). IsOTP: true,
WithOTP(). Args: args,
WithArgs(args), },
queue.WithQueueName(notification.QueueName),
queue.WithMaxAttempts(u.maxAttempts),
) )
}), nil }), nil
} }
@ -459,20 +478,23 @@ func (u *userNotifier) reduceOTPEmailCodeAdded(event eventstore.Event) (*handler
} }
args := otpArgs(ctx, e.Expiry) args := otpArgs(ctx, e.Expiry)
args.AuthRequestID = authRequestID args.AuthRequestID = authRequestID
return u.commands.RequestNotification(ctx, return u.queue.Insert(ctx,
e.Aggregate().ResourceOwner, &notification.Request{
command.NewNotificationRequest( Aggregate: e.Aggregate(),
e.Aggregate().ID, UserID: e.Aggregate().ID,
e.Aggregate().ResourceOwner, UserResourceOwner: e.Aggregate().ResourceOwner,
origin, TriggeredAtOrigin: origin,
e.EventType, EventType: e.EventType,
domain.NotificationTypeEmail, NotificationType: domain.NotificationTypeEmail,
domain.VerifyEmailOTPMessageType, MessageType: domain.VerifyEmailOTPMessageType,
). Code: e.Code,
WithURLTemplate(login.OTPLinkTemplate(origin, authRequestID, domain.MFATypeOTPEmail)). CodeExpiry: e.Expiry,
WithCode(e.Code, e.Expiry). IsOTP: true,
WithOTP(). URLTemplate: login.OTPLinkTemplate(origin, authRequestID, domain.MFATypeOTPEmail),
WithArgs(args), Args: args,
},
queue.WithQueueName(notification.QueueName),
queue.WithMaxAttempts(u.maxAttempts),
) )
}), nil }), nil
} }
@ -509,21 +531,23 @@ func (u *userNotifier) reduceSessionOTPEmailChallenged(event eventstore.Event) (
args := otpArgs(ctx, e.Expiry) args := otpArgs(ctx, e.Expiry)
args.SessionID = e.Aggregate().ID args.SessionID = e.Aggregate().ID
return u.commands.RequestNotification(ctx, return u.queue.Insert(ctx,
s.UserFactor.ResourceOwner, &notification.Request{
command.NewNotificationRequest( Aggregate: e.Aggregate(),
s.UserFactor.UserID, UserID: s.UserFactor.UserID,
s.UserFactor.ResourceOwner, UserResourceOwner: s.UserFactor.ResourceOwner,
origin, TriggeredAtOrigin: origin,
e.EventType, EventType: e.EventType,
domain.NotificationTypeEmail, NotificationType: domain.NotificationTypeEmail,
domain.VerifyEmailOTPMessageType, MessageType: domain.VerifyEmailOTPMessageType,
). Code: e.Code,
WithAggregate(e.Aggregate().ID, e.Aggregate().ResourceOwner). CodeExpiry: e.Expiry,
WithURLTemplate(u.otpEmailTemplate(origin, e)). IsOTP: true,
WithCode(e.Code, e.Expiry). URLTemplate: u.otpEmailTemplate(origin, e),
WithOTP(). Args: args,
WithArgs(args), },
queue.WithQueueName(notification.QueueName),
queue.WithMaxAttempts(u.maxAttempts),
) )
}), nil }), nil
} }
@ -564,22 +588,24 @@ func (u *userNotifier) reduceDomainClaimed(event eventstore.Event) (*handler.Sta
return err return err
} }
origin := http_util.DomainContext(ctx).Origin() origin := http_util.DomainContext(ctx).Origin()
return u.commands.RequestNotification(ctx, return u.queue.Insert(ctx,
e.Aggregate().ResourceOwner, &notification.Request{
command.NewNotificationRequest( Aggregate: e.Aggregate(),
e.Aggregate().ID, UserID: e.Aggregate().ID,
e.Aggregate().ResourceOwner, UserResourceOwner: e.Aggregate().ResourceOwner,
origin, TriggeredAtOrigin: origin,
e.EventType, EventType: e.EventType,
domain.NotificationTypeEmail, NotificationType: domain.NotificationTypeEmail,
domain.DomainClaimedMessageType, MessageType: domain.DomainClaimedMessageType,
). URLTemplate: login.LoginLink(origin, e.Aggregate().ResourceOwner),
WithURLTemplate(login.LoginLink(origin, e.Aggregate().ResourceOwner)). UnverifiedNotificationChannel: true,
WithUnverifiedChannel(). Args: &domain.NotificationArguments{
WithPreviousDomain().
WithArgs(&domain.NotificationArguments{
TempUsername: e.UserName, TempUsername: e.UserName,
}), },
RequiresPreviousDomain: true,
},
queue.WithQueueName(notification.QueueName),
queue.WithMaxAttempts(u.maxAttempts),
) )
}), nil }), nil
} }
@ -607,21 +633,24 @@ func (u *userNotifier) reducePasswordlessCodeRequested(event eventstore.Event) (
return err return err
} }
origin := http_util.DomainContext(ctx).Origin() origin := http_util.DomainContext(ctx).Origin()
return u.commands.RequestNotification(ctx, return u.queue.Insert(ctx,
e.Aggregate().ResourceOwner, &notification.Request{
command.NewNotificationRequest( Aggregate: e.Aggregate(),
e.Aggregate().ID, UserID: e.Aggregate().ID,
e.Aggregate().ResourceOwner, UserResourceOwner: e.Aggregate().ResourceOwner,
origin, TriggeredAtOrigin: origin,
e.EventType, EventType: e.EventType,
domain.NotificationTypeEmail, NotificationType: domain.NotificationTypeEmail,
domain.PasswordlessRegistrationMessageType, MessageType: domain.PasswordlessRegistrationMessageType,
). URLTemplate: u.passwordlessCodeTemplate(origin, e),
WithURLTemplate(u.passwordlessCodeTemplate(origin, e)). Args: &domain.NotificationArguments{
WithCode(e.Code, e.Expiry).
WithArgs(&domain.NotificationArguments{
CodeID: e.ID, CodeID: e.ID,
}), },
CodeExpiry: e.Expiry,
Code: e.Code,
},
queue.WithQueueName(notification.QueueName),
queue.WithMaxAttempts(u.maxAttempts),
) )
}), nil }), nil
} }
@ -664,18 +693,20 @@ func (u *userNotifier) reducePasswordChanged(event eventstore.Event) (*handler.S
} }
origin := http_util.DomainContext(ctx).Origin() origin := http_util.DomainContext(ctx).Origin()
return u.commands.RequestNotification(ctx, return u.queue.Insert(ctx,
e.Aggregate().ResourceOwner, &notification.Request{
command.NewNotificationRequest( Aggregate: e.Aggregate(),
e.Aggregate().ID, UserID: e.Aggregate().ID,
e.Aggregate().ResourceOwner, UserResourceOwner: e.Aggregate().ResourceOwner,
origin, TriggeredAtOrigin: origin,
e.EventType, EventType: e.EventType,
domain.NotificationTypeEmail, NotificationType: domain.NotificationTypeEmail,
domain.PasswordChangeMessageType, MessageType: domain.PasswordChangeMessageType,
). URLTemplate: console.LoginHintLink(origin, "{{.PreferredLoginName}}"),
WithURLTemplate(console.LoginHintLink(origin, "{{.PreferredLoginName}}")). UnverifiedNotificationChannel: true,
WithUnverifiedChannel(), },
queue.WithQueueName(notification.QueueName),
queue.WithMaxAttempts(u.maxAttempts),
) )
}), nil }), nil
} }
@ -706,21 +737,24 @@ func (u *userNotifier) reducePhoneCodeAdded(event eventstore.Event) (*handler.St
return err return err
} }
return u.commands.RequestNotification(ctx, return u.queue.Insert(ctx,
e.Aggregate().ResourceOwner, &notification.Request{
command.NewNotificationRequest( Aggregate: e.Aggregate(),
e.Aggregate().ID, UserID: e.Aggregate().ID,
e.Aggregate().ResourceOwner, UserResourceOwner: e.Aggregate().ResourceOwner,
http_util.DomainContext(ctx).Origin(), TriggeredAtOrigin: http_util.DomainContext(ctx).Origin(),
e.EventType, EventType: e.EventType,
domain.NotificationTypeSms, NotificationType: domain.NotificationTypeSms,
domain.VerifyPhoneMessageType, MessageType: domain.VerifyPhoneMessageType,
). CodeExpiry: e.Expiry,
WithCode(e.Code, e.Expiry). Code: e.Code,
WithUnverifiedChannel(). UnverifiedNotificationChannel: true,
WithArgs(&domain.NotificationArguments{ Args: &domain.NotificationArguments{
Domain: http_util.DomainContext(ctx).RequestedDomain(), Domain: http_util.DomainContext(ctx).RequestedDomain(),
}), },
},
queue.WithQueueName(notification.QueueName),
queue.WithMaxAttempts(u.maxAttempts),
) )
}), nil }), nil
} }
@ -755,23 +789,26 @@ func (u *userNotifier) reduceInviteCodeAdded(event eventstore.Event) (*handler.S
if applicationName == "" { if applicationName == "" {
applicationName = "ZITADEL" applicationName = "ZITADEL"
} }
return u.commands.RequestNotification(ctx, return u.queue.Insert(ctx,
e.Aggregate().ResourceOwner, &notification.Request{
command.NewNotificationRequest( Aggregate: e.Aggregate(),
e.Aggregate().ID, UserID: e.Aggregate().ID,
e.Aggregate().ResourceOwner, UserResourceOwner: e.Aggregate().ResourceOwner,
origin, TriggeredAtOrigin: origin,
e.EventType, EventType: e.EventType,
domain.NotificationTypeEmail, NotificationType: domain.NotificationTypeEmail,
domain.InviteUserMessageType, MessageType: domain.InviteUserMessageType,
). CodeExpiry: e.Expiry,
WithURLTemplate(u.inviteCodeTemplate(origin, e)). Code: e.Code,
WithCode(e.Code, e.Expiry). UnverifiedNotificationChannel: true,
WithUnverifiedChannel(). URLTemplate: u.inviteCodeTemplate(origin, e),
WithArgs(&domain.NotificationArguments{ Args: &domain.NotificationArguments{
AuthRequestID: e.AuthRequestID, AuthRequestID: e.AuthRequestID,
ApplicationName: applicationName, ApplicationName: applicationName,
}), },
},
queue.WithQueueName(notification.QueueName),
queue.WithMaxAttempts(u.maxAttempts),
) )
}), nil }), nil
} }

View File

@ -171,7 +171,7 @@ func (u *userNotifierLegacy) reduceInitCodeAdded(event eventstore.Event) (*handl
if err != nil { if err != nil {
return err return err
} }
err = types.SendEmail(ctx, u.channels, string(template.Template), translator, notifyUser, colors, e). err = types.SendEmail(ctx, u.channels, string(template.Template), translator, notifyUser, colors, e.Type()).
SendUserInitCode(ctx, notifyUser, code, e.AuthRequestID) SendUserInitCode(ctx, notifyUser, code, e.AuthRequestID)
if err != nil { if err != nil {
if errors.Is(err, &channels.CancelError{}) { if errors.Is(err, &channels.CancelError{}) {
@ -232,7 +232,7 @@ func (u *userNotifierLegacy) reduceEmailCodeAdded(event eventstore.Event) (*hand
if err != nil { if err != nil {
return err return err
} }
err = types.SendEmail(ctx, u.channels, string(template.Template), translator, notifyUser, colors, e). err = types.SendEmail(ctx, u.channels, string(template.Template), translator, notifyUser, colors, event.Type()).
SendEmailVerificationCode(ctx, notifyUser, code, e.URLTemplate, e.AuthRequestID) SendEmailVerificationCode(ctx, notifyUser, code, e.URLTemplate, e.AuthRequestID)
if err != nil { if err != nil {
if errors.Is(err, &channels.CancelError{}) { if errors.Is(err, &channels.CancelError{}) {
@ -296,9 +296,9 @@ func (u *userNotifierLegacy) reducePasswordCodeAdded(event eventstore.Event) (*h
return err return err
} }
generatorInfo := new(senders.CodeGeneratorInfo) generatorInfo := new(senders.CodeGeneratorInfo)
notify := types.SendEmail(ctx, u.channels, string(template.Template), translator, notifyUser, colors, e) notify := types.SendEmail(ctx, u.channels, string(template.Template), translator, notifyUser, colors, event.Type())
if e.NotificationType == domain.NotificationTypeSms { if e.NotificationType == domain.NotificationTypeSms {
notify = types.SendSMS(ctx, u.channels, translator, notifyUser, colors, e, generatorInfo) notify = types.SendSMS(ctx, u.channels, translator, notifyUser, colors, e.Type(), e.Aggregate().InstanceID, e.ID, generatorInfo)
} }
err = notify.SendPasswordCode(ctx, notifyUser, code, e.URLTemplate, e.AuthRequestID) err = notify.SendPasswordCode(ctx, notifyUser, code, e.URLTemplate, e.AuthRequestID)
if err != nil { if err != nil {
@ -396,7 +396,7 @@ func (u *userNotifierLegacy) reduceOTPSMS(
return nil, err return nil, err
} }
generatorInfo := new(senders.CodeGeneratorInfo) generatorInfo := new(senders.CodeGeneratorInfo)
notify := types.SendSMS(ctx, u.channels, translator, notifyUser, colors, event, generatorInfo) notify := types.SendSMS(ctx, u.channels, translator, notifyUser, colors, event.Type(), event.Aggregate().InstanceID, event.Aggregate().ID, generatorInfo)
err = notify.SendOTPSMSCode(ctx, plainCode, expiry) err = notify.SendOTPSMSCode(ctx, plainCode, expiry)
if err != nil { if err != nil {
if errors.Is(err, &channels.CancelError{}) { if errors.Is(err, &channels.CancelError{}) {
@ -522,7 +522,7 @@ func (u *userNotifierLegacy) reduceOTPEmail(
if err != nil { if err != nil {
return nil, err return nil, err
} }
notify := types.SendEmail(ctx, u.channels, string(template.Template), translator, notifyUser, colors, event) notify := types.SendEmail(ctx, u.channels, string(template.Template), translator, notifyUser, colors, event.Type())
err = notify.SendOTPEmailCode(ctx, url, plainCode, expiry) err = notify.SendOTPEmailCode(ctx, url, plainCode, expiry)
if err != nil { if err != nil {
if errors.Is(err, &channels.CancelError{}) { if errors.Is(err, &channels.CancelError{}) {
@ -576,7 +576,7 @@ func (u *userNotifierLegacy) reduceDomainClaimed(event eventstore.Event) (*handl
if err != nil { if err != nil {
return err return err
} }
err = types.SendEmail(ctx, u.channels, string(template.Template), translator, notifyUser, colors, e). err = types.SendEmail(ctx, u.channels, string(template.Template), translator, notifyUser, colors, event.Type()).
SendDomainClaimed(ctx, notifyUser, e.UserName) SendDomainClaimed(ctx, notifyUser, e.UserName)
if err != nil { if err != nil {
if errors.Is(err, &channels.CancelError{}) { if errors.Is(err, &channels.CancelError{}) {
@ -634,7 +634,7 @@ func (u *userNotifierLegacy) reducePasswordlessCodeRequested(event eventstore.Ev
if err != nil { if err != nil {
return err return err
} }
err = types.SendEmail(ctx, u.channels, string(template.Template), translator, notifyUser, colors, e). err = types.SendEmail(ctx, u.channels, string(template.Template), translator, notifyUser, colors, event.Type()).
SendPasswordlessRegistrationLink(ctx, notifyUser, code, e.ID, e.URLTemplate) SendPasswordlessRegistrationLink(ctx, notifyUser, code, e.ID, e.URLTemplate)
if err != nil { if err != nil {
if errors.Is(err, &channels.CancelError{}) { if errors.Is(err, &channels.CancelError{}) {
@ -697,7 +697,7 @@ func (u *userNotifierLegacy) reducePasswordChanged(event eventstore.Event) (*han
if err != nil { if err != nil {
return err return err
} }
err = types.SendEmail(ctx, u.channels, string(template.Template), translator, notifyUser, colors, e). err = types.SendEmail(ctx, u.channels, string(template.Template), translator, notifyUser, colors, event.Type()).
SendPasswordChange(ctx, notifyUser) SendPasswordChange(ctx, notifyUser)
if err != nil { if err != nil {
if errors.Is(err, &channels.CancelError{}) { if errors.Is(err, &channels.CancelError{}) {
@ -756,7 +756,7 @@ func (u *userNotifierLegacy) reducePhoneCodeAdded(event eventstore.Event) (*hand
return err return err
} }
generatorInfo := new(senders.CodeGeneratorInfo) generatorInfo := new(senders.CodeGeneratorInfo)
if err = types.SendSMS(ctx, u.channels, translator, notifyUser, colors, e, generatorInfo). if err = types.SendSMS(ctx, u.channels, translator, notifyUser, colors, e.Type(), e.Aggregate().InstanceID, e.ID, generatorInfo).
SendPhoneVerificationCode(ctx, code); err != nil { SendPhoneVerificationCode(ctx, code); err != nil {
if errors.Is(err, &channels.CancelError{}) { if errors.Is(err, &channels.CancelError{}) {
// if the notification was canceled, we don't want to return the error, so there is no retry // if the notification was canceled, we don't want to return the error, so there is no retry
@ -814,7 +814,7 @@ func (u *userNotifierLegacy) reduceInviteCodeAdded(event eventstore.Event) (*han
if err != nil { if err != nil {
return err return err
} }
notify := types.SendEmail(ctx, u.channels, string(template.Template), translator, notifyUser, colors, e) notify := types.SendEmail(ctx, u.channels, string(template.Template), translator, notifyUser, colors, event.Type())
err = notify.SendInviteCode(ctx, notifyUser, code, e.ApplicationName, e.URLTemplate, e.AuthRequestID) err = notify.SendInviteCode(ctx, notifyUser, code, e.ApplicationName, e.URLTemplate, e.AuthRequestID)
if err != nil { if err != nil {
if errors.Is(err, &channels.CancelError{}) { if errors.Is(err, &channels.CancelError{}) {

View File

@ -611,328 +611,331 @@ func Test_userNotifierLegacy_reducePasswordCodeAdded(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
test func(*gomock.Controller, *mock.MockQueries, *mock.MockCommands) (fields, args, wantLegacy) test func(*gomock.Controller, *mock.MockQueries, *mock.MockCommands) (fields, args, wantLegacy)
}{{ }{
name: "asset url with event trigger url", {
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w wantLegacy) { name: "asset url with event trigger url",
givenTemplate := "{{.LogoURL}}" test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w wantLegacy) {
expectContent := fmt.Sprintf("%s%s/%s/%s", eventOrigin, assetsPath, policyID, logoURL) givenTemplate := "{{.LogoURL}}"
w.message = &wantLegacyEmail{ expectContent := fmt.Sprintf("%s%s/%s/%s", eventOrigin, assetsPath, policyID, logoURL)
email: &messages.Email{ w.message = &wantLegacyEmail{
Recipients: []string{lastEmail}, email: &messages.Email{
Subject: expectMailSubject, Recipients: []string{lastEmail},
Content: expectContent, Subject: expectMailSubject,
}, Content: expectContent,
}
codeAlg, code := cryptoValue(t, ctrl, "testcode")
expectTemplateWithNotifyUserQueries(queries, givenTemplate)
commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID, &senders.CodeGeneratorInfo{}).Return(nil)
return fields{
queries: queries,
commands: commands,
es: eventstore.NewEventstore(&eventstore.Config{
Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
}),
userDataCrypto: codeAlg,
}, args{
event: &user.HumanPasswordCodeAddedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
AggregateID: userID,
ResourceOwner: sql.NullString{String: orgID},
CreationDate: time.Now().UTC(),
}),
Code: code,
Expiry: time.Hour,
URLTemplate: "",
CodeReturned: false,
TriggeredAtOrigin: eventOrigin,
}, },
}, w }
}, codeAlg, code := cryptoValue(t, ctrl, "testcode")
}, { expectTemplateWithNotifyUserQueries(queries, givenTemplate)
name: "asset url without event trigger url", commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID, &senders.CodeGeneratorInfo{}).Return(nil)
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w wantLegacy) { return fields{
givenTemplate := "{{.LogoURL}}" queries: queries,
expectContent := fmt.Sprintf("%s://%s:%d%s/%s/%s", externalProtocol, instancePrimaryDomain, externalPort, assetsPath, policyID, logoURL) commands: commands,
w.message = &wantLegacyEmail{ es: eventstore.NewEventstore(&eventstore.Config{
email: &messages.Email{ Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
},
}
codeAlg, code := cryptoValue(t, ctrl, "testcode")
queries.EXPECT().SearchInstanceDomains(gomock.Any(), gomock.Any()).Return(&query.InstanceDomains{
Domains: []*query.InstanceDomain{{
Domain: instancePrimaryDomain,
IsPrimary: true,
}},
}, nil)
expectTemplateWithNotifyUserQueries(queries, givenTemplate)
commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID, &senders.CodeGeneratorInfo{}).Return(nil)
return fields{
queries: queries,
commands: commands,
es: eventstore.NewEventstore(&eventstore.Config{
Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
}),
userDataCrypto: codeAlg,
}, args{
event: &user.HumanPasswordCodeAddedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
AggregateID: userID,
ResourceOwner: sql.NullString{String: orgID},
CreationDate: time.Now().UTC(),
}), }),
Code: code, userDataCrypto: codeAlg,
Expiry: time.Hour, }, args{
URLTemplate: "", event: &user.HumanPasswordCodeAddedEvent{
CodeReturned: false, BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
AggregateID: userID,
ResourceOwner: sql.NullString{String: orgID},
CreationDate: time.Now().UTC(),
}),
Code: code,
Expiry: time.Hour,
URLTemplate: "",
CodeReturned: false,
TriggeredAtOrigin: eventOrigin,
},
}, w
},
}, {
name: "asset url without event trigger url",
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w wantLegacy) {
givenTemplate := "{{.LogoURL}}"
expectContent := fmt.Sprintf("%s://%s:%d%s/%s/%s", externalProtocol, instancePrimaryDomain, externalPort, assetsPath, policyID, logoURL)
w.message = &wantLegacyEmail{
email: &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
}, },
}, w }
}, codeAlg, code := cryptoValue(t, ctrl, "testcode")
}, { queries.EXPECT().SearchInstanceDomains(gomock.Any(), gomock.Any()).Return(&query.InstanceDomains{
name: "button url with event trigger url", Domains: []*query.InstanceDomain{{
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w wantLegacy) { Domain: instancePrimaryDomain,
givenTemplate := "{{.URL}}" IsPrimary: true,
testCode := "testcode" }},
expectContent := fmt.Sprintf("%s/ui/login/password/init?authRequestID=%s&code=%s&orgID=%s&userID=%s", eventOrigin, "", testCode, orgID, userID) }, nil)
w.message = &wantLegacyEmail{ expectTemplateWithNotifyUserQueries(queries, givenTemplate)
email: &messages.Email{ commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID, &senders.CodeGeneratorInfo{}).Return(nil)
Recipients: []string{lastEmail}, return fields{
Subject: expectMailSubject, queries: queries,
Content: expectContent, commands: commands,
}, es: eventstore.NewEventstore(&eventstore.Config{
} Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
codeAlg, code := cryptoValue(t, ctrl, testCode)
expectTemplateWithNotifyUserQueries(queries, givenTemplate)
commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID, &senders.CodeGeneratorInfo{}).Return(nil)
return fields{
queries: queries,
commands: commands,
es: eventstore.NewEventstore(&eventstore.Config{
Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
}),
userDataCrypto: codeAlg,
SMSTokenCrypto: nil,
}, args{
event: &user.HumanPasswordCodeAddedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
AggregateID: userID,
ResourceOwner: sql.NullString{String: orgID},
CreationDate: time.Now().UTC(),
}), }),
Code: code, userDataCrypto: codeAlg,
Expiry: time.Hour, }, args{
URLTemplate: "", event: &user.HumanPasswordCodeAddedEvent{
CodeReturned: false, BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
TriggeredAtOrigin: eventOrigin, AggregateID: userID,
ResourceOwner: sql.NullString{String: orgID},
CreationDate: time.Now().UTC(),
}),
Code: code,
Expiry: time.Hour,
URLTemplate: "",
CodeReturned: false,
},
}, w
},
}, {
name: "button url with event trigger url",
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w wantLegacy) {
givenTemplate := "{{.URL}}"
testCode := "testcode"
expectContent := fmt.Sprintf("%s/ui/login/password/init?authRequestID=%s&code=%s&orgID=%s&userID=%s", eventOrigin, "", testCode, orgID, userID)
w.message = &wantLegacyEmail{
email: &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
}, },
}, w }
}, codeAlg, code := cryptoValue(t, ctrl, testCode)
}, { expectTemplateWithNotifyUserQueries(queries, givenTemplate)
name: "button url without event trigger url", commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID, &senders.CodeGeneratorInfo{}).Return(nil)
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w wantLegacy) { return fields{
givenTemplate := "{{.URL}}" queries: queries,
testCode := "testcode" commands: commands,
expectContent := fmt.Sprintf("%s://%s:%d/ui/login/password/init?authRequestID=%s&code=%s&orgID=%s&userID=%s", externalProtocol, instancePrimaryDomain, externalPort, "", testCode, orgID, userID) es: eventstore.NewEventstore(&eventstore.Config{
w.message = &wantLegacyEmail{ Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
email: &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
},
}
codeAlg, code := cryptoValue(t, ctrl, testCode)
queries.EXPECT().SearchInstanceDomains(gomock.Any(), gomock.Any()).Return(&query.InstanceDomains{
Domains: []*query.InstanceDomain{{
Domain: instancePrimaryDomain,
IsPrimary: true,
}},
}, nil)
expectTemplateWithNotifyUserQueries(queries, givenTemplate)
commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID, &senders.CodeGeneratorInfo{}).Return(nil)
return fields{
queries: queries,
commands: commands,
es: eventstore.NewEventstore(&eventstore.Config{
Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
}),
userDataCrypto: codeAlg,
}, args{
event: &user.HumanPasswordCodeAddedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
AggregateID: userID,
ResourceOwner: sql.NullString{String: orgID},
CreationDate: time.Now().UTC(),
}), }),
Code: code, userDataCrypto: codeAlg,
Expiry: time.Hour, SMSTokenCrypto: nil,
URLTemplate: "", }, args{
CodeReturned: false, event: &user.HumanPasswordCodeAddedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
AggregateID: userID,
ResourceOwner: sql.NullString{String: orgID},
CreationDate: time.Now().UTC(),
}),
Code: code,
Expiry: time.Hour,
URLTemplate: "",
CodeReturned: false,
TriggeredAtOrigin: eventOrigin,
},
}, w
},
}, {
name: "button url without event trigger url",
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w wantLegacy) {
givenTemplate := "{{.URL}}"
testCode := "testcode"
expectContent := fmt.Sprintf("%s://%s:%d/ui/login/password/init?authRequestID=%s&code=%s&orgID=%s&userID=%s", externalProtocol, instancePrimaryDomain, externalPort, "", testCode, orgID, userID)
w.message = &wantLegacyEmail{
email: &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
}, },
}, w }
}, codeAlg, code := cryptoValue(t, ctrl, testCode)
}, { queries.EXPECT().SearchInstanceDomains(gomock.Any(), gomock.Any()).Return(&query.InstanceDomains{
name: "button url without event trigger url with authRequestID", Domains: []*query.InstanceDomain{{
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w wantLegacy) { Domain: instancePrimaryDomain,
givenTemplate := "{{.URL}}" IsPrimary: true,
testCode := "testcode" }},
expectContent := fmt.Sprintf("%s://%s:%d/ui/login/password/init?authRequestID=%s&code=%s&orgID=%s&userID=%s", externalProtocol, instancePrimaryDomain, externalPort, authRequestID, testCode, orgID, userID) }, nil)
w.message = &wantLegacyEmail{ expectTemplateWithNotifyUserQueries(queries, givenTemplate)
email: &messages.Email{ commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID, &senders.CodeGeneratorInfo{}).Return(nil)
Recipients: []string{lastEmail}, return fields{
Subject: expectMailSubject, queries: queries,
Content: expectContent, commands: commands,
}, es: eventstore.NewEventstore(&eventstore.Config{
} Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
codeAlg, code := cryptoValue(t, ctrl, testCode)
queries.EXPECT().SearchInstanceDomains(gomock.Any(), gomock.Any()).Return(&query.InstanceDomains{
Domains: []*query.InstanceDomain{{
Domain: instancePrimaryDomain,
IsPrimary: true,
}},
}, nil)
expectTemplateWithNotifyUserQueries(queries, givenTemplate)
commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID, &senders.CodeGeneratorInfo{}).Return(nil)
return fields{
queries: queries,
commands: commands,
es: eventstore.NewEventstore(&eventstore.Config{
Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
}),
userDataCrypto: codeAlg,
}, args{
event: &user.HumanPasswordCodeAddedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
AggregateID: userID,
ResourceOwner: sql.NullString{String: orgID},
CreationDate: time.Now().UTC(),
}), }),
Code: code, userDataCrypto: codeAlg,
Expiry: time.Hour, }, args{
URLTemplate: "", event: &user.HumanPasswordCodeAddedEvent{
CodeReturned: false, BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
AuthRequestID: authRequestID, AggregateID: userID,
ResourceOwner: sql.NullString{String: orgID},
CreationDate: time.Now().UTC(),
}),
Code: code,
Expiry: time.Hour,
URLTemplate: "",
CodeReturned: false,
},
}, w
},
}, {
name: "button url without event trigger url with authRequestID",
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w wantLegacy) {
givenTemplate := "{{.URL}}"
testCode := "testcode"
expectContent := fmt.Sprintf("%s://%s:%d/ui/login/password/init?authRequestID=%s&code=%s&orgID=%s&userID=%s", externalProtocol, instancePrimaryDomain, externalPort, authRequestID, testCode, orgID, userID)
w.message = &wantLegacyEmail{
email: &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
}, },
}, w }
}, codeAlg, code := cryptoValue(t, ctrl, testCode)
}, { queries.EXPECT().SearchInstanceDomains(gomock.Any(), gomock.Any()).Return(&query.InstanceDomains{
name: "button url with url template and event trigger url", Domains: []*query.InstanceDomain{{
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w wantLegacy) { Domain: instancePrimaryDomain,
givenTemplate := "{{.URL}}" IsPrimary: true,
urlTemplate := "https://my.custom.url/org/{{.OrgID}}/user/{{.UserID}}/verify/{{.Code}}" }},
testCode := "testcode" }, nil)
expectContent := fmt.Sprintf("https://my.custom.url/org/%s/user/%s/verify/%s", orgID, userID, testCode) expectTemplateWithNotifyUserQueries(queries, givenTemplate)
w.message = &wantLegacyEmail{ commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID, &senders.CodeGeneratorInfo{}).Return(nil)
email: &messages.Email{ return fields{
Recipients: []string{lastEmail}, queries: queries,
Subject: expectMailSubject, commands: commands,
Content: expectContent, es: eventstore.NewEventstore(&eventstore.Config{
}, Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
}
codeAlg, code := cryptoValue(t, ctrl, testCode)
expectTemplateWithNotifyUserQueries(queries, givenTemplate)
commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID, &senders.CodeGeneratorInfo{}).Return(nil)
return fields{
queries: queries,
commands: commands,
es: eventstore.NewEventstore(&eventstore.Config{
Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
}),
userDataCrypto: codeAlg,
SMSTokenCrypto: nil,
}, args{
event: &user.HumanPasswordCodeAddedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
AggregateID: userID,
ResourceOwner: sql.NullString{String: orgID},
CreationDate: time.Now().UTC(),
}), }),
Code: code, userDataCrypto: codeAlg,
Expiry: time.Hour, }, args{
URLTemplate: urlTemplate, event: &user.HumanPasswordCodeAddedEvent{
CodeReturned: false, BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
TriggeredAtOrigin: eventOrigin, AggregateID: userID,
ResourceOwner: sql.NullString{String: orgID},
CreationDate: time.Now().UTC(),
}),
Code: code,
Expiry: time.Hour,
URLTemplate: "",
CodeReturned: false,
AuthRequestID: authRequestID,
},
}, w
},
}, {
name: "button url with url template and event trigger url",
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w wantLegacy) {
givenTemplate := "{{.URL}}"
urlTemplate := "https://my.custom.url/org/{{.OrgID}}/user/{{.UserID}}/verify/{{.Code}}"
testCode := "testcode"
expectContent := fmt.Sprintf("https://my.custom.url/org/%s/user/%s/verify/%s", orgID, userID, testCode)
w.message = &wantLegacyEmail{
email: &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
}, },
}, w }
}, codeAlg, code := cryptoValue(t, ctrl, testCode)
}, { expectTemplateWithNotifyUserQueries(queries, givenTemplate)
name: "external code", commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID, &senders.CodeGeneratorInfo{}).Return(nil)
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w wantLegacy) { return fields{
givenTemplate := "{{.URL}}" queries: queries,
expectContent := "We received a password reset request. Please use the button below to reset your password. (Code ) If you didn't ask for this mail, please ignore it." commands: commands,
w.messageSMS = &wantLegacySMS{ es: eventstore.NewEventstore(&eventstore.Config{
sms: &messages.SMS{ Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
SenderPhoneNumber: "senderNumber",
RecipientPhoneNumber: lastPhone,
Content: expectContent,
},
}
expectTemplateWithNotifyUserQueries(queries, givenTemplate)
commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID, &senders.CodeGeneratorInfo{ID: smsProviderID, VerificationID: verificationID}).Return(nil)
return fields{
queries: queries,
commands: commands,
es: eventstore.NewEventstore(&eventstore.Config{
Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
}),
SMSTokenCrypto: nil,
}, args{
event: &user.HumanPasswordCodeAddedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
AggregateID: userID,
ResourceOwner: sql.NullString{String: orgID},
CreationDate: time.Now().UTC(),
}), }),
Code: nil, userDataCrypto: codeAlg,
Expiry: 0, SMSTokenCrypto: nil,
URLTemplate: "", }, args{
CodeReturned: false, event: &user.HumanPasswordCodeAddedEvent{
NotificationType: domain.NotificationTypeSms, BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
GeneratorID: smsProviderID, AggregateID: userID,
TriggeredAtOrigin: eventOrigin, ResourceOwner: sql.NullString{String: orgID},
CreationDate: time.Now().UTC(),
}),
Code: code,
Expiry: time.Hour,
URLTemplate: urlTemplate,
CodeReturned: false,
TriggeredAtOrigin: eventOrigin,
},
}, w
},
}, {
name: "external code",
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w wantLegacy) {
givenTemplate := "{{.URL}}"
expectContent := "We received a password reset request. Please use the button below to reset your password. (Code ) If you didn't ask for this mail, please ignore it."
w.messageSMS = &wantLegacySMS{
sms: &messages.SMS{
SenderPhoneNumber: "senderNumber",
RecipientPhoneNumber: lastPhone,
Content: expectContent,
UserID: userID,
}, },
}, w }
}, expectTemplateWithNotifyUserQueries(queries, givenTemplate)
}, { commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID, &senders.CodeGeneratorInfo{ID: smsProviderID, VerificationID: verificationID}).Return(nil)
name: "cancel error, no reduce error expected", return fields{
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w wantLegacy) { queries: queries,
givenTemplate := "{{.URL}}" commands: commands,
expectContent := "We received a password reset request. Please use the button below to reset your password. (Code ) If you didn't ask for this mail, please ignore it." es: eventstore.NewEventstore(&eventstore.Config{
w.messageSMS = &wantLegacySMS{ Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
sms: &messages.SMS{
SenderPhoneNumber: "senderNumber",
RecipientPhoneNumber: lastPhone,
Content: expectContent,
},
err: channels.NewCancelError(nil),
}
expectTemplateWithNotifyUserQueries(queries, givenTemplate)
return fields{
queries: queries,
commands: commands,
es: eventstore.NewEventstore(&eventstore.Config{
Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
}),
SMSTokenCrypto: nil,
}, args{
event: &user.HumanPasswordCodeAddedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
AggregateID: userID,
ResourceOwner: sql.NullString{String: orgID},
CreationDate: time.Now().UTC(),
}), }),
Code: nil, SMSTokenCrypto: nil,
Expiry: 0, }, args{
URLTemplate: "", event: &user.HumanPasswordCodeAddedEvent{
CodeReturned: false, BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
NotificationType: domain.NotificationTypeSms, AggregateID: userID,
GeneratorID: smsProviderID, ResourceOwner: sql.NullString{String: orgID},
TriggeredAtOrigin: eventOrigin, CreationDate: time.Now().UTC(),
}),
Code: nil,
Expiry: 0,
URLTemplate: "",
CodeReturned: false,
NotificationType: domain.NotificationTypeSms,
GeneratorID: smsProviderID,
TriggeredAtOrigin: eventOrigin,
},
}, w
},
}, {
name: "cancel error, no reduce error expected",
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w wantLegacy) {
givenTemplate := "{{.URL}}"
expectContent := "We received a password reset request. Please use the button below to reset your password. (Code ) If you didn't ask for this mail, please ignore it."
w.messageSMS = &wantLegacySMS{
sms: &messages.SMS{
SenderPhoneNumber: "senderNumber",
RecipientPhoneNumber: lastPhone,
Content: expectContent,
UserID: userID,
}, },
}, w err: channels.NewCancelError(nil),
}
expectTemplateWithNotifyUserQueries(queries, givenTemplate)
return fields{
queries: queries,
commands: commands,
es: eventstore.NewEventstore(&eventstore.Config{
Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
}),
SMSTokenCrypto: nil,
}, args{
event: &user.HumanPasswordCodeAddedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
AggregateID: userID,
ResourceOwner: sql.NullString{String: orgID},
CreationDate: time.Now().UTC(),
}),
Code: nil,
Expiry: 0,
URLTemplate: "",
CodeReturned: false,
NotificationType: domain.NotificationTypeSms,
GeneratorID: smsProviderID,
TriggeredAtOrigin: eventOrigin,
},
}, w
},
}, },
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -1774,131 +1777,138 @@ func Test_userNotifierLegacy_reduceOTPSMSChallenged(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
test func(*gomock.Controller, *mock.MockQueries, *mock.MockCommands) (fields, args, wantLegacy) test func(*gomock.Controller, *mock.MockQueries, *mock.MockCommands) (fields, args, wantLegacy)
}{{ }{
name: "asset url with event trigger url", {
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w wantLegacy) { name: "asset url with event trigger url",
testCode := "" test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w wantLegacy) {
expiry := 0 * time.Hour testCode := ""
expectContent := fmt.Sprintf(`%[1]s is your one-time password for %[2]s. Use it within the next %[3]s. expiry := 0 * time.Hour
expectContent := fmt.Sprintf(`%[1]s is your one-time password for %[2]s. Use it within the next %[3]s.
@%[2]s #%[1]s`, testCode, eventOriginDomain, expiry) @%[2]s #%[1]s`, testCode, eventOriginDomain, expiry)
w.messageSMS = &wantLegacySMS{ w.messageSMS = &wantLegacySMS{
sms: &messages.SMS{ sms: &messages.SMS{
SenderPhoneNumber: "senderNumber", SenderPhoneNumber: "senderNumber",
RecipientPhoneNumber: verifiedPhone, RecipientPhoneNumber: verifiedPhone,
Content: expectContent, Content: expectContent,
}, JobID: userID,
} UserID: userID,
expectTemplateWithNotifyUserQueriesSMS(queries)
queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any(), nil).Return(&query.Session{}, nil)
commands.EXPECT().OTPSMSSent(gomock.Any(), userID, orgID, &senders.CodeGeneratorInfo{ID: smsProviderID, VerificationID: verificationID}).Return(nil)
return fields{
queries: queries,
commands: commands,
es: eventstore.NewEventstore(&eventstore.Config{
Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
}),
}, args{
event: &session.OTPSMSChallengedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
AggregateID: userID,
ResourceOwner: sql.NullString{String: orgID},
CreationDate: time.Now().UTC(),
}),
Code: nil,
Expiry: expiry,
CodeReturned: false,
GeneratorID: smsProviderID,
TriggeredAtOrigin: eventOrigin,
}, },
}, w }
}, expectTemplateWithNotifyUserQueriesSMS(queries)
}, { queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any(), nil).Return(&query.Session{}, nil)
name: "asset url without event trigger url", commands.EXPECT().OTPSMSSent(gomock.Any(), userID, orgID, &senders.CodeGeneratorInfo{ID: smsProviderID, VerificationID: verificationID}).Return(nil)
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w wantLegacy) { return fields{
testCode := "" queries: queries,
expiry := 0 * time.Hour commands: commands,
expectContent := fmt.Sprintf(`%[1]s is your one-time password for %[2]s. Use it within the next %[3]s. es: eventstore.NewEventstore(&eventstore.Config{
Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
}),
}, args{
event: &session.OTPSMSChallengedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
AggregateID: userID,
ResourceOwner: sql.NullString{String: orgID},
CreationDate: time.Now().UTC(),
}),
Code: nil,
Expiry: expiry,
CodeReturned: false,
GeneratorID: smsProviderID,
TriggeredAtOrigin: eventOrigin,
},
}, w
},
}, {
name: "asset url without event trigger url",
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w wantLegacy) {
testCode := ""
expiry := 0 * time.Hour
expectContent := fmt.Sprintf(`%[1]s is your one-time password for %[2]s. Use it within the next %[3]s.
@%[2]s #%[1]s`, testCode, instancePrimaryDomain, expiry) @%[2]s #%[1]s`, testCode, instancePrimaryDomain, expiry)
w.messageSMS = &wantLegacySMS{ w.messageSMS = &wantLegacySMS{
sms: &messages.SMS{ sms: &messages.SMS{
SenderPhoneNumber: "senderNumber", SenderPhoneNumber: "senderNumber",
RecipientPhoneNumber: verifiedPhone, RecipientPhoneNumber: verifiedPhone,
Content: expectContent, Content: expectContent,
}, JobID: userID,
} UserID: userID,
expectTemplateWithNotifyUserQueriesSMS(queries)
queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any(), nil).Return(&query.Session{}, nil)
queries.EXPECT().SearchInstanceDomains(gomock.Any(), gomock.Any()).Return(&query.InstanceDomains{
Domains: []*query.InstanceDomain{{
Domain: instancePrimaryDomain,
IsPrimary: true,
}},
}, nil)
commands.EXPECT().OTPSMSSent(gomock.Any(), userID, orgID, &senders.CodeGeneratorInfo{ID: smsProviderID, VerificationID: verificationID}).Return(nil)
return fields{
queries: queries,
commands: commands,
es: eventstore.NewEventstore(&eventstore.Config{
Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
}),
}, args{
event: &session.OTPSMSChallengedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
AggregateID: userID,
ResourceOwner: sql.NullString{String: orgID},
CreationDate: time.Now().UTC(),
}),
Code: nil,
Expiry: expiry,
CodeReturned: false,
GeneratorID: smsProviderID,
}, },
}, w }
}, expectTemplateWithNotifyUserQueriesSMS(queries)
}, { queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any(), nil).Return(&query.Session{}, nil)
name: "cancel error, no reduce error expected", queries.EXPECT().SearchInstanceDomains(gomock.Any(), gomock.Any()).Return(&query.InstanceDomains{
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w wantLegacy) { Domains: []*query.InstanceDomain{{
testCode := "" Domain: instancePrimaryDomain,
expiry := 0 * time.Hour IsPrimary: true,
expectContent := fmt.Sprintf(`%[1]s is your one-time password for %[2]s. Use it within the next %[3]s. }},
}, nil)
commands.EXPECT().OTPSMSSent(gomock.Any(), userID, orgID, &senders.CodeGeneratorInfo{ID: smsProviderID, VerificationID: verificationID}).Return(nil)
return fields{
queries: queries,
commands: commands,
es: eventstore.NewEventstore(&eventstore.Config{
Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
}),
}, args{
event: &session.OTPSMSChallengedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
AggregateID: userID,
ResourceOwner: sql.NullString{String: orgID},
CreationDate: time.Now().UTC(),
}),
Code: nil,
Expiry: expiry,
CodeReturned: false,
GeneratorID: smsProviderID,
},
}, w
},
}, {
name: "cancel error, no reduce error expected",
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w wantLegacy) {
testCode := ""
expiry := 0 * time.Hour
expectContent := fmt.Sprintf(`%[1]s is your one-time password for %[2]s. Use it within the next %[3]s.
@%[2]s #%[1]s`, testCode, instancePrimaryDomain, expiry) @%[2]s #%[1]s`, testCode, instancePrimaryDomain, expiry)
w.messageSMS = &wantLegacySMS{ w.messageSMS = &wantLegacySMS{
sms: &messages.SMS{ sms: &messages.SMS{
SenderPhoneNumber: "senderNumber", SenderPhoneNumber: "senderNumber",
RecipientPhoneNumber: verifiedPhone, RecipientPhoneNumber: verifiedPhone,
Content: expectContent, Content: expectContent,
}, JobID: userID,
err: channels.NewCancelError(nil), UserID: userID,
}
expectTemplateWithNotifyUserQueriesSMS(queries)
queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any(), nil).Return(&query.Session{}, nil)
queries.EXPECT().SearchInstanceDomains(gomock.Any(), gomock.Any()).Return(&query.InstanceDomains{
Domains: []*query.InstanceDomain{{
Domain: instancePrimaryDomain,
IsPrimary: true,
}},
}, nil)
return fields{
queries: queries,
commands: commands,
es: eventstore.NewEventstore(&eventstore.Config{
Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
}),
}, args{
event: &session.OTPSMSChallengedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
AggregateID: userID,
ResourceOwner: sql.NullString{String: orgID},
CreationDate: time.Now().UTC(),
}),
Code: nil,
Expiry: expiry,
CodeReturned: false,
GeneratorID: smsProviderID,
}, },
}, w err: channels.NewCancelError(nil),
}
expectTemplateWithNotifyUserQueriesSMS(queries)
queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any(), nil).Return(&query.Session{}, nil)
queries.EXPECT().SearchInstanceDomains(gomock.Any(), gomock.Any()).Return(&query.InstanceDomains{
Domains: []*query.InstanceDomain{{
Domain: instancePrimaryDomain,
IsPrimary: true,
}},
}, nil)
return fields{
queries: queries,
commands: commands,
es: eventstore.NewEventstore(&eventstore.Config{
Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
}),
}, args{
event: &session.OTPSMSChallengedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
AggregateID: userID,
ResourceOwner: sql.NullString{String: orgID},
CreationDate: time.Now().UTC(),
}),
Code: nil,
Expiry: expiry,
CodeReturned: false,
GeneratorID: smsProviderID,
},
}, w
},
}, },
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
@ -1938,11 +1948,11 @@ func newUserNotifierLegacy(t *testing.T, ctrl *gomock.Controller, queries *mock.
channel := channel_mock.NewMockNotificationChannel(ctrl) channel := channel_mock.NewMockNotificationChannel(ctrl)
if w.err == nil { if w.err == nil {
if w.message != nil { if w.message != nil {
w.message.email.TriggeringEvent = a.event w.message.email.TriggeringEventType = a.event.Type()
channel.EXPECT().HandleMessage(w.message.email).Return(w.message.err) channel.EXPECT().HandleMessage(w.message.email).Return(w.message.err)
} }
if w.messageSMS != nil { if w.messageSMS != nil {
w.messageSMS.sms.TriggeringEvent = a.event w.messageSMS.sms.TriggeringEventType = a.event.Type()
channel.EXPECT().HandleMessage(w.messageSMS.sms).DoAndReturn(func(message *messages.SMS) error { channel.EXPECT().HandleMessage(w.messageSMS.sms).DoAndReturn(func(message *messages.SMS) error {
message.VerificationID = gu.Ptr(verificationID) message.VerificationID = gu.Ptr(verificationID)
return w.messageSMS.err return w.messageSMS.err

File diff suppressed because it is too large Load Diff

View File

@ -19,15 +19,15 @@ var (
var _ channels.Message = (*Email)(nil) var _ channels.Message = (*Email)(nil)
type Email struct { type Email struct {
Recipients []string Recipients []string
BCC []string BCC []string
CC []string CC []string
SenderEmail string SenderEmail string
SenderName string SenderName string
ReplyToAddress string ReplyToAddress string
Subject string Subject string
Content string Content string
TriggeringEvent eventstore.Event TriggeringEventType eventstore.EventType
} }
func (msg *Email) GetContent() (string, error) { func (msg *Email) GetContent() (string, error) {
@ -61,8 +61,8 @@ func (msg *Email) GetContent() (string, error) {
return message, nil return message, nil
} }
func (msg *Email) GetTriggeringEvent() eventstore.Event { func (msg *Email) GetTriggeringEventType() eventstore.EventType {
return msg.TriggeringEvent return msg.TriggeringEventType
} }
func isHTML(input string) bool { func isHTML(input string) bool {

View File

@ -12,8 +12,8 @@ import (
var _ channels.Message = (*Form)(nil) var _ channels.Message = (*Form)(nil)
type Form struct { type Form struct {
Serializable any Serializable any
TriggeringEvent eventstore.Event TriggeringEventType eventstore.EventType
} }
func (msg *Form) GetContent() (string, error) { func (msg *Form) GetContent() (string, error) {
@ -22,6 +22,6 @@ func (msg *Form) GetContent() (string, error) {
return values.Encode(), err return values.Encode(), err
} }
func (msg *Form) GetTriggeringEvent() eventstore.Event { func (msg *Form) GetTriggeringEventType() eventstore.EventType {
return msg.TriggeringEvent return msg.TriggeringEventType
} }

View File

@ -10,8 +10,8 @@ import (
var _ channels.Message = (*JSON)(nil) var _ channels.Message = (*JSON)(nil)
type JSON struct { type JSON struct {
Serializable interface{} Serializable interface{}
TriggeringEvent eventstore.Event TriggeringEventType eventstore.EventType
} }
func (msg *JSON) GetContent() (string, error) { func (msg *JSON) GetContent() (string, error) {
@ -19,6 +19,6 @@ func (msg *JSON) GetContent() (string, error) {
return string(bytes), err return string(bytes), err
} }
func (msg *JSON) GetTriggeringEvent() eventstore.Event { func (msg *JSON) GetTriggeringEventType() eventstore.EventType {
return msg.TriggeringEvent return msg.TriggeringEventType
} }

View File

@ -11,16 +11,19 @@ type SMS struct {
SenderPhoneNumber string SenderPhoneNumber string
RecipientPhoneNumber string RecipientPhoneNumber string
Content string Content string
TriggeringEvent eventstore.Event TriggeringEventType eventstore.EventType
// VerificationID is set by the sender // VerificationID is set by the sender
VerificationID *string VerificationID *string
InstanceID string
JobID string
UserID string
} }
func (msg *SMS) GetContent() (string, error) { func (msg *SMS) GetContent() (string, error) {
return msg.Content, nil return msg.Content, nil
} }
func (msg *SMS) GetTriggeringEvent() eventstore.Event { func (msg *SMS) GetTriggeringEventType() eventstore.EventType {
return msg.TriggeringEvent return msg.TriggeringEventType
} }

View File

@ -13,6 +13,7 @@ import (
_ "github.com/zitadel/zitadel/internal/notification/statik" _ "github.com/zitadel/zitadel/internal/notification/statik"
"github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/queue"
) )
var ( var (
@ -35,10 +36,15 @@ func Register(
userEncryption, smtpEncryption, smsEncryption, keysEncryptionAlg crypto.EncryptionAlgorithm, userEncryption, smtpEncryption, smsEncryption, keysEncryptionAlg crypto.EncryptionAlgorithm,
tokenLifetime time.Duration, tokenLifetime time.Duration,
client *database.DB, client *database.DB,
queue *queue.Queue,
) { ) {
if !notificationWorkerConfig.LegacyEnabled {
queue.ShouldStart()
}
q := handlers.NewNotificationQueries(queries, es, externalDomain, externalPort, externalSecure, fileSystemPath, userEncryption, smtpEncryption, smsEncryption) q := handlers.NewNotificationQueries(queries, es, externalDomain, externalPort, externalSecure, fileSystemPath, userEncryption, smtpEncryption, smsEncryption)
c := newChannels(q) c := newChannels(q)
projections = append(projections, handlers.NewUserNotifier(ctx, projection.ApplyCustomConfig(userHandlerCustomConfig), commands, q, c, otpEmailTmpl, notificationWorkerConfig.LegacyEnabled)) projections = append(projections, handlers.NewUserNotifier(ctx, projection.ApplyCustomConfig(userHandlerCustomConfig), commands, q, c, otpEmailTmpl, notificationWorkerConfig, queue))
projections = append(projections, handlers.NewQuotaNotifier(ctx, projection.ApplyCustomConfig(quotaHandlerCustomConfig), commands, q, c)) projections = append(projections, handlers.NewQuotaNotifier(ctx, projection.ApplyCustomConfig(quotaHandlerCustomConfig), commands, q, c))
projections = append(projections, handlers.NewBackChannelLogoutNotifier( projections = append(projections, handlers.NewBackChannelLogoutNotifier(
ctx, ctx,
@ -53,14 +59,13 @@ func Register(
if telemetryCfg.Enabled { if telemetryCfg.Enabled {
projections = append(projections, handlers.NewTelemetryPusher(ctx, telemetryCfg, projection.ApplyCustomConfig(telemetryHandlerCustomConfig), commands, q, c)) projections = append(projections, handlers.NewTelemetryPusher(ctx, telemetryCfg, projection.ApplyCustomConfig(telemetryHandlerCustomConfig), commands, q, c))
} }
worker = handlers.NewNotificationWorker(notificationWorkerConfig, commands, q, es, client, c) worker = handlers.NewNotificationWorker(notificationWorkerConfig, commands, q, es, client, c, queue)
} }
func Start(ctx context.Context) { func Start(ctx context.Context) {
for _, projection := range projections { for _, projection := range projections {
projection.Start(ctx) projection.Start(ctx)
} }
worker.Start(ctx)
} }
func ProjectInstance(ctx context.Context) error { func ProjectInstance(ctx context.Context) error {

View File

@ -39,7 +39,7 @@ func SendEmail(
translator *i18n.Translator, translator *i18n.Translator,
user *query.NotifyUser, user *query.NotifyUser,
colors *query.LabelPolicy, colors *query.LabelPolicy,
triggeringEvent eventstore.Event, triggeringEventType eventstore.EventType,
) Notify { ) Notify {
return func( return func(
urlTmpl string, urlTmpl string,
@ -66,7 +66,7 @@ func SendEmail(
data, data,
args, args,
allowUnverifiedNotificationChannel, allowUnverifiedNotificationChannel,
triggeringEvent, triggeringEventType,
) )
} }
} }
@ -102,7 +102,9 @@ func SendSMS(
translator *i18n.Translator, translator *i18n.Translator,
user *query.NotifyUser, user *query.NotifyUser,
colors *query.LabelPolicy, colors *query.LabelPolicy,
triggeringEvent eventstore.Event, triggeringEventType eventstore.EventType,
instanceID string,
jobID string,
generatorInfo *senders.CodeGeneratorInfo, generatorInfo *senders.CodeGeneratorInfo,
) Notify { ) Notify {
return func( return func(
@ -124,7 +126,9 @@ func SendSMS(
data, data,
args, args,
allowUnverifiedNotificationChannel, allowUnverifiedNotificationChannel,
triggeringEvent, triggeringEventType,
instanceID,
jobID,
generatorInfo, generatorInfo,
) )
} }
@ -135,7 +139,7 @@ func SendJSON(
webhookConfig webhook.Config, webhookConfig webhook.Config,
channels ChannelChains, channels ChannelChains,
serializable interface{}, serializable interface{},
triggeringEvent eventstore.Event, triggeringEventType eventstore.EventType,
) Notify { ) Notify {
return func(_ string, _ map[string]interface{}, _ string, _ bool) error { return func(_ string, _ map[string]interface{}, _ string, _ bool) error {
return handleWebhook( return handleWebhook(
@ -143,7 +147,7 @@ func SendJSON(
webhookConfig, webhookConfig,
channels, channels,
serializable, serializable,
triggeringEvent, triggeringEventType,
) )
} }
} }
@ -153,7 +157,7 @@ func SendSecurityTokenEvent(
setConfig set.Config, setConfig set.Config,
channels ChannelChains, channels ChannelChains,
token any, token any,
triggeringEvent eventstore.Event, triggeringEventType eventstore.EventType,
) Notify { ) Notify {
return func(_ string, _ map[string]interface{}, _ string, _ bool) error { return func(_ string, _ map[string]interface{}, _ string, _ bool) error {
return handleSecurityTokenEvent( return handleSecurityTokenEvent(
@ -161,7 +165,7 @@ func SendSecurityTokenEvent(
setConfig, setConfig,
channels, channels,
token, token,
triggeringEvent, triggeringEventType,
) )
} }
} }

View File

@ -13,11 +13,11 @@ func handleSecurityTokenEvent(
setConfig set.Config, setConfig set.Config,
channels ChannelChains, channels ChannelChains,
token any, token any,
triggeringEvent eventstore.Event, triggeringEventType eventstore.EventType,
) error { ) error {
message := &messages.Form{ message := &messages.Form{
Serializable: token, Serializable: token,
TriggeringEvent: triggeringEvent, TriggeringEventType: triggeringEventType,
} }
setChannels, err := channels.SecurityTokenEvent(ctx, setConfig) setChannels, err := channels.SecurityTokenEvent(ctx, setConfig)
if err != nil { if err != nil {

View File

@ -23,7 +23,7 @@ func generateEmail(
data templates.TemplateData, data templates.TemplateData,
args map[string]interface{}, args map[string]interface{},
lastEmail bool, lastEmail bool,
triggeringEvent eventstore.Event, triggeringEventType eventstore.EventType,
) error { ) error {
emailChannels, config, err := channels.Email(ctx) emailChannels, config, err := channels.Email(ctx)
logging.OnError(err).Error("could not create email channel") logging.OnError(err).Error("could not create email channel")
@ -38,10 +38,10 @@ func generateEmail(
} }
if config.SMTPConfig != nil { if config.SMTPConfig != nil {
message := &messages.Email{ message := &messages.Email{
Recipients: []string{recipient}, Recipients: []string{recipient},
Subject: data.Subject, Subject: data.Subject,
Content: html.UnescapeString(template), Content: html.UnescapeString(template),
TriggeringEvent: triggeringEvent, TriggeringEventType: triggeringEventType,
} }
return emailChannels.HandleMessage(message) return emailChannels.HandleMessage(message)
} }
@ -52,7 +52,7 @@ func generateEmail(
} }
contextInfo := map[string]interface{}{ contextInfo := map[string]interface{}{
"recipientEmailAddress": recipient, "recipientEmailAddress": recipient,
"eventType": triggeringEvent.Type(), "eventType": triggeringEventType,
"provider": config.ProviderConfig, "provider": config.ProviderConfig,
} }
@ -62,7 +62,7 @@ func generateEmail(
TemplateData: data, TemplateData: data,
Args: caseArgs, Args: caseArgs,
}, },
TriggeringEvent: triggeringEvent, TriggeringEventType: triggeringEventType,
} }
webhookChannels, err := channels.Webhook(ctx, *config.WebhookConfig) webhookChannels, err := channels.Webhook(ctx, *config.WebhookConfig)
if err != nil { if err != nil {

View File

@ -28,7 +28,9 @@ func generateSms(
data templates.TemplateData, data templates.TemplateData,
args map[string]interface{}, args map[string]interface{},
lastPhone bool, lastPhone bool,
triggeringEvent eventstore.Event, triggeringEventType eventstore.EventType,
instanceID string,
jobID string,
generatorInfo *senders.CodeGeneratorInfo, generatorInfo *senders.CodeGeneratorInfo,
) error { ) error {
smsChannels, config, err := channels.SMS(ctx) smsChannels, config, err := channels.SMS(ctx)
@ -51,7 +53,10 @@ func generateSms(
SenderPhoneNumber: number, SenderPhoneNumber: number,
RecipientPhoneNumber: recipient, RecipientPhoneNumber: recipient,
Content: data.Text, Content: data.Text,
TriggeringEvent: triggeringEvent, TriggeringEventType: triggeringEventType,
InstanceID: instanceID,
JobID: jobID,
UserID: user.ID,
} }
err = smsChannels.HandleMessage(message) err = smsChannels.HandleMessage(message)
if err != nil { if err != nil {
@ -70,7 +75,7 @@ func generateSms(
} }
contextInfo := map[string]interface{}{ contextInfo := map[string]interface{}{
"recipientPhoneNumber": recipient, "recipientPhoneNumber": recipient,
"eventType": triggeringEvent.Type(), "eventType": triggeringEventType,
"provider": config.ProviderConfig, "provider": config.ProviderConfig,
} }
@ -80,7 +85,7 @@ func generateSms(
Args: caseArgs, Args: caseArgs,
ContextInfo: contextInfo, ContextInfo: contextInfo,
}, },
TriggeringEvent: triggeringEvent, TriggeringEventType: triggeringEventType,
} }
webhookChannels, err := channels.Webhook(ctx, *config.WebhookConfig) webhookChannels, err := channels.Webhook(ctx, *config.WebhookConfig)
if err != nil { if err != nil {

View File

@ -13,11 +13,11 @@ func handleWebhook(
webhookConfig webhook.Config, webhookConfig webhook.Config,
channels ChannelChains, channels ChannelChains,
serializable interface{}, serializable interface{},
triggeringEvent eventstore.Event, triggeringEventType eventstore.EventType,
) error { ) error {
message := &messages.JSON{ message := &messages.JSON{
Serializable: serializable, Serializable: serializable,
TriggeringEvent: triggeringEvent, TriggeringEventType: triggeringEventType,
} }
webhookChannels, err := channels.Webhook(ctx, webhookConfig) webhookChannels, err := channels.Webhook(ctx, webhookConfig)
if err != nil { if err != nil {

View File

@ -0,0 +1,45 @@
package queue
import (
"context"
"sync"
"github.com/jackc/pgx/v5"
"github.com/zitadel/zitadel/internal/database/dialect"
)
const (
schema = "queue"
applicationName = "zitadel_queue"
)
var conns = &sync.Map{}
type queueKey struct{}
func WithQueue(parent context.Context) context.Context {
return context.WithValue(parent, queueKey{}, struct{}{})
}
func init() {
dialect.RegisterBeforeAcquire(func(ctx context.Context, c *pgx.Conn) error {
if _, ok := ctx.Value(queueKey{}).(struct{}); !ok {
return nil
}
_, err := c.Exec(ctx, "SET search_path TO "+schema+"; SET application_name TO "+applicationName)
if err != nil {
return err
}
conns.Store(c, struct{}{})
return nil
})
dialect.RegisterAfterRelease(func(c *pgx.Conn) error {
_, ok := conns.LoadAndDelete(c)
if !ok {
return nil
}
_, err := c.Exec(context.Background(), "SET search_path TO DEFAULT; SET application_name TO "+dialect.DefaultAppName)
return err
})
}

38
internal/queue/migrate.go Normal file
View File

@ -0,0 +1,38 @@
package queue
import (
"context"
"github.com/jackc/pgx/v5"
"github.com/riverqueue/river/riverdriver"
"github.com/riverqueue/river/riverdriver/riverpgxv5"
"github.com/riverqueue/river/rivermigrate"
"github.com/zitadel/zitadel/internal/database"
)
type Migrator struct {
driver riverdriver.Driver[pgx.Tx]
}
func NewMigrator(client *database.DB) *Migrator {
return &Migrator{
driver: riverpgxv5.New(client.Pool),
}
}
func (m *Migrator) Execute(ctx context.Context) error {
_, err := m.driver.GetExecutor().Exec(ctx, "CREATE SCHEMA IF NOT EXISTS "+schema)
if err != nil {
return err
}
migrator, err := rivermigrate.New(m.driver, nil)
if err != nil {
return err
}
ctx = WithQueue(ctx)
_, err = migrator.Migrate(ctx, rivermigrate.DirectionUp, nil)
return err
}

View File

@ -2,74 +2,96 @@ package queue
import ( import (
"context" "context"
"sync"
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
"github.com/riverqueue/river"
"github.com/riverqueue/river/riverdriver" "github.com/riverqueue/river/riverdriver"
"github.com/riverqueue/river/riverdriver/riverpgxv5" "github.com/riverqueue/river/riverdriver/riverpgxv5"
"github.com/riverqueue/river/rivermigrate" "github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/database" "github.com/zitadel/zitadel/internal/database"
"github.com/zitadel/zitadel/internal/database/dialect"
) )
const (
schema = "queue"
applicationName = "zitadel_queue"
)
var conns = &sync.Map{}
type queueKey struct{}
func WithQueue(parent context.Context) context.Context {
return context.WithValue(parent, queueKey{}, struct{}{})
}
func init() {
dialect.RegisterBeforeAcquire(func(ctx context.Context, c *pgx.Conn) error {
if _, ok := ctx.Value(queueKey{}).(struct{}); !ok {
return nil
}
_, err := c.Exec(ctx, "SET search_path TO "+schema+"; SET application_name TO "+applicationName)
if err != nil {
return err
}
conns.Store(c, struct{}{})
return nil
})
dialect.RegisterAfterRelease(func(c *pgx.Conn) error {
_, ok := conns.LoadAndDelete(c)
if !ok {
return nil
}
_, err := c.Exec(context.Background(), "SET search_path TO DEFAULT; SET application_name TO "+dialect.DefaultAppName)
return err
})
}
// Queue abstracts the underlying queuing library // Queue abstracts the underlying queuing library
// For more information see github.com/riverqueue/river // For more information see github.com/riverqueue/river
// TODO(adlerhurst): maybe it makes more sense to split the effective queue from the migrator.
type Queue struct { type Queue struct {
driver riverdriver.Driver[pgx.Tx] driver riverdriver.Driver[pgx.Tx]
client *river.Client[pgx.Tx]
config *river.Config
shouldStart bool
} }
func New(client *database.DB) *Queue { type Config struct {
return &Queue{driver: riverpgxv5.New(client.Pool)} Client *database.DB `mapstructure:"-"` // mapstructure is needed if we would like to use viper to configure the queue
} }
func (q *Queue) ExecuteMigrations(ctx context.Context) error { func NewQueue(config *Config) (_ *Queue, err error) {
_, err := q.driver.GetExecutor().Exec(ctx, "CREATE SCHEMA IF NOT EXISTS "+schema) return &Queue{
if err != nil { driver: riverpgxv5.New(config.Client.Pool),
return err config: &river.Config{
Workers: river.NewWorkers(),
Queues: make(map[string]river.QueueConfig),
JobTimeout: -1,
},
}, nil
}
func (q *Queue) ShouldStart() {
if q == nil {
return
} }
q.shouldStart = true
}
migrator, err := rivermigrate.New(q.driver, nil) func (q *Queue) Start(ctx context.Context) (err error) {
if err != nil { if q == nil || !q.shouldStart {
return err return nil
} }
ctx = WithQueue(ctx) ctx = WithQueue(ctx)
_, err = migrator.Migrate(ctx, rivermigrate.DirectionUp, nil)
q.client, err = river.NewClient(q.driver, q.config)
if err != nil {
return err
}
return q.client.Start(ctx)
}
func (q *Queue) AddWorkers(w ...Worker) {
if q == nil {
logging.Info("skip adding workers because queue is not set")
return
}
for _, worker := range w {
worker.Register(q.config.Workers, q.config.Queues)
}
}
type InsertOpt func(*river.InsertOpts)
func WithMaxAttempts(maxAttempts uint8) InsertOpt {
return func(opts *river.InsertOpts) {
opts.MaxAttempts = int(maxAttempts)
}
}
func WithQueueName(name string) InsertOpt {
return func(opts *river.InsertOpts) {
opts.Queue = name
}
}
func (q *Queue) Insert(ctx context.Context, args river.JobArgs, opts ...InsertOpt) error {
options := new(river.InsertOpts)
ctx = WithQueue(ctx)
for _, opt := range opts {
opt(options)
}
_, err := q.client.Insert(ctx, args, options)
return err return err
} }
type Worker interface {
Register(workers *river.Workers, queues map[string]river.QueueConfig)
}

View File

@ -1,25 +0,0 @@
package notification
import (
"github.com/zitadel/zitadel/internal/eventstore"
)
const (
AggregateType = "notification"
AggregateVersion = "v1"
)
type Aggregate struct {
eventstore.Aggregate
}
func NewAggregate(id, resourceOwner string) *Aggregate {
return &Aggregate{
Aggregate: eventstore.Aggregate{
Type: AggregateType,
Version: AggregateVersion,
ID: id,
ResourceOwner: resourceOwner,
},
}
}

View File

@ -1,12 +0,0 @@
package notification
import (
"github.com/zitadel/zitadel/internal/eventstore"
)
func init() {
eventstore.RegisterFilterEventMapper(AggregateType, RequestedType, eventstore.GenericEventMapper[RequestedEvent])
eventstore.RegisterFilterEventMapper(AggregateType, SentType, eventstore.GenericEventMapper[SentEvent])
eventstore.RegisterFilterEventMapper(AggregateType, RetryRequestedType, eventstore.GenericEventMapper[RetryRequestedEvent])
eventstore.RegisterFilterEventMapper(AggregateType, CanceledType, eventstore.GenericEventMapper[CanceledEvent])
}

View File

@ -1,28 +1,21 @@
package notification package notification
import ( import (
"context"
"time" "time"
"github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore" "github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/query"
) )
const ( const (
notificationEventPrefix = "notification." QueueName = "notification"
RequestedType = notificationEventPrefix + "requested"
RetryRequestedType = notificationEventPrefix + "retry.requested"
SentType = notificationEventPrefix + "sent"
CanceledType = notificationEventPrefix + "canceled"
) )
type Request struct { type Request struct {
Aggregate *eventstore.Aggregate `json:"aggregate"`
UserID string `json:"userID"` UserID string `json:"userID"`
UserResourceOwner string `json:"userResourceOwner"` UserResourceOwner string `json:"userResourceOwner"`
AggregateID string `json:"notificationAggregateID"`
AggregateResourceOwner string `json:"notificationAggregateResourceOwner"`
TriggeredAtOrigin string `json:"triggeredAtOrigin"` TriggeredAtOrigin string `json:"triggeredAtOrigin"`
EventType eventstore.EventType `json:"eventType"` EventType eventstore.EventType `json:"eventType"`
MessageType string `json:"messageType"` MessageType string `json:"messageType"`
@ -32,213 +25,10 @@ type Request struct {
Code *crypto.CryptoValue `json:"code,omitempty"` Code *crypto.CryptoValue `json:"code,omitempty"`
UnverifiedNotificationChannel bool `json:"unverifiedNotificationChannel,omitempty"` UnverifiedNotificationChannel bool `json:"unverifiedNotificationChannel,omitempty"`
IsOTP bool `json:"isOTP,omitempty"` IsOTP bool `json:"isOTP,omitempty"`
RequiresPreviousDomain bool `json:"RequiresPreviousDomain,omitempty"` RequiresPreviousDomain bool `json:"requiresPreviousDomain,omitempty"`
Args *domain.NotificationArguments `json:"args,omitempty"` Args *domain.NotificationArguments `json:"args,omitempty"`
} }
func (e *Request) NotificationAggregateID() string { func (e *Request) Kind() string {
if e.AggregateID == "" { return "notification_request"
return e.UserID
}
return e.AggregateID
}
func (e *Request) NotificationAggregateResourceOwner() string {
if e.AggregateResourceOwner == "" {
return e.UserResourceOwner
}
return e.AggregateResourceOwner
}
type RequestedEvent struct {
eventstore.BaseEvent `json:"-"`
Request `json:"request"`
}
func (e *RequestedEvent) TriggerOrigin() string {
return e.TriggeredAtOrigin
}
func (e *RequestedEvent) Payload() interface{} {
return e
}
func (e *RequestedEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
return nil
}
func (e *RequestedEvent) SetBaseEvent(event *eventstore.BaseEvent) {
e.BaseEvent = *event
}
func NewRequestedEvent(ctx context.Context,
aggregate *eventstore.Aggregate,
userID,
userResourceOwner,
aggregateID,
aggregateResourceOwner,
triggerOrigin,
urlTemplate string,
code *crypto.CryptoValue,
codeExpiry time.Duration,
eventType eventstore.EventType,
notificationType domain.NotificationType,
messageType string,
unverifiedNotificationChannel,
isOTP,
requiresPreviousDomain bool,
args *domain.NotificationArguments,
) *RequestedEvent {
return &RequestedEvent{
BaseEvent: *eventstore.NewBaseEventForPush(
ctx,
aggregate,
RequestedType,
),
Request: Request{
UserID: userID,
UserResourceOwner: userResourceOwner,
AggregateID: aggregateID,
AggregateResourceOwner: aggregateResourceOwner,
TriggeredAtOrigin: triggerOrigin,
EventType: eventType,
MessageType: messageType,
NotificationType: notificationType,
URLTemplate: urlTemplate,
CodeExpiry: codeExpiry,
Code: code,
UnverifiedNotificationChannel: unverifiedNotificationChannel,
IsOTP: isOTP,
RequiresPreviousDomain: requiresPreviousDomain,
Args: args,
},
}
}
type SentEvent struct {
eventstore.BaseEvent `json:"-"`
}
func (e *SentEvent) Payload() interface{} {
return e
}
func (e *SentEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
return nil
}
func (e *SentEvent) SetBaseEvent(event *eventstore.BaseEvent) {
e.BaseEvent = *event
}
func NewSentEvent(ctx context.Context,
aggregate *eventstore.Aggregate,
) *SentEvent {
return &SentEvent{
BaseEvent: *eventstore.NewBaseEventForPush(
ctx,
aggregate,
SentType,
),
}
}
type CanceledEvent struct {
eventstore.BaseEvent `json:"-"`
Error string `json:"error"`
}
func (e *CanceledEvent) Payload() interface{} {
return e
}
func (e *CanceledEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
return nil
}
func (e *CanceledEvent) SetBaseEvent(event *eventstore.BaseEvent) {
e.BaseEvent = *event
}
func NewCanceledEvent(ctx context.Context, aggregate *eventstore.Aggregate, errorMessage string) *CanceledEvent {
return &CanceledEvent{
BaseEvent: *eventstore.NewBaseEventForPush(
ctx,
aggregate,
CanceledType,
),
Error: errorMessage,
}
}
type RetryRequestedEvent struct {
eventstore.BaseEvent `json:"-"`
Request `json:"request"`
Error string `json:"error"`
NotifyUser *query.NotifyUser `json:"notifyUser"`
BackOff time.Duration `json:"backOff"`
}
func (e *RetryRequestedEvent) Payload() interface{} {
return e
}
func (e *RetryRequestedEvent) UniqueConstraints() []*eventstore.UniqueConstraint {
return nil
}
func (e *RetryRequestedEvent) SetBaseEvent(event *eventstore.BaseEvent) {
e.BaseEvent = *event
}
func NewRetryRequestedEvent(
ctx context.Context,
aggregate *eventstore.Aggregate,
userID,
userResourceOwner,
aggregateID,
aggregateResourceOwner,
triggerOrigin,
urlTemplate string,
code *crypto.CryptoValue,
codeExpiry time.Duration,
eventType eventstore.EventType,
notificationType domain.NotificationType,
messageType string,
unverifiedNotificationChannel,
isOTP bool,
args *domain.NotificationArguments,
notifyUser *query.NotifyUser,
backoff time.Duration,
errorMessage string,
) *RetryRequestedEvent {
return &RetryRequestedEvent{
BaseEvent: *eventstore.NewBaseEventForPush(
ctx,
aggregate,
RetryRequestedType,
),
Request: Request{
UserID: userID,
UserResourceOwner: userResourceOwner,
AggregateID: aggregateID,
AggregateResourceOwner: aggregateResourceOwner,
TriggeredAtOrigin: triggerOrigin,
EventType: eventType,
MessageType: messageType,
NotificationType: notificationType,
URLTemplate: urlTemplate,
CodeExpiry: codeExpiry,
Code: code,
UnverifiedNotificationChannel: unverifiedNotificationChannel,
IsOTP: isOTP,
Args: args,
},
NotifyUser: notifyUser,
BackOff: backoff,
Error: errorMessage,
}
} }