This commit is contained in:
adlerhurst
2025-03-19 11:27:15 +01:00
parent e00ab397e2
commit d0044058ec
8 changed files with 295 additions and 25 deletions

View File

@@ -64,6 +64,5 @@ func (b *Instance) SetUp(ctx context.Context, request *SetUpInstance) (err error
return err return err
} }
_, err = b.user.Create(ctx, tx, request.User) _, err = b.user.Create(ctx, tx, request.User)
b.authorizations.authorizeusers
return err return err
} }

View File

@@ -3,11 +3,124 @@ package domain
import ( import (
"context" "context"
"github.com/zitadel/zitadel/backend/handler"
"github.com/zitadel/zitadel/backend/repository" "github.com/zitadel/zitadel/backend/repository"
"github.com/zitadel/zitadel/backend/storage/database" "github.com/zitadel/zitadel/backend/storage/database"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
) )
type userRepository interface { type User struct {
Create(ctx context.Context, tx database.Transaction, user *repository.User) (*repository.User, error) db database.Pool
ByID(ctx context.Context, querier database.Querier, id string) (*repository.User, error)
user userRepository
secretGenerator secretGeneratorRepository
}
type UserRepositoryConstructor interface {
NewUserExecutor(database.Executor) userRepository
NewUserQuerier(database.Querier) userRepository
}
type userRepository interface {
Create(ctx context.Context, tx database.Executor, user *repository.User) (*repository.User, error)
ByID(ctx context.Context, querier database.Querier, id string) (*repository.User, error)
EmailVerificationCode(ctx context.Context, client database.Querier, userID string) (*repository.EmailVerificationCode, error)
EmailVerificationFailed(ctx context.Context, client database.Executor, code *repository.EmailVerificationCode) error
EmailVerificationSucceeded(ctx context.Context, client database.Executor, code *repository.EmailVerificationCode) error
}
type secretGeneratorRepository interface {
GeneratorConfigByType(ctx context.Context, client database.Querier, typ repository.SecretGeneratorType) (*crypto.GeneratorConfig, error)
}
func NewUser(db database.Pool) *User {
b := &User{
db: db,
user: repository.NewUser(),
secretGenerator: repository.NewSecretGenerator(),
}
return b
}
type VerifyEmail struct {
UserID string
Code string
Alg crypto.EncryptionAlgorithm
client database.QueryExecutor
config *crypto.GeneratorConfig
gen crypto.Generator
code *repository.EmailVerificationCode
verificationErr error
}
func (u *User) VerifyEmail(ctx context.Context, in *VerifyEmail) error {
_, err := handler.Deferrable(
func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, _ func(context.Context, error) error, err error) {
client, err := u.db.Acquire(ctx)
if err != nil {
return nil, nil, err
}
in.client = client
return in, func(ctx context.Context, _ error) error { return client.Release(ctx) }, err
},
handler.Chains(
func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) {
in.config, err = u.secretGenerator.GeneratorConfigByType(ctx, in.client, domain.SecretGeneratorTypeVerifyEmailCode)
return in, err
},
func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) {
in.gen = crypto.NewEncryptionGenerator(*in.config, in.Alg)
return in, nil
},
handler.Deferrable(
func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, _ func(context.Context, error) error, err error) {
client := in.client
tx, err := in.client.(database.Client).Begin(ctx, nil)
if err != nil {
return nil, nil, err
}
in.client = tx
return in, func(ctx context.Context, err error) error {
err = tx.End(ctx, err)
if err != nil {
return err
}
in.client = client
return nil
}, err
},
handler.Chains(
func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) {
in.code, err = u.user.EmailVerificationCode(ctx, in.client, in.UserID)
return in, err
},
func(ctx context.Context, in *VerifyEmail) (*VerifyEmail, error) {
in.verificationErr = crypto.VerifyCode(in.code.CreatedAt, in.code.Expiry, in.code.Code, in.Code, in.gen.Alg())
return in, nil
},
handler.HandleIf(
func(in *VerifyEmail) bool {
return in.verificationErr == nil
},
func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) {
return in, u.user.EmailVerificationSucceeded(ctx, in.client, in.code)
},
),
handler.HandleIf(
func(in *VerifyEmail) bool {
return in.verificationErr != nil
},
func(ctx context.Context, in *VerifyEmail) (_ *VerifyEmail, err error) {
return in, u.user.EmailVerificationFailed(ctx, in.client, in.code)
},
),
),
),
),
)(ctx, in)
return err
} }

View File

@@ -4,20 +4,42 @@ import (
"context" "context"
) )
type Parameter[P, C any] struct {
Previous P
Current C
}
// Handle is a function that handles the in. // Handle is a function that handles the in.
type Handle[Out, In any] func(ctx context.Context, in Out) (out In, err error) type Handle[In, Out any] func(ctx context.Context, in In) (out Out, err error)
type DeferrableHandle[In, Out any] func(ctx context.Context, in In) (out Out, deferrable func(context.Context, error) error, err error)
type HandleNoReturn[In any] func(ctx context.Context, in In) error
// Middleware is a function that decorates the handle function. // Middleware is a function that decorates the handle function.
// It must call the handle function but its up the the middleware to decide when and how. // It must call the handle function but its up the the middleware to decide when and how.
type Middleware[In, Out any] func(ctx context.Context, in In, handle Handle[In, Out]) (out Out, err error) type Middleware[In, Out any] func(ctx context.Context, in In, handle Handle[In, Out]) (out Out, err error)
func Deferrable[In, Out, NextOut any](handle DeferrableHandle[In, Out], next Handle[Out, NextOut]) Handle[In, NextOut] {
return func(ctx context.Context, in In) (nextOut NextOut, err error) {
out, deferrable, err := handle(ctx, in)
if err != nil {
return nextOut, err
}
defer func() {
err = deferrable(ctx, err)
}()
return next(ctx, out)
}
}
// Chain chains the handle function with the next handler. // Chain chains the handle function with the next handler.
// The next handler is called after the handle function. // The next handler is called after the handle function.
func Chain[In, Out any](handle Handle[In, Out], next Handle[Out, Out]) Handle[In, Out] { func Chain[In, Out, NextOut any](handle Handle[In, Out], next Handle[Out, NextOut]) Handle[In, NextOut] {
return func(ctx context.Context, in In) (out Out, err error) { return func(ctx context.Context, in In) (nextOut NextOut, err error) {
out, err = handle(ctx, in) out, err := handle(ctx, in)
if err != nil { if err != nil {
return out, err return nextOut, err
} }
return next(ctx, out) return next(ctx, out)
} }
@@ -67,6 +89,15 @@ func SkipNext[In, Out any](handle Handle[In, Out], next Handle[In, Out]) Handle[
} }
} }
func HandleIf[In any](cond func(In) bool, handle Handle[In, In]) Handle[In, In] {
return func(ctx context.Context, in In) (out In, err error) {
if !cond(in) {
return in, nil
}
return handle(ctx, in)
}
}
// SkipNilHandler skips the handle function if the handler is nil. // SkipNilHandler skips the handle function if the handler is nil.
// If handle is nil, an empty output is returned. // If handle is nil, an empty output is returned.
// The function is safe to call with nil handler. // The function is safe to call with nil handler.
@@ -90,6 +121,12 @@ func SkipReturnPreviousHandler[O, In any](handler *O, handle Handle[In, In]) Han
} }
} }
func CtxFuncToHandle[Out any](fn func(context.Context) (Out, error)) Handle[struct{}, Out] {
return func(ctx context.Context, in struct{}) (out Out, err error) {
return fn(ctx)
}
}
func ResFuncToHandle[In any, Out any](fn func(context.Context, In) Out) Handle[In, Out] { func ResFuncToHandle[In any, Out any](fn func(context.Context, In) Out) Handle[In, Out] {
return func(ctx context.Context, in In) (out Out, err error) { return func(ctx context.Context, in In) (out Out, err error) {
return fn(ctx, in), nil return fn(ctx, in), nil

View File

@@ -0,0 +1,33 @@
package repository
import (
"context"
"github.com/zitadel/zitadel/backend/storage/database"
"github.com/zitadel/zitadel/backend/telemetry/tracing"
"github.com/zitadel/zitadel/internal/crypto"
"github.com/zitadel/zitadel/internal/domain"
)
type SecretGeneratorOptions struct{}
type SecretGenerator struct {
options[SecretGeneratorOptions]
}
func NewSecretGenerator(opts ...Option[SecretGeneratorOptions]) *SecretGenerator {
i := new(SecretGenerator)
for _, opt := range opts {
opt.apply(&i.options)
}
return i
}
type SecretGeneratorType = domain.SecretGeneratorType
func (sg *SecretGenerator) GeneratorConfigByType(ctx context.Context, client database.Querier, typ SecretGeneratorType) (*crypto.GeneratorConfig, error) {
return tracing.Wrap(sg.tracer, "secretGenerator.GeneratorConfigByType",
query(client).SecretGeneratorConfigByType,
)(ctx, typ)
}

View File

@@ -0,0 +1,25 @@
package repository
import (
"context"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/crypto"
)
const secretGeneratorByTypeStmt = `SELECT * FROM secret_generators WHERE instance_id = $1 AND type = $2`
func (q querier) SecretGeneratorConfigByType(ctx context.Context, typ SecretGeneratorType) (config *crypto.GeneratorConfig, err error) {
err = q.client.QueryRow(ctx, secretGeneratorByTypeStmt, authz.GetInstance(ctx).InstanceID, typ).Scan(
&config.Length,
&config.Expiry,
&config.IncludeLowerLetters,
&config.IncludeUpperLetters,
&config.IncludeDigits,
&config.IncludeSymbols,
)
if err != nil {
return nil, err
}
return config, nil
}

View File

@@ -2,6 +2,7 @@ package repository
import ( import (
"context" "context"
"time"
"github.com/zitadel/zitadel/backend/handler" "github.com/zitadel/zitadel/backend/handler"
"github.com/zitadel/zitadel/backend/storage/database" "github.com/zitadel/zitadel/backend/storage/database"
@@ -39,15 +40,15 @@ func WithUserCache(c *UserCache) Option[UserOptions] {
} }
} }
func (u *user) Create(ctx context.Context, tx database.Transaction, user *User) (*User, error) { func (u *user) Create(ctx context.Context, client database.Executor, user *User) (*User, error) {
return tracing.Wrap(u.tracer, "user.Create", return tracing.Wrap(u.tracer, "user.Create",
handler.Chain( handler.Chain(
handler.Decorate( handler.Decorate(
execute(tx).CreateUser, execute(client).CreateUser,
tracing.Decorate[*User, *User](u.tracer, tracing.WithSpanName("user.sql.Create")), tracing.Decorate[*User, *User](u.tracer, tracing.WithSpanName("user.sql.Create")),
), ),
handler.Decorate( handler.Decorate(
events(tx).CreateUser, events(client).CreateUser,
tracing.Decorate[*User, *User](u.tracer, tracing.WithSpanName("user.event.Create")), tracing.Decorate[*User, *User](u.tracer, tracing.WithSpanName("user.event.Create")),
), ),
), ),
@@ -72,26 +73,59 @@ func (u *user) ByID(ctx context.Context, client database.Querier, id string) (*U
type ChangeEmail struct { type ChangeEmail struct {
UserID string UserID string
Email string Email string
Opt *ChangeEmailOption // Opt *ChangeEmailOption
} }
type ChangeEmailOption struct { // type ChangeEmailOption struct {
returnCode bool // returnCode bool
isVerified bool // isVerified bool
sendCode bool // sendCode bool
// }
// type ChangeEmailVerifiedOption struct {
// isVerified bool
// }
// type ChangeEmailReturnCodeOption struct {
// alg crypto.EncryptionAlgorithm
// }
// type ChangeEmailSendCodeOption struct {
// alg crypto.EncryptionAlgorithm
// urlTemplate string
// }
func (u *user) ChangeEmail(ctx context.Context, client database.Executor, change *ChangeEmail) {
} }
type ChangeEmailVerifiedOption struct { type EmailVerificationCode struct {
isVerified bool Code *crypto.CryptoValue
CreatedAt time.Time
Expiry time.Duration
} }
type ChangeEmailReturnCodeOption struct { func (u *user) EmailVerificationCode(ctx context.Context, client database.Querier, userID string) (*EmailVerificationCode, error) {
alg crypto.EncryptionAlgorithm return tracing.Wrap(u.tracer, "user.EmailVerificationCode",
handler.Decorate(
query(client).EmailVerificationCode,
tracing.Decorate[string, *EmailVerificationCode](u.tracer, tracing.WithSpanName("user.sql.EmailVerificationCode")),
),
)(ctx, userID)
} }
type ChangeEmailSendCodeOption struct { func (u *user) EmailVerificationFailed(ctx context.Context, client database.Executor, userID string) error {
alg crypto.EncryptionAlgorithm _, err := tracing.Wrap(u.tracer, "user.EmailVerificationFailed",
urlTemplate string handler.ErrFuncToHandle(execute(client).EmailVerificationFailed),
)(ctx, userID)
return err
} }
func (u *user) ChangeEmail(ctx context.Context, client database.Executor, change *ChangeEmail) func (u *user) EmailVerificationSucceeded(ctx context.Context, client database.Executor, userID string) error {
_, err := tracing.Wrap(u.tracer, "user.EmailVerificationSucceeded",
handler.ErrFuncToHandle(execute(client).EmailVerificationSucceeded),
)(ctx, userID)
return err
}

View File

@@ -2,6 +2,7 @@ package repository
import ( import (
"context" "context"
"errors"
"log" "log"
) )
@@ -17,6 +18,24 @@ func (q *querier) UserByID(ctx context.Context, id string) (res *User, err error
return &user, nil return &user, nil
} }
const emailVerificationCodeStmt = `SELECT created_at, expiry,code FROM email_verification_codes WHERE user_id = $1`
func (q *querier) EmailVerificationCode(ctx context.Context, userID string) (res *EmailVerificationCode, err error) {
log.Println("sql.user.emailVerificationCode")
res = new(EmailVerificationCode)
err = q.client.QueryRow(ctx, emailVerificationCodeStmt, userID).
Scan(
&res.CreatedAt,
&res.Expiry,
&res.Code,
)
if err != nil {
return nil, err
}
return res, nil
}
func (e *executor) CreateUser(ctx context.Context, user *User) (res *User, err error) { func (e *executor) CreateUser(ctx context.Context, user *User) (res *User, err error) {
log.Println("sql.user.create") log.Println("sql.user.create")
err = e.client.Exec(ctx, "INSERT INTO users (id, username) VALUES ($1, $2)", user.ID, user.Username) err = e.client.Exec(ctx, "INSERT INTO users (id, username) VALUES ($1, $2)", user.ID, user.Username)
@@ -25,3 +44,11 @@ func (e *executor) CreateUser(ctx context.Context, user *User) (res *User, err e
} }
return user, nil return user, nil
} }
func (e *executor) EmailVerificationFailed(ctx context.Context, userID string) error {
return errors.New("not implemented")
}
func (e *executor) EmailVerificationSucceeded(ctx context.Context, userID string) error {
return errors.New("not implemented")
}

View File

@@ -25,6 +25,8 @@ type EncryptionAlgorithm interface {
DecryptString(hashed []byte, keyID string) (string, error) DecryptString(hashed []byte, keyID string) (string, error)
} }
// CryptoValue is a struct that can be used to store encrypted values in a database.
// The struct is compatible with the [driver.Valuer] and database/sql.Scanner interfaces.
type CryptoValue struct { type CryptoValue struct {
CryptoType CryptoType CryptoType CryptoType
Algorithm string Algorithm string