mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 20:57:31 +00:00
bla
This commit is contained in:
@@ -64,6 +64,5 @@ func (b *Instance) SetUp(ctx context.Context, request *SetUpInstance) (err error
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
_, err = b.user.Create(ctx, tx, request.User)
|
_, err = b.user.Create(ctx, tx, request.User)
|
||||||
b.authorizations.authorizeusers
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@@ -3,11 +3,124 @@ package domain
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
|
"github.com/zitadel/zitadel/backend/handler"
|
||||||
"github.com/zitadel/zitadel/backend/repository"
|
"github.com/zitadel/zitadel/backend/repository"
|
||||||
"github.com/zitadel/zitadel/backend/storage/database"
|
"github.com/zitadel/zitadel/backend/storage/database"
|
||||||
|
"github.com/zitadel/zitadel/internal/crypto"
|
||||||
|
"github.com/zitadel/zitadel/internal/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
type userRepository interface {
|
type User struct {
|
||||||
Create(ctx context.Context, tx database.Transaction, user *repository.User) (*repository.User, error)
|
db database.Pool
|
||||||
ByID(ctx context.Context, querier database.Querier, id string) (*repository.User, error)
|
|
||||||
|
user userRepository
|
||||||
|
secretGenerator secretGeneratorRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
type UserRepositoryConstructor interface {
|
||||||
|
NewUserExecutor(database.Executor) userRepository
|
||||||
|
NewUserQuerier(database.Querier) userRepository
|
||||||
|
}
|
||||||
|
|
||||||
|
type userRepository interface {
|
||||||
|
Create(ctx context.Context, tx database.Executor, user *repository.User) (*repository.User, error)
|
||||||
|
ByID(ctx context.Context, querier database.Querier, id string) (*repository.User, error)
|
||||||
|
|
||||||
|
EmailVerificationCode(ctx context.Context, client database.Querier, userID string) (*repository.EmailVerificationCode, error)
|
||||||
|
EmailVerificationFailed(ctx context.Context, client database.Executor, code *repository.EmailVerificationCode) error
|
||||||
|
EmailVerificationSucceeded(ctx context.Context, client database.Executor, code *repository.EmailVerificationCode) error
|
||||||
|
}
|
||||||
|
|
||||||
|
type secretGeneratorRepository interface {
|
||||||
|
GeneratorConfigByType(ctx context.Context, client database.Querier, typ repository.SecretGeneratorType) (*crypto.GeneratorConfig, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUser(db database.Pool) *User {
|
||||||
|
b := &User{
|
||||||
|
db: db,
|
||||||
|
user: repository.NewUser(),
|
||||||
|
secretGenerator: repository.NewSecretGenerator(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return b
|
||||||
|
}
|
||||||
|
|
||||||
|
type VerifyEmail struct {
|
||||||
|
UserID string
|
||||||
|
Code string
|
||||||
|
Alg crypto.EncryptionAlgorithm
|
||||||
|
|
||||||
|
client database.QueryExecutor
|
||||||
|
config *crypto.GeneratorConfig
|
||||||
|
gen crypto.Generator
|
||||||
|
code *repository.EmailVerificationCode
|
||||||
|
verificationErr error
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *User) VerifyEmail(ctx context.Context, in *VerifyEmail) error {
|
||||||
|
_, err := handler.Deferrable(
|
||||||
|
func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, _ func(context.Context, error) error, err error) {
|
||||||
|
client, err := u.db.Acquire(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
in.client = client
|
||||||
|
return in, func(ctx context.Context, _ error) error { return client.Release(ctx) }, err
|
||||||
|
},
|
||||||
|
handler.Chains(
|
||||||
|
func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) {
|
||||||
|
in.config, err = u.secretGenerator.GeneratorConfigByType(ctx, in.client, domain.SecretGeneratorTypeVerifyEmailCode)
|
||||||
|
return in, err
|
||||||
|
},
|
||||||
|
func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) {
|
||||||
|
in.gen = crypto.NewEncryptionGenerator(*in.config, in.Alg)
|
||||||
|
return in, nil
|
||||||
|
},
|
||||||
|
handler.Deferrable(
|
||||||
|
func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, _ func(context.Context, error) error, err error) {
|
||||||
|
client := in.client
|
||||||
|
tx, err := in.client.(database.Client).Begin(ctx, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, err
|
||||||
|
}
|
||||||
|
in.client = tx
|
||||||
|
return in, func(ctx context.Context, err error) error {
|
||||||
|
err = tx.End(ctx, err)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
in.client = client
|
||||||
|
return nil
|
||||||
|
}, err
|
||||||
|
},
|
||||||
|
handler.Chains(
|
||||||
|
func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) {
|
||||||
|
in.code, err = u.user.EmailVerificationCode(ctx, in.client, in.UserID)
|
||||||
|
return in, err
|
||||||
|
},
|
||||||
|
func(ctx context.Context, in *VerifyEmail) (*VerifyEmail, error) {
|
||||||
|
in.verificationErr = crypto.VerifyCode(in.code.CreatedAt, in.code.Expiry, in.code.Code, in.Code, in.gen.Alg())
|
||||||
|
return in, nil
|
||||||
|
},
|
||||||
|
handler.HandleIf(
|
||||||
|
func(in *VerifyEmail) bool {
|
||||||
|
return in.verificationErr == nil
|
||||||
|
},
|
||||||
|
func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) {
|
||||||
|
return in, u.user.EmailVerificationSucceeded(ctx, in.client, in.code)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
handler.HandleIf(
|
||||||
|
func(in *VerifyEmail) bool {
|
||||||
|
return in.verificationErr != nil
|
||||||
|
},
|
||||||
|
func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) {
|
||||||
|
return in, u.user.EmailVerificationFailed(ctx, in.client, in.code)
|
||||||
|
},
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)(ctx, in)
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
@@ -4,20 +4,42 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Parameter[P, C any] struct {
|
||||||
|
Previous P
|
||||||
|
Current C
|
||||||
|
}
|
||||||
|
|
||||||
// Handle is a function that handles the in.
|
// Handle is a function that handles the in.
|
||||||
type Handle[Out, In any] func(ctx context.Context, in Out) (out In, err error)
|
type Handle[In, Out any] func(ctx context.Context, in In) (out Out, err error)
|
||||||
|
|
||||||
|
type DeferrableHandle[In, Out any] func(ctx context.Context, in In) (out Out, deferrable func(context.Context, error) error, err error)
|
||||||
|
|
||||||
|
type HandleNoReturn[In any] func(ctx context.Context, in In) error
|
||||||
|
|
||||||
// Middleware is a function that decorates the handle function.
|
// Middleware is a function that decorates the handle function.
|
||||||
// It must call the handle function but its up the the middleware to decide when and how.
|
// It must call the handle function but its up the the middleware to decide when and how.
|
||||||
type Middleware[In, Out any] func(ctx context.Context, in In, handle Handle[In, Out]) (out Out, err error)
|
type Middleware[In, Out any] func(ctx context.Context, in In, handle Handle[In, Out]) (out Out, err error)
|
||||||
|
|
||||||
|
func Deferrable[In, Out, NextOut any](handle DeferrableHandle[In, Out], next Handle[Out, NextOut]) Handle[In, NextOut] {
|
||||||
|
return func(ctx context.Context, in In) (nextOut NextOut, err error) {
|
||||||
|
out, deferrable, err := handle(ctx, in)
|
||||||
|
if err != nil {
|
||||||
|
return nextOut, err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err = deferrable(ctx, err)
|
||||||
|
}()
|
||||||
|
return next(ctx, out)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Chain chains the handle function with the next handler.
|
// Chain chains the handle function with the next handler.
|
||||||
// The next handler is called after the handle function.
|
// The next handler is called after the handle function.
|
||||||
func Chain[In, Out any](handle Handle[In, Out], next Handle[Out, Out]) Handle[In, Out] {
|
func Chain[In, Out, NextOut any](handle Handle[In, Out], next Handle[Out, NextOut]) Handle[In, NextOut] {
|
||||||
return func(ctx context.Context, in In) (out Out, err error) {
|
return func(ctx context.Context, in In) (nextOut NextOut, err error) {
|
||||||
out, err = handle(ctx, in)
|
out, err := handle(ctx, in)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return out, err
|
return nextOut, err
|
||||||
}
|
}
|
||||||
return next(ctx, out)
|
return next(ctx, out)
|
||||||
}
|
}
|
||||||
@@ -67,6 +89,15 @@ func SkipNext[In, Out any](handle Handle[In, Out], next Handle[In, Out]) Handle[
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func HandleIf[In any](cond func(In) bool, handle Handle[In, In]) Handle[In, In] {
|
||||||
|
return func(ctx context.Context, in In) (out In, err error) {
|
||||||
|
if !cond(in) {
|
||||||
|
return in, nil
|
||||||
|
}
|
||||||
|
return handle(ctx, in)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// SkipNilHandler skips the handle function if the handler is nil.
|
// SkipNilHandler skips the handle function if the handler is nil.
|
||||||
// If handle is nil, an empty output is returned.
|
// If handle is nil, an empty output is returned.
|
||||||
// The function is safe to call with nil handler.
|
// The function is safe to call with nil handler.
|
||||||
@@ -90,6 +121,12 @@ func SkipReturnPreviousHandler[O, In any](handler *O, handle Handle[In, In]) Han
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func CtxFuncToHandle[Out any](fn func(context.Context) (Out, error)) Handle[struct{}, Out] {
|
||||||
|
return func(ctx context.Context, in struct{}) (out Out, err error) {
|
||||||
|
return fn(ctx)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func ResFuncToHandle[In any, Out any](fn func(context.Context, In) Out) Handle[In, Out] {
|
func ResFuncToHandle[In any, Out any](fn func(context.Context, In) Out) Handle[In, Out] {
|
||||||
return func(ctx context.Context, in In) (out Out, err error) {
|
return func(ctx context.Context, in In) (out Out, err error) {
|
||||||
return fn(ctx, in), nil
|
return fn(ctx, in), nil
|
||||||
|
33
backend/repository/secret_generator.go
Normal file
33
backend/repository/secret_generator.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/zitadel/zitadel/backend/storage/database"
|
||||||
|
"github.com/zitadel/zitadel/backend/telemetry/tracing"
|
||||||
|
"github.com/zitadel/zitadel/internal/crypto"
|
||||||
|
"github.com/zitadel/zitadel/internal/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SecretGeneratorOptions struct{}
|
||||||
|
|
||||||
|
type SecretGenerator struct {
|
||||||
|
options[SecretGeneratorOptions]
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSecretGenerator(opts ...Option[SecretGeneratorOptions]) *SecretGenerator {
|
||||||
|
i := new(SecretGenerator)
|
||||||
|
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt.apply(&i.options)
|
||||||
|
}
|
||||||
|
return i
|
||||||
|
}
|
||||||
|
|
||||||
|
type SecretGeneratorType = domain.SecretGeneratorType
|
||||||
|
|
||||||
|
func (sg *SecretGenerator) GeneratorConfigByType(ctx context.Context, client database.Querier, typ SecretGeneratorType) (*crypto.GeneratorConfig, error) {
|
||||||
|
return tracing.Wrap(sg.tracer, "secretGenerator.GeneratorConfigByType",
|
||||||
|
query(client).SecretGeneratorConfigByType,
|
||||||
|
)(ctx, typ)
|
||||||
|
}
|
25
backend/repository/secret_generator_db.go
Normal file
25
backend/repository/secret_generator_db.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package repository
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/zitadel/zitadel/internal/api/authz"
|
||||||
|
"github.com/zitadel/zitadel/internal/crypto"
|
||||||
|
)
|
||||||
|
|
||||||
|
const secretGeneratorByTypeStmt = `SELECT * FROM secret_generators WHERE instance_id = $1 AND type = $2`
|
||||||
|
|
||||||
|
func (q querier) SecretGeneratorConfigByType(ctx context.Context, typ SecretGeneratorType) (config *crypto.GeneratorConfig, err error) {
|
||||||
|
err = q.client.QueryRow(ctx, secretGeneratorByTypeStmt, authz.GetInstance(ctx).InstanceID, typ).Scan(
|
||||||
|
&config.Length,
|
||||||
|
&config.Expiry,
|
||||||
|
&config.IncludeLowerLetters,
|
||||||
|
&config.IncludeUpperLetters,
|
||||||
|
&config.IncludeDigits,
|
||||||
|
&config.IncludeSymbols,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return config, nil
|
||||||
|
}
|
@@ -2,6 +2,7 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/zitadel/zitadel/backend/handler"
|
"github.com/zitadel/zitadel/backend/handler"
|
||||||
"github.com/zitadel/zitadel/backend/storage/database"
|
"github.com/zitadel/zitadel/backend/storage/database"
|
||||||
@@ -39,15 +40,15 @@ func WithUserCache(c *UserCache) Option[UserOptions] {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *user) Create(ctx context.Context, tx database.Transaction, user *User) (*User, error) {
|
func (u *user) Create(ctx context.Context, client database.Executor, user *User) (*User, error) {
|
||||||
return tracing.Wrap(u.tracer, "user.Create",
|
return tracing.Wrap(u.tracer, "user.Create",
|
||||||
handler.Chain(
|
handler.Chain(
|
||||||
handler.Decorate(
|
handler.Decorate(
|
||||||
execute(tx).CreateUser,
|
execute(client).CreateUser,
|
||||||
tracing.Decorate[*User, *User](u.tracer, tracing.WithSpanName("user.sql.Create")),
|
tracing.Decorate[*User, *User](u.tracer, tracing.WithSpanName("user.sql.Create")),
|
||||||
),
|
),
|
||||||
handler.Decorate(
|
handler.Decorate(
|
||||||
events(tx).CreateUser,
|
events(client).CreateUser,
|
||||||
tracing.Decorate[*User, *User](u.tracer, tracing.WithSpanName("user.event.Create")),
|
tracing.Decorate[*User, *User](u.tracer, tracing.WithSpanName("user.event.Create")),
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
@@ -72,26 +73,59 @@ func (u *user) ByID(ctx context.Context, client database.Querier, id string) (*U
|
|||||||
type ChangeEmail struct {
|
type ChangeEmail struct {
|
||||||
UserID string
|
UserID string
|
||||||
Email string
|
Email string
|
||||||
Opt *ChangeEmailOption
|
// Opt *ChangeEmailOption
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChangeEmailOption struct {
|
// type ChangeEmailOption struct {
|
||||||
returnCode bool
|
// returnCode bool
|
||||||
isVerified bool
|
// isVerified bool
|
||||||
sendCode bool
|
// sendCode bool
|
||||||
|
// }
|
||||||
|
|
||||||
|
// type ChangeEmailVerifiedOption struct {
|
||||||
|
// isVerified bool
|
||||||
|
// }
|
||||||
|
|
||||||
|
// type ChangeEmailReturnCodeOption struct {
|
||||||
|
// alg crypto.EncryptionAlgorithm
|
||||||
|
// }
|
||||||
|
|
||||||
|
// type ChangeEmailSendCodeOption struct {
|
||||||
|
// alg crypto.EncryptionAlgorithm
|
||||||
|
// urlTemplate string
|
||||||
|
// }
|
||||||
|
|
||||||
|
func (u *user) ChangeEmail(ctx context.Context, client database.Executor, change *ChangeEmail) {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChangeEmailVerifiedOption struct {
|
type EmailVerificationCode struct {
|
||||||
isVerified bool
|
Code *crypto.CryptoValue
|
||||||
|
CreatedAt time.Time
|
||||||
|
Expiry time.Duration
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChangeEmailReturnCodeOption struct {
|
func (u *user) EmailVerificationCode(ctx context.Context, client database.Querier, userID string) (*EmailVerificationCode, error) {
|
||||||
alg crypto.EncryptionAlgorithm
|
return tracing.Wrap(u.tracer, "user.EmailVerificationCode",
|
||||||
|
handler.Decorate(
|
||||||
|
query(client).EmailVerificationCode,
|
||||||
|
tracing.Decorate[string, *EmailVerificationCode](u.tracer, tracing.WithSpanName("user.sql.EmailVerificationCode")),
|
||||||
|
),
|
||||||
|
)(ctx, userID)
|
||||||
}
|
}
|
||||||
|
|
||||||
type ChangeEmailSendCodeOption struct {
|
func (u *user) EmailVerificationFailed(ctx context.Context, client database.Executor, userID string) error {
|
||||||
alg crypto.EncryptionAlgorithm
|
_, err := tracing.Wrap(u.tracer, "user.EmailVerificationFailed",
|
||||||
urlTemplate string
|
handler.ErrFuncToHandle(execute(client).EmailVerificationFailed),
|
||||||
|
)(ctx, userID)
|
||||||
|
|
||||||
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *user) ChangeEmail(ctx context.Context, client database.Executor, change *ChangeEmail)
|
func (u *user) EmailVerificationSucceeded(ctx context.Context, client database.Executor, userID string) error {
|
||||||
|
_, err := tracing.Wrap(u.tracer, "user.EmailVerificationSucceeded",
|
||||||
|
handler.ErrFuncToHandle(execute(client).EmailVerificationSucceeded),
|
||||||
|
)(ctx, userID)
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
@@ -2,6 +2,7 @@ package repository
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"errors"
|
||||||
"log"
|
"log"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -17,6 +18,24 @@ func (q *querier) UserByID(ctx context.Context, id string) (res *User, err error
|
|||||||
return &user, nil
|
return &user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const emailVerificationCodeStmt = `SELECT created_at, expiry,code FROM email_verification_codes WHERE user_id = $1`
|
||||||
|
|
||||||
|
func (q *querier) EmailVerificationCode(ctx context.Context, userID string) (res *EmailVerificationCode, err error) {
|
||||||
|
log.Println("sql.user.emailVerificationCode")
|
||||||
|
|
||||||
|
res = new(EmailVerificationCode)
|
||||||
|
err = q.client.QueryRow(ctx, emailVerificationCodeStmt, userID).
|
||||||
|
Scan(
|
||||||
|
&res.CreatedAt,
|
||||||
|
&res.Expiry,
|
||||||
|
&res.Code,
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return res, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (e *executor) CreateUser(ctx context.Context, user *User) (res *User, err error) {
|
func (e *executor) CreateUser(ctx context.Context, user *User) (res *User, err error) {
|
||||||
log.Println("sql.user.create")
|
log.Println("sql.user.create")
|
||||||
err = e.client.Exec(ctx, "INSERT INTO users (id, username) VALUES ($1, $2)", user.ID, user.Username)
|
err = e.client.Exec(ctx, "INSERT INTO users (id, username) VALUES ($1, $2)", user.ID, user.Username)
|
||||||
@@ -25,3 +44,11 @@ func (e *executor) CreateUser(ctx context.Context, user *User) (res *User, err e
|
|||||||
}
|
}
|
||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *executor) EmailVerificationFailed(ctx context.Context, userID string) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *executor) EmailVerificationSucceeded(ctx context.Context, userID string) error {
|
||||||
|
return errors.New("not implemented")
|
||||||
|
}
|
||||||
|
@@ -25,6 +25,8 @@ type EncryptionAlgorithm interface {
|
|||||||
DecryptString(hashed []byte, keyID string) (string, error)
|
DecryptString(hashed []byte, keyID string) (string, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CryptoValue is a struct that can be used to store encrypted values in a database.
|
||||||
|
// The struct is compatible with the [driver.Valuer] and database/sql.Scanner interfaces.
|
||||||
type CryptoValue struct {
|
type CryptoValue struct {
|
||||||
CryptoType CryptoType
|
CryptoType CryptoType
|
||||||
Algorithm string
|
Algorithm string
|
||||||
|
Reference in New Issue
Block a user