From c4197fef872db02e6f147ebe8d59425d5df2436e Mon Sep 17 00:00:00 2001 From: Elio Bischof Date: Tue, 19 Sep 2023 11:02:59 +0200 Subject: [PATCH] take baseurl if saved on event --- internal/notification/handlers/origin.go | 27 +++++++++++++++---- .../notification/handlers/user_notifier.go | 27 +++++++------------ 2 files changed, 31 insertions(+), 23 deletions(-) diff --git a/internal/notification/handlers/origin.go b/internal/notification/handlers/origin.go index 2e8549a18b..17e502bec7 100644 --- a/internal/notification/handlers/origin.go +++ b/internal/notification/handlers/origin.go @@ -2,27 +2,44 @@ package handlers import ( "context" + "fmt" "github.com/zitadel/zitadel/internal/api/authz" + + "github.com/zitadel/zitadel/internal/eventstore" + http_utils "github.com/zitadel/zitadel/internal/api/http" "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/query" ) -func (n *NotificationQueries) Origin(ctx context.Context) (context.Context, string, error) { +type BaseURLEvent interface { + eventstore.Event + GetBaseURL() string +} + +func (n *NotificationQueries) Origin(ctx context.Context, e eventstore.Event) (string, error) { + baseURLEvent, ok := e.(BaseURLEvent) + if !ok { + return "", errors.ThrowInternal(fmt.Errorf("event of type %T doesn't implement BaseURLEvent", e), "NOTIF-3m9fs", "Errors.Internal") + } + baseURL := baseURLEvent.GetBaseURL() + if baseURL != "" { + return baseURL, nil + } primary, err := query.NewInstanceDomainPrimarySearchQuery(true) if err != nil { - return ctx, "", err + return "", err } domains, err := n.SearchInstanceDomains(ctx, &query.InstanceDomainSearchQueries{ Queries: []query.SearchQuery{primary}, }) if err != nil { - return ctx, "", err + return "", err } if len(domains.Domains) < 1 { - return ctx, "", errors.ThrowInternal(nil, "NOTIF-Ef3r1", "Errors.Notification.NoDomain") + return "", errors.ThrowInternal(nil, "NOTIF-Ef3r1", "Errors.Notification.NoDomain") } ctx = authz.WithRequestedDomain(ctx, domains.Domains[0].Domain) - return ctx, http_utils.BuildHTTP(domains.Domains[0].Domain, n.externalPort, n.externalSecure), nil + return http_utils.BuildHTTP(domains.Domains[0].Domain, n.externalPort, n.externalSecure), nil } diff --git a/internal/notification/handlers/user_notifier.go b/internal/notification/handlers/user_notifier.go index 7b1d407662..86d3db2d3e 100644 --- a/internal/notification/handlers/user_notifier.go +++ b/internal/notification/handlers/user_notifier.go @@ -177,8 +177,7 @@ func (u *userNotifier) reduceInitCodeAdded(event eventstore.Event) (*handler.Sta if err != nil { return nil, err } - - ctx, origin, err := u.queries.Origin(ctx) + origin, err := u.queries.Origin(ctx, e) if err != nil { return nil, err } @@ -247,8 +246,7 @@ func (u *userNotifier) reduceEmailCodeAdded(event eventstore.Event) (*handler.St if err != nil { return nil, err } - - ctx, origin, err := u.queries.Origin(ctx) + origin, err := u.queries.Origin(ctx, e) if err != nil { return nil, err } @@ -316,8 +314,7 @@ func (u *userNotifier) reducePasswordCodeAdded(event eventstore.Event) (*handler if err != nil { return nil, err } - - ctx, origin, err := u.queries.Origin(ctx) + origin, err := u.queries.Origin(ctx, e) if err != nil { return nil, err } @@ -437,8 +434,7 @@ func (u *userNotifier) reduceOTPSMS( if err != nil { return nil, err } - - ctx, origin, err := u.queries.Origin(ctx) + origin, err := u.queries.Origin(ctx, event) if err != nil { return nil, err } @@ -568,8 +564,7 @@ func (u *userNotifier) reduceOTPEmail( if err != nil { return nil, err } - - ctx, origin, err := u.queries.Origin(ctx) + origin, err := u.queries.Origin(ctx, event) if err != nil { return nil, err } @@ -634,8 +629,7 @@ func (u *userNotifier) reduceDomainClaimed(event eventstore.Event) (*handler.Sta if err != nil { return nil, err } - - ctx, origin, err := u.queries.Origin(ctx) + origin, err := u.queries.Origin(ctx, e) if err != nil { return nil, err } @@ -701,8 +695,7 @@ func (u *userNotifier) reducePasswordlessCodeRequested(event eventstore.Event) ( if err != nil { return nil, err } - - ctx, origin, err := u.queries.Origin(ctx) + origin, err := u.queries.Origin(ctx, e) if err != nil { return nil, err } @@ -771,8 +764,7 @@ func (u *userNotifier) reducePasswordChanged(event eventstore.Event) (*handler.S if err != nil { return nil, err } - - ctx, origin, err := u.queries.Origin(ctx) + origin, err := u.queries.Origin(ctx, e) if err != nil { return nil, err } @@ -836,8 +828,7 @@ func (u *userNotifier) reducePhoneCodeAdded(event eventstore.Event) (*handler.St if err != nil { return nil, err } - - ctx, origin, err := u.queries.Origin(ctx) + origin, err := u.queries.Origin(ctx, e) if err != nil { return nil, err }