diff --git a/backend/domain/database.go b/backend/domain/database.go new file mode 100644 index 0000000000..789386b24d --- /dev/null +++ b/backend/domain/database.go @@ -0,0 +1,45 @@ +package domain + +import ( + "context" + + "github.com/zitadel/zitadel/backend/storage/database" +) + +type poolHandler[T any] struct { + pool database.Pool + + client database.QueryExecutor +} + +func (h *poolHandler[T]) acquire(ctx context.Context, in T) (out T, _ func(context.Context, error) error, err error) { + client, err := h.pool.Acquire(ctx) + if err != nil { + return in, nil, err + } + h.client = client + + return in, func(ctx context.Context, _ error) error { return client.Release(ctx) }, nil +} + +func (h *poolHandler[T]) begin(ctx context.Context, in T) (out T, _ func(context.Context, error) error, err error) { + var beginner database.Beginner = h.pool + if h.client != nil { + beginner = h.client.(database.Beginner) + } + previousClient := h.client + tx, err := beginner.Begin(ctx, nil) + if err != nil { + return in, nil, err + } + h.client = tx + + return in, func(ctx context.Context, err error) error { + err = tx.End(ctx, err) + if err != nil { + return err + } + h.client = previousClient + return nil + }, nil +} diff --git a/backend/domain/domain.go b/backend/domain/domain.go new file mode 100644 index 0000000000..2a85f9790b --- /dev/null +++ b/backend/domain/domain.go @@ -0,0 +1,20 @@ +package domain + +import ( + "context" + + "github.com/zitadel/zitadel/backend/storage/database" +) + +type defaults struct { + db database.Pool +} + +type clientSetter interface { + setClient(database.QueryExecutor) +} + +func (d *defaults) acquire(ctx context.Context, setter clientSetter) { + d.db.Acquire(ctx) + setter.setClient(d.db) +} diff --git a/backend/domain/user.go b/backend/domain/user.go index 54ef8cb257..d5dfccde2f 100644 --- a/backend/domain/user.go +++ b/backend/domain/user.go @@ -3,16 +3,15 @@ package domain import ( "context" - "github.com/zitadel/zitadel/backend/handler" "github.com/zitadel/zitadel/backend/repository" "github.com/zitadel/zitadel/backend/storage/database" "github.com/zitadel/zitadel/internal/crypto" - "github.com/zitadel/zitadel/internal/domain" ) type User struct { - db database.Pool + defaults + userCodeAlg crypto.EncryptionAlgorithm user userRepository secretGenerator secretGeneratorRepository } @@ -44,83 +43,3 @@ func NewUser(db database.Pool) *User { return b } - -type VerifyEmail struct { - UserID string - Code string - Alg crypto.EncryptionAlgorithm - - client database.QueryExecutor - config *crypto.GeneratorConfig - gen crypto.Generator - code *repository.EmailVerificationCode - verificationErr error -} - -func (u *User) VerifyEmail(ctx context.Context, in *VerifyEmail) error { - _, err := handler.Deferrable( - func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, _ func(context.Context, error) error, err error) { - client, err := u.db.Acquire(ctx) - if err != nil { - return nil, nil, err - } - in.client = client - return in, func(ctx context.Context, _ error) error { return client.Release(ctx) }, err - }, - handler.Chains( - func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) { - in.config, err = u.secretGenerator.GeneratorConfigByType(ctx, in.client, domain.SecretGeneratorTypeVerifyEmailCode) - return in, err - }, - func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) { - in.gen = crypto.NewEncryptionGenerator(*in.config, in.Alg) - return in, nil - }, - handler.Deferrable( - func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, _ func(context.Context, error) error, err error) { - client := in.client - tx, err := in.client.(database.Client).Begin(ctx, nil) - if err != nil { - return nil, nil, err - } - in.client = tx - return in, func(ctx context.Context, err error) error { - err = tx.End(ctx, err) - if err != nil { - return err - } - in.client = client - return nil - }, err - }, - handler.Chains( - func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) { - in.code, err = u.user.EmailVerificationCode(ctx, in.client, in.UserID) - return in, err - }, - func(ctx context.Context, in *VerifyEmail) (*VerifyEmail, error) { - in.verificationErr = crypto.VerifyCode(in.code.CreatedAt, in.code.Expiry, in.code.Code, in.Code, in.gen.Alg()) - return in, nil - }, - handler.HandleIf( - func(in *VerifyEmail) bool { - return in.verificationErr == nil - }, - func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) { - return in, u.user.EmailVerificationSucceeded(ctx, in.client, in.code) - }, - ), - handler.HandleIf( - func(in *VerifyEmail) bool { - return in.verificationErr != nil - }, - func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) { - return in, u.user.EmailVerificationFailed(ctx, in.client, in.code) - }, - ), - ), - ), - ), - )(ctx, in) - return err -} diff --git a/backend/domain/user_email.go b/backend/domain/user_email.go new file mode 100644 index 0000000000..cf1ad2dfe1 --- /dev/null +++ b/backend/domain/user_email.go @@ -0,0 +1,250 @@ +package domain + +import ( + "context" + "text/template" + + "github.com/zitadel/zitadel/backend/handler" + "github.com/zitadel/zitadel/backend/repository" + "github.com/zitadel/zitadel/backend/storage/database" + "github.com/zitadel/zitadel/internal/crypto" + "github.com/zitadel/zitadel/internal/domain" +) + +type VerifyEmail struct { + UserID string + Code string + + client database.QueryExecutor + config *crypto.GeneratorConfig + gen crypto.Generator + code *repository.EmailVerificationCode + verificationErr error +} + +type SetEmail struct { + *poolHandler[*SetEmail] + + UserID string + Email string + Verification handler.Handle[*SetEmail, *SetEmail] + + // config *crypto.GeneratorConfig + gen crypto.Generator + + code *crypto.CryptoValue + plainCode string + + currentEmail string +} + +func (u *User) WithEmailConfirmationURL(url template.Template) handler.Handle[*SetEmail, *SetEmail] { + return handler.Chain( + u.WithEmailReturnCode(), + func(ctx context.Context, in *SetEmail) (out *SetEmail, err error) { + // TODO: queue notification + return in, nil + }, + ) +} + +func (u *User) WithEmailReturnCode() handler.Handle[*SetEmail, *SetEmail] { + return handler.Chains( + handler.ErrFuncToHandle( + func(ctx context.Context, in *SetEmail) (err error) { + in.code, in.plainCode, err = crypto.NewCode(in.gen) + return err + }, + ), + handler.ErrFuncToHandle( + func(ctx context.Context, in *SetEmail) (err error) { + return u.user.SetEmailVerificationCode(ctx, in.poolHandler.client, in.UserID, in.code) + }, + ), + ) +} + +func (u *User) WithEmailVerified() handler.Handle[*SetEmail, *SetEmail] { + return handler.Chain( + handler.ErrFuncToHandle( + func(ctx context.Context, in *SetEmail) (err error) { + return repository.SetEmailVerificationCode(ctx, in.poolHandler.client, in.UserID, in.code) + }, + ), + handler.ErrFuncToHandle( + func(ctx context.Context, in *SetEmail) (err error) { + return u.user.EmailVerificationSucceeded(ctx, in.poolHandler.client, &repository.EmailVerificationCode{ + Code: in.code, + }) + }, + ), + ) +} + +func (u *User) WithDefaultEmailVerification() handler.Handle[*SetEmail, *SetEmail] { + return handler.Chain( + u.WithEmailReturnCode(), + func(ctx context.Context, in *SetEmail) (out *SetEmail, err error) { + // TODO: queue notification + return in, nil + }, + ) +} + +func (u *User) SetEmailDifferent(ctx context.Context, in *SetEmail) (err error) { + if in.Verification == nil { + in.Verification = u.WithDefaultEmailVerification() + } + + client, err := u.db.Acquire(ctx) + if err != nil { + return err + } + defer client.Release(ctx) + + config, err := u.secretGenerator.GeneratorConfigByType(ctx, client, domain.SecretGeneratorTypeVerifyEmailCode) + if err != nil { + return err + } + in.gen = crypto.NewEncryptionGenerator(*config, u.userCodeAlg) + + tx, err := client.Begin(ctx, nil) + if err != nil { + return err + } + defer tx.End(ctx, err) + + user, err := u.user.ByID(ctx, tx, in.UserID) + if err != nil { + return err + } + + if user.Email == in.Email { + return nil + } + + _, err = in.Verification(ctx, in) + return err +} + +func (u *User) SetEmail(ctx context.Context, in *SetEmail) error { + _, err := handler.Chain( + handler.HandleIf( + func(in *SetEmail) bool { + return in.Verification == nil + }, + func(ctx context.Context, in *SetEmail) (*SetEmail, error) { + in.Verification = u.WithDefaultEmailVerification() + return in, nil + }, + ), + handler.Deferrable( + in.poolHandler.acquire, + handler.Chains( + func(ctx context.Context, in *SetEmail) (_ *SetEmail, err error) { + config, err := u.secretGenerator.GeneratorConfigByType(ctx, in.poolHandler.client, domain.SecretGeneratorTypeVerifyEmailCode) + if err != nil { + return nil, err + } + in.gen = crypto.NewEncryptionGenerator(*config, u.userCodeAlg) + return in, nil + }, + handler.Deferrable( + in.poolHandler.begin, + handler.Chains( + func(ctx context.Context, in *SetEmail) (*SetEmail, error) { + // TODO: repository.EmailByUserID + user, err := u.user.ByID(ctx, in.poolHandler.client, in.UserID) + if err != nil { + return nil, err + } + in.currentEmail = user.Email + return in, nil + }, + handler.SkipIf( + func(in *SetEmail) bool { + return in.currentEmail == in.Email + }, + handler.Chains( + func(ctx context.Context, in *SetEmail) (*SetEmail, error) { + // TODO: repository.SetEmail + return in, nil + }, + in.Verification, + ), + ), + ), + ), + ), + ), + )(ctx, in) + return err +} + +func (u *User) VerifyEmail(ctx context.Context, in *VerifyEmail) error { + _, err := handler.Deferrable( + func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, _ func(context.Context, error) error, err error) { + client, err := u.db.Acquire(ctx) + if err != nil { + return nil, nil, err + } + in.client = client + return in, func(ctx context.Context, _ error) error { return client.Release(ctx) }, err + }, + handler.Chains( + func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) { + in.config, err = u.secretGenerator.GeneratorConfigByType(ctx, in.client, domain.SecretGeneratorTypeVerifyEmailCode) + return in, err + }, + func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) { + in.gen = crypto.NewEncryptionGenerator(*in.config, u.userCodeAlg) + return in, nil + }, + handler.Deferrable( + func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, _ func(context.Context, error) error, err error) { + client := in.client + tx, err := in.client.(database.Client).Begin(ctx, nil) + if err != nil { + return nil, nil, err + } + in.client = tx + return in, func(ctx context.Context, err error) error { + err = tx.End(ctx, err) + if err != nil { + return err + } + in.client = client + return nil + }, err + }, + handler.Chains( + func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) { + in.code, err = u.user.EmailVerificationCode(ctx, in.client, in.UserID) + return in, err + }, + func(ctx context.Context, in *VerifyEmail) (*VerifyEmail, error) { + in.verificationErr = crypto.VerifyCode(in.code.CreatedAt, in.code.Expiry, in.code.Code, in.Code, in.gen.Alg()) + return in, nil + }, + handler.HandleIf( + func(in *VerifyEmail) bool { + return in.verificationErr == nil + }, + func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) { + return in, u.user.EmailVerificationSucceeded(ctx, in.client, in.code) + }, + ), + handler.HandleIf( + func(in *VerifyEmail) bool { + return in.verificationErr != nil + }, + func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) { + return in, u.user.EmailVerificationFailed(ctx, in.client, in.code) + }, + ), + ), + ), + ), + )(ctx, in) + return err +} diff --git a/backend/handler/handle.go b/backend/handler/handle.go index ccb9505dc6..7c10447e9f 100644 --- a/backend/handler/handle.go +++ b/backend/handler/handle.go @@ -14,6 +14,8 @@ type Handle[In, Out any] func(ctx context.Context, in In) (out Out, err error) type DeferrableHandle[In, Out any] func(ctx context.Context, in In) (out Out, deferrable func(context.Context, error) error, err error) +type Defer[In, Out, NextOut any] func(handle DeferrableHandle[In, Out], next Handle[Out, NextOut]) Handle[In, NextOut] + type HandleNoReturn[In any] func(ctx context.Context, in In) error // Middleware is a function that decorates the handle function. @@ -98,6 +100,15 @@ func HandleIf[In any](cond func(In) bool, handle Handle[In, In]) Handle[In, In] } } +func SkipIf[In any](cond func(In) bool, handle Handle[In, In]) Handle[In, In] { + return func(ctx context.Context, in In) (out In, err error) { + if cond(in) { + return in, nil + } + return handle(ctx, in) + } +} + // SkipNilHandler skips the handle function if the handler is nil. // If handle is nil, an empty output is returned. // The function is safe to call with nil handler. diff --git a/backend/repository/user.go b/backend/repository/user.go index 62e1b5382b..95184bad43 100644 --- a/backend/repository/user.go +++ b/backend/repository/user.go @@ -13,6 +13,7 @@ import ( type User struct { ID string Username string + Email string } type UserOptions struct { @@ -114,18 +115,18 @@ func (u *user) EmailVerificationCode(ctx context.Context, client database.Querie )(ctx, userID) } -func (u *user) EmailVerificationFailed(ctx context.Context, client database.Executor, userID string) error { +func (u *user) EmailVerificationFailed(ctx context.Context, client database.Executor, code *EmailVerificationCode) error { _, err := tracing.Wrap(u.tracer, "user.EmailVerificationFailed", handler.ErrFuncToHandle(execute(client).EmailVerificationFailed), - )(ctx, userID) + )(ctx, code) return err } -func (u *user) EmailVerificationSucceeded(ctx context.Context, client database.Executor, userID string) error { +func (u *user) EmailVerificationSucceeded(ctx context.Context, client database.Executor, code *EmailVerificationCode) error { _, err := tracing.Wrap(u.tracer, "user.EmailVerificationSucceeded", handler.ErrFuncToHandle(execute(client).EmailVerificationSucceeded), - )(ctx, userID) + )(ctx, code) return err } diff --git a/backend/repository/user_db.go b/backend/repository/user_db.go index 78e05edec7..4d394e544a 100644 --- a/backend/repository/user_db.go +++ b/backend/repository/user_db.go @@ -45,10 +45,14 @@ func (e *executor) CreateUser(ctx context.Context, user *User) (res *User, err e return user, nil } -func (e *executor) EmailVerificationFailed(ctx context.Context, userID string) error { +func (e *executor) EmailVerificationFailed(ctx context.Context, code *EmailVerificationCode) error { return errors.New("not implemented") } -func (e *executor) EmailVerificationSucceeded(ctx context.Context, userID string) error { +func (e *executor) EmailVerificationSucceeded(ctx context.Context, code *EmailVerificationCode) error { + return errors.New("not implemented") +} + +func (e *executor) SetEmail(ctx context.Context, userID, email string) error { return errors.New("not implemented") } diff --git a/backend/storage/database/database.go b/backend/storage/database/database.go index 7a4dc932ae..c5cee166c2 100644 --- a/backend/storage/database/database.go +++ b/backend/storage/database/database.go @@ -21,6 +21,8 @@ type Transaction interface { Rollback(ctx context.Context) error End(ctx context.Context, err error) error + Begin(ctx context.Context, opts *TransactionOptions) (Transaction, error) + QueryExecutor } diff --git a/backend/storage/database/dialect/gosql/tx.go b/backend/storage/database/dialect/gosql/tx.go index 8578459316..a55f97d0fa 100644 --- a/backend/storage/database/dialect/gosql/tx.go +++ b/backend/storage/database/dialect/gosql/tx.go @@ -3,6 +3,7 @@ package gosql import ( "context" "database/sql" + "errors" "github.com/zitadel/zitadel/backend/storage/database" ) @@ -49,6 +50,12 @@ func (tx *sqlTx) Exec(ctx context.Context, sql string, args ...any) error { return err } +// Begin implements [database.Transaction]. +// it is unimplemented +func (tx *sqlTx) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) { + return nil, errors.New("nested transactions are not supported") +} + func transactionOptionsToSql(opts *database.TransactionOptions) *sql.TxOptions { if opts == nil { return nil diff --git a/backend/storage/database/dialect/postgres/tx.go b/backend/storage/database/dialect/postgres/tx.go index 767111b324..cfb6355dac 100644 --- a/backend/storage/database/dialect/postgres/tx.go +++ b/backend/storage/database/dialect/postgres/tx.go @@ -44,13 +44,24 @@ func (tx *pgxTx) QueryRow(ctx context.Context, sql string, args ...any) database return tx.Tx.QueryRow(ctx, sql, args...) } -// Exec implements [database.Pool]. +// Exec implements [database.Transaction]. // Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool. func (tx *pgxTx) Exec(ctx context.Context, sql string, args ...any) error { _, err := tx.Tx.Exec(ctx, sql, args...) return err } +// Begin implements [database.Transaction]. +// As postgres does not support nested transactions we use savepoints to emulate them. +// TransactionOptions are ignored as savepoints do not support changing isolation levels. +func (tx *pgxTx) Begin(ctx context.Context, _ *database.TransactionOptions) (database.Transaction, error) { + savepoint, err := tx.Tx.Begin(ctx) + if err != nil { + return nil, err + } + return &pgxTx{savepoint}, nil +} + func transactionOptionsToPgx(opts *database.TransactionOptions) pgx.TxOptions { if opts == nil { return pgx.TxOptions{} diff --git a/backend/storage/database/handle.go b/backend/storage/database/handle.go new file mode 100644 index 0000000000..7dab993ede --- /dev/null +++ b/backend/storage/database/handle.go @@ -0,0 +1,52 @@ +package database + +// import ( +// "context" +// "fmt" + +// "github.com/zitadel/zitadel/backend/handler" +// ) + +// func Begin[In, Out, NextOut any](ctx context.Context, beginner Beginner, opts *TransactionOptions) handler.Defer[In, Out, NextOut] { +// // func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, _ func(context.Context, error) error, err error) { +// return func(handle handler.DeferrableHandle[In, Out], next handler.Handle[Out, NextOut]) handler.Handle[In, NextOut] { +// return func(ctx context.Context, in In) (out NextOut, err error) { +// tx, err := beginner.Begin(ctx, opts) +// if err != nil { +// return out, err +// } +// defer func() { +// if err != nil { +// rollbackErr := tx.Rollback(ctx) +// if rollbackErr != nil { +// err = fmt.Errorf("query failed: %w, rollback failed: %v", err, rollbackErr) +// } +// } else { +// err = tx.Commit(ctx) +// } +// }() +// return handle(ctx, in, tx) +// } +// } + +// } + +// type QueryExecutorSetter interface { +// SetQueryExecutor(QueryExecutor) +// } + +// func Begin[In QueryExecutorSetter](ctx context.Context, beginner Beginner, in In) (_ In, _ func(context.Context, error) error, err error) { +// tx, err := beginner.Begin(ctx, nil) +// if err != nil { +// return in, nil, err +// } +// in.SetQueryExecutor(tx) +// return in, func(ctx context.Context, err error) error { +// err = tx.End(ctx, err) +// if err != nil { +// return err +// } +// in.SetQueryExecutor(beginner) +// return nil +// }, err +// } diff --git a/backend/storage/database/mock/transaction.go b/backend/storage/database/mock/transaction.go index b3fd47bb15..bb32e502fd 100644 --- a/backend/storage/database/mock/transaction.go +++ b/backend/storage/database/mock/transaction.go @@ -118,6 +118,12 @@ func (tx *Transaction) Exec(ctx context.Context, stmt string, args ...any) error return nil } +// Begin implements [database.Transaction]. +// it is unimplemented +func (tx *Transaction) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) { + return nil, errors.New("nested transactions are not supported") +} + // Query implements [database.Transaction]. func (tx *Transaction) Query(ctx context.Context, stmt string, args ...any) (database.Rows, error) { e := tx.nextExpecter()