fix: use triggering origin for notification links (#6628)

* take baseurl if saved on event

* refactor: make es mocks reusable

* Revert "refactor: make es mocks reusable"

This reverts commit 434ce12a6a.

* make messages testable

* test asset url

* fmt

* fmt

* simplify notification.Start

* test url combinations

* support init code added

* support password changed

* support reset pw

* support user domain claimed

* support add pwless login

* support verify phone

* Revert "support verify phone"

This reverts commit e40503303e.

* save trigger origin from ctx

* add ready for review check

* camel

* test email otp

* fix variable naming

* fix DefaultOTPEmailURLV2

* Revert "fix DefaultOTPEmailURLV2"

This reverts commit fa34d4d2a8.

* fix email otp challenged test

* fix email otp challenged test

* pass origin in login and gateway requests

* take origin from header

* take x-forwarded if present

* Update internal/notification/handlers/queries.go

Co-authored-by: Tim Möhlmann <tim+github@zitadel.com>

* Update internal/notification/handlers/commands.go

Co-authored-by: Tim Möhlmann <tim+github@zitadel.com>

* move origin header to ctx if available

* generate

* cleanup

* use forwarded header

* support X-Forwarded-* headers

* standardize context handling

* fix linting

---------

Co-authored-by: Tim Möhlmann <tim+github@zitadel.com>
This commit is contained in:
Elio Bischof
2023-10-10 15:20:53 +02:00
committed by GitHub
parent 0180779d6d
commit 8f6cb47567
47 changed files with 2405 additions and 508 deletions

View File

@@ -0,0 +1,24 @@
package handlers
import (
"context"
"github.com/zitadel/zitadel/internal/repository/milestone"
"github.com/zitadel/zitadel/internal/repository/quota"
)
type Commands interface {
HumanInitCodeSent(ctx context.Context, orgID, userID string) error
HumanEmailVerificationCodeSent(ctx context.Context, orgID, userID string) error
PasswordCodeSent(ctx context.Context, orgID, userID string) error
HumanOTPSMSCodeSent(ctx context.Context, userID, resourceOwner string) error
HumanOTPEmailCodeSent(ctx context.Context, userID, resourceOwner string) error
OTPSMSSent(ctx context.Context, sessionID, resourceOwner string) error
OTPEmailSent(ctx context.Context, sessionID, resourceOwner string) error
UserDomainClaimedSent(ctx context.Context, orgID, userID string) error
HumanPasswordlessInitCodeSent(ctx context.Context, userID, resourceOwner, codeID string) error
PasswordChangeSent(ctx context.Context, orgID, userID string) error
HumanPhoneVerificationCodeSent(ctx context.Context, orgID, userID string) error
UsageNotificationSent(ctx context.Context, dueEvent *quota.NotificationDueEvent) error
MilestonePushed(ctx context.Context, msType milestone.Type, endpoints []string, primaryDomain string) error
}

View File

@@ -0,0 +1,4 @@
package handlers
//go:generate mockgen -package mock -destination ./mock/queries.mock.go github.com/zitadel/zitadel/internal/notification/handlers Queries
//go:generate mockgen -package mock -destination ./mock/commands.mock.go github.com/zitadel/zitadel/internal/notification/handlers Commands

View File

@@ -0,0 +1,218 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/zitadel/internal/notification/handlers (interfaces: Commands)
// Package mock is a generated GoMock package.
package mock
import (
context "context"
gomock "github.com/golang/mock/gomock"
milestone "github.com/zitadel/zitadel/internal/repository/milestone"
quota "github.com/zitadel/zitadel/internal/repository/quota"
reflect "reflect"
)
// MockCommands is a mock of Commands interface
type MockCommands struct {
ctrl *gomock.Controller
recorder *MockCommandsMockRecorder
}
// MockCommandsMockRecorder is the mock recorder for MockCommands
type MockCommandsMockRecorder struct {
mock *MockCommands
}
// NewMockCommands creates a new mock instance
func NewMockCommands(ctrl *gomock.Controller) *MockCommands {
mock := &MockCommands{ctrl: ctrl}
mock.recorder = &MockCommandsMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockCommands) EXPECT() *MockCommandsMockRecorder {
return m.recorder
}
// HumanEmailVerificationCodeSent mocks base method
func (m *MockCommands) HumanEmailVerificationCodeSent(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HumanEmailVerificationCodeSent", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// HumanEmailVerificationCodeSent indicates an expected call of HumanEmailVerificationCodeSent
func (mr *MockCommandsMockRecorder) HumanEmailVerificationCodeSent(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanEmailVerificationCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanEmailVerificationCodeSent), arg0, arg1, arg2)
}
// HumanInitCodeSent mocks base method
func (m *MockCommands) HumanInitCodeSent(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HumanInitCodeSent", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// HumanInitCodeSent indicates an expected call of HumanInitCodeSent
func (mr *MockCommandsMockRecorder) HumanInitCodeSent(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanInitCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanInitCodeSent), arg0, arg1, arg2)
}
// HumanOTPEmailCodeSent mocks base method
func (m *MockCommands) HumanOTPEmailCodeSent(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HumanOTPEmailCodeSent", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// HumanOTPEmailCodeSent indicates an expected call of HumanOTPEmailCodeSent
func (mr *MockCommandsMockRecorder) HumanOTPEmailCodeSent(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanOTPEmailCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanOTPEmailCodeSent), arg0, arg1, arg2)
}
// HumanOTPSMSCodeSent mocks base method
func (m *MockCommands) HumanOTPSMSCodeSent(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HumanOTPSMSCodeSent", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// HumanOTPSMSCodeSent indicates an expected call of HumanOTPSMSCodeSent
func (mr *MockCommandsMockRecorder) HumanOTPSMSCodeSent(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanOTPSMSCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanOTPSMSCodeSent), arg0, arg1, arg2)
}
// HumanPasswordlessInitCodeSent mocks base method
func (m *MockCommands) HumanPasswordlessInitCodeSent(arg0 context.Context, arg1, arg2, arg3 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HumanPasswordlessInitCodeSent", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(error)
return ret0
}
// HumanPasswordlessInitCodeSent indicates an expected call of HumanPasswordlessInitCodeSent
func (mr *MockCommandsMockRecorder) HumanPasswordlessInitCodeSent(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanPasswordlessInitCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanPasswordlessInitCodeSent), arg0, arg1, arg2, arg3)
}
// HumanPhoneVerificationCodeSent mocks base method
func (m *MockCommands) HumanPhoneVerificationCodeSent(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "HumanPhoneVerificationCodeSent", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// HumanPhoneVerificationCodeSent indicates an expected call of HumanPhoneVerificationCodeSent
func (mr *MockCommandsMockRecorder) HumanPhoneVerificationCodeSent(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "HumanPhoneVerificationCodeSent", reflect.TypeOf((*MockCommands)(nil).HumanPhoneVerificationCodeSent), arg0, arg1, arg2)
}
// MilestonePushed mocks base method
func (m *MockCommands) MilestonePushed(arg0 context.Context, arg1 milestone.Type, arg2 []string, arg3 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MilestonePushed", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(error)
return ret0
}
// MilestonePushed indicates an expected call of MilestonePushed
func (mr *MockCommandsMockRecorder) MilestonePushed(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MilestonePushed", reflect.TypeOf((*MockCommands)(nil).MilestonePushed), arg0, arg1, arg2, arg3)
}
// OTPEmailSent mocks base method
func (m *MockCommands) OTPEmailSent(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OTPEmailSent", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// OTPEmailSent indicates an expected call of OTPEmailSent
func (mr *MockCommandsMockRecorder) OTPEmailSent(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OTPEmailSent", reflect.TypeOf((*MockCommands)(nil).OTPEmailSent), arg0, arg1, arg2)
}
// OTPSMSSent mocks base method
func (m *MockCommands) OTPSMSSent(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "OTPSMSSent", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// OTPSMSSent indicates an expected call of OTPSMSSent
func (mr *MockCommandsMockRecorder) OTPSMSSent(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "OTPSMSSent", reflect.TypeOf((*MockCommands)(nil).OTPSMSSent), arg0, arg1, arg2)
}
// PasswordChangeSent mocks base method
func (m *MockCommands) PasswordChangeSent(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PasswordChangeSent", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// PasswordChangeSent indicates an expected call of PasswordChangeSent
func (mr *MockCommandsMockRecorder) PasswordChangeSent(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordChangeSent", reflect.TypeOf((*MockCommands)(nil).PasswordChangeSent), arg0, arg1, arg2)
}
// PasswordCodeSent mocks base method
func (m *MockCommands) PasswordCodeSent(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "PasswordCodeSent", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// PasswordCodeSent indicates an expected call of PasswordCodeSent
func (mr *MockCommandsMockRecorder) PasswordCodeSent(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "PasswordCodeSent", reflect.TypeOf((*MockCommands)(nil).PasswordCodeSent), arg0, arg1, arg2)
}
// UsageNotificationSent mocks base method
func (m *MockCommands) UsageNotificationSent(arg0 context.Context, arg1 *quota.NotificationDueEvent) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UsageNotificationSent", arg0, arg1)
ret0, _ := ret[0].(error)
return ret0
}
// UsageNotificationSent indicates an expected call of UsageNotificationSent
func (mr *MockCommandsMockRecorder) UsageNotificationSent(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UsageNotificationSent", reflect.TypeOf((*MockCommands)(nil).UsageNotificationSent), arg0, arg1)
}
// UserDomainClaimedSent mocks base method
func (m *MockCommands) UserDomainClaimedSent(arg0 context.Context, arg1, arg2 string) error {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UserDomainClaimedSent", arg0, arg1, arg2)
ret0, _ := ret[0].(error)
return ret0
}
// UserDomainClaimedSent indicates an expected call of UserDomainClaimedSent
func (mr *MockCommandsMockRecorder) UserDomainClaimedSent(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UserDomainClaimedSent", reflect.TypeOf((*MockCommands)(nil).UserDomainClaimedSent), arg0, arg1, arg2)
}

View File

@@ -0,0 +1,226 @@
// Code generated by MockGen. DO NOT EDIT.
// Source: github.com/zitadel/zitadel/internal/notification/handlers (interfaces: Queries)
// Package mock is a generated GoMock package.
package mock
import (
context "context"
gomock "github.com/golang/mock/gomock"
domain "github.com/zitadel/zitadel/internal/domain"
query "github.com/zitadel/zitadel/internal/query"
language "golang.org/x/text/language"
reflect "reflect"
)
// MockQueries is a mock of Queries interface
type MockQueries struct {
ctrl *gomock.Controller
recorder *MockQueriesMockRecorder
}
// MockQueriesMockRecorder is the mock recorder for MockQueries
type MockQueriesMockRecorder struct {
mock *MockQueries
}
// NewMockQueries creates a new mock instance
func NewMockQueries(ctrl *gomock.Controller) *MockQueries {
mock := &MockQueries{ctrl: ctrl}
mock.recorder = &MockQueriesMockRecorder{mock}
return mock
}
// EXPECT returns an object that allows the caller to indicate expected use
func (m *MockQueries) EXPECT() *MockQueriesMockRecorder {
return m.recorder
}
// ActiveLabelPolicyByOrg mocks base method
func (m *MockQueries) ActiveLabelPolicyByOrg(arg0 context.Context, arg1 string, arg2 bool) (*query.LabelPolicy, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "ActiveLabelPolicyByOrg", arg0, arg1, arg2)
ret0, _ := ret[0].(*query.LabelPolicy)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// ActiveLabelPolicyByOrg indicates an expected call of ActiveLabelPolicyByOrg
func (mr *MockQueriesMockRecorder) ActiveLabelPolicyByOrg(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ActiveLabelPolicyByOrg", reflect.TypeOf((*MockQueries)(nil).ActiveLabelPolicyByOrg), arg0, arg1, arg2)
}
// CustomTextListByTemplate mocks base method
func (m *MockQueries) CustomTextListByTemplate(arg0 context.Context, arg1, arg2 string, arg3 bool) (*query.CustomTexts, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "CustomTextListByTemplate", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*query.CustomTexts)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// CustomTextListByTemplate indicates an expected call of CustomTextListByTemplate
func (mr *MockQueriesMockRecorder) CustomTextListByTemplate(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CustomTextListByTemplate", reflect.TypeOf((*MockQueries)(nil).CustomTextListByTemplate), arg0, arg1, arg2, arg3)
}
// GetDefaultLanguage mocks base method
func (m *MockQueries) GetDefaultLanguage(arg0 context.Context) language.Tag {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "GetDefaultLanguage", arg0)
ret0, _ := ret[0].(language.Tag)
return ret0
}
// GetDefaultLanguage indicates an expected call of GetDefaultLanguage
func (mr *MockQueriesMockRecorder) GetDefaultLanguage(arg0 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetDefaultLanguage", reflect.TypeOf((*MockQueries)(nil).GetDefaultLanguage), arg0)
}
// GetNotifyUserByID mocks base method
func (m *MockQueries) GetNotifyUserByID(arg0 context.Context, arg1 bool, arg2 string, arg3 bool, arg4 ...query.SearchQuery) (*query.NotifyUser, error) {
m.ctrl.T.Helper()
varargs := []interface{}{arg0, arg1, arg2, arg3}
for _, a := range arg4 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "GetNotifyUserByID", varargs...)
ret0, _ := ret[0].(*query.NotifyUser)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// GetNotifyUserByID indicates an expected call of GetNotifyUserByID
func (mr *MockQueriesMockRecorder) GetNotifyUserByID(arg0, arg1, arg2, arg3 interface{}, arg4 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0, arg1, arg2, arg3}, arg4...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetNotifyUserByID", reflect.TypeOf((*MockQueries)(nil).GetNotifyUserByID), varargs...)
}
// MailTemplateByOrg mocks base method
func (m *MockQueries) MailTemplateByOrg(arg0 context.Context, arg1 string, arg2 bool) (*query.MailTemplate, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "MailTemplateByOrg", arg0, arg1, arg2)
ret0, _ := ret[0].(*query.MailTemplate)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// MailTemplateByOrg indicates an expected call of MailTemplateByOrg
func (mr *MockQueriesMockRecorder) MailTemplateByOrg(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "MailTemplateByOrg", reflect.TypeOf((*MockQueries)(nil).MailTemplateByOrg), arg0, arg1, arg2)
}
// NotificationPolicyByOrg mocks base method
func (m *MockQueries) NotificationPolicyByOrg(arg0 context.Context, arg1 bool, arg2 string, arg3 bool) (*query.NotificationPolicy, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NotificationPolicyByOrg", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*query.NotificationPolicy)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// NotificationPolicyByOrg indicates an expected call of NotificationPolicyByOrg
func (mr *MockQueriesMockRecorder) NotificationPolicyByOrg(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationPolicyByOrg", reflect.TypeOf((*MockQueries)(nil).NotificationPolicyByOrg), arg0, arg1, arg2, arg3)
}
// NotificationProviderByIDAndType mocks base method
func (m *MockQueries) NotificationProviderByIDAndType(arg0 context.Context, arg1 string, arg2 domain.NotificationProviderType) (*query.DebugNotificationProvider, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "NotificationProviderByIDAndType", arg0, arg1, arg2)
ret0, _ := ret[0].(*query.DebugNotificationProvider)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// NotificationProviderByIDAndType indicates an expected call of NotificationProviderByIDAndType
func (mr *MockQueriesMockRecorder) NotificationProviderByIDAndType(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "NotificationProviderByIDAndType", reflect.TypeOf((*MockQueries)(nil).NotificationProviderByIDAndType), arg0, arg1, arg2)
}
// SMSProviderConfig mocks base method
func (m *MockQueries) SMSProviderConfig(arg0 context.Context, arg1 ...query.SearchQuery) (*query.SMSConfig, error) {
m.ctrl.T.Helper()
varargs := []interface{}{arg0}
for _, a := range arg1 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "SMSProviderConfig", varargs...)
ret0, _ := ret[0].(*query.SMSConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SMSProviderConfig indicates an expected call of SMSProviderConfig
func (mr *MockQueriesMockRecorder) SMSProviderConfig(arg0 interface{}, arg1 ...interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
varargs := append([]interface{}{arg0}, arg1...)
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SMSProviderConfig", reflect.TypeOf((*MockQueries)(nil).SMSProviderConfig), varargs...)
}
// SMTPConfigByAggregateID mocks base method
func (m *MockQueries) SMTPConfigByAggregateID(arg0 context.Context, arg1 string) (*query.SMTPConfig, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SMTPConfigByAggregateID", arg0, arg1)
ret0, _ := ret[0].(*query.SMTPConfig)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SMTPConfigByAggregateID indicates an expected call of SMTPConfigByAggregateID
func (mr *MockQueriesMockRecorder) SMTPConfigByAggregateID(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SMTPConfigByAggregateID", reflect.TypeOf((*MockQueries)(nil).SMTPConfigByAggregateID), arg0, arg1)
}
// SearchInstanceDomains mocks base method
func (m *MockQueries) SearchInstanceDomains(arg0 context.Context, arg1 *query.InstanceDomainSearchQueries) (*query.InstanceDomains, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SearchInstanceDomains", arg0, arg1)
ret0, _ := ret[0].(*query.InstanceDomains)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SearchInstanceDomains indicates an expected call of SearchInstanceDomains
func (mr *MockQueriesMockRecorder) SearchInstanceDomains(arg0, arg1 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchInstanceDomains", reflect.TypeOf((*MockQueries)(nil).SearchInstanceDomains), arg0, arg1)
}
// SearchMilestones mocks base method
func (m *MockQueries) SearchMilestones(arg0 context.Context, arg1 []string, arg2 *query.MilestonesSearchQueries) (*query.Milestones, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SearchMilestones", arg0, arg1, arg2)
ret0, _ := ret[0].(*query.Milestones)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SearchMilestones indicates an expected call of SearchMilestones
func (mr *MockQueriesMockRecorder) SearchMilestones(arg0, arg1, arg2 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SearchMilestones", reflect.TypeOf((*MockQueries)(nil).SearchMilestones), arg0, arg1, arg2)
}
// SessionByID mocks base method
func (m *MockQueries) SessionByID(arg0 context.Context, arg1 bool, arg2, arg3 string) (*query.Session, error) {
m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "SessionByID", arg0, arg1, arg2, arg3)
ret0, _ := ret[0].(*query.Session)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// SessionByID indicates an expected call of SessionByID
func (mr *MockQueriesMockRecorder) SessionByID(arg0, arg1, arg2, arg3 interface{}) *gomock.Call {
mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SessionByID", reflect.TypeOf((*MockQueries)(nil).SessionByID), arg0, arg1, arg2, arg3)
}

View File

@@ -2,27 +2,56 @@ package handlers
import (
"context"
"fmt"
"net/url"
"github.com/zitadel/zitadel/internal/api/authz"
http_utils "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/eventstore"
"github.com/zitadel/zitadel/internal/query"
)
func (n *NotificationQueries) Origin(ctx context.Context) (context.Context, string, error) {
type OriginEvent interface {
eventstore.Event
TriggerOrigin() string
}
func (n *NotificationQueries) Origin(ctx context.Context, e eventstore.Event) (context.Context, error) {
originEvent, ok := e.(OriginEvent)
if !ok {
return ctx, errors.ThrowInternal(fmt.Errorf("event of type %T doesn't implement OriginEvent", e), "NOTIF-3m9fs", "Errors.Internal")
}
origin := originEvent.TriggerOrigin()
if origin != "" {
originURL, err := url.Parse(origin)
if err != nil {
return ctx, err
}
return enrichCtx(ctx, originURL.Hostname(), origin), nil
}
primary, err := query.NewInstanceDomainPrimarySearchQuery(true)
if err != nil {
return ctx, "", err
return ctx, err
}
domains, err := n.SearchInstanceDomains(ctx, &query.InstanceDomainSearchQueries{
Queries: []query.SearchQuery{primary},
})
if err != nil {
return ctx, "", err
return ctx, err
}
if len(domains.Domains) < 1 {
return ctx, "", errors.ThrowInternal(nil, "NOTIF-Ef3r1", "Errors.Notification.NoDomain")
return ctx, errors.ThrowInternal(nil, "NOTIF-Ef3r1", "Errors.Notification.NoDomain")
}
ctx = authz.WithRequestedDomain(ctx, domains.Domains[0].Domain)
return ctx, http_utils.BuildHTTP(domains.Domains[0].Domain, n.externalPort, n.externalSecure), nil
return enrichCtx(
ctx,
domains.Domains[0].Domain,
http_utils.BuildHTTP(domains.Domains[0].Domain, n.externalPort, n.externalSecure),
), nil
}
func enrichCtx(ctx context.Context, host, origin string) context.Context {
ctx = authz.WithRequestedDomain(ctx, host)
ctx = http_utils.WithComposedOrigin(ctx, origin)
return ctx
}

View File

@@ -1,16 +1,34 @@
package handlers
import (
"context"
"net/http"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/eventstore"
_ "github.com/zitadel/zitadel/internal/notification/statik"
"github.com/zitadel/zitadel/internal/query"
)
type Queries interface {
ActiveLabelPolicyByOrg(ctx context.Context, orgID string, withOwnerRemoved bool) (*query.LabelPolicy, error)
MailTemplateByOrg(ctx context.Context, orgID string, withOwnerRemoved bool) (*query.MailTemplate, error)
GetNotifyUserByID(ctx context.Context, shouldTriggered bool, userID string, withOwnerRemoved bool, queries ...query.SearchQuery) (*query.NotifyUser, error)
CustomTextListByTemplate(ctx context.Context, aggregateID, template string, withOwnerRemoved bool) (*query.CustomTexts, error)
SearchInstanceDomains(ctx context.Context, queries *query.InstanceDomainSearchQueries) (*query.InstanceDomains, error)
SessionByID(ctx context.Context, shouldTriggerBulk bool, id, sessionToken string) (*query.Session, error)
NotificationPolicyByOrg(ctx context.Context, shouldTriggerBulk bool, orgID string, withOwnerRemoved bool) (*query.NotificationPolicy, error)
SearchMilestones(ctx context.Context, instanceIDs []string, queries *query.MilestonesSearchQueries) (*query.Milestones, error)
NotificationProviderByIDAndType(ctx context.Context, aggID string, providerType domain.NotificationProviderType) (*query.DebugNotificationProvider, error)
SMSProviderConfig(ctx context.Context, queries ...query.SearchQuery) (*query.SMSConfig, error)
SMTPConfigByAggregateID(ctx context.Context, aggregateID string) (*query.SMTPConfig, error)
GetDefaultLanguage(ctx context.Context) language.Tag
}
type NotificationQueries struct {
*query.Queries
Queries
es *eventstore.Eventstore
externalDomain string
externalPort uint16
@@ -23,7 +41,7 @@ type NotificationQueries struct {
}
func NewNotificationQueries(
baseQueries *query.Queries,
baseQueries Queries,
es *eventstore.Eventstore,
externalDomain string,
externalPort uint16,

View File

@@ -22,10 +22,9 @@ const (
type quotaNotifier struct {
crdb.StatementHandler
commands *command.Commands
queries *NotificationQueries
metricSuccessfulDeliveriesJSON string
metricFailedDeliveriesJSON string
commands *command.Commands
queries *NotificationQueries
channels types.ChannelChains
}
func NewQuotaNotifier(
@@ -33,8 +32,7 @@ func NewQuotaNotifier(
config crdb.StatementHandlerConfig,
commands *command.Commands,
queries *NotificationQueries,
metricSuccessfulDeliveriesJSON,
metricFailedDeliveriesJSON string,
channels types.ChannelChains,
) *quotaNotifier {
p := new(quotaNotifier)
config.ProjectionName = QuotaNotificationsProjectionTable
@@ -42,8 +40,7 @@ func NewQuotaNotifier(
p.StatementHandler = crdb.NewStatementHandler(ctx, config)
p.commands = commands
p.queries = queries
p.metricSuccessfulDeliveriesJSON = metricSuccessfulDeliveriesJSON
p.metricFailedDeliveriesJSON = metricFailedDeliveriesJSON
p.channels = channels
projection.NotificationsQuotaProjection = p
return p
}
@@ -75,19 +72,7 @@ func (u *quotaNotifier) reduceNotificationDue(event eventstore.Event) (*handler.
if alreadyHandled {
return crdb.NewNoOpStatement(e), nil
}
err = types.SendJSON(
ctx,
webhook.Config{
CallURL: e.CallURL,
Method: http.MethodPost,
},
u.queries.GetFileSystemProvider,
u.queries.GetLogProvider,
e,
e,
u.metricSuccessfulDeliveriesJSON,
u.metricFailedDeliveriesJSON,
).WithoutTemplate()
err = types.SendJSON(ctx, webhook.Config{CallURL: e.CallURL, Method: http.MethodPost}, u.channels, e, e).WithoutTemplate()
if err != nil {
return nil, err
}

View File

@@ -38,11 +38,10 @@ type TelemetryPusherConfig struct {
type telemetryPusher struct {
crdb.StatementHandler
cfg TelemetryPusherConfig
commands *command.Commands
queries *NotificationQueries
metricSuccessfulDeliveriesJSON string
metricFailedDeliveriesJSON string
cfg TelemetryPusherConfig
commands *command.Commands
queries *NotificationQueries
channels types.ChannelChains
}
func NewTelemetryPusher(
@@ -51,8 +50,7 @@ func NewTelemetryPusher(
handlerCfg crdb.StatementHandlerConfig,
commands *command.Commands,
queries *NotificationQueries,
metricSuccessfulDeliveriesJSON,
metricFailedDeliveriesJSON string,
channels types.ChannelChains,
) *telemetryPusher {
p := new(telemetryPusher)
handlerCfg.ProjectionName = TelemetryProjectionTable
@@ -62,8 +60,7 @@ func NewTelemetryPusher(
p.StatementHandler = crdb.NewStatementHandler(ctx, handlerCfg)
p.commands = commands
p.queries = queries
p.metricSuccessfulDeliveriesJSON = metricSuccessfulDeliveriesJSON
p.metricFailedDeliveriesJSON = metricFailedDeliveriesJSON
p.channels = channels
projection.TelemetryPusherProjection = p
return p
}
@@ -132,8 +129,7 @@ func (t *telemetryPusher) pushMilestone(ctx context.Context, event *pseudo.Sched
Method: http.MethodPost,
Headers: t.cfg.Headers,
},
t.queries.GetFileSystemProvider,
t.queries.GetLogProvider,
t.channels,
&struct {
InstanceID string `json:"instanceId"`
ExternalDomain string `json:"externalDomain"`
@@ -148,8 +144,6 @@ func (t *telemetryPusher) pushMilestone(ctx context.Context, event *pseudo.Sched
ReachedDate: ms.ReachedDate,
},
event,
t.metricSuccessfulDeliveriesJSON,
t.metricFailedDeliveriesJSON,
).WithoutTemplate(); err != nil {
return err
}

View File

@@ -5,9 +5,8 @@ import (
"strings"
"time"
"github.com/zitadel/zitadel/internal/api/authz"
http_util "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/api/ui/login"
"github.com/zitadel/zitadel/internal/command"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
@@ -27,27 +26,19 @@ const (
type userNotifier struct {
crdb.StatementHandler
commands *command.Commands
commands Commands
queries *NotificationQueries
assetsPrefix func(context.Context) string
channels types.ChannelChains
otpEmailTmpl string
metricSuccessfulDeliveriesEmail,
metricFailedDeliveriesEmail,
metricSuccessfulDeliveriesSMS,
metricFailedDeliveriesSMS string
}
func NewUserNotifier(
ctx context.Context,
config crdb.StatementHandlerConfig,
commands *command.Commands,
commands Commands,
queries *NotificationQueries,
assetsPrefix func(context.Context) string,
channels types.ChannelChains,
otpEmailTmpl string,
metricSuccessfulDeliveriesEmail,
metricFailedDeliveriesEmail,
metricSuccessfulDeliveriesSMS,
metricFailedDeliveriesSMS string,
) *userNotifier {
p := new(userNotifier)
config.ProjectionName = UserNotificationsProjectionTable
@@ -55,12 +46,8 @@ func NewUserNotifier(
p.StatementHandler = crdb.NewStatementHandler(ctx, config)
p.commands = commands
p.queries = queries
p.assetsPrefix = assetsPrefix
p.channels = channels
p.otpEmailTmpl = otpEmailTmpl
p.metricSuccessfulDeliveriesEmail = metricSuccessfulDeliveriesEmail
p.metricFailedDeliveriesEmail = metricFailedDeliveriesEmail
p.metricSuccessfulDeliveriesSMS = metricSuccessfulDeliveriesSMS
p.metricFailedDeliveriesSMS = metricFailedDeliveriesSMS
projection.NotificationsProjection = p
return p
}
@@ -177,25 +164,12 @@ func (u *userNotifier) reduceInitCodeAdded(event eventstore.Event) (*handler.Sta
if err != nil {
return nil, err
}
ctx, origin, err := u.queries.Origin(ctx)
ctx, err = u.queries.Origin(ctx, e)
if err != nil {
return nil, err
}
err = types.SendEmail(
ctx,
string(template.Template),
translator,
notifyUser,
u.queries.GetSMTPConfig,
u.queries.GetFileSystemProvider,
u.queries.GetLogProvider,
colors,
u.assetsPrefix(ctx),
e,
u.metricSuccessfulDeliveriesEmail,
u.metricFailedDeliveriesEmail,
).SendUserInitCode(notifyUser, origin, code)
err = types.SendEmail(ctx, u.channels, string(template.Template), translator, notifyUser, colors, e).
SendUserInitCode(ctx, notifyUser, code)
if err != nil {
return nil, err
}
@@ -247,25 +221,12 @@ func (u *userNotifier) reduceEmailCodeAdded(event eventstore.Event) (*handler.St
if err != nil {
return nil, err
}
ctx, origin, err := u.queries.Origin(ctx)
ctx, err = u.queries.Origin(ctx, e)
if err != nil {
return nil, err
}
err = types.SendEmail(
ctx,
string(template.Template),
translator,
notifyUser,
u.queries.GetSMTPConfig,
u.queries.GetFileSystemProvider,
u.queries.GetLogProvider,
colors,
u.assetsPrefix(ctx),
e,
u.metricSuccessfulDeliveriesEmail,
u.metricFailedDeliveriesEmail,
).SendEmailVerificationCode(notifyUser, origin, code, e.URLTemplate)
err = types.SendEmail(ctx, u.channels, string(template.Template), translator, notifyUser, colors, e).
SendEmailVerificationCode(ctx, notifyUser, code, e.URLTemplate)
if err != nil {
return nil, err
}
@@ -316,41 +277,15 @@ func (u *userNotifier) reducePasswordCodeAdded(event eventstore.Event) (*handler
if err != nil {
return nil, err
}
ctx, origin, err := u.queries.Origin(ctx)
ctx, err = u.queries.Origin(ctx, e)
if err != nil {
return nil, err
}
notify := types.SendEmail(
ctx,
string(template.Template),
translator,
notifyUser,
u.queries.GetSMTPConfig,
u.queries.GetFileSystemProvider,
u.queries.GetLogProvider,
colors,
u.assetsPrefix(ctx),
e,
u.metricSuccessfulDeliveriesEmail,
u.metricFailedDeliveriesEmail,
)
notify := types.SendEmail(ctx, u.channels, string(template.Template), translator, notifyUser, colors, e)
if e.NotificationType == domain.NotificationTypeSms {
notify = types.SendSMSTwilio(
ctx,
translator,
notifyUser,
u.queries.GetTwilioConfig,
u.queries.GetFileSystemProvider,
u.queries.GetLogProvider,
colors,
u.assetsPrefix(ctx),
e,
u.metricSuccessfulDeliveriesSMS,
u.metricFailedDeliveriesSMS,
)
notify = types.SendSMSTwilio(ctx, u.channels, translator, notifyUser, colors, e)
}
err = notify.SendPasswordCode(notifyUser, origin, code, e.URLTemplate)
err = notify.SendPasswordCode(ctx, notifyUser, code, e.URLTemplate)
if err != nil {
return nil, err
}
@@ -437,25 +372,12 @@ func (u *userNotifier) reduceOTPSMS(
if err != nil {
return nil, err
}
ctx, origin, err := u.queries.Origin(ctx)
ctx, err = u.queries.Origin(ctx, event)
if err != nil {
return nil, err
}
notify := types.SendSMSTwilio(
ctx,
translator,
notifyUser,
u.queries.GetTwilioConfig,
u.queries.GetFileSystemProvider,
u.queries.GetLogProvider,
colors,
u.assetsPrefix(ctx),
event,
u.metricSuccessfulDeliveriesSMS,
u.metricFailedDeliveriesSMS,
)
err = notify.SendOTPSMSCode(authz.GetInstance(ctx).RequestedDomain(), origin, plainCode, expiry)
notify := types.SendSMSTwilio(ctx, u.channels, translator, notifyUser, colors, event)
err = notify.SendOTPSMSCode(ctx, plainCode, expiry)
if err != nil {
return nil, err
}
@@ -568,30 +490,16 @@ func (u *userNotifier) reduceOTPEmail(
if err != nil {
return nil, err
}
ctx, origin, err := u.queries.Origin(ctx)
ctx, err = u.queries.Origin(ctx, event)
if err != nil {
return nil, err
}
url, err := urlTmpl(plainCode, origin, notifyUser)
url, err := urlTmpl(plainCode, http_util.ComposedOrigin(ctx), notifyUser)
if err != nil {
return nil, err
}
notify := types.SendEmail(
ctx,
string(template.Template),
translator,
notifyUser,
u.queries.GetSMTPConfig,
u.queries.GetFileSystemProvider,
u.queries.GetLogProvider,
colors,
u.assetsPrefix(ctx),
event,
u.metricSuccessfulDeliveriesEmail,
u.metricFailedDeliveriesEmail,
)
err = notify.SendOTPEmailCode(notifyUser, url, authz.GetInstance(ctx).RequestedDomain(), origin, plainCode, expiry)
notify := types.SendEmail(ctx, u.channels, string(template.Template), translator, notifyUser, colors, event)
err = notify.SendOTPEmailCode(ctx, url, plainCode, expiry)
if err != nil {
return nil, err
}
@@ -634,25 +542,12 @@ func (u *userNotifier) reduceDomainClaimed(event eventstore.Event) (*handler.Sta
if err != nil {
return nil, err
}
ctx, origin, err := u.queries.Origin(ctx)
ctx, err = u.queries.Origin(ctx, e)
if err != nil {
return nil, err
}
err = types.SendEmail(
ctx,
string(template.Template),
translator,
notifyUser,
u.queries.GetSMTPConfig,
u.queries.GetFileSystemProvider,
u.queries.GetLogProvider,
colors,
u.assetsPrefix(ctx),
e,
u.metricSuccessfulDeliveriesEmail,
u.metricFailedDeliveriesEmail,
).SendDomainClaimed(notifyUser, origin, e.UserName)
err = types.SendEmail(ctx, u.channels, string(template.Template), translator, notifyUser, colors, e).
SendDomainClaimed(ctx, notifyUser, e.UserName)
if err != nil {
return nil, err
}
@@ -701,25 +596,12 @@ func (u *userNotifier) reducePasswordlessCodeRequested(event eventstore.Event) (
if err != nil {
return nil, err
}
ctx, origin, err := u.queries.Origin(ctx)
ctx, err = u.queries.Origin(ctx, e)
if err != nil {
return nil, err
}
err = types.SendEmail(
ctx,
string(template.Template),
translator,
notifyUser,
u.queries.GetSMTPConfig,
u.queries.GetFileSystemProvider,
u.queries.GetLogProvider,
colors,
u.assetsPrefix(ctx),
e,
u.metricSuccessfulDeliveriesEmail,
u.metricFailedDeliveriesEmail,
).SendPasswordlessRegistrationLink(notifyUser, origin, code, e.ID, e.URLTemplate)
err = types.SendEmail(ctx, u.channels, string(template.Template), translator, notifyUser, colors, e).
SendPasswordlessRegistrationLink(ctx, notifyUser, code, e.ID, e.URLTemplate)
if err != nil {
return nil, err
}
@@ -771,25 +653,12 @@ func (u *userNotifier) reducePasswordChanged(event eventstore.Event) (*handler.S
if err != nil {
return nil, err
}
ctx, origin, err := u.queries.Origin(ctx)
ctx, err = u.queries.Origin(ctx, e)
if err != nil {
return nil, err
}
err = types.SendEmail(
ctx,
string(template.Template),
translator,
notifyUser,
u.queries.GetSMTPConfig,
u.queries.GetFileSystemProvider,
u.queries.GetLogProvider,
colors,
u.assetsPrefix(ctx),
e,
u.metricSuccessfulDeliveriesEmail,
u.metricFailedDeliveriesEmail,
).SendPasswordChange(notifyUser, origin)
err = types.SendEmail(ctx, u.channels, string(template.Template), translator, notifyUser, colors, e).
SendPasswordChange(ctx, notifyUser)
if err != nil {
return nil, err
}
@@ -836,24 +705,12 @@ func (u *userNotifier) reducePhoneCodeAdded(event eventstore.Event) (*handler.St
if err != nil {
return nil, err
}
ctx, origin, err := u.queries.Origin(ctx)
ctx, err = u.queries.Origin(ctx, e)
if err != nil {
return nil, err
}
err = types.SendSMSTwilio(
ctx,
translator,
notifyUser,
u.queries.GetTwilioConfig,
u.queries.GetFileSystemProvider,
u.queries.GetLogProvider,
colors,
u.assetsPrefix(ctx),
e,
u.metricSuccessfulDeliveriesSMS,
u.metricFailedDeliveriesSMS,
).SendPhoneVerificationCode(notifyUser, origin, code, authz.GetInstance(ctx).RequestedDomain())
err = types.SendSMSTwilio(ctx, u.channels, translator, notifyUser, colors, e).
SendPhoneVerificationCode(ctx, code)
if err != nil {
return nil, err
}

File diff suppressed because it is too large Load Diff