feat: Add Twilio Verification Service (#8678)

# Which Problems Are Solved
Twilio supports a robust, multi-channel verification service that
notably supports multi-region SMS sender numbers required for our use
case. Currently, Zitadel does much of the work of the Twilio Verify (eg.
localization, code generation, messaging) but doesn't support the pool
of sender numbers that Twilio Verify does.

# How the Problems Are Solved
To support this API, we need to be able to store the Twilio Service ID
and send that in a verification request where appropriate: phone number
verification and SMS 2FA code paths.

This PR does the following: 
- Adds the ability to use Twilio Verify of standard messaging through
Twilio
- Adds support for international numbers and more reliable verification
messages sent from multiple numbers
- Adds a new Twilio configuration option to support Twilio Verify in the
admin console
- Sends verification SMS messages through Twilio Verify
- Implements Twilio Verification Checks for codes generated through the
same

# Additional Changes

# Additional Context
- base was implemented by @zhirschtritt in
https://github.com/zitadel/zitadel/pull/8268 ❤️
- closes https://github.com/zitadel/zitadel/issues/8581

---------

Co-authored-by: Zachary Hirschtritt <zachary.hirschtritt@klaviyo.com>
Co-authored-by: Joey Biscoglia <joey.biscoglia@klaviyo.com>
This commit is contained in:
Livio Spring
2024-09-26 09:14:33 +02:00
committed by GitHub
parent 4eaa3163b6
commit 14e2aba1bc
89 changed files with 3888 additions and 782 deletions

View File

@@ -1,7 +1,9 @@
package twilio
import (
"github.com/kevinburke/twilio-go"
newTwilio "github.com/twilio/twilio-go"
openapi "github.com/twilio/twilio-go/rest/api/v2010"
verify "github.com/twilio/twilio-go/rest/verify/v2"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/notification/channels"
@@ -10,8 +12,7 @@ import (
)
func InitChannel(config Config) channels.NotificationChannel {
client := twilio.NewClient(config.SID, config.Token, nil)
client := newTwilio.NewRestClientWithParams(newTwilio.ClientParams{Username: config.SID, Password: config.Token})
logging.Debug("successfully initialized twilio sms channel")
return channels.HandleMessageFunc(func(message channels.Message) error {
@@ -19,11 +20,30 @@ func InitChannel(config Config) channels.NotificationChannel {
if !ok {
return zerrors.ThrowInternal(nil, "TWILI-s0pLc", "message is not SMS")
}
if config.VerifyServiceSID != "" {
params := &verify.CreateVerificationParams{}
params.SetTo(twilioMsg.RecipientPhoneNumber)
params.SetChannel("sms")
resp, err := client.VerifyV2.CreateVerification(config.VerifyServiceSID, params)
if err != nil {
return zerrors.ThrowInternal(err, "TWILI-0s9f2", "could not send verification")
}
logging.WithFields("sid", resp.Sid, "status", resp.Status).Debug("verification sent")
twilioMsg.VerificationID = resp.Sid
return nil
}
content, err := twilioMsg.GetContent()
if err != nil {
return err
}
m, err := client.Messages.SendMessage(twilioMsg.SenderPhoneNumber, twilioMsg.RecipientPhoneNumber, content, nil)
params := &openapi.CreateMessageParams{}
params.SetTo(twilioMsg.RecipientPhoneNumber)
params.SetFrom(twilioMsg.SenderPhoneNumber)
params.SetBody(content)
m, err := client.Api.CreateMessage(params)
if err != nil {
return zerrors.ThrowInternal(err, "TWILI-osk3S", "could not send message")
}

View File

@@ -1,11 +1,40 @@
package twilio
import (
newTwilio "github.com/twilio/twilio-go"
verify "github.com/twilio/twilio-go/rest/verify/v2"
"github.com/zitadel/zitadel/internal/zerrors"
)
type Config struct {
SID string
Token string
SenderNumber string
SID string
Token string
SenderNumber string
VerifyServiceSID string
}
func (t *Config) IsValid() bool {
return t.SID != "" && t.Token != "" && t.SenderNumber != ""
}
func (t *Config) VerifyCode(verificationID, code string) error {
client := newTwilio.NewRestClientWithParams(newTwilio.ClientParams{Username: t.SID, Password: t.Token})
checkParams := &verify.CreateVerificationCheckParams{}
checkParams.SetVerificationSid(verificationID)
checkParams.SetCode(code)
resp, err := client.VerifyV2.CreateVerificationCheck(t.VerifyServiceSID, checkParams)
if err != nil || resp.Status == nil {
return zerrors.ThrowInvalidArgument(err, "TWILI-JK3ta", "Errors.User.Code.NotFound")
}
switch *resp.Status {
case "approved":
return nil
case "expired":
return zerrors.ThrowInvalidArgument(nil, "TWILI-SF3ba", "Errors.User.Code.Expired")
case "max_attempts_reached":
return zerrors.ThrowInvalidArgument(nil, "TWILI-Ok39a", "Errors.User.Code.NotFound")
default:
return zerrors.ThrowInvalidArgument(nil, "TWILI-Skwe4", "Errors.User.Code.Invalid")
}
}

View File

@@ -3,6 +3,7 @@ package handlers
import (
"context"
"github.com/zitadel/zitadel/internal/notification/senders"
"github.com/zitadel/zitadel/internal/repository/milestone"
"github.com/zitadel/zitadel/internal/repository/quota"
)
@@ -10,15 +11,15 @@ import (
type Commands interface {
HumanInitCodeSent(ctx context.Context, orgID, userID string) error
HumanEmailVerificationCodeSent(ctx context.Context, orgID, userID string) error
PasswordCodeSent(ctx context.Context, orgID, userID string) error
HumanOTPSMSCodeSent(ctx context.Context, userID, resourceOwner string) error
PasswordCodeSent(ctx context.Context, orgID, userID string, generatorInfo *senders.CodeGeneratorInfo) error
HumanOTPSMSCodeSent(ctx context.Context, userID, resourceOwner string, generatorInfo *senders.CodeGeneratorInfo) error
HumanOTPEmailCodeSent(ctx context.Context, userID, resourceOwner string) error
OTPSMSSent(ctx context.Context, sessionID, resourceOwner string) error
OTPSMSSent(ctx context.Context, sessionID, resourceOwner string, generatorInfo *senders.CodeGeneratorInfo) error
OTPEmailSent(ctx context.Context, sessionID, resourceOwner string) error
UserDomainClaimedSent(ctx context.Context, orgID, userID string) error
HumanPasswordlessInitCodeSent(ctx context.Context, userID, resourceOwner, codeID string) error
PasswordChangeSent(ctx context.Context, orgID, userID string) error
HumanPhoneVerificationCodeSent(ctx context.Context, orgID, userID string) error
HumanPhoneVerificationCodeSent(ctx context.Context, orgID, userID string, generatorInfo *senders.CodeGeneratorInfo) error
InviteCodeSent(ctx context.Context, orgID, userID string) error
UsageNotificationSent(ctx context.Context, dueEvent *quota.NotificationDueEvent) error
MilestonePushed(ctx context.Context, msType milestone.Type, endpoints []string, primaryDomain string) error

View File

@@ -31,9 +31,10 @@ func (n *NotificationQueries) GetActiveSMSConfig(ctx context.Context) (*sms.Conf
return &sms.Config{
ProviderConfig: provider,
TwilioConfig: &twilio.Config{
SID: config.TwilioConfig.SID,
Token: token,
SenderNumber: config.TwilioConfig.SenderNumber,
SID: config.TwilioConfig.SID,
Token: token,
SenderNumber: config.TwilioConfig.SenderNumber,
VerifyServiceSID: config.TwilioConfig.VerifyServiceSID,
},
}, nil
}

View File

@@ -13,35 +13,36 @@ import (
context "context"
reflect "reflect"
senders "github.com/zitadel/zitadel/internal/notification/senders"
milestone "github.com/zitadel/zitadel/internal/repository/milestone"
quota "github.com/zitadel/zitadel/internal/repository/quota"
gomock "go.uber.org/mock/gomock"
)
// MockCommands is a mock of Commands interface
// MockCommands is a mock of Commands interface.
type MockCommands struct {
ctrl *gomock.Controller
recorder *MockCommandsMockRecorder
}
// MockCommandsMockRecorder is the mock recorder for MockCommands
// MockCommandsMockRecorder is the mock recorder for MockCommands.
type MockCommandsMockRecorder struct {
mock *MockCommands
}
// NewMockCommands creates a new mock instance
// NewMockCommands creates a new mock instance.
func NewMockCommands(ctrl *gomock.Controller) *MockCommands {
mock := &MockCommands{ctrl: ctrl}
mock.recorder = &MockCommandsMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockCommands) EXPECT() *MockCommandsMockRecorder {
return m.recorder
}
// HumanEmailVerificationCodeSent mocks base method
// HumanEmailVerificationCodeSent mocks base method.
func (m *MockCommands) HumanEmailVerificationCodeSent(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HumanEmailVerificationCodeSent", arg0, arg1, arg2)
@@ -49,13 +50,13 @@ func (m *MockCommands) HumanEmailVerificationCodeSent(arg0 context.Context, arg1
return ret0
}
// HumanEmailVerificationCodeSent indicates an expected call of HumanEmailVerificationCodeSent
func (mr *MockCommandsMockRecorder) HumanEmailVerificationCodeSent(arg0, arg1, arg2 interface{}) *gomock.Call {
// HumanEmailVerificationCodeSent indicates an expected call of HumanEmailVerificationCodeSent.
func (mr *MockCommandsMockRecorder) HumanEmailVerificationCodeSent(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanEmailVerificationCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanEmailVerificationCodeSent), arg0, arg1, arg2)
}
// HumanInitCodeSent mocks base method
// HumanInitCodeSent mocks base method.
func (m *MockCommands) HumanInitCodeSent(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HumanInitCodeSent", arg0, arg1, arg2)
@@ -63,13 +64,13 @@ func (m *MockCommands) HumanInitCodeSent(arg0 context.Context, arg1, arg2 string
return ret0
}
// HumanInitCodeSent indicates an expected call of HumanInitCodeSent
func (mr *MockCommandsMockRecorder) HumanInitCodeSent(arg0, arg1, arg2 interface{}) *gomock.Call {
// HumanInitCodeSent indicates an expected call of HumanInitCodeSent.
func (mr *MockCommandsMockRecorder) HumanInitCodeSent(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanInitCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanInitCodeSent), arg0, arg1, arg2)
}
// HumanOTPEmailCodeSent mocks base method
// HumanOTPEmailCodeSent mocks base method.
func (m *MockCommands) HumanOTPEmailCodeSent(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HumanOTPEmailCodeSent", arg0, arg1, arg2)
@@ -77,27 +78,27 @@ func (m *MockCommands) HumanOTPEmailCodeSent(arg0 context.Context, arg1, arg2 st
return ret0
}
// HumanOTPEmailCodeSent indicates an expected call of HumanOTPEmailCodeSent
func (mr *MockCommandsMockRecorder) HumanOTPEmailCodeSent(arg0, arg1, arg2 interface{}) *gomock.Call {
// HumanOTPEmailCodeSent indicates an expected call of HumanOTPEmailCodeSent.
func (mr *MockCommandsMockRecorder) HumanOTPEmailCodeSent(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanOTPEmailCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanOTPEmailCodeSent), arg0, arg1, arg2)
}
// HumanOTPSMSCodeSent mocks base method
func (m *MockCommands) HumanOTPSMSCodeSent(arg0 context.Context, arg1, arg2 string) error {
// HumanOTPSMSCodeSent mocks base method.
func (m *MockCommands) HumanOTPSMSCodeSent(arg0 context.Context, arg1, arg2 string, arg3 *senders.CodeGeneratorInfo) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HumanOTPSMSCodeSent", arg0, arg1, arg2)
ret := m.ctrl.Call(m, "HumanOTPSMSCodeSent", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(error)
return ret0
}
// HumanOTPSMSCodeSent indicates an expected call of HumanOTPSMSCodeSent
func (mr *MockCommandsMockRecorder) HumanOTPSMSCodeSent(arg0, arg1, arg2 interface{}) *gomock.Call {
// HumanOTPSMSCodeSent indicates an expected call of HumanOTPSMSCodeSent.
func (mr *MockCommandsMockRecorder) HumanOTPSMSCodeSent(arg0, arg1, arg2, arg3 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanOTPSMSCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanOTPSMSCodeSent), arg0, arg1, arg2)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanOTPSMSCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanOTPSMSCodeSent), arg0, arg1, arg2, arg3)
}
// HumanPasswordlessInitCodeSent mocks base method
// HumanPasswordlessInitCodeSent mocks base method.
func (m *MockCommands) HumanPasswordlessInitCodeSent(arg0 context.Context, arg1, arg2, arg3 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HumanPasswordlessInitCodeSent", arg0, arg1, arg2, arg3)
@@ -105,27 +106,27 @@ func (m *MockCommands) HumanPasswordlessInitCodeSent(arg0 context.Context, arg1,
return ret0
}
// HumanPasswordlessInitCodeSent indicates an expected call of HumanPasswordlessInitCodeSent
func (mr *MockCommandsMockRecorder) HumanPasswordlessInitCodeSent(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
// HumanPasswordlessInitCodeSent indicates an expected call of HumanPasswordlessInitCodeSent.
func (mr *MockCommandsMockRecorder) HumanPasswordlessInitCodeSent(arg0, arg1, arg2, arg3 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanPasswordlessInitCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanPasswordlessInitCodeSent), arg0, arg1, arg2, arg3)
}
// HumanPhoneVerificationCodeSent mocks base method
func (m *MockCommands) HumanPhoneVerificationCodeSent(arg0 context.Context, arg1, arg2 string) error {
// HumanPhoneVerificationCodeSent mocks base method.
func (m *MockCommands) HumanPhoneVerificationCodeSent(arg0 context.Context, arg1, arg2 string, arg3 *senders.CodeGeneratorInfo) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HumanPhoneVerificationCodeSent", arg0, arg1, arg2)
ret := m.ctrl.Call(m, "HumanPhoneVerificationCodeSent", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(error)
return ret0
}
// HumanPhoneVerificationCodeSent indicates an expected call of HumanPhoneVerificationCodeSent
func (mr *MockCommandsMockRecorder) HumanPhoneVerificationCodeSent(arg0, arg1, arg2 interface{}) *gomock.Call {
// HumanPhoneVerificationCodeSent indicates an expected call of HumanPhoneVerificationCodeSent.
func (mr *MockCommandsMockRecorder) HumanPhoneVerificationCodeSent(arg0, arg1, arg2, arg3 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanPhoneVerificationCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanPhoneVerificationCodeSent), arg0, arg1, arg2)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanPhoneVerificationCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanPhoneVerificationCodeSent), arg0, arg1, arg2, arg3)
}
// InviteCodeSent mocks base method
// InviteCodeSent mocks base method.
func (m *MockCommands) InviteCodeSent(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "InviteCodeSent", arg0, arg1, arg2)
@@ -133,13 +134,13 @@ func (m *MockCommands) InviteCodeSent(arg0 context.Context, arg1, arg2 string) e
return ret0
}
// InviteCodeSent indicates an expected call of InviteCodeSent
func (mr *MockCommandsMockRecorder) InviteCodeSent(arg0, arg1, arg2 interface{}) *gomock.Call {
// InviteCodeSent indicates an expected call of InviteCodeSent.
func (mr *MockCommandsMockRecorder) InviteCodeSent(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "InviteCodeSent", reflect.TypeOf((*MockCommands)(nil).InviteCodeSent), arg0, arg1, arg2)
}
// MilestonePushed mocks base method
// MilestonePushed mocks base method.
func (m *MockCommands) MilestonePushed(arg0 context.Context, arg1 milestone.Type, arg2 []string, arg3 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MilestonePushed", arg0, arg1, arg2, arg3)
@@ -147,13 +148,13 @@ func (m *MockCommands) MilestonePushed(arg0 context.Context, arg1 milestone.Type
return ret0
}
// MilestonePushed indicates an expected call of MilestonePushed
func (mr *MockCommandsMockRecorder) MilestonePushed(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
// MilestonePushed indicates an expected call of MilestonePushed.
func (mr *MockCommandsMockRecorder) MilestonePushed(arg0, arg1, arg2, arg3 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MilestonePushed", reflect.TypeOf((*MockCommands)(nil).MilestonePushed), arg0, arg1, arg2, arg3)
}
// OTPEmailSent mocks base method
// OTPEmailSent mocks base method.
func (m *MockCommands) OTPEmailSent(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OTPEmailSent", arg0, arg1, arg2)
@@ -161,27 +162,27 @@ func (m *MockCommands) OTPEmailSent(arg0 context.Context, arg1, arg2 string) err
return ret0
}
// OTPEmailSent indicates an expected call of OTPEmailSent
func (mr *MockCommandsMockRecorder) OTPEmailSent(arg0, arg1, arg2 interface{}) *gomock.Call {
// OTPEmailSent indicates an expected call of OTPEmailSent.
func (mr *MockCommandsMockRecorder) OTPEmailSent(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OTPEmailSent", reflect.TypeOf((*MockCommands)(nil).OTPEmailSent), arg0, arg1, arg2)
}
// OTPSMSSent mocks base method
func (m *MockCommands) OTPSMSSent(arg0 context.Context, arg1, arg2 string) error {
// OTPSMSSent mocks base method.
func (m *MockCommands) OTPSMSSent(arg0 context.Context, arg1, arg2 string, arg3 *senders.CodeGeneratorInfo) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OTPSMSSent", arg0, arg1, arg2)
ret := m.ctrl.Call(m, "OTPSMSSent", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(error)
return ret0
}
// OTPSMSSent indicates an expected call of OTPSMSSent
func (mr *MockCommandsMockRecorder) OTPSMSSent(arg0, arg1, arg2 interface{}) *gomock.Call {
// OTPSMSSent indicates an expected call of OTPSMSSent.
func (mr *MockCommandsMockRecorder) OTPSMSSent(arg0, arg1, arg2, arg3 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OTPSMSSent", reflect.TypeOf((*MockCommands)(nil).OTPSMSSent), arg0, arg1, arg2)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OTPSMSSent", reflect.TypeOf((*MockCommands)(nil).OTPSMSSent), arg0, arg1, arg2, arg3)
}
// PasswordChangeSent mocks base method
// PasswordChangeSent mocks base method.
func (m *MockCommands) PasswordChangeSent(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PasswordChangeSent", arg0, arg1, arg2)
@@ -189,27 +190,27 @@ func (m *MockCommands) PasswordChangeSent(arg0 context.Context, arg1, arg2 strin
return ret0
}
// PasswordChangeSent indicates an expected call of PasswordChangeSent
func (mr *MockCommandsMockRecorder) PasswordChangeSent(arg0, arg1, arg2 interface{}) *gomock.Call {
// PasswordChangeSent indicates an expected call of PasswordChangeSent.
func (mr *MockCommandsMockRecorder) PasswordChangeSent(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordChangeSent", reflect.TypeOf((*MockCommands)(nil).PasswordChangeSent), arg0, arg1, arg2)
}
// PasswordCodeSent mocks base method
func (m *MockCommands) PasswordCodeSent(arg0 context.Context, arg1, arg2 string) error {
// PasswordCodeSent mocks base method.
func (m *MockCommands) PasswordCodeSent(arg0 context.Context, arg1, arg2 string, arg3 *senders.CodeGeneratorInfo) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PasswordCodeSent", arg0, arg1, arg2)
ret := m.ctrl.Call(m, "PasswordCodeSent", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(error)
return ret0
}
// PasswordCodeSent indicates an expected call of PasswordCodeSent
func (mr *MockCommandsMockRecorder) PasswordCodeSent(arg0, arg1, arg2 interface{}) *gomock.Call {
// PasswordCodeSent indicates an expected call of PasswordCodeSent.
func (mr *MockCommandsMockRecorder) PasswordCodeSent(arg0, arg1, arg2, arg3 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordCodeSent", reflect.TypeOf((*MockCommands)(nil).PasswordCodeSent), arg0, arg1, arg2)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordCodeSent", reflect.TypeOf((*MockCommands)(nil).PasswordCodeSent), arg0, arg1, arg2, arg3)
}
// UsageNotificationSent mocks base method
// UsageNotificationSent mocks base method.
func (m *MockCommands) UsageNotificationSent(arg0 context.Context, arg1 *quota.NotificationDueEvent) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UsageNotificationSent", arg0, arg1)
@@ -217,13 +218,13 @@ func (m *MockCommands) UsageNotificationSent(arg0 context.Context, arg1 *quota.N
return ret0
}
// UsageNotificationSent indicates an expected call of UsageNotificationSent
func (mr *MockCommandsMockRecorder) UsageNotificationSent(arg0, arg1 interface{}) *gomock.Call {
// UsageNotificationSent indicates an expected call of UsageNotificationSent.
func (mr *MockCommandsMockRecorder) UsageNotificationSent(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UsageNotificationSent", reflect.TypeOf((*MockCommands)(nil).UsageNotificationSent), arg0, arg1)
}
// UserDomainClaimedSent mocks base method
// UserDomainClaimedSent mocks base method.
func (m *MockCommands) UserDomainClaimedSent(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UserDomainClaimedSent", arg0, arg1, arg2)
@@ -231,8 +232,8 @@ func (m *MockCommands) UserDomainClaimedSent(arg0 context.Context, arg1, arg2 st
return ret0
}
// UserDomainClaimedSent indicates an expected call of UserDomainClaimedSent
func (mr *MockCommandsMockRecorder) UserDomainClaimedSent(arg0, arg1, arg2 interface{}) *gomock.Call {
// UserDomainClaimedSent indicates an expected call of UserDomainClaimedSent.
func (mr *MockCommandsMockRecorder) UserDomainClaimedSent(arg0, arg1, arg2 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserDomainClaimedSent", reflect.TypeOf((*MockCommands)(nil).UserDomainClaimedSent), arg0, arg1, arg2)
}

View File

@@ -11,6 +11,7 @@ import (
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/handler/v2"
"github.com/zitadel/zitadel/internal/notification/senders"
"github.com/zitadel/zitadel/internal/notification/types"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/repository/session"
@@ -258,9 +259,12 @@ func (u *userNotifier) reducePasswordCodeAdded(event eventstore.Event) (*handler
if alreadyHandled {
return nil
}
code, err := crypto.DecryptString(e.Code, u.queries.UserDataCrypto)
if err != nil {
return err
var code string
if e.Code != nil {
code, err = crypto.DecryptString(e.Code, u.queries.UserDataCrypto)
if err != nil {
return err
}
}
colors, err := u.queries.ActiveLabelPolicyByOrg(ctx, e.Aggregate().ResourceOwner, false)
if err != nil {
@@ -285,15 +289,16 @@ func (u *userNotifier) reducePasswordCodeAdded(event eventstore.Event) (*handler
if err != nil {
return err
}
generatorInfo := new(senders.CodeGeneratorInfo)
notify := types.SendEmail(ctx, u.channels, string(template.Template), translator, notifyUser, colors, e)
if e.NotificationType == domain.NotificationTypeSms {
notify = types.SendSMS(ctx, u.channels, translator, notifyUser, colors, e)
notify = types.SendSMS(ctx, u.channels, translator, notifyUser, colors, e, generatorInfo)
}
err = notify.SendPasswordCode(ctx, notifyUser, code, e.URLTemplate, e.AuthRequestID)
if err != nil {
return err
}
return u.commands.PasswordCodeSent(ctx, e.Aggregate().ResourceOwner, e.Aggregate().ID)
return u.commands.PasswordCodeSent(ctx, e.Aggregate().ResourceOwner, e.Aggregate().ID, generatorInfo)
}), nil
}
@@ -345,7 +350,7 @@ func (u *userNotifier) reduceOTPSMS(
expiry time.Duration,
userID,
resourceOwner string,
sentCommand func(ctx context.Context, userID string, resourceOwner string) (err error),
sentCommand func(ctx context.Context, userID, resourceOwner string, generatorInfo *senders.CodeGeneratorInfo) (err error),
eventTypes ...eventstore.EventType,
) (*handler.Statement, error) {
ctx := HandlerContext(event.Aggregate())
@@ -356,9 +361,12 @@ func (u *userNotifier) reduceOTPSMS(
if alreadyHandled {
return handler.NewNoOpStatement(event), nil
}
plainCode, err := crypto.DecryptString(code, u.queries.UserDataCrypto)
if err != nil {
return nil, err
var plainCode string
if code != nil {
plainCode, err = crypto.DecryptString(code, u.queries.UserDataCrypto)
if err != nil {
return nil, err
}
}
colors, err := u.queries.ActiveLabelPolicyByOrg(ctx, resourceOwner, false)
if err != nil {
@@ -377,12 +385,13 @@ func (u *userNotifier) reduceOTPSMS(
if err != nil {
return nil, err
}
notify := types.SendSMS(ctx, u.channels, translator, notifyUser, colors, event)
generatorInfo := new(senders.CodeGeneratorInfo)
notify := types.SendSMS(ctx, u.channels, translator, notifyUser, colors, event, generatorInfo)
err = notify.SendOTPSMSCode(ctx, plainCode, expiry)
if err != nil {
return nil, err
}
err = sentCommand(ctx, event.Aggregate().ID, event.Aggregate().ResourceOwner)
err = sentCommand(ctx, event.Aggregate().ID, event.Aggregate().ResourceOwner, generatorInfo)
if err != nil {
return nil, err
}
@@ -691,9 +700,12 @@ func (u *userNotifier) reducePhoneCodeAdded(event eventstore.Event) (*handler.St
if alreadyHandled {
return nil
}
code, err := crypto.DecryptString(e.Code, u.queries.UserDataCrypto)
if err != nil {
return err
var code string
if e.Code != nil {
code, err = crypto.DecryptString(e.Code, u.queries.UserDataCrypto)
if err != nil {
return err
}
}
colors, err := u.queries.ActiveLabelPolicyByOrg(ctx, e.Aggregate().ResourceOwner, false)
if err != nil {
@@ -713,12 +725,12 @@ func (u *userNotifier) reducePhoneCodeAdded(event eventstore.Event) (*handler.St
if err != nil {
return err
}
err = types.SendSMS(ctx, u.channels, translator, notifyUser, colors, e).
SendPhoneVerificationCode(ctx, code)
if err != nil {
generatorInfo := new(senders.CodeGeneratorInfo)
if err = types.SendSMS(ctx, u.channels, translator, notifyUser, colors, e, generatorInfo).
SendPhoneVerificationCode(ctx, code); err != nil {
return err
}
return u.commands.HumanPhoneVerificationCodeSent(ctx, e.Aggregate().ResourceOwner, e.Aggregate().ID)
return u.commands.HumanPhoneVerificationCodeSent(ctx, e.Aggregate().ResourceOwner, e.Aggregate().ID, generatorInfo)
}), nil
}
@@ -778,7 +790,7 @@ func (u *userNotifier) reduceInviteCodeAdded(event eventstore.Event) (*handler.S
}
func (u *userNotifier) checkIfCodeAlreadyHandledOrExpired(ctx context.Context, event eventstore.Event, expiry time.Duration, data map[string]interface{}, eventTypes ...eventstore.EventType) (bool, error) {
if event.CreatedAt().Add(expiry).Before(time.Now().UTC()) {
if expiry > 0 && event.CreatedAt().Add(expiry).Before(time.Now().UTC()) {
return true, nil
}
return u.queries.IsAlreadyHandled(ctx, event, data, eventTypes...)

View File

@@ -7,11 +7,13 @@ import (
"testing"
"time"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"go.uber.org/mock/gomock"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/eventstore/repository"
es_repo_mock "github.com/zitadel/zitadel/internal/eventstore/repository/mock"
@@ -19,6 +21,7 @@ import (
channel_mock "github.com/zitadel/zitadel/internal/notification/channels/mock"
"github.com/zitadel/zitadel/internal/notification/channels/sms"
"github.com/zitadel/zitadel/internal/notification/channels/smtp"
"github.com/zitadel/zitadel/internal/notification/channels/twilio"
"github.com/zitadel/zitadel/internal/notification/channels/webhook"
"github.com/zitadel/zitadel/internal/notification/handlers/mock"
"github.com/zitadel/zitadel/internal/notification/messages"
@@ -36,10 +39,13 @@ const (
codeID = "event1"
logoURL = "logo.png"
eventOrigin = "https://triggered.here"
eventOriginDomain = "triggered.here"
assetsPath = "/assets/v1"
preferredLoginName = "loginName1"
lastEmail = "last@email.com"
verifiedEmail = "verified@email.com"
lastPhone = "+41797654321"
verifiedPhone = "+41791234567"
instancePrimaryDomain = "primary.domain"
externalDomain = "external.domain"
externalPort = 3000
@@ -47,6 +53,9 @@ const (
externalProtocol = "http"
defaultOTPEmailTemplate = "/otp/verify?loginName={{.LoginName}}&code={{.Code}}"
authRequestID = "authRequestID"
smsProviderID = "smsProviderID"
emailProviderID = "emailProviderID"
verificationID = "verificationID"
)
func Test_userNotifier_reduceInitCodeAdded(t *testing.T) {
@@ -59,7 +68,7 @@ func Test_userNotifier_reduceInitCodeAdded(t *testing.T) {
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) {
givenTemplate := "{{.LogoURL}}"
expectContent := fmt.Sprintf("%s%s/%s/%s", eventOrigin, assetsPath, policyID, logoURL)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -92,7 +101,7 @@ func Test_userNotifier_reduceInitCodeAdded(t *testing.T) {
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) {
givenTemplate := "{{.LogoURL}}"
expectContent := fmt.Sprintf("%s://%s:%d%s/%s/%s", externalProtocol, instancePrimaryDomain, externalPort, assetsPath, policyID, logoURL)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -131,7 +140,7 @@ func Test_userNotifier_reduceInitCodeAdded(t *testing.T) {
givenTemplate := "{{.URL}}"
testCode := "testcode"
expectContent := fmt.Sprintf("%s/ui/login/user/init?authRequestID=%s&code=%s&loginname=%s&orgID=%s&passwordset=%t&userID=%s", eventOrigin, "", testCode, preferredLoginName, orgID, false, userID)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -165,7 +174,7 @@ func Test_userNotifier_reduceInitCodeAdded(t *testing.T) {
givenTemplate := "{{.URL}}"
testCode := "testcode"
expectContent := fmt.Sprintf("%s://%s:%d/ui/login/user/init?authRequestID=%s&code=%s&loginname=%s&orgID=%s&passwordset=%t&userID=%s", externalProtocol, instancePrimaryDomain, externalPort, "", testCode, preferredLoginName, orgID, false, userID)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -204,7 +213,7 @@ func Test_userNotifier_reduceInitCodeAdded(t *testing.T) {
givenTemplate := "{{.URL}}"
testCode := "testcode"
expectContent := fmt.Sprintf("%s://%s:%d/ui/login/user/init?authRequestID=%s&code=%s&loginname=%s&orgID=%s&passwordset=%t&userID=%s", externalProtocol, instancePrimaryDomain, externalPort, authRequestID, testCode, preferredLoginName, orgID, false, userID)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -272,7 +281,7 @@ func Test_userNotifier_reduceEmailCodeAdded(t *testing.T) {
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) {
givenTemplate := "{{.LogoURL}}"
expectContent := fmt.Sprintf("%s%s/%s/%s", eventOrigin, assetsPath, policyID, logoURL)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -307,7 +316,7 @@ func Test_userNotifier_reduceEmailCodeAdded(t *testing.T) {
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) {
givenTemplate := "{{.LogoURL}}"
expectContent := fmt.Sprintf("%s://%s:%d%s/%s/%s", externalProtocol, instancePrimaryDomain, externalPort, assetsPath, policyID, logoURL)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -348,7 +357,7 @@ func Test_userNotifier_reduceEmailCodeAdded(t *testing.T) {
givenTemplate := "{{.URL}}"
testCode := "testcode"
expectContent := fmt.Sprintf("%s/ui/login/mail/verification?authRequestID=%s&code=%s&orgID=%s&userID=%s", eventOrigin, "", testCode, orgID, userID)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -385,7 +394,7 @@ func Test_userNotifier_reduceEmailCodeAdded(t *testing.T) {
givenTemplate := "{{.URL}}"
testCode := "testcode"
expectContent := fmt.Sprintf("%s://%s:%d/ui/login/mail/verification?authRequestID=%s&code=%s&orgID=%s&userID=%s", externalProtocol, instancePrimaryDomain, externalPort, "", testCode, orgID, userID)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -426,7 +435,7 @@ func Test_userNotifier_reduceEmailCodeAdded(t *testing.T) {
givenTemplate := "{{.URL}}"
testCode := "testcode"
expectContent := fmt.Sprintf("%s://%s:%d/ui/login/mail/verification?authRequestID=%s&code=%s&orgID=%s&userID=%s", externalProtocol, instancePrimaryDomain, externalPort, authRequestID, testCode, orgID, userID)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -469,7 +478,7 @@ func Test_userNotifier_reduceEmailCodeAdded(t *testing.T) {
urlTemplate := "https://my.custom.url/org/{{.OrgID}}/user/{{.UserID}}/verify/{{.Code}}"
testCode := "testcode"
expectContent := fmt.Sprintf("https://my.custom.url/org/%s/user/%s/verify/%s", orgID, userID, testCode)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -533,14 +542,14 @@ func Test_userNotifier_reducePasswordCodeAdded(t *testing.T) {
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) {
givenTemplate := "{{.LogoURL}}"
expectContent := fmt.Sprintf("%s%s/%s/%s", eventOrigin, assetsPath, policyID, logoURL)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
}
codeAlg, code := cryptoValue(t, ctrl, "testcode")
expectTemplateQueries(queries, givenTemplate)
commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID).Return(nil)
commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID, &senders.CodeGeneratorInfo{}).Return(nil)
return fields{
queries: queries,
commands: commands,
@@ -568,7 +577,7 @@ func Test_userNotifier_reducePasswordCodeAdded(t *testing.T) {
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) {
givenTemplate := "{{.LogoURL}}"
expectContent := fmt.Sprintf("%s://%s:%d%s/%s/%s", externalProtocol, instancePrimaryDomain, externalPort, assetsPath, policyID, logoURL)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -581,7 +590,7 @@ func Test_userNotifier_reducePasswordCodeAdded(t *testing.T) {
}},
}, nil)
expectTemplateQueries(queries, givenTemplate)
commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID).Return(nil)
commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID, &senders.CodeGeneratorInfo{}).Return(nil)
return fields{
queries: queries,
commands: commands,
@@ -609,14 +618,14 @@ func Test_userNotifier_reducePasswordCodeAdded(t *testing.T) {
givenTemplate := "{{.URL}}"
testCode := "testcode"
expectContent := fmt.Sprintf("%s/ui/login/password/init?authRequestID=%s&code=%s&orgID=%s&userID=%s", eventOrigin, "", testCode, orgID, userID)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
}
codeAlg, code := cryptoValue(t, ctrl, testCode)
expectTemplateQueries(queries, givenTemplate)
commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID).Return(nil)
commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID, &senders.CodeGeneratorInfo{}).Return(nil)
return fields{
queries: queries,
commands: commands,
@@ -646,7 +655,7 @@ func Test_userNotifier_reducePasswordCodeAdded(t *testing.T) {
givenTemplate := "{{.URL}}"
testCode := "testcode"
expectContent := fmt.Sprintf("%s://%s:%d/ui/login/password/init?authRequestID=%s&code=%s&orgID=%s&userID=%s", externalProtocol, instancePrimaryDomain, externalPort, "", testCode, orgID, userID)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -659,7 +668,7 @@ func Test_userNotifier_reducePasswordCodeAdded(t *testing.T) {
}},
}, nil)
expectTemplateQueries(queries, givenTemplate)
commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID).Return(nil)
commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID, &senders.CodeGeneratorInfo{}).Return(nil)
return fields{
queries: queries,
commands: commands,
@@ -687,7 +696,7 @@ func Test_userNotifier_reducePasswordCodeAdded(t *testing.T) {
givenTemplate := "{{.URL}}"
testCode := "testcode"
expectContent := fmt.Sprintf("%s://%s:%d/ui/login/password/init?authRequestID=%s&code=%s&orgID=%s&userID=%s", externalProtocol, instancePrimaryDomain, externalPort, authRequestID, testCode, orgID, userID)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -700,7 +709,7 @@ func Test_userNotifier_reducePasswordCodeAdded(t *testing.T) {
}},
}, nil)
expectTemplateQueries(queries, givenTemplate)
commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID).Return(nil)
commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID, &senders.CodeGeneratorInfo{}).Return(nil)
return fields{
queries: queries,
commands: commands,
@@ -730,14 +739,14 @@ func Test_userNotifier_reducePasswordCodeAdded(t *testing.T) {
urlTemplate := "https://my.custom.url/org/{{.OrgID}}/user/{{.UserID}}/verify/{{.Code}}"
testCode := "testcode"
expectContent := fmt.Sprintf("https://my.custom.url/org/%s/user/%s/verify/%s", orgID, userID, testCode)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
}
codeAlg, code := cryptoValue(t, ctrl, testCode)
expectTemplateQueries(queries, givenTemplate)
commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID).Return(nil)
commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID, &senders.CodeGeneratorInfo{}).Return(nil)
return fields{
queries: queries,
commands: commands,
@@ -761,7 +770,44 @@ func Test_userNotifier_reducePasswordCodeAdded(t *testing.T) {
},
}, w
},
}}
}, {
name: "external code",
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) {
givenTemplate := "{{.URL}}"
expectContent := "We received a password reset request. Please use the button below to reset your password. (Code ) If you didn't ask for this mail, please ignore it."
w.messageSMS = &messages.SMS{
SenderPhoneNumber: "senderNumber",
RecipientPhoneNumber: lastPhone,
Content: expectContent,
}
expectTemplateQueries(queries, givenTemplate)
commands.EXPECT().PasswordCodeSent(gomock.Any(), orgID, userID, &senders.CodeGeneratorInfo{ID: smsProviderID, VerificationID: verificationID}).Return(nil)
return fields{
queries: queries,
commands: commands,
es: eventstore.NewEventstore(&eventstore.Config{
Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
}),
SMSTokenCrypto: nil,
}, args{
event: &user.HumanPasswordCodeAddedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
AggregateID: userID,
ResourceOwner: sql.NullString{String: orgID},
CreationDate: time.Now().UTC(),
}),
Code: nil,
Expiry: 0,
URLTemplate: "",
CodeReturned: false,
NotificationType: domain.NotificationTypeSms,
GeneratorID: smsProviderID,
TriggeredAtOrigin: eventOrigin,
},
}, w
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
@@ -794,7 +840,7 @@ func Test_userNotifier_reduceDomainClaimed(t *testing.T) {
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) {
givenTemplate := "{{.LogoURL}}"
expectContent := fmt.Sprintf("%s%s/%s/%s", eventOrigin, assetsPath, policyID, logoURL)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -823,7 +869,7 @@ func Test_userNotifier_reduceDomainClaimed(t *testing.T) {
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) {
givenTemplate := "{{.LogoURL}}"
expectContent := fmt.Sprintf("%s://%s:%d%s/%s/%s", externalProtocol, instancePrimaryDomain, externalPort, assetsPath, policyID, logoURL)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -885,7 +931,7 @@ func Test_userNotifier_reducePasswordlessCodeRequested(t *testing.T) {
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) {
givenTemplate := "{{.LogoURL}}"
expectContent := fmt.Sprintf("%s%s/%s/%s", eventOrigin, assetsPath, policyID, logoURL)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -921,7 +967,7 @@ func Test_userNotifier_reducePasswordlessCodeRequested(t *testing.T) {
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) {
givenTemplate := "{{.LogoURL}}"
expectContent := fmt.Sprintf("%s://%s:%d%s/%s/%s", externalProtocol, instancePrimaryDomain, externalPort, assetsPath, policyID, logoURL)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -964,7 +1010,7 @@ func Test_userNotifier_reducePasswordlessCodeRequested(t *testing.T) {
testCode := "testcode"
codeAlg, code := cryptoValue(t, ctrl, testCode)
expectContent := fmt.Sprintf("%s/ui/login/login/passwordless/init?userID=%s&orgID=%s&codeID=%s&code=%s", eventOrigin, userID, orgID, codeID, testCode)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -1002,7 +1048,7 @@ func Test_userNotifier_reducePasswordlessCodeRequested(t *testing.T) {
testCode := "testcode"
codeAlg, code := cryptoValue(t, ctrl, testCode)
expectContent := fmt.Sprintf("%s://%s:%d/ui/login/login/passwordless/init?userID=%s&orgID=%s&codeID=%s&code=%s", externalProtocol, instancePrimaryDomain, externalPort, userID, orgID, codeID, testCode)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -1044,7 +1090,7 @@ func Test_userNotifier_reducePasswordlessCodeRequested(t *testing.T) {
urlTemplate := "https://my.custom.url/org/{{.OrgID}}/user/{{.UserID}}/verify/{{.Code}}"
testCode := "testcode"
expectContent := fmt.Sprintf("https://my.custom.url/org/%s/user/%s/verify/%s", orgID, userID, testCode)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -1109,7 +1155,7 @@ func Test_userNotifier_reducePasswordChanged(t *testing.T) {
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) {
givenTemplate := "{{.LogoURL}}"
expectContent := fmt.Sprintf("%s%s/%s/%s", eventOrigin, assetsPath, policyID, logoURL)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -1141,7 +1187,7 @@ func Test_userNotifier_reducePasswordChanged(t *testing.T) {
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) {
givenTemplate := "{{.LogoURL}}"
expectContent := fmt.Sprintf("%s://%s:%d%s/%s/%s", externalProtocol, instancePrimaryDomain, externalPort, assetsPath, policyID, logoURL)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{lastEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -1206,7 +1252,7 @@ func Test_userNotifier_reduceOTPEmailChallenged(t *testing.T) {
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) {
givenTemplate := "{{.LogoURL}}"
expectContent := fmt.Sprintf("%s%s/%s/%s", eventOrigin, assetsPath, policyID, logoURL)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{verifiedEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -1242,7 +1288,7 @@ func Test_userNotifier_reduceOTPEmailChallenged(t *testing.T) {
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) {
givenTemplate := "{{.LogoURL}}"
expectContent := fmt.Sprintf("%s://%s:%d%s/%s/%s", externalProtocol, instancePrimaryDomain, externalPort, assetsPath, policyID, logoURL)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{verifiedEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -1284,7 +1330,7 @@ func Test_userNotifier_reduceOTPEmailChallenged(t *testing.T) {
givenTemplate := "{{.URL}}"
testCode := "testcode"
expectContent := fmt.Sprintf("%s/otp/verify?loginName=%s&code=%s", eventOrigin, preferredLoginName, testCode)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{verifiedEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -1322,7 +1368,7 @@ func Test_userNotifier_reduceOTPEmailChallenged(t *testing.T) {
givenTemplate := "{{.URL}}"
testCode := "testcode"
expectContent := fmt.Sprintf("%s://%s:%d/otp/verify?loginName=%s&code=%s", externalProtocol, instancePrimaryDomain, externalPort, preferredLoginName, testCode)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{verifiedEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -1364,7 +1410,7 @@ func Test_userNotifier_reduceOTPEmailChallenged(t *testing.T) {
urlTemplate := "https://my.custom.url/user/{{.LoginName}}/verify"
testCode := "testcode"
expectContent := fmt.Sprintf("https://my.custom.url/user/%s/verify", preferredLoginName)
w.message = messages.Email{
w.message = &messages.Email{
Recipients: []string{verifiedEmail},
Subject: expectMailSubject,
Content: expectContent,
@@ -1413,6 +1459,107 @@ func Test_userNotifier_reduceOTPEmailChallenged(t *testing.T) {
}
}
func Test_userNotifier_reduceOTPSMSChallenged(t *testing.T) {
tests := []struct {
name string
test func(*gomock.Controller, *mock.MockQueries, *mock.MockCommands) (fields, args, want)
}{{
name: "asset url with event trigger url",
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) {
givenTemplate := "{{.LogoURL}}"
testCode := ""
expiry := 0 * time.Hour
expectContent := fmt.Sprintf(`%[1]s is your one-time-password for %[2]s. Use it within the next %[3]s.
@%[2]s #%[1]s`, testCode, eventOriginDomain, expiry)
w.messageSMS = &messages.SMS{
SenderPhoneNumber: "senderNumber",
RecipientPhoneNumber: verifiedPhone,
Content: expectContent,
}
expectTemplateQueriesSMS(queries, givenTemplate)
queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any()).Return(&query.Session{}, nil)
commands.EXPECT().OTPSMSSent(gomock.Any(), userID, orgID, &senders.CodeGeneratorInfo{ID: smsProviderID, VerificationID: verificationID}).Return(nil)
return fields{
queries: queries,
commands: commands,
es: eventstore.NewEventstore(&eventstore.Config{
Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
}),
}, args{
event: &session.OTPSMSChallengedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
AggregateID: userID,
ResourceOwner: sql.NullString{String: orgID},
CreationDate: time.Now().UTC(),
}),
Code: nil,
Expiry: expiry,
CodeReturned: false,
GeneratorID: smsProviderID,
TriggeredAtOrigin: eventOrigin,
},
}, w
},
}, {
name: "asset url without event trigger url",
test: func(ctrl *gomock.Controller, queries *mock.MockQueries, commands *mock.MockCommands) (f fields, a args, w want) {
givenTemplate := "{{.LogoURL}}"
testCode := ""
expiry := 0 * time.Hour
expectContent := fmt.Sprintf(`%[1]s is your one-time-password for %[2]s. Use it within the next %[3]s.
@%[2]s #%[1]s`, testCode, instancePrimaryDomain, expiry)
w.messageSMS = &messages.SMS{
SenderPhoneNumber: "senderNumber",
RecipientPhoneNumber: verifiedPhone,
Content: expectContent,
}
expectTemplateQueriesSMS(queries, givenTemplate)
queries.EXPECT().SessionByID(gomock.Any(), gomock.Any(), userID, gomock.Any()).Return(&query.Session{}, nil)
queries.EXPECT().SearchInstanceDomains(gomock.Any(), gomock.Any()).Return(&query.InstanceDomains{
Domains: []*query.InstanceDomain{{
Domain: instancePrimaryDomain,
IsPrimary: true,
}},
}, nil)
commands.EXPECT().OTPSMSSent(gomock.Any(), userID, orgID, &senders.CodeGeneratorInfo{ID: smsProviderID, VerificationID: verificationID}).Return(nil)
return fields{
queries: queries,
commands: commands,
es: eventstore.NewEventstore(&eventstore.Config{
Querier: es_repo_mock.NewRepo(t).ExpectFilterEvents().MockQuerier,
}),
}, args{
event: &session.OTPSMSChallengedEvent{
BaseEvent: *eventstore.BaseEventFromRepo(&repository.Event{
AggregateID: userID,
ResourceOwner: sql.NullString{String: orgID},
CreationDate: time.Now().UTC(),
}),
Code: nil,
Expiry: expiry,
CodeReturned: false,
GeneratorID: smsProviderID,
},
}, w
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
queries := mock.NewMockQueries(ctrl)
commands := mock.NewMockCommands(ctrl)
f, a, w := tt.test(ctrl, queries, commands)
_, err := newUserNotifier(t, ctrl, queries, f, a, w).reduceSessionOTPSMSChallenged(a.event)
if w.err != nil {
w.err(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
type fields struct {
queries *mock.MockQueries
commands *mock.MockCommands
@@ -1424,8 +1571,9 @@ type args struct {
event eventstore.Event
}
type want struct {
message messages.Email
err assert.ErrorAssertionFunc
message *messages.Email
messageSMS *messages.SMS
err assert.ErrorAssertionFunc
}
func newUserNotifier(t *testing.T, ctrl *gomock.Controller, queries *mock.MockQueries, f fields, a args, w want) *userNotifier {
@@ -1433,8 +1581,17 @@ func newUserNotifier(t *testing.T, ctrl *gomock.Controller, queries *mock.MockQu
smtpAlg, _ := cryptoValue(t, ctrl, "smtppw")
channel := channel_mock.NewMockNotificationChannel(ctrl)
if w.err == nil {
w.message.TriggeringEvent = a.event
channel.EXPECT().HandleMessage(&w.message).Return(nil)
if w.message != nil {
w.message.TriggeringEvent = a.event
channel.EXPECT().HandleMessage(w.message).Return(nil)
}
if w.messageSMS != nil {
w.messageSMS.TriggeringEvent = a.event
channel.EXPECT().HandleMessage(w.messageSMS).DoAndReturn(func(message *messages.SMS) error {
message.VerificationID = gu.Ptr(verificationID)
return nil
})
}
}
return &userNotifier{
commands: f.commands,
@@ -1454,8 +1611,8 @@ func newUserNotifier(t *testing.T, ctrl *gomock.Controller, queries *mock.MockQu
Chain: *senders.ChainChannels(channel),
EmailConfig: &email.Config{
ProviderConfig: &email.Provider{
ID: "ID",
Description: "Description",
ID: "emailProviderID",
Description: "description",
},
SMTPConfig: &smtp.Config{
SMTP: smtp.SMTP{
@@ -1470,6 +1627,18 @@ func newUserNotifier(t *testing.T, ctrl *gomock.Controller, queries *mock.MockQu
},
WebhookConfig: nil,
},
SMSConfig: &sms.Config{
ProviderConfig: &sms.Provider{
ID: "smsProviderID",
Description: "description",
},
TwilioConfig: &twilio.Config{
SID: "sid",
Token: "token",
SenderNumber: "senderNumber",
VerifyServiceSID: "verifyServiceSID",
},
},
},
}
}
@@ -1479,6 +1648,7 @@ var _ types.ChannelChains = (*channels)(nil)
type channels struct {
senders.Chain
EmailConfig *email.Config
SMSConfig *sms.Config
}
func (c *channels) Email(context.Context) (*senders.Chain, *email.Config, error) {
@@ -1486,7 +1656,7 @@ func (c *channels) Email(context.Context) (*senders.Chain, *email.Config, error)
}
func (c *channels) SMS(context.Context) (*senders.Chain, *sms.Config, error) {
return &c.Chain, nil, nil
return &c.Chain, c.SMSConfig, nil
}
func (c *channels) Webhook(context.Context, webhook.Config) (*senders.Chain, error) {
@@ -1510,6 +1680,31 @@ func expectTemplateQueries(queries *mock.MockQueries, template string) {
LastEmail: lastEmail,
VerifiedEmail: verifiedEmail,
PreferredLoginName: preferredLoginName,
LastPhone: lastPhone,
VerifiedPhone: verifiedPhone,
}, nil)
queries.EXPECT().GetDefaultLanguage(gomock.Any()).Return(language.English)
queries.EXPECT().CustomTextListByTemplate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2).Return(&query.CustomTexts{}, nil)
}
func expectTemplateQueriesSMS(queries *mock.MockQueries, template string) {
queries.EXPECT().GetInstanceRestrictions(gomock.Any()).Return(query.Restrictions{
AllowedLanguages: []language.Tag{language.English},
}, nil)
queries.EXPECT().ActiveLabelPolicyByOrg(gomock.Any(), gomock.Any(), gomock.Any()).Return(&query.LabelPolicy{
ID: policyID,
Light: query.Theme{
LogoURL: logoURL,
},
}, nil)
queries.EXPECT().GetNotifyUserByID(gomock.Any(), gomock.Any(), gomock.Any()).Return(&query.NotifyUser{
ID: userID,
ResourceOwner: orgID,
LastEmail: lastEmail,
VerifiedEmail: verifiedEmail,
PreferredLoginName: preferredLoginName,
LastPhone: lastPhone,
VerifiedPhone: verifiedPhone,
}, nil)
queries.EXPECT().GetDefaultLanguage(gomock.Any()).Return(language.English)
queries.EXPECT().CustomTextListByTemplate(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any()).Times(2).Return(&query.CustomTexts{}, nil)

View File

@@ -12,6 +12,9 @@ type SMS struct {
RecipientPhoneNumber string
Content string
TriggeringEvent eventstore.Event
// VerificationID is set by the sender
VerificationID *string
}
func (msg *SMS) GetContent() (string, error) {

View File

@@ -0,0 +1,24 @@
package senders
type CodeGenerator interface {
VerifyCode(verificationID, code string) error
}
type CodeGeneratorInfo struct {
ID string `json:"id,omitempty"`
VerificationID string `json:"verificationId,omitempty"`
}
func (c *CodeGeneratorInfo) GetID() string {
if c == nil {
return ""
}
return c.ID
}
func (c *CodeGeneratorInfo) GetVerificationID() string {
if c == nil {
return ""
}
return c.VerificationID
}

View File

@@ -0,0 +1,3 @@
package senders
//go:generate mockgen -package mock -destination ./mock/code_generator.mock.go github.com/zitadel/zitadel/internal/notification/senders CodeGenerator

View File

@@ -0,0 +1,53 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/zitadel/internal/notification/senders (interfaces: CodeGenerator)
//
// Generated by this command:
//
// mockgen -package mock -destination ./mock/code_generator.mock.go github.com/zitadel/zitadel/internal/notification/senders CodeGenerator
//
// Package mock is a generated GoMock package.
package mock
import (
reflect "reflect"
gomock "go.uber.org/mock/gomock"
)
// MockCodeGenerator is a mock of CodeGenerator interface.
type MockCodeGenerator struct {
ctrl *gomock.Controller
recorder *MockCodeGeneratorMockRecorder
}
// MockCodeGeneratorMockRecorder is the mock recorder for MockCodeGenerator.
type MockCodeGeneratorMockRecorder struct {
mock *MockCodeGenerator
}
// NewMockCodeGenerator creates a new mock instance.
func NewMockCodeGenerator(ctrl *gomock.Controller) *MockCodeGenerator {
mock := &MockCodeGenerator{ctrl: ctrl}
mock.recorder = &MockCodeGeneratorMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use.
func (m *MockCodeGenerator) EXPECT() *MockCodeGeneratorMockRecorder {
return m.recorder
}
// VerifyCode mocks base method.
func (m *MockCodeGenerator) VerifyCode(arg0, arg1 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "VerifyCode", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// VerifyCode indicates an expected call of VerifyCode.
func (mr *MockCodeGeneratorMockRecorder) VerifyCode(arg0, arg1 any) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "VerifyCode", reflect.TypeOf((*MockCodeGenerator)(nil).VerifyCode), arg0, arg1)
}

View File

@@ -87,6 +87,7 @@ func SendSMS(
user *query.NotifyUser,
colors *query.LabelPolicy,
triggeringEvent eventstore.Event,
generatorInfo *senders.CodeGeneratorInfo,
) Notify {
return func(
url string,
@@ -104,6 +105,7 @@ func SendSMS(
args,
allowUnverifiedNotificationChannel,
triggeringEvent,
generatorInfo,
)
}
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/notification/messages"
"github.com/zitadel/zitadel/internal/notification/senders"
"github.com/zitadel/zitadel/internal/notification/templates"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/zerrors"
@@ -27,6 +28,7 @@ func generateSms(
args map[string]interface{},
lastPhone bool,
triggeringEvent eventstore.Event,
generatorInfo *senders.CodeGeneratorInfo,
) error {
smsChannels, config, err := channels.SMS(ctx)
logging.OnError(err).Error("could not create sms channel")
@@ -48,7 +50,15 @@ func generateSms(
Content: data.Text,
TriggeringEvent: triggeringEvent,
}
return smsChannels.HandleMessage(message)
err = smsChannels.HandleMessage(message)
if err != nil {
return err
}
if config.TwilioConfig.VerifyServiceSID != "" {
generatorInfo.ID = config.ProviderConfig.ID
generatorInfo.VerificationID = *message.VerificationID
}
return nil
}
if config.WebhookConfig != nil {
caseArgs := make(map[string]interface{}, len(args))