diff --git a/backend/domain/instance.go b/backend/domain/instance.go index 8036274e60..17c966bfa2 100644 --- a/backend/domain/instance.go +++ b/backend/domain/instance.go @@ -64,6 +64,5 @@ func (b *Instance) SetUp(ctx context.Context, request *SetUpInstance) (err error return err } _, err = b.user.Create(ctx, tx, request.User) - b.authorizations.authorizeusers return err } diff --git a/backend/domain/user.go b/backend/domain/user.go index 8df253c132..54ef8cb257 100644 --- a/backend/domain/user.go +++ b/backend/domain/user.go @@ -3,11 +3,124 @@ package domain import ( "context" + "github.com/zitadel/zitadel/backend/handler" "github.com/zitadel/zitadel/backend/repository" "github.com/zitadel/zitadel/backend/storage/database" + "github.com/zitadel/zitadel/internal/crypto" + "github.com/zitadel/zitadel/internal/domain" ) -type userRepository interface { - Create(ctx context.Context, tx database.Transaction, user *repository.User) (*repository.User, error) - ByID(ctx context.Context, querier database.Querier, id string) (*repository.User, error) +type User struct { + db database.Pool + + user userRepository + secretGenerator secretGeneratorRepository +} + +type UserRepositoryConstructor interface { + NewUserExecutor(database.Executor) userRepository + NewUserQuerier(database.Querier) userRepository +} + +type userRepository interface { + Create(ctx context.Context, tx database.Executor, user *repository.User) (*repository.User, error) + ByID(ctx context.Context, querier database.Querier, id string) (*repository.User, error) + + EmailVerificationCode(ctx context.Context, client database.Querier, userID string) (*repository.EmailVerificationCode, error) + EmailVerificationFailed(ctx context.Context, client database.Executor, code *repository.EmailVerificationCode) error + EmailVerificationSucceeded(ctx context.Context, client database.Executor, code *repository.EmailVerificationCode) error +} + +type secretGeneratorRepository interface { + GeneratorConfigByType(ctx context.Context, client database.Querier, typ repository.SecretGeneratorType) (*crypto.GeneratorConfig, error) +} + +func NewUser(db database.Pool) *User { + b := &User{ + db: db, + user: repository.NewUser(), + secretGenerator: repository.NewSecretGenerator(), + } + + return b +} + +type VerifyEmail struct { + UserID string + Code string + Alg crypto.EncryptionAlgorithm + + client database.QueryExecutor + config *crypto.GeneratorConfig + gen crypto.Generator + code *repository.EmailVerificationCode + verificationErr error +} + +func (u *User) VerifyEmail(ctx context.Context, in *VerifyEmail) error { + _, err := handler.Deferrable( + func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, _ func(context.Context, error) error, err error) { + client, err := u.db.Acquire(ctx) + if err != nil { + return nil, nil, err + } + in.client = client + return in, func(ctx context.Context, _ error) error { return client.Release(ctx) }, err + }, + handler.Chains( + func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) { + in.config, err = u.secretGenerator.GeneratorConfigByType(ctx, in.client, domain.SecretGeneratorTypeVerifyEmailCode) + return in, err + }, + func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) { + in.gen = crypto.NewEncryptionGenerator(*in.config, in.Alg) + return in, nil + }, + handler.Deferrable( + func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, _ func(context.Context, error) error, err error) { + client := in.client + tx, err := in.client.(database.Client).Begin(ctx, nil) + if err != nil { + return nil, nil, err + } + in.client = tx + return in, func(ctx context.Context, err error) error { + err = tx.End(ctx, err) + if err != nil { + return err + } + in.client = client + return nil + }, err + }, + handler.Chains( + func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) { + in.code, err = u.user.EmailVerificationCode(ctx, in.client, in.UserID) + return in, err + }, + func(ctx context.Context, in *VerifyEmail) (*VerifyEmail, error) { + in.verificationErr = crypto.VerifyCode(in.code.CreatedAt, in.code.Expiry, in.code.Code, in.Code, in.gen.Alg()) + return in, nil + }, + handler.HandleIf( + func(in *VerifyEmail) bool { + return in.verificationErr == nil + }, + func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) { + return in, u.user.EmailVerificationSucceeded(ctx, in.client, in.code) + }, + ), + handler.HandleIf( + func(in *VerifyEmail) bool { + return in.verificationErr != nil + }, + func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) { + return in, u.user.EmailVerificationFailed(ctx, in.client, in.code) + }, + ), + ), + ), + ), + )(ctx, in) + return err } diff --git a/backend/handler/handle.go b/backend/handler/handle.go index 8317ef2448..ccb9505dc6 100644 --- a/backend/handler/handle.go +++ b/backend/handler/handle.go @@ -4,20 +4,42 @@ import ( "context" ) +type Parameter[P, C any] struct { + Previous P + Current C +} + // Handle is a function that handles the in. -type Handle[Out, In any] func(ctx context.Context, in Out) (out In, err error) +type Handle[In, Out any] func(ctx context.Context, in In) (out Out, err error) + +type DeferrableHandle[In, Out any] func(ctx context.Context, in In) (out Out, deferrable func(context.Context, error) error, err error) + +type HandleNoReturn[In any] func(ctx context.Context, in In) error // Middleware is a function that decorates the handle function. // It must call the handle function but its up the the middleware to decide when and how. type Middleware[In, Out any] func(ctx context.Context, in In, handle Handle[In, Out]) (out Out, err error) +func Deferrable[In, Out, NextOut any](handle DeferrableHandle[In, Out], next Handle[Out, NextOut]) Handle[In, NextOut] { + return func(ctx context.Context, in In) (nextOut NextOut, err error) { + out, deferrable, err := handle(ctx, in) + if err != nil { + return nextOut, err + } + defer func() { + err = deferrable(ctx, err) + }() + return next(ctx, out) + } +} + // Chain chains the handle function with the next handler. // The next handler is called after the handle function. -func Chain[In, Out any](handle Handle[In, Out], next Handle[Out, Out]) Handle[In, Out] { - return func(ctx context.Context, in In) (out Out, err error) { - out, err = handle(ctx, in) +func Chain[In, Out, NextOut any](handle Handle[In, Out], next Handle[Out, NextOut]) Handle[In, NextOut] { + return func(ctx context.Context, in In) (nextOut NextOut, err error) { + out, err := handle(ctx, in) if err != nil { - return out, err + return nextOut, err } return next(ctx, out) } @@ -67,6 +89,15 @@ func SkipNext[In, Out any](handle Handle[In, Out], next Handle[In, Out]) Handle[ } } +func HandleIf[In any](cond func(In) bool, handle Handle[In, In]) Handle[In, In] { + return func(ctx context.Context, in In) (out In, err error) { + if !cond(in) { + return in, nil + } + return handle(ctx, in) + } +} + // SkipNilHandler skips the handle function if the handler is nil. // If handle is nil, an empty output is returned. // The function is safe to call with nil handler. @@ -90,6 +121,12 @@ func SkipReturnPreviousHandler[O, In any](handler *O, handle Handle[In, In]) Han } } +func CtxFuncToHandle[Out any](fn func(context.Context) (Out, error)) Handle[struct{}, Out] { + return func(ctx context.Context, in struct{}) (out Out, err error) { + return fn(ctx) + } +} + func ResFuncToHandle[In any, Out any](fn func(context.Context, In) Out) Handle[In, Out] { return func(ctx context.Context, in In) (out Out, err error) { return fn(ctx, in), nil diff --git a/backend/repository/secret_generator.go b/backend/repository/secret_generator.go new file mode 100644 index 0000000000..708dd611f5 --- /dev/null +++ b/backend/repository/secret_generator.go @@ -0,0 +1,33 @@ +package repository + +import ( + "context" + + "github.com/zitadel/zitadel/backend/storage/database" + "github.com/zitadel/zitadel/backend/telemetry/tracing" + "github.com/zitadel/zitadel/internal/crypto" + "github.com/zitadel/zitadel/internal/domain" +) + +type SecretGeneratorOptions struct{} + +type SecretGenerator struct { + options[SecretGeneratorOptions] +} + +func NewSecretGenerator(opts ...Option[SecretGeneratorOptions]) *SecretGenerator { + i := new(SecretGenerator) + + for _, opt := range opts { + opt.apply(&i.options) + } + return i +} + +type SecretGeneratorType = domain.SecretGeneratorType + +func (sg *SecretGenerator) GeneratorConfigByType(ctx context.Context, client database.Querier, typ SecretGeneratorType) (*crypto.GeneratorConfig, error) { + return tracing.Wrap(sg.tracer, "secretGenerator.GeneratorConfigByType", + query(client).SecretGeneratorConfigByType, + )(ctx, typ) +} diff --git a/backend/repository/secret_generator_db.go b/backend/repository/secret_generator_db.go new file mode 100644 index 0000000000..c1fa987640 --- /dev/null +++ b/backend/repository/secret_generator_db.go @@ -0,0 +1,25 @@ +package repository + +import ( + "context" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/crypto" +) + +const secretGeneratorByTypeStmt = `SELECT * FROM secret_generators WHERE instance_id = $1 AND type = $2` + +func (q querier) SecretGeneratorConfigByType(ctx context.Context, typ SecretGeneratorType) (config *crypto.GeneratorConfig, err error) { + err = q.client.QueryRow(ctx, secretGeneratorByTypeStmt, authz.GetInstance(ctx).InstanceID, typ).Scan( + &config.Length, + &config.Expiry, + &config.IncludeLowerLetters, + &config.IncludeUpperLetters, + &config.IncludeDigits, + &config.IncludeSymbols, + ) + if err != nil { + return nil, err + } + return config, nil +} diff --git a/backend/repository/user.go b/backend/repository/user.go index a44cfbd30e..62e1b5382b 100644 --- a/backend/repository/user.go +++ b/backend/repository/user.go @@ -2,6 +2,7 @@ package repository import ( "context" + "time" "github.com/zitadel/zitadel/backend/handler" "github.com/zitadel/zitadel/backend/storage/database" @@ -39,15 +40,15 @@ func WithUserCache(c *UserCache) Option[UserOptions] { } } -func (u *user) Create(ctx context.Context, tx database.Transaction, user *User) (*User, error) { +func (u *user) Create(ctx context.Context, client database.Executor, user *User) (*User, error) { return tracing.Wrap(u.tracer, "user.Create", handler.Chain( handler.Decorate( - execute(tx).CreateUser, + execute(client).CreateUser, tracing.Decorate[*User, *User](u.tracer, tracing.WithSpanName("user.sql.Create")), ), handler.Decorate( - events(tx).CreateUser, + events(client).CreateUser, tracing.Decorate[*User, *User](u.tracer, tracing.WithSpanName("user.event.Create")), ), ), @@ -72,26 +73,59 @@ func (u *user) ByID(ctx context.Context, client database.Querier, id string) (*U type ChangeEmail struct { UserID string Email string - Opt *ChangeEmailOption + // Opt *ChangeEmailOption } -type ChangeEmailOption struct { - returnCode bool - isVerified bool - sendCode bool +// type ChangeEmailOption struct { +// returnCode bool +// isVerified bool +// sendCode bool +// } + +// type ChangeEmailVerifiedOption struct { +// isVerified bool +// } + +// type ChangeEmailReturnCodeOption struct { +// alg crypto.EncryptionAlgorithm +// } + +// type ChangeEmailSendCodeOption struct { +// alg crypto.EncryptionAlgorithm +// urlTemplate string +// } + +func (u *user) ChangeEmail(ctx context.Context, client database.Executor, change *ChangeEmail) { + } -type ChangeEmailVerifiedOption struct { - isVerified bool +type EmailVerificationCode struct { + Code *crypto.CryptoValue + CreatedAt time.Time + Expiry time.Duration } -type ChangeEmailReturnCodeOption struct { - alg crypto.EncryptionAlgorithm +func (u *user) EmailVerificationCode(ctx context.Context, client database.Querier, userID string) (*EmailVerificationCode, error) { + return tracing.Wrap(u.tracer, "user.EmailVerificationCode", + handler.Decorate( + query(client).EmailVerificationCode, + tracing.Decorate[string, *EmailVerificationCode](u.tracer, tracing.WithSpanName("user.sql.EmailVerificationCode")), + ), + )(ctx, userID) } -type ChangeEmailSendCodeOption struct { - alg crypto.EncryptionAlgorithm - urlTemplate string +func (u *user) EmailVerificationFailed(ctx context.Context, client database.Executor, userID string) error { + _, err := tracing.Wrap(u.tracer, "user.EmailVerificationFailed", + handler.ErrFuncToHandle(execute(client).EmailVerificationFailed), + )(ctx, userID) + + return err } -func (u *user) ChangeEmail(ctx context.Context, client database.Executor, change *ChangeEmail) +func (u *user) EmailVerificationSucceeded(ctx context.Context, client database.Executor, userID string) error { + _, err := tracing.Wrap(u.tracer, "user.EmailVerificationSucceeded", + handler.ErrFuncToHandle(execute(client).EmailVerificationSucceeded), + )(ctx, userID) + + return err +} diff --git a/backend/repository/user_db.go b/backend/repository/user_db.go index 4495358642..78e05edec7 100644 --- a/backend/repository/user_db.go +++ b/backend/repository/user_db.go @@ -2,6 +2,7 @@ package repository import ( "context" + "errors" "log" ) @@ -17,6 +18,24 @@ func (q *querier) UserByID(ctx context.Context, id string) (res *User, err error return &user, nil } +const emailVerificationCodeStmt = `SELECT created_at, expiry,code FROM email_verification_codes WHERE user_id = $1` + +func (q *querier) EmailVerificationCode(ctx context.Context, userID string) (res *EmailVerificationCode, err error) { + log.Println("sql.user.emailVerificationCode") + + res = new(EmailVerificationCode) + err = q.client.QueryRow(ctx, emailVerificationCodeStmt, userID). + Scan( + &res.CreatedAt, + &res.Expiry, + &res.Code, + ) + if err != nil { + return nil, err + } + return res, nil +} + func (e *executor) CreateUser(ctx context.Context, user *User) (res *User, err error) { log.Println("sql.user.create") err = e.client.Exec(ctx, "INSERT INTO users (id, username) VALUES ($1, $2)", user.ID, user.Username) @@ -25,3 +44,11 @@ func (e *executor) CreateUser(ctx context.Context, user *User) (res *User, err e } return user, nil } + +func (e *executor) EmailVerificationFailed(ctx context.Context, userID string) error { + return errors.New("not implemented") +} + +func (e *executor) EmailVerificationSucceeded(ctx context.Context, userID string) error { + return errors.New("not implemented") +} diff --git a/internal/crypto/crypto.go b/internal/crypto/crypto.go index ff3b6e2418..0aadba35c8 100644 --- a/internal/crypto/crypto.go +++ b/internal/crypto/crypto.go @@ -25,6 +25,8 @@ type EncryptionAlgorithm interface { DecryptString(hashed []byte, keyID string) (string, error) } +// CryptoValue is a struct that can be used to store encrypted values in a database. +// The struct is compatible with the [driver.Valuer] and database/sql.Scanner interfaces. type CryptoValue struct { CryptoType CryptoType Algorithm string