1
0
mirror of https://github.com/zitadel/zitadel.git synced 2024-12-23 00:07:31 +00:00
Silvan 99e1c654a3
feat(storage): read only transactions for queries ()
* fix: tests

* bastle wie en grosse

* fix(database): scan as callback

* fix tests

* fix merge failures

* remove as of system time

* refactor: remove unused test

* refactor: remove unused lines
2023-08-22 10:49:22 +00:00

352 lines
10 KiB
Go

package query
import (
"context"
"database/sql"
"encoding/json"
errs "errors"
"fmt"
"io/ioutil"
"net/http"
"os"
"time"
sq "github.com/Masterminds/squirrel"
"golang.org/x/text/language"
"sigs.k8s.io/yaml"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/call"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/query/projection"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
// MessageTexts bundles one MessageText per notification message type.
// It is the target the translation files are unmarshalled into, and
// GetMessageTextByType selects the field that belongs to a given type.
type MessageTexts struct {
InitCode MessageText
PasswordReset MessageText
VerifyEmail MessageText
VerifyPhone MessageText
VerifySMSOTP MessageText
VerifyEmailOTP MessageText
DomainClaimed MessageText
PasswordlessRegistration MessageText
PasswordChange MessageText
}
// MessageText is the text set of a single notification message in one
// language. It is filled either from a row of the message text projection
// (see prepareMessageTextQuery) or from the translation files.
type MessageText struct {
// AggregateID is read from the aggregate ID column; for texts built by
// IAMMessageTextByTypeAndLanguage it is set to the instance ID.
AggregateID string
Sequence uint64
CreationDate time.Time
ChangeDate time.Time
State domain.PolicyState
// IsDefault is set to true by IAMMessageTextByTypeAndLanguage; it is not
// part of the columns scanned from the projection.
IsDefault bool
// Type is the message type, matched against the domain.*MessageType values.
Type string
// Language is parsed from the stored language string with language.Make.
Language language.Tag
// The following text parts are scanned as nullable strings, so a NULL
// column results in an empty string.
Title string
PreHeader string
Subject string
Greeting string
Text string
ButtonText string
Footer string
}
// Table and column descriptors of the message text projection. The names
// come from the projection package; every column is bound to messageTextTable.
var (
// messageTextTable describes the projection table and its instance ID column.
messageTextTable = table{
name: projection.MessageTextTable,
instanceIDCol: projection.MessageTextInstanceIDCol,
}
MessageTextColAggregateID = Column{
name: projection.MessageTextAggregateIDCol,
table: messageTextTable,
}
// MessageTextColInstanceID is used as a filter by the queries in this file;
// it is not part of the columns selected by prepareMessageTextQuery.
MessageTextColInstanceID = Column{
name: projection.MessageTextInstanceIDCol,
table: messageTextTable,
}
MessageTextColSequence = Column{
name: projection.MessageTextSequenceCol,
table: messageTextTable,
}
MessageTextColCreationDate = Column{
name: projection.MessageTextCreationDateCol,
table: messageTextTable,
}
MessageTextColChangeDate = Column{
name: projection.MessageTextChangeDateCol,
table: messageTextTable,
}
MessageTextColState = Column{
name: projection.MessageTextStateCol,
table: messageTextTable,
}
MessageTextColType = Column{
name: projection.MessageTextTypeCol,
table: messageTextTable,
}
MessageTextColLanguage = Column{
name: projection.MessageTextLanguageCol,
table: messageTextTable,
}
MessageTextColTitle = Column{
name: projection.MessageTextTitleCol,
table: messageTextTable,
}
MessageTextColPreHeader = Column{
name: projection.MessageTextPreHeaderCol,
table: messageTextTable,
}
MessageTextColSubject = Column{
name: projection.MessageTextSubjectCol,
table: messageTextTable,
}
MessageTextColGreeting = Column{
name: projection.MessageTextGreetingCol,
table: messageTextTable,
}
MessageTextColText = Column{
name: projection.MessageTextTextCol,
table: messageTextTable,
}
MessageTextColButtonText = Column{
name: projection.MessageTextButtonTextCol,
table: messageTextTable,
}
MessageTextColFooter = Column{
name: projection.MessageTextFooterCol,
table: messageTextTable,
}
// MessageTextColOwnerRemoved is only used as a filter (see
// CustomMessageTextByTypeAndLanguage); it is not selected by
// prepareMessageTextQuery.
MessageTextColOwnerRemoved = Column{
name: projection.MessageTextOwnerRemovedCol,
table: messageTextTable,
}
)
// DefaultMessageText loads a single message text row whose aggregate ID and
// instance ID both equal the ID of the instance carried by ctx.
//
// A failure to build the SQL statement is returned as an internal error;
// any error of the row scan is returned as produced by the scan function.
func (q *Queries) DefaultMessageText(ctx context.Context) (text *MessageText, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	builder, scanRow := prepareMessageTextQuery(ctx, q.client)

	instanceID := authz.GetInstance(ctx).InstanceID()
	filter := sq.Eq{
		MessageTextColAggregateID.identifier(): instanceID,
		MessageTextColInstanceID.identifier():  instanceID,
	}

	query, args, err := builder.Where(filter).Limit(1).ToSql()
	if err != nil {
		return nil, errors.ThrowInternal(err, "QUERY-1b9mf", "Errors.Query.SQLStatement")
	}

	// The callback hands the scanned row out through the named result.
	err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
		var scanErr error
		text, scanErr = scanRow(row)
		return scanErr
	}, query, args...)
	return text, err
}
// DefaultMessageTextByTypeAndLanguageFromFileSystem reads the notification
// translation file for language, unmarshals it into MessageTexts and returns
// the entry that belongs to messageType.
//
// The returned pointer is whatever GetMessageTextByType yields for
// messageType, which is nil for a type it does not know.
func (q *Queries) DefaultMessageTextByTypeAndLanguageFromFileSystem(ctx context.Context, messageType, language string) (_ *MessageText, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	raw, err := q.readNotificationTextMessages(ctx, language)
	if err != nil {
		return nil, err
	}

	var defaults MessageTexts
	if unmarshalErr := yaml.Unmarshal(raw, &defaults); unmarshalErr != nil {
		return nil, errors.ThrowInternal(unmarshalErr, "TEXT-3N9fs", "Errors.TranslationFile.ReadError")
	}
	return defaults.GetMessageTextByType(messageType), nil
}
// CustomMessageTextByTypeAndLanguage looks up the message text stored for
// aggregateID, messageType and language within the instance carried by ctx.
// Unless withOwnerRemoved is set, rows flagged as owner-removed are excluded.
//
// When the lookup ends in a not-found error, the result of
// IAMMessageTextByTypeAndLanguage is returned instead.
func (q *Queries) CustomMessageTextByTypeAndLanguage(ctx context.Context, aggregateID, messageType, language string, withOwnerRemoved bool) (msg *MessageText, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	builder, scanRow := prepareMessageTextQuery(ctx, q.client)

	conditions := sq.Eq{
		MessageTextColAggregateID.identifier(): aggregateID,
		MessageTextColInstanceID.identifier():  authz.GetInstance(ctx).InstanceID(),
		MessageTextColType.identifier():        messageType,
		MessageTextColLanguage.identifier():    language,
	}
	if !withOwnerRemoved {
		conditions[MessageTextColOwnerRemoved.identifier()] = false
	}

	builder = builder.
		Where(conditions).
		OrderBy(MessageTextColAggregateID.identifier()).
		Limit(1)
	query, args, err := builder.ToSql()
	if err != nil {
		return nil, errors.ThrowInternal(err, "QUERY-1b9mf", "Errors.Query.SQLStatement")
	}

	// The callback hands the scanned row out through the named result.
	err = q.client.QueryRowContext(ctx, func(row *sql.Row) error {
		var scanErr error
		msg, scanErr = scanRow(row)
		return scanErr
	}, query, args...)
	if errors.IsNotFound(err) {
		// No stored text for this aggregate: fall back to the instance texts.
		return q.IAMMessageTextByTypeAndLanguage(ctx, messageType, language)
	}
	return msg, err
}
// IAMMessageTextByTypeAndLanguage builds the message text of messageType in
// language for the instance carried by ctx.
//
// It starts from the notification translation file, overlays every custom
// text returned by CustomTextList for the instance onto the section of
// messageType, and converts the merged map into MessageTexts via JSON.
// The returned text has IsDefault set and the instance ID as AggregateID.
//
// A not-found error is returned when messageType is not one of the types
// known to GetMessageTextByType; errors from reading the translation file or
// from CustomTextList are passed through, and (un)marshalling failures are
// returned as internal errors.
func (q *Queries) IAMMessageTextByTypeAndLanguage(ctx context.Context, messageType, language string) (_ *MessageText, err error) {
	ctx, span := tracing.NewSpan(ctx)
	defer func() { span.EndWithError(err) }()

	contents, err := q.readNotificationTextMessages(ctx, language)
	if err != nil {
		return nil, err
	}
	notificationTextMap := make(map[string]interface{})
	if err := yaml.Unmarshal(contents, &notificationTextMap); err != nil {
		return nil, errors.ThrowInternal(err, "QUERY-ekjFF", "Errors.TranslationFile.ReadError")
	}

	texts, err := q.CustomTextList(ctx, authz.GetInstance(ctx).InstanceID(), messageType, language, false)
	if err != nil {
		return nil, err
	}
	// The section for messageType is the same for every custom text, so it is
	// resolved once. Without such a section there is nothing to overlay.
	if messageTextMap, ok := notificationTextMap[messageType].(map[string]interface{}); ok {
		for _, text := range texts.CustomTexts {
			messageTextMap[text.Key] = text.Text
		}
	}

	// Round-trip through JSON to turn the merged generic map into MessageTexts.
	jsonbody, err := json.Marshal(notificationTextMap)
	if err != nil {
		return nil, errors.ThrowInternal(err, "QUERY-3m8fJ", "Errors.TranslationFile.MergeError")
	}
	notificationText := new(MessageTexts)
	if err := json.Unmarshal(jsonbody, notificationText); err != nil {
		return nil, errors.ThrowInternal(err, "QUERY-9MkfD", "Errors.TranslationFile.MergeError")
	}

	result := notificationText.GetMessageTextByType(messageType)
	if result == nil {
		// GetMessageTextByType returns nil for unknown types; report that
		// instead of dereferencing a nil pointer below.
		return nil, errors.ThrowNotFound(nil, "QUERY-Jm2pQ", "Errors.MessageText.NotFound")
	}
	result.IsDefault = true
	result.AggregateID = authz.GetInstance(ctx).InstanceID()
	return result, nil
}
// readNotificationTextMessages returns the raw notification translation file
// for language, serving it from NotificationTranslationFileContents when it
// was read before. The whole lookup runs while holding q.mutex.
//
// If no file exists for language, the file of the instance's default language
// is read instead. Whatever was read is stored in the cache under language.
//
// NOTE(review): the fallback contents depend on the default language of the
// instance in ctx but are cached under the requested language only — confirm
// this is intended when instances use different default languages.
func (q *Queries) readNotificationTextMessages(ctx context.Context, language string) ([]byte, error) {
	q.mutex.Lock()
	defer q.mutex.Unlock()

	if cached, found := q.NotificationTranslationFileContents[language]; found {
		return cached, nil
	}

	data, err := q.readTranslationFile(q.NotificationDir, fmt.Sprintf("/i18n/%s.yaml", language))
	if errors.IsNotFound(err) {
		fallback := authz.GetInstance(ctx).DefaultLanguage().String()
		data, err = q.readTranslationFile(q.NotificationDir, fmt.Sprintf("/i18n/%s.yaml", fallback))
	}
	if err != nil {
		return nil, err
	}

	q.NotificationTranslationFileContents[language] = data
	return data, nil
}
// prepareMessageTextQuery returns the select statement for the message text
// table together with the function that scans one result row.
//
// The FROM clause is the table identifier followed by whatever
// db.Timetravel yields for call.Took(ctx); placeholders use the dollar format.
// The scan function maps sql.ErrNoRows to a not-found error and every other
// scan failure to an internal error.
func prepareMessageTextQuery(ctx context.Context, db prepareDatabase) (sq.SelectBuilder, func(*sql.Row) (*MessageText, error)) {
	// The order of the columns must match the order of the Scan targets below.
	columns := []string{
		MessageTextColAggregateID.identifier(),
		MessageTextColSequence.identifier(),
		MessageTextColCreationDate.identifier(),
		MessageTextColChangeDate.identifier(),
		MessageTextColState.identifier(),
		MessageTextColType.identifier(),
		MessageTextColLanguage.identifier(),
		MessageTextColTitle.identifier(),
		MessageTextColPreHeader.identifier(),
		MessageTextColSubject.identifier(),
		MessageTextColGreeting.identifier(),
		MessageTextColText.identifier(),
		MessageTextColButtonText.identifier(),
		MessageTextColFooter.identifier(),
	}
	builder := sq.Select(columns...).
		From(messageTextTable.identifier() + db.Timetravel(call.Took(ctx))).
		PlaceholderFormat(sq.Dollar)

	scanRow := func(row *sql.Row) (*MessageText, error) {
		var (
			result  = new(MessageText)
			langTag string
			// The text parts are scanned as nullable strings.
			title, preHeader, subject, greeting, body, buttonText, footer sql.NullString
		)
		scanErr := row.Scan(
			&result.AggregateID,
			&result.Sequence,
			&result.CreationDate,
			&result.ChangeDate,
			&result.State,
			&result.Type,
			&langTag,
			&title,
			&preHeader,
			&subject,
			&greeting,
			&body,
			&buttonText,
			&footer,
		)
		switch {
		case scanErr == nil:
			// fall out of the switch and populate the remaining fields
		case errs.Is(scanErr, sql.ErrNoRows):
			return nil, errors.ThrowNotFound(scanErr, "QUERY-3nlrS", "Errors.MessageText.NotFound")
		default:
			return nil, errors.ThrowInternal(scanErr, "QUERY-499gJ", "Errors.Internal")
		}

		result.Language = language.Make(langTag)
		result.Title = title.String
		result.PreHeader = preHeader.String
		result.Subject = subject.String
		result.Greeting = greeting.String
		result.Text = body.String
		result.ButtonText = buttonText.String
		result.Footer = footer.String
		return result, nil
	}

	return builder, scanRow
}
// readTranslationFile opens filename in dir and returns its complete contents.
//
// A missing file is reported as a not-found error, any other failure to open
// or read the file as an internal error. The file is closed before returning.
func (q *Queries) readTranslationFile(dir http.FileSystem, filename string) ([]byte, error) {
	r, err := dir.Open(filename)
	// errs.Is also recognises a wrapped "not exist" error, which
	// os.IsNotExist does not.
	if errs.Is(err, os.ErrNotExist) {
		return nil, errors.ThrowNotFound(err, "QUERY-sN9wg", "Errors.TranslationFile.NotFound")
	}
	if err != nil {
		return nil, errors.ThrowInternal(err, "QUERY-93njw", "Errors.TranslationFile.ReadError")
	}
	// The file is only read, so a failing Close cannot lose data and its
	// error is deliberately ignored.
	defer r.Close()

	contents, err := ioutil.ReadAll(r)
	if err != nil {
		return nil, errors.ThrowInternal(err, "QUERY-l0fse", "Errors.TranslationFile.ReadError")
	}
	return contents, nil
}
// GetMessageTextByType returns a pointer to the field of m that belongs to
// msgType, so changes made through the pointer are visible in m.
// It returns nil when msgType matches none of the known message types.
func (m *MessageTexts) GetMessageTextByType(msgType string) *MessageText {
	var selected *MessageText
	switch msgType {
	case domain.InitCodeMessageType:
		selected = &m.InitCode
	case domain.PasswordResetMessageType:
		selected = &m.PasswordReset
	case domain.VerifyEmailMessageType:
		selected = &m.VerifyEmail
	case domain.VerifyPhoneMessageType:
		selected = &m.VerifyPhone
	case domain.VerifySMSOTPMessageType:
		selected = &m.VerifySMSOTP
	case domain.VerifyEmailOTPMessageType:
		selected = &m.VerifyEmailOTP
	case domain.DomainClaimedMessageType:
		selected = &m.DomainClaimed
	case domain.PasswordlessRegistrationMessageType:
		selected = &m.PasswordlessRegistration
	case domain.PasswordChangeMessageType:
		selected = &m.PasswordChange
	default:
		// unknown type: selected stays nil
	}
	return selected
}