mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 18:17:35 +00:00
multiple tries
This commit is contained in:
19
backend/v3/api/instance/v2/server.go
Normal file
19
backend/v3/api/instance/v2/server.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/backend/v3/telemetry/logging"
|
||||
"github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
|
||||
)
|
||||
|
||||
var (
|
||||
logger logging.Logger
|
||||
tracer tracing.Tracer
|
||||
)
|
||||
|
||||
func SetLogger(l logging.Logger) {
|
||||
logger = l
|
||||
}
|
||||
|
||||
func SetTracer(t tracing.Tracer) {
|
||||
tracer = t
|
||||
}
|
93
backend/v3/api/user/v2/email.go
Normal file
93
backend/v3/api/user/v2/email.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package userv2
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/user/v2"
|
||||
)
|
||||
|
||||
func SetEmail(ctx context.Context, req *user.SetEmailRequest) (resp *user.SetEmailResponse, err error) {
|
||||
var (
|
||||
verification domain.SetEmailOpt
|
||||
returnCode *domain.ReturnCodeCommand
|
||||
)
|
||||
|
||||
switch req.GetVerification().(type) {
|
||||
case *user.SetEmailRequest_IsVerified:
|
||||
verification = domain.NewEmailVerifiedCommand(req.GetUserId(), req.GetIsVerified())
|
||||
case *user.SetEmailRequest_SendCode:
|
||||
verification = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
|
||||
case *user.SetEmailRequest_ReturnCode:
|
||||
returnCode = domain.NewReturnCodeCommand(req.GetUserId())
|
||||
verification = returnCode
|
||||
default:
|
||||
verification = domain.NewSendCodeCommand(req.GetUserId(), nil)
|
||||
}
|
||||
|
||||
err = domain.Invoke(ctx, domain.NewSetEmailCommand(req.GetUserId(), req.GetEmail(), verification))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var code *string
|
||||
if returnCode != nil && returnCode.Code != "" {
|
||||
code = &returnCode.Code
|
||||
}
|
||||
|
||||
return &user.SetEmailResponse{
|
||||
VerificationCode: code,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func SendEmailCode(ctx context.Context, req *user.SendEmailCodeRequest) (resp *user.SendEmailCodeResponse, err error) {
|
||||
var (
|
||||
returnCode *domain.ReturnCodeCommand
|
||||
cmd domain.Commander
|
||||
)
|
||||
|
||||
switch req.GetVerification().(type) {
|
||||
case *user.SendEmailCodeRequest_SendCode:
|
||||
cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
|
||||
case *user.SendEmailCodeRequest_ReturnCode:
|
||||
returnCode = domain.NewReturnCodeCommand(req.GetUserId())
|
||||
cmd = returnCode
|
||||
default:
|
||||
cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
|
||||
}
|
||||
err = domain.Invoke(ctx, cmd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp = new(user.SendEmailCodeResponse)
|
||||
if returnCode != nil {
|
||||
resp.VerificationCode = &returnCode.Code
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func ResendEmailCode(ctx context.Context, req *user.ResendEmailCodeRequest) (resp *user.SendEmailCodeResponse, err error) {
|
||||
var (
|
||||
returnCode *domain.ReturnCodeCommand
|
||||
cmd domain.Commander
|
||||
)
|
||||
|
||||
switch req.GetVerification().(type) {
|
||||
case *user.ResendEmailCodeRequest_SendCode:
|
||||
cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
|
||||
case *user.ResendEmailCodeRequest_ReturnCode:
|
||||
returnCode = domain.NewReturnCodeCommand(req.GetUserId())
|
||||
cmd = returnCode
|
||||
default:
|
||||
cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
|
||||
}
|
||||
err = domain.Invoke(ctx, cmd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp = new(user.SendEmailCodeResponse)
|
||||
if returnCode != nil {
|
||||
resp.VerificationCode = &returnCode.Code
|
||||
}
|
||||
return resp, nil
|
||||
}
|
19
backend/v3/api/user/v2/server.go
Normal file
19
backend/v3/api/user/v2/server.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package userv2
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/backend/v3/telemetry/logging"
|
||||
"github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
|
||||
)
|
||||
|
||||
var (
|
||||
logger logging.Logger
|
||||
tracer tracing.Tracer
|
||||
)
|
||||
|
||||
func SetLogger(l logging.Logger) {
|
||||
logger = l
|
||||
}
|
||||
|
||||
func SetTracer(t tracing.Tracer) {
|
||||
tracer = t
|
||||
}
|
12
backend/v3/doc.go
Normal file
12
backend/v3/doc.go
Normal file
@@ -0,0 +1,12 @@
|
||||
// the test used the manly relies on the following patterns:
|
||||
// - domain:
|
||||
// - hexagonal architecture, it defines its dependencies as interfaces and the dependencies must use the objects defined by this package
|
||||
// - command pattern which implements the changes
|
||||
// - the invoker decorates the commands by checking for events and tracing
|
||||
// - the database connections are manged in this package
|
||||
// - the database connections are passed to the repositories
|
||||
//
|
||||
// - storage:
|
||||
// - repository pattern, the repositories are defined as interfaces and the implementations are in the storage package
|
||||
// - the repositories are used by the domain package to access the database
|
||||
package v3
|
105
backend/v3/domain/command.go
Normal file
105
backend/v3/domain/command.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type Commander interface {
|
||||
Execute(ctx context.Context, opts *CommandOpts) (err error)
|
||||
}
|
||||
|
||||
type Invoker interface {
|
||||
Invoke(ctx context.Context, command Commander, opts *CommandOpts) error
|
||||
}
|
||||
|
||||
type CommandOpts struct {
|
||||
DB database.QueryExecutor
|
||||
Invoker Invoker
|
||||
}
|
||||
|
||||
type ensureTxOpts struct {
|
||||
*database.TransactionOptions
|
||||
}
|
||||
|
||||
type EnsureTransactionOpt func(*ensureTxOpts)
|
||||
|
||||
// EnsureTx ensures that the DB is a transaction. If it is not, it will start a new transaction.
|
||||
// The returned close function will end the transaction. If the DB is already a transaction, the close function
|
||||
// will do nothing because another [Commander] is already responsible for ending the transaction.
|
||||
func (o *CommandOpts) EnsureTx(ctx context.Context, opts ...EnsureTransactionOpt) (close func(context.Context, error) error, err error) {
|
||||
beginner, ok := o.DB.(database.Beginner)
|
||||
if !ok {
|
||||
// db is already a transaction
|
||||
return func(_ context.Context, err error) error {
|
||||
return err
|
||||
}, nil
|
||||
}
|
||||
|
||||
txOpts := &ensureTxOpts{
|
||||
TransactionOptions: new(database.TransactionOptions),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(txOpts)
|
||||
}
|
||||
|
||||
tx, err := beginner.Begin(ctx, txOpts.TransactionOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
o.DB = tx
|
||||
|
||||
return func(ctx context.Context, err error) error {
|
||||
return tx.End(ctx, err)
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EnsureClient ensures that the o.DB is a client. If it is not, it will get a new client from the [database.Pool].
|
||||
// The returned close function will release the client. If the o.DB is already a client or transaction, the close function
|
||||
// will do nothing because another [Commander] is already responsible for releasing the client.
|
||||
func (o *CommandOpts) EnsureClient(ctx context.Context) (close func(_ context.Context) error, err error) {
|
||||
pool, ok := o.DB.(database.Pool)
|
||||
if !ok {
|
||||
// o.DB is already a client
|
||||
return func(_ context.Context) error {
|
||||
return nil
|
||||
}, nil
|
||||
}
|
||||
client, err := pool.Acquire(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
o.DB = client
|
||||
return func(ctx context.Context) error {
|
||||
return client.Release(ctx)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (o *CommandOpts) Invoke(ctx context.Context, command Commander) error {
|
||||
if o.Invoker == nil {
|
||||
return command.Execute(ctx, o)
|
||||
}
|
||||
return o.Invoker.Invoke(ctx, command, o)
|
||||
}
|
||||
|
||||
func DefaultOpts(invoker Invoker) *CommandOpts {
|
||||
if invoker == nil {
|
||||
invoker = &noopInvoker{}
|
||||
}
|
||||
return &CommandOpts{
|
||||
DB: pool,
|
||||
Invoker: invoker,
|
||||
}
|
||||
}
|
||||
|
||||
type noopInvoker struct {
|
||||
next Invoker
|
||||
}
|
||||
|
||||
func (i *noopInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) error {
|
||||
if i.next != nil {
|
||||
return i.next.Invoke(ctx, command, opts)
|
||||
}
|
||||
return command.Execute(ctx, opts)
|
||||
}
|
76
backend/v3/domain/create_user.go
Normal file
76
backend/v3/domain/create_user.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
v4 "github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt/v4"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/eventstore"
|
||||
)
|
||||
|
||||
type CreateUserCommand struct {
|
||||
user *User
|
||||
email *SetEmailCommand
|
||||
}
|
||||
|
||||
var (
|
||||
_ Commander = (*CreateUserCommand)(nil)
|
||||
_ eventer = (*CreateUserCommand)(nil)
|
||||
)
|
||||
|
||||
func NewCreateHumanCommand(username string, opts ...CreateHumanOpt) *CreateUserCommand {
|
||||
cmd := &CreateUserCommand{
|
||||
user: &User{
|
||||
User: v4.User{
|
||||
Username: username,
|
||||
Traits: &v4.Human{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt.applyOnCreateHuman(cmd)
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
// Events implements [eventer].
|
||||
func (c *CreateUserCommand) Events() []*eventstore.Event {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// Execute implements [Commander].
|
||||
func (c *CreateUserCommand) Execute(ctx context.Context, opts *CommandOpts) error {
|
||||
if err := c.ensureUserID(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.email.UserID = c.user.ID
|
||||
if err := opts.Invoke(ctx, c.email); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type CreateHumanOpt interface {
|
||||
applyOnCreateHuman(*CreateUserCommand)
|
||||
}
|
||||
|
||||
type createHumanIDOpt string
|
||||
|
||||
// applyOnCreateHuman implements [CreateHumanOpt].
|
||||
func (c createHumanIDOpt) applyOnCreateHuman(cmd *CreateUserCommand) {
|
||||
cmd.user.ID = string(c)
|
||||
}
|
||||
|
||||
var _ CreateHumanOpt = (*createHumanIDOpt)(nil)
|
||||
|
||||
func CreateHumanWithID(id string) CreateHumanOpt {
|
||||
return createHumanIDOpt(id)
|
||||
}
|
||||
|
||||
func (c *CreateUserCommand) ensureUserID() (err error) {
|
||||
if c.user.ID != "" {
|
||||
return nil
|
||||
}
|
||||
c.user.ID, err = generateID()
|
||||
return err
|
||||
}
|
26
backend/v3/domain/crypto.go
Normal file
26
backend/v3/domain/crypto.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
)
|
||||
|
||||
type generateCodeCommand struct {
|
||||
code string
|
||||
value *crypto.CryptoValue
|
||||
}
|
||||
|
||||
type CryptoRepository interface {
|
||||
GetEncryptionConfig(ctx context.Context) (*crypto.GeneratorConfig, error)
|
||||
}
|
||||
|
||||
func (cmd *generateCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error {
|
||||
config, err := cryptoRepo(opts.DB).GetEncryptionConfig(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
generator := crypto.NewEncryptionGenerator(*config, userCodeAlgorithm)
|
||||
cmd.value, cmd.code, err = crypto.NewCode(generator)
|
||||
return err
|
||||
}
|
52
backend/v3/domain/domain.go
Normal file
52
backend/v3/domain/domain.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"math/rand/v2"
|
||||
"strconv"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/cache"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
"github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
)
|
||||
|
||||
var (
|
||||
pool database.Pool
|
||||
userCodeAlgorithm crypto.EncryptionAlgorithm
|
||||
tracer tracing.Tracer
|
||||
|
||||
// userRepo func(database.QueryExecutor) UserRepository
|
||||
instanceRepo func(database.QueryExecutor) InstanceRepository
|
||||
cryptoRepo func(database.QueryExecutor) CryptoRepository
|
||||
orgRepo func(database.QueryExecutor) OrgRepository
|
||||
|
||||
instanceCache cache.Cache[string, string, *Instance]
|
||||
|
||||
generateID func() (string, error) = func() (string, error) {
|
||||
return strconv.FormatUint(rand.Uint64(), 10), nil
|
||||
}
|
||||
)
|
||||
|
||||
func SetPool(p database.Pool) {
|
||||
pool = p
|
||||
}
|
||||
|
||||
func SetUserCodeAlgorithm(algorithm crypto.EncryptionAlgorithm) {
|
||||
userCodeAlgorithm = algorithm
|
||||
}
|
||||
|
||||
func SetTracer(t tracing.Tracer) {
|
||||
tracer = t
|
||||
}
|
||||
|
||||
// func SetUserRepository(repo func(database.QueryExecutor) UserRepository) {
|
||||
// userRepo = repo
|
||||
// }
|
||||
|
||||
func SetInstanceRepository(repo func(database.QueryExecutor) InstanceRepository) {
|
||||
instanceRepo = repo
|
||||
}
|
||||
|
||||
func SetCryptoRepository(repo func(database.QueryExecutor) CryptoRepository) {
|
||||
cryptoRepo = repo
|
||||
}
|
45
backend/v3/domain/domain_test.go
Normal file
45
backend/v3/domain/domain_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package domain_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
|
||||
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||
|
||||
. "github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
|
||||
"github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
|
||||
)
|
||||
|
||||
func TestExample(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// SetPool(pool)
|
||||
|
||||
exporter, err := stdouttrace.New(stdouttrace.WithPrettyPrint())
|
||||
require.NoError(t, err)
|
||||
tracerProvider := sdktrace.NewTracerProvider(
|
||||
sdktrace.WithSyncer(exporter),
|
||||
)
|
||||
otel.SetTracerProvider(tracerProvider)
|
||||
SetTracer(tracing.Tracer{Tracer: tracerProvider.Tracer("test")})
|
||||
defer func() { assert.NoError(t, tracerProvider.Shutdown(ctx)) }()
|
||||
|
||||
SetUserRepository(repository.User)
|
||||
SetInstanceRepository(repository.Instance)
|
||||
SetCryptoRepository(repository.Crypto)
|
||||
|
||||
t.Run("verified email", func(t *testing.T) {
|
||||
err := Invoke(ctx, NewSetEmailCommand("u1", "test@example.com", NewEmailVerifiedCommand("u1", true)))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("unverified email", func(t *testing.T) {
|
||||
err := Invoke(ctx, NewSetEmailCommand("u2", "test2@example.com", NewEmailVerifiedCommand("u2", false)))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
155
backend/v3/domain/email_verification.go
Normal file
155
backend/v3/domain/email_verification.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
v4 "github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt/v4"
|
||||
)
|
||||
|
||||
type EmailVerifiedCommand struct {
|
||||
UserID string `json:"userId"`
|
||||
Email *Email `json:"email"`
|
||||
}
|
||||
|
||||
func NewEmailVerifiedCommand(userID string, isVerified bool) *EmailVerifiedCommand {
|
||||
return &EmailVerifiedCommand{
|
||||
UserID: userID,
|
||||
Email: &Email{
|
||||
IsVerified: isVerified,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
_ Commander = (*EmailVerifiedCommand)(nil)
|
||||
_ SetEmailOpt = (*EmailVerifiedCommand)(nil)
|
||||
)
|
||||
|
||||
// Execute implements [Commander]
|
||||
func (cmd *EmailVerifiedCommand) Execute(ctx context.Context, opts *CommandOpts) error {
|
||||
return userRepo(opts.DB).Human().ByID(cmd.UserID).Exec().SetEmailVerified(ctx, cmd.Email.Address)
|
||||
}
|
||||
|
||||
// applyOnSetEmail implements [SetEmailOpt]
|
||||
func (cmd *EmailVerifiedCommand) applyOnSetEmail(setEmailCmd *SetEmailCommand) {
|
||||
cmd.UserID = setEmailCmd.UserID
|
||||
cmd.Email.Address = setEmailCmd.Email
|
||||
setEmailCmd.verification = cmd
|
||||
}
|
||||
|
||||
type SendCodeCommand struct {
|
||||
UserID string `json:"userId"`
|
||||
Email string `json:"email"`
|
||||
URLTemplate *string `json:"urlTemplate"`
|
||||
generator *generateCodeCommand
|
||||
}
|
||||
|
||||
var (
|
||||
_ Commander = (*SendCodeCommand)(nil)
|
||||
_ SetEmailOpt = (*SendCodeCommand)(nil)
|
||||
)
|
||||
|
||||
func NewSendCodeCommand(userID string, urlTemplate *string) *SendCodeCommand {
|
||||
return &SendCodeCommand{
|
||||
UserID: userID,
|
||||
generator: &generateCodeCommand{},
|
||||
URLTemplate: urlTemplate,
|
||||
}
|
||||
}
|
||||
|
||||
// Execute implements [Commander]
|
||||
func (cmd *SendCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error {
|
||||
if err := cmd.ensureEmail(ctx, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := cmd.ensureURL(ctx, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := opts.Invoker.Invoke(ctx, cmd.generator, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO: queue notification
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cmd *SendCodeCommand) ensureEmail(ctx context.Context, opts *CommandOpts) error {
|
||||
if cmd.Email != "" {
|
||||
return nil
|
||||
}
|
||||
email, err := userRepo(opts.DB).Human().ByID(cmd.UserID).Exec().GetEmail(ctx)
|
||||
if err != nil || email.IsVerified {
|
||||
return err
|
||||
}
|
||||
cmd.Email = email.Address
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cmd *SendCodeCommand) ensureURL(ctx context.Context, opts *CommandOpts) error {
|
||||
if cmd.URLTemplate != nil && *cmd.URLTemplate != "" {
|
||||
return nil
|
||||
}
|
||||
_, _ = ctx, opts
|
||||
// TODO: load default template
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyOnSetEmail implements [SetEmailOpt]
|
||||
func (cmd *SendCodeCommand) applyOnSetEmail(setEmailCmd *SetEmailCommand) {
|
||||
cmd.UserID = setEmailCmd.UserID
|
||||
cmd.Email = setEmailCmd.Email
|
||||
setEmailCmd.verification = cmd
|
||||
}
|
||||
|
||||
type ReturnCodeCommand struct {
|
||||
UserID string `json:"userId"`
|
||||
Email string `json:"email"`
|
||||
Code string `json:"code"`
|
||||
generator *generateCodeCommand
|
||||
}
|
||||
|
||||
var (
|
||||
_ Commander = (*ReturnCodeCommand)(nil)
|
||||
_ SetEmailOpt = (*ReturnCodeCommand)(nil)
|
||||
)
|
||||
|
||||
func NewReturnCodeCommand(userID string) *ReturnCodeCommand {
|
||||
return &ReturnCodeCommand{
|
||||
UserID: userID,
|
||||
generator: &generateCodeCommand{},
|
||||
}
|
||||
}
|
||||
|
||||
// Execute implements [Commander]
|
||||
func (cmd *ReturnCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error {
|
||||
if err := cmd.ensureEmail(ctx, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := opts.Invoker.Invoke(ctx, cmd.generator, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Code = cmd.generator.code
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cmd *ReturnCodeCommand) ensureEmail(ctx context.Context, opts *CommandOpts) error {
|
||||
if cmd.Email != "" {
|
||||
return nil
|
||||
}
|
||||
user := v4.UserRepository(opts.DB)
|
||||
user.WithCondition(user.IDCondition(cmd.UserID))
|
||||
email, err := user.he.GetEmail(ctx)
|
||||
if err != nil || email.IsVerified {
|
||||
return err
|
||||
}
|
||||
cmd.Email = email.Address
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyOnSetEmail implements [SetEmailOpt]
|
||||
func (cmd *ReturnCodeCommand) applyOnSetEmail(setEmailCmd *SetEmailCommand) {
|
||||
cmd.UserID = setEmailCmd.UserID
|
||||
cmd.Email = setEmailCmd.Email
|
||||
setEmailCmd.verification = cmd
|
||||
}
|
7
backend/v3/domain/errors.go
Normal file
7
backend/v3/domain/errors.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package domain
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrNoAdminSpecified = errors.New("at least one admin must be specified")
|
||||
)
|
36
backend/v3/domain/instance.go
Normal file
36
backend/v3/domain/instance.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Instance struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
CreatedAt time.Time `json:"-"`
|
||||
UpdatedAt time.Time `json:"-"`
|
||||
DeletedAt time.Time `json:"-"`
|
||||
}
|
||||
|
||||
// Keys implements the [cache.Entry].
|
||||
func (i *Instance) Keys(index string) (key []string) {
|
||||
// TODO: Return the correct keys for the instance cache, e.g., i.ID, i.Domain
|
||||
return []string{}
|
||||
}
|
||||
|
||||
type InstanceRepository interface {
|
||||
ByID(ctx context.Context, id string) (*Instance, error)
|
||||
Create(ctx context.Context, instance *Instance) error
|
||||
On(id string) InstanceOperation
|
||||
}
|
||||
|
||||
type InstanceOperation interface {
|
||||
AdminRepository
|
||||
Update(ctx context.Context, instance *Instance) error
|
||||
Delete(ctx context.Context) error
|
||||
}
|
||||
|
||||
type CreateInstance struct {
|
||||
Name string `json:"name"`
|
||||
}
|
94
backend/v3/domain/invoke.go
Normal file
94
backend/v3/domain/invoke.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/eventstore"
|
||||
)
|
||||
|
||||
var defaultInvoker = newEventStoreInvoker(newTraceInvoker(nil))
|
||||
|
||||
func Invoke(ctx context.Context, cmd Commander) error {
|
||||
invoker := newEventStoreInvoker(newTraceInvoker(nil))
|
||||
opts := &CommandOpts{
|
||||
Invoker: invoker.collector,
|
||||
}
|
||||
return invoker.Invoke(ctx, cmd, opts)
|
||||
}
|
||||
|
||||
type eventStoreInvoker struct {
|
||||
collector *eventCollector
|
||||
}
|
||||
|
||||
func newEventStoreInvoker(next Invoker) *eventStoreInvoker {
|
||||
return &eventStoreInvoker{collector: &eventCollector{next: next}}
|
||||
}
|
||||
|
||||
func (i *eventStoreInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
|
||||
err = i.collector.Invoke(ctx, command, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(i.collector.events) > 0 {
|
||||
err = eventstore.Publish(ctx, i.collector.events, opts.DB)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type eventCollector struct {
|
||||
next Invoker
|
||||
events []*eventstore.Event
|
||||
}
|
||||
|
||||
type eventer interface {
|
||||
Events() []*eventstore.Event
|
||||
}
|
||||
|
||||
func (i *eventCollector) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
|
||||
if e, ok := command.(eventer); ok && len(e.Events()) > 0 {
|
||||
// we need to ensure all commands are executed in the same transaction
|
||||
close, err := opts.EnsureTx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { err = close(ctx, err) }()
|
||||
|
||||
i.events = append(i.events, e.Events()...)
|
||||
}
|
||||
if i.next != nil {
|
||||
err = i.next.Invoke(ctx, command, opts)
|
||||
} else {
|
||||
err = command.Execute(ctx, opts)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type traceInvoker struct {
|
||||
next Invoker
|
||||
}
|
||||
|
||||
func newTraceInvoker(next Invoker) *traceInvoker {
|
||||
return &traceInvoker{next: next}
|
||||
}
|
||||
|
||||
func (i *traceInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
|
||||
ctx, span := tracer.Start(ctx, fmt.Sprintf("%T", command))
|
||||
defer span.End()
|
||||
|
||||
if i.next != nil {
|
||||
err = i.next.Invoke(ctx, command, opts)
|
||||
} else {
|
||||
err = command.Execute(ctx, opts)
|
||||
}
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
}
|
||||
return err
|
||||
}
|
39
backend/v3/domain/org.go
Normal file
39
backend/v3/domain/org.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Org struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
type OrgRepository interface {
|
||||
ByID(ctx context.Context, orgID string) (*Org, error)
|
||||
Create(ctx context.Context, org *Org) error
|
||||
On(id string) OrgOperation
|
||||
}
|
||||
|
||||
type OrgOperation interface {
|
||||
AdminRepository
|
||||
DomainRepository
|
||||
Update(ctx context.Context, org *Org) error
|
||||
Delete(ctx context.Context) error
|
||||
}
|
||||
|
||||
type AdminRepository interface {
|
||||
AddAdmin(ctx context.Context, userID string, roles []string) error
|
||||
SetAdminRoles(ctx context.Context, userID string, roles []string) error
|
||||
RemoveAdmin(ctx context.Context, userID string) error
|
||||
}
|
||||
|
||||
type DomainRepository interface {
|
||||
AddDomain(ctx context.Context, domain string) error
|
||||
SetDomainVerified(ctx context.Context, domain string) error
|
||||
RemoveDomain(ctx context.Context, domain string) error
|
||||
}
|
74
backend/v3/domain/org_add.go
Normal file
74
backend/v3/domain/org_add.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type AddOrgCommand struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Admins []AddAdminCommand `json:"admins"`
|
||||
}
|
||||
|
||||
func NewAddOrgCommand(name string, admins ...AddAdminCommand) *AddOrgCommand {
|
||||
return &AddOrgCommand{
|
||||
Name: name,
|
||||
Admins: admins,
|
||||
}
|
||||
}
|
||||
|
||||
// Execute implements Commander.
|
||||
func (cmd *AddOrgCommand) Execute(ctx context.Context, opts *CommandOpts) (err error) {
|
||||
if len(cmd.Admins) == 0 {
|
||||
return ErrNoAdminSpecified
|
||||
}
|
||||
if err = cmd.ensureID(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
close, err := opts.EnsureTx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { err = close(ctx, err) }()
|
||||
err = orgRepo(opts.DB).Create(ctx, &Org{
|
||||
ID: cmd.ID,
|
||||
Name: cmd.Name,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
_ Commander = (*AddOrgCommand)(nil)
|
||||
)
|
||||
|
||||
func (cmd *AddOrgCommand) ensureID() (err error) {
|
||||
if cmd.ID != "" {
|
||||
return nil
|
||||
}
|
||||
cmd.ID, err = generateID()
|
||||
return err
|
||||
}
|
||||
|
||||
type AddAdminCommand struct {
|
||||
UserID string `json:"userId"`
|
||||
Roles []string `json:"roles"`
|
||||
}
|
||||
|
||||
// Execute implements Commander.
|
||||
func (a *AddAdminCommand) Execute(ctx context.Context, opts *CommandOpts) (err error) {
|
||||
close, err := opts.EnsureTx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { err = close(ctx, err) }()
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
_ Commander = (*AddAdminCommand)(nil)
|
||||
)
|
82
backend/v3/domain/repository.go
Normal file
82
backend/v3/domain/repository.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
type Operation interface {
|
||||
// TextOperation |
|
||||
// NumberOperation |
|
||||
// BoolOperation
|
||||
|
||||
op()
|
||||
}
|
||||
|
||||
type clause[F ~uint8, Op Operation] struct {
|
||||
field F
|
||||
op Op
|
||||
}
|
||||
|
||||
func (c *clause[F, Op]) Field() F {
|
||||
return c.field
|
||||
}
|
||||
|
||||
func (c *clause[F, Op]) Operation() Op {
|
||||
return c.op
|
||||
}
|
||||
|
||||
type Text interface {
|
||||
~string | ~[]byte
|
||||
}
|
||||
|
||||
type TextOperation uint8
|
||||
|
||||
const (
|
||||
TextOperationEqual TextOperation = iota
|
||||
TextOperationNotEqual
|
||||
TextOperationStartsWith
|
||||
TextOperationStartsWithIgnoreCase
|
||||
)
|
||||
|
||||
func (TextOperation) op() {}
|
||||
|
||||
type Number interface {
|
||||
constraints.Integer | constraints.Float | constraints.Complex | time.Time
|
||||
}
|
||||
|
||||
type NumberOperation uint8
|
||||
|
||||
const (
|
||||
NumberOperationEqual NumberOperation = iota
|
||||
NumberOperationNotEqual
|
||||
NumberOperationLessThan
|
||||
NumberOperationLessThanOrEqual
|
||||
NumberOperationGreaterThan
|
||||
NumberOperationGreaterThanOrEqual
|
||||
)
|
||||
|
||||
func (NumberOperation) op() {}
|
||||
|
||||
type Bool interface {
|
||||
~bool
|
||||
}
|
||||
|
||||
type BoolOperation uint8
|
||||
|
||||
const (
|
||||
BoolOperationIs BoolOperation = iota
|
||||
BoolOperationNot
|
||||
)
|
||||
|
||||
func (BoolOperation) op() {}
|
||||
|
||||
type ListOperation uint8
|
||||
|
||||
const (
|
||||
ListOperationContains ListOperation = iota
|
||||
ListOperationNotContains
|
||||
)
|
||||
|
||||
func (ListOperation) op() {}
|
64
backend/v3/domain/set_email.go
Normal file
64
backend/v3/domain/set_email.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/eventstore"
|
||||
)
|
||||
|
||||
type SetEmailCommand struct {
|
||||
UserID string `json:"userId"`
|
||||
Email string `json:"email"`
|
||||
verification Commander
|
||||
}
|
||||
|
||||
var (
|
||||
_ Commander = (*SetEmailCommand)(nil)
|
||||
_ eventer = (*SetEmailCommand)(nil)
|
||||
_ CreateHumanOpt = (*SetEmailCommand)(nil)
|
||||
)
|
||||
|
||||
type SetEmailOpt interface {
|
||||
applyOnSetEmail(*SetEmailCommand)
|
||||
}
|
||||
|
||||
func NewSetEmailCommand(userID, email string, verificationType SetEmailOpt) *SetEmailCommand {
|
||||
cmd := &SetEmailCommand{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
}
|
||||
verificationType.applyOnSetEmail(cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (cmd *SetEmailCommand) Execute(ctx context.Context, opts *CommandOpts) error {
|
||||
close, err := opts.EnsureTx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { err = close(ctx, err) }()
|
||||
// userStatement(opts.DB).Human().ByID(cmd.UserID).SetEmail(ctx, cmd.Email)
|
||||
err = userRepo(opts.DB).Human().ByID(cmd.UserID).Exec().SetEmail(ctx, cmd.Email)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return opts.Invoke(ctx, cmd.verification)
|
||||
}
|
||||
|
||||
// Events implements [eventer].
|
||||
func (cmd *SetEmailCommand) Events() []*eventstore.Event {
|
||||
return []*eventstore.Event{
|
||||
{
|
||||
AggregateType: "user",
|
||||
AggregateID: cmd.UserID,
|
||||
Type: "user.email.set",
|
||||
Payload: cmd,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// applyOnCreateHuman implements [CreateHumanOpt].
|
||||
func (cmd *SetEmailCommand) applyOnCreateHuman(createUserCmd *CreateUserCommand[Human]) {
|
||||
createUserCmd.email = cmd
|
||||
}
|
193
backend/v3/domain/user.go
Normal file
193
backend/v3/domain/user.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
v4 "github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt/v4"
|
||||
)
|
||||
|
||||
type userColumns interface {
|
||||
// TODO: move v4.columns to domain
|
||||
InstanceIDColumn() column
|
||||
OrgIDColumn() column
|
||||
IDColumn() column
|
||||
usernameColumn() column
|
||||
CreatedAtColumn() column
|
||||
UpdatedAtColumn() column
|
||||
DeletedAtColumn() column
|
||||
}
|
||||
|
||||
type userConditions interface {
|
||||
InstanceIDCondition(instanceID string) v4.Condition
|
||||
OrgIDCondition(orgID string) v4.Condition
|
||||
IDCondition(userID string) v4.Condition
|
||||
UsernameCondition(op v4.TextOperator, username string) v4.Condition
|
||||
CreatedAtCondition(op v4.NumberOperator, createdAt time.Time) v4.Condition
|
||||
UpdatedAtCondition(op v4.NumberOperator, updatedAt time.Time) v4.Condition
|
||||
DeletedCondition(isDeleted bool) v4.Condition
|
||||
DeletedAtCondition(op v4.NumberOperator, deletedAt time.Time) v4.Condition
|
||||
}
|
||||
|
||||
type UserRepository interface {
|
||||
userColumns
|
||||
userConditions
|
||||
// TODO: move condition to domain
|
||||
WithCondition(condition v4.Condition) UserRepository
|
||||
Get(ctx context.Context) (*User, error)
|
||||
List(ctx context.Context) ([]*User, error)
|
||||
Create(ctx context.Context, user *User) error
|
||||
Delete(ctx context.Context) error
|
||||
|
||||
Human() HumanRepository
|
||||
Machine() MachineRepository
|
||||
}
|
||||
|
||||
type humanColumns interface {
|
||||
FirstNameColumn() column
|
||||
LastNameColumn() column
|
||||
EmailAddressColumn() column
|
||||
EmailVerifiedAtColumn() column
|
||||
PhoneNumberColumn() column
|
||||
PhoneVerifiedAtColumn() column
|
||||
}
|
||||
|
||||
type humanConditions interface {
|
||||
FirstNameCondition(op v4.TextOperator, firstName string) v4.Condition
|
||||
LastNameCondition(op v4.TextOperator, lastName string) v4.Condition
|
||||
EmailAddressCondition(op v4.TextOperator, email string) v4.Condition
|
||||
EmailAddressVerifiedCondition(isVerified bool) v4.Condition
|
||||
EmailVerifiedAtCondition(op v4.TextOperator, emailVerifiedAt string) v4.Condition
|
||||
PhoneNumberCondition(op v4.TextOperator, phoneNumber string) v4.Condition
|
||||
PhoneNumberVerifiedCondition(isVerified bool) v4.Condition
|
||||
PhoneVerifiedAtCondition(op v4.TextOperator, phoneVerifiedAt string) v4.Condition
|
||||
}
|
||||
|
||||
type HumanRepository interface {
|
||||
humanColumns
|
||||
humanConditions
|
||||
|
||||
GetEmail(ctx context.Context) (*Email, error)
|
||||
// TODO: replace any with add email update columns
|
||||
SetEmail(ctx context.Context, columns ...any) error
|
||||
}
|
||||
|
||||
type machineColumns interface {
|
||||
DescriptionColumn() column
|
||||
}
|
||||
|
||||
type machineConditions interface {
|
||||
DescriptionCondition(op v4.TextOperator, description string) v4.Condition
|
||||
}
|
||||
|
||||
type MachineRepository interface {
|
||||
machineColumns
|
||||
machineConditions
|
||||
}
|
||||
|
||||
// type UserRepository interface {
|
||||
// // Get(ctx context.Context, clauses ...UserClause) (*User, error)
|
||||
// // Search(ctx context.Context, clauses ...UserClause) ([]*User, error)
|
||||
|
||||
// UserQuery[UserOperation]
|
||||
// Human() HumanQuery
|
||||
// Machine() MachineQuery
|
||||
// }
|
||||
|
||||
// type UserQuery[Op UserOperation] interface {
|
||||
// ByID(id string) UserQuery[Op]
|
||||
// Username(username string) UserQuery[Op]
|
||||
// Exec() Op
|
||||
// }
|
||||
|
||||
// type HumanQuery interface {
|
||||
// UserQuery[HumanOperation]
|
||||
// Email(op TextOperation, email string) HumanQuery
|
||||
// HumanOperation
|
||||
// }
|
||||
|
||||
// type MachineQuery interface {
|
||||
// UserQuery[MachineOperation]
|
||||
// MachineOperation
|
||||
// }
|
||||
|
||||
// type UserClause interface {
|
||||
// Field() UserField
|
||||
// Operation() Operation
|
||||
// Args() []any
|
||||
// }
|
||||
|
||||
// type UserField uint8
|
||||
|
||||
// const (
|
||||
// // Fields used for all users
|
||||
// UserFieldInstanceID UserField = iota + 1
|
||||
// UserFieldOrgID
|
||||
// UserFieldID
|
||||
// UserFieldUsername
|
||||
|
||||
// // Fields used for human users
|
||||
// UserHumanFieldEmail
|
||||
// UserHumanFieldEmailVerified
|
||||
|
||||
// // Fields used for machine users
|
||||
// UserMachineFieldDescription
|
||||
// )
|
||||
|
||||
// type userByIDClause struct {
|
||||
// id string
|
||||
// }
|
||||
|
||||
// func (c *userByIDClause) Field() UserField {
|
||||
// return UserFieldID
|
||||
// }
|
||||
|
||||
// func (c *userByIDClause) Operation() Operation {
|
||||
// return TextOperationEqual
|
||||
// }
|
||||
|
||||
// func (c *userByIDClause) Args() []any {
|
||||
// return []any{c.id}
|
||||
// }
|
||||
|
||||
// type UserOperation interface {
|
||||
// Delete(ctx context.Context) error
|
||||
// SetUsername(ctx context.Context, username string) error
|
||||
// }
|
||||
|
||||
// type HumanOperation interface {
|
||||
// UserOperation
|
||||
// SetEmail(ctx context.Context, email string) error
|
||||
// SetEmailVerified(ctx context.Context, email string) error
|
||||
// GetEmail(ctx context.Context) (*Email, error)
|
||||
// }
|
||||
|
||||
// type MachineOperation interface {
|
||||
// UserOperation
|
||||
// SetDescription(ctx context.Context, description string) error
|
||||
// }
|
||||
|
||||
type User struct {
|
||||
v4.User
|
||||
}
|
||||
|
||||
// type userTraits interface {
|
||||
// isUserTraits()
|
||||
// }
|
||||
|
||||
// type Human struct {
|
||||
// Email *Email `json:"email"`
|
||||
// }
|
||||
|
||||
// func (*Human) isUserTraits() {}
|
||||
|
||||
// type Machine struct {
|
||||
// Description string `json:"description"`
|
||||
// }
|
||||
|
||||
// func (*Machine) isUserTraits() {}
|
||||
|
||||
// type Email struct {
|
||||
// Address string `json:"address"`
|
||||
// IsVerified bool `json:"isVerified"`
|
||||
// }
|
112
backend/v3/storage/cache/cache.go
vendored
Normal file
112
backend/v3/storage/cache/cache.go
vendored
Normal file
@@ -0,0 +1,112 @@
|
||||
// Package cache provides abstraction of cache implementations that can be used by zitadel.
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
)
|
||||
|
||||
// Purpose describes which object types are stored by a cache.
|
||||
type Purpose int
|
||||
|
||||
//go:generate enumer -type Purpose -transform snake -trimprefix Purpose
|
||||
const (
|
||||
PurposeUnspecified Purpose = iota
|
||||
PurposeAuthzInstance
|
||||
PurposeMilestones
|
||||
PurposeOrganization
|
||||
PurposeIdPFormCallback
|
||||
)
|
||||
|
||||
// Cache stores objects with a value of type `V`.
|
||||
// Objects may be referred to by one or more indices.
|
||||
// Implementations may encode the value for storage.
|
||||
// This means non-exported fields may be lost and objects
|
||||
// with function values may fail to encode.
|
||||
// See https://pkg.go.dev/encoding/json#Marshal for example.
|
||||
//
|
||||
// `I` is the type by which indices are identified,
|
||||
// typically an enum for type-safe access.
|
||||
// Indices are defined when calling the constructor of an implementation of this interface.
|
||||
// It is illegal to refer to an idex not defined during construction.
|
||||
//
|
||||
// `K` is the type used as key in each index.
|
||||
// Due to the limitations in type constraints, all indices use the same key type.
|
||||
//
|
||||
// Implementations are free to use stricter type constraints or fixed typing.
|
||||
type Cache[I, K comparable, V Entry[I, K]] interface {
|
||||
// Get an object through specified index.
|
||||
// An [IndexUnknownError] may be returned if the index is unknown.
|
||||
// [ErrCacheMiss] is returned if the key was not found in the index,
|
||||
// or the object is not valid.
|
||||
Get(ctx context.Context, index I, key K) (V, bool)
|
||||
|
||||
// Set an object.
|
||||
// Keys are created on each index based in the [Entry.Keys] method.
|
||||
// If any key maps to an existing object, the object is invalidated,
|
||||
// regardless if the object has other keys defined in the new entry.
|
||||
// This to prevent ghost objects when an entry reduces the amount of keys
|
||||
// for a given index.
|
||||
Set(ctx context.Context, value V)
|
||||
|
||||
// Invalidate an object through specified index.
|
||||
// Implementations may choose to instantly delete the object,
|
||||
// defer until prune or a separate cleanup routine.
|
||||
// Invalidated object are no longer returned from Get.
|
||||
// It is safe to call Invalidate multiple times or on non-existing entries.
|
||||
Invalidate(ctx context.Context, index I, key ...K) error
|
||||
|
||||
// Delete one or more keys from a specific index.
|
||||
// An [IndexUnknownError] may be returned if the index is unknown.
|
||||
// The referred object is not invalidated and may still be accessible though
|
||||
// other indices and keys.
|
||||
// It is safe to call Delete multiple times or on non-existing entries
|
||||
Delete(ctx context.Context, index I, key ...K) error
|
||||
|
||||
// Truncate deletes all cached objects.
|
||||
Truncate(ctx context.Context) error
|
||||
}
|
||||
|
||||
// Entry contains a value of type `V` to be cached.
|
||||
//
|
||||
// `I` is the type by which indices are identified,
|
||||
// typically an enum for type-safe access.
|
||||
//
|
||||
// `K` is the type used as key in an index.
|
||||
// Due to the limitations in type constraints, all indices use the same key type.
|
||||
type Entry[I, K comparable] interface {
|
||||
// Keys returns which keys map to the object in a specified index.
|
||||
// May return nil if the index in unknown or when there are no keys.
|
||||
Keys(index I) (key []K)
|
||||
}
|
||||
|
||||
type Connector int
|
||||
|
||||
//go:generate enumer -type Connector -transform snake -trimprefix Connector -linecomment -text
|
||||
const (
|
||||
// Empty line comment ensures empty string for unspecified value
|
||||
ConnectorUnspecified Connector = iota //
|
||||
ConnectorMemory
|
||||
ConnectorPostgres
|
||||
ConnectorRedis
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Connector Connector
|
||||
|
||||
// Age since an object was added to the cache,
|
||||
// after which the object is considered invalid.
|
||||
// 0 disables max age checks.
|
||||
MaxAge time.Duration
|
||||
|
||||
// Age since last use (Get) of an object,
|
||||
// after which the object is considered invalid.
|
||||
// 0 disables last use age checks.
|
||||
LastUseAge time.Duration
|
||||
|
||||
// Log allows logging of the specific cache.
|
||||
// By default only errors are logged to stdout.
|
||||
Log *logging.Config
|
||||
}
|
49
backend/v3/storage/cache/connector/connector.go
vendored
Normal file
49
backend/v3/storage/cache/connector/connector.go
vendored
Normal file
@@ -0,0 +1,49 @@
|
||||
// Package connector provides glue between the [cache.Cache] interface and implementations from the connector sub-packages.
|
||||
package connector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/cache"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/cache/connector/gomap"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/cache/connector/noop"
|
||||
)
|
||||
|
||||
type CachesConfig struct {
|
||||
Connectors struct {
|
||||
Memory gomap.Config
|
||||
}
|
||||
Instance *cache.Config
|
||||
Milestones *cache.Config
|
||||
Organization *cache.Config
|
||||
IdPFormCallbacks *cache.Config
|
||||
}
|
||||
|
||||
type Connectors struct {
|
||||
Config CachesConfig
|
||||
Memory *gomap.Connector
|
||||
}
|
||||
|
||||
func StartConnectors(conf *CachesConfig) (Connectors, error) {
|
||||
if conf == nil {
|
||||
return Connectors{}, nil
|
||||
}
|
||||
return Connectors{
|
||||
Config: *conf,
|
||||
Memory: gomap.NewConnector(conf.Connectors.Memory),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func StartCache[I ~int, K ~string, V cache.Entry[I, K]](background context.Context, indices []I, purpose cache.Purpose, conf *cache.Config, connectors Connectors) (cache.Cache[I, K, V], error) {
|
||||
if conf == nil || conf.Connector == cache.ConnectorUnspecified {
|
||||
return noop.NewCache[I, K, V](), nil
|
||||
}
|
||||
if conf.Connector == cache.ConnectorMemory && connectors.Memory != nil {
|
||||
c := gomap.NewCache[I, K, V](background, indices, *conf)
|
||||
connectors.Memory.Config.StartAutoPrune(background, c, purpose)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("cache connector %q not enabled", conf.Connector)
|
||||
}
|
23
backend/v3/storage/cache/connector/gomap/connector.go
vendored
Normal file
23
backend/v3/storage/cache/connector/gomap/connector.go
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
package gomap
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/cache"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Enabled bool
|
||||
AutoPrune cache.AutoPruneConfig
|
||||
}
|
||||
|
||||
type Connector struct {
|
||||
Config cache.AutoPruneConfig
|
||||
}
|
||||
|
||||
func NewConnector(config Config) *Connector {
|
||||
if !config.Enabled {
|
||||
return nil
|
||||
}
|
||||
return &Connector{
|
||||
Config: config.AutoPrune,
|
||||
}
|
||||
}
|
200
backend/v3/storage/cache/connector/gomap/gomap.go
vendored
Normal file
200
backend/v3/storage/cache/connector/gomap/gomap.go
vendored
Normal file
@@ -0,0 +1,200 @@
|
||||
package gomap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/cache"
|
||||
)
|
||||
|
||||
type mapCache[I, K comparable, V cache.Entry[I, K]] struct {
|
||||
config *cache.Config
|
||||
indexMap map[I]*index[K, V]
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewCache returns an in-memory Cache implementation based on the builtin go map type.
|
||||
// Object values are stored as-is and there is no encoding or decoding involved.
|
||||
func NewCache[I, K comparable, V cache.Entry[I, K]](background context.Context, indices []I, config cache.Config) cache.PrunerCache[I, K, V] {
|
||||
m := &mapCache[I, K, V]{
|
||||
config: &config,
|
||||
indexMap: make(map[I]*index[K, V], len(indices)),
|
||||
logger: slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
AddSource: true,
|
||||
Level: slog.LevelError,
|
||||
})),
|
||||
}
|
||||
if config.Log != nil {
|
||||
m.logger = config.Log.Slog()
|
||||
}
|
||||
m.logger.InfoContext(background, "map cache logging enabled")
|
||||
|
||||
for _, name := range indices {
|
||||
m.indexMap[name] = &index[K, V]{
|
||||
config: m.config,
|
||||
entries: make(map[K]*entry[V]),
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Get(ctx context.Context, index I, key K) (value V, ok bool) {
|
||||
i, ok := c.indexMap[index]
|
||||
if !ok {
|
||||
c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key)
|
||||
return value, false
|
||||
}
|
||||
entry, err := i.Get(key)
|
||||
if err == nil {
|
||||
c.logger.DebugContext(ctx, "map cache get", "index", index, "key", key)
|
||||
return entry.value, true
|
||||
}
|
||||
if errors.Is(err, cache.ErrCacheMiss) {
|
||||
c.logger.InfoContext(ctx, "map cache get", "err", err, "index", index, "key", key)
|
||||
return value, false
|
||||
}
|
||||
c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key)
|
||||
return value, false
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Set(ctx context.Context, value V) {
|
||||
now := time.Now()
|
||||
entry := &entry[V]{
|
||||
value: value,
|
||||
created: now,
|
||||
}
|
||||
entry.lastUse.Store(now.UnixMicro())
|
||||
|
||||
for name, i := range c.indexMap {
|
||||
keys := value.Keys(name)
|
||||
i.Set(keys, entry)
|
||||
c.logger.DebugContext(ctx, "map cache set", "index", name, "keys", keys)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Invalidate(ctx context.Context, index I, keys ...K) error {
|
||||
i, ok := c.indexMap[index]
|
||||
if !ok {
|
||||
return cache.NewIndexUnknownErr(index)
|
||||
}
|
||||
i.Invalidate(keys)
|
||||
c.logger.DebugContext(ctx, "map cache invalidate", "index", index, "keys", keys)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Delete(ctx context.Context, index I, keys ...K) error {
|
||||
i, ok := c.indexMap[index]
|
||||
if !ok {
|
||||
return cache.NewIndexUnknownErr(index)
|
||||
}
|
||||
i.Delete(keys)
|
||||
c.logger.DebugContext(ctx, "map cache delete", "index", index, "keys", keys)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Prune(ctx context.Context) error {
|
||||
for name, index := range c.indexMap {
|
||||
index.Prune()
|
||||
c.logger.DebugContext(ctx, "map cache prune", "index", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Truncate(ctx context.Context) error {
|
||||
for name, index := range c.indexMap {
|
||||
index.Truncate()
|
||||
c.logger.DebugContext(ctx, "map cache truncate", "index", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type index[K comparable, V any] struct {
|
||||
mutex sync.RWMutex
|
||||
config *cache.Config
|
||||
entries map[K]*entry[V]
|
||||
}
|
||||
|
||||
func (i *index[K, V]) Get(key K) (*entry[V], error) {
|
||||
i.mutex.RLock()
|
||||
entry, ok := i.entries[key]
|
||||
i.mutex.RUnlock()
|
||||
if ok && entry.isValid(i.config) {
|
||||
return entry, nil
|
||||
}
|
||||
return nil, cache.ErrCacheMiss
|
||||
}
|
||||
|
||||
func (c *index[K, V]) Set(keys []K, entry *entry[V]) {
|
||||
c.mutex.Lock()
|
||||
for _, key := range keys {
|
||||
c.entries[key] = entry
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (i *index[K, V]) Invalidate(keys []K) {
|
||||
i.mutex.RLock()
|
||||
for _, key := range keys {
|
||||
if entry, ok := i.entries[key]; ok {
|
||||
entry.invalid.Store(true)
|
||||
}
|
||||
}
|
||||
i.mutex.RUnlock()
|
||||
}
|
||||
|
||||
func (c *index[K, V]) Delete(keys []K) {
|
||||
c.mutex.Lock()
|
||||
for _, key := range keys {
|
||||
delete(c.entries, key)
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (c *index[K, V]) Prune() {
|
||||
c.mutex.Lock()
|
||||
maps.DeleteFunc(c.entries, func(_ K, entry *entry[V]) bool {
|
||||
return !entry.isValid(c.config)
|
||||
})
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (c *index[K, V]) Truncate() {
|
||||
c.mutex.Lock()
|
||||
c.entries = make(map[K]*entry[V])
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
type entry[V any] struct {
|
||||
value V
|
||||
created time.Time
|
||||
invalid atomic.Bool
|
||||
lastUse atomic.Int64 // UnixMicro time
|
||||
}
|
||||
|
||||
func (e *entry[V]) isValid(c *cache.Config) bool {
|
||||
if e.invalid.Load() {
|
||||
return false
|
||||
}
|
||||
now := time.Now()
|
||||
if c.MaxAge > 0 {
|
||||
if e.created.Add(c.MaxAge).Before(now) {
|
||||
e.invalid.Store(true)
|
||||
return false
|
||||
}
|
||||
}
|
||||
if c.LastUseAge > 0 {
|
||||
lastUse := e.lastUse.Load()
|
||||
if time.UnixMicro(lastUse).Add(c.LastUseAge).Before(now) {
|
||||
e.invalid.Store(true)
|
||||
return false
|
||||
}
|
||||
e.lastUse.CompareAndSwap(lastUse, now.UnixMicro())
|
||||
}
|
||||
return true
|
||||
}
|
329
backend/v3/storage/cache/connector/gomap/gomap_test.go
vendored
Normal file
329
backend/v3/storage/cache/connector/gomap/gomap_test.go
vendored
Normal file
@@ -0,0 +1,329 @@
|
||||
package gomap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/cache"
|
||||
)
|
||||
|
||||
type testIndex int
|
||||
|
||||
const (
|
||||
testIndexID testIndex = iota
|
||||
testIndexName
|
||||
)
|
||||
|
||||
var testIndices = []testIndex{
|
||||
testIndexID,
|
||||
testIndexName,
|
||||
}
|
||||
|
||||
type testObject struct {
|
||||
id string
|
||||
names []string
|
||||
}
|
||||
|
||||
func (o *testObject) Keys(index testIndex) []string {
|
||||
switch index {
|
||||
case testIndexID:
|
||||
return []string{o.id}
|
||||
case testIndexName:
|
||||
return o.names
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func Test_mapCache_Get(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
obj := &testObject{
|
||||
id: "id",
|
||||
names: []string{"foo", "bar"},
|
||||
}
|
||||
c.Set(context.Background(), obj)
|
||||
|
||||
type args struct {
|
||||
index testIndex
|
||||
key string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *testObject
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
args: args{
|
||||
index: testIndexID,
|
||||
key: "id",
|
||||
},
|
||||
want: obj,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "miss",
|
||||
args: args{
|
||||
index: testIndexID,
|
||||
key: "spanac",
|
||||
},
|
||||
want: nil,
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "unknown index",
|
||||
args: args{
|
||||
index: 99,
|
||||
key: "id",
|
||||
},
|
||||
want: nil,
|
||||
wantOk: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, ok := c.Get(context.Background(), tt.args.index, tt.args.key)
|
||||
assert.Equal(t, tt.want, got)
|
||||
assert.Equal(t, tt.wantOk, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_mapCache_Invalidate(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
obj := &testObject{
|
||||
id: "id",
|
||||
names: []string{"foo", "bar"},
|
||||
}
|
||||
c.Set(context.Background(), obj)
|
||||
err := c.Invalidate(context.Background(), testIndexName, "bar")
|
||||
require.NoError(t, err)
|
||||
got, ok := c.Get(context.Background(), testIndexID, "id")
|
||||
assert.Nil(t, got)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func Test_mapCache_Delete(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
obj := &testObject{
|
||||
id: "id",
|
||||
names: []string{"foo", "bar"},
|
||||
}
|
||||
c.Set(context.Background(), obj)
|
||||
err := c.Delete(context.Background(), testIndexName, "bar")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Shouldn't find object by deleted name
|
||||
got, ok := c.Get(context.Background(), testIndexName, "bar")
|
||||
assert.Nil(t, got)
|
||||
assert.False(t, ok)
|
||||
|
||||
// Should find object by other name
|
||||
got, ok = c.Get(context.Background(), testIndexName, "foo")
|
||||
assert.Equal(t, obj, got)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Should find object by id
|
||||
got, ok = c.Get(context.Background(), testIndexID, "id")
|
||||
assert.Equal(t, obj, got)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
func Test_mapCache_Prune(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
|
||||
objects := []*testObject{
|
||||
{
|
||||
id: "id1",
|
||||
names: []string{"foo", "bar"},
|
||||
},
|
||||
{
|
||||
id: "id2",
|
||||
names: []string{"hello"},
|
||||
},
|
||||
}
|
||||
for _, obj := range objects {
|
||||
c.Set(context.Background(), obj)
|
||||
}
|
||||
// invalidate one entry
|
||||
err := c.Invalidate(context.Background(), testIndexName, "bar")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = c.(cache.Pruner).Prune(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Other object should still be found
|
||||
got, ok := c.Get(context.Background(), testIndexID, "id2")
|
||||
assert.Equal(t, objects[1], got)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
func Test_mapCache_Truncate(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
objects := []*testObject{
|
||||
{
|
||||
id: "id1",
|
||||
names: []string{"foo", "bar"},
|
||||
},
|
||||
{
|
||||
id: "id2",
|
||||
names: []string{"hello"},
|
||||
},
|
||||
}
|
||||
for _, obj := range objects {
|
||||
c.Set(context.Background(), obj)
|
||||
}
|
||||
|
||||
err := c.Truncate(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
mc := c.(*mapCache[testIndex, string, *testObject])
|
||||
for _, index := range mc.indexMap {
|
||||
index.mutex.RLock()
|
||||
assert.Len(t, index.entries, 0)
|
||||
index.mutex.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
func Test_entry_isValid(t *testing.T) {
|
||||
type fields struct {
|
||||
created time.Time
|
||||
invalid bool
|
||||
lastUse time.Time
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
config *cache.Config
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "invalid",
|
||||
fields: fields{
|
||||
created: time.Now(),
|
||||
invalid: true,
|
||||
lastUse: time.Now(),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "max age exceeded",
|
||||
fields: fields{
|
||||
created: time.Now().Add(-(time.Minute + time.Second)),
|
||||
invalid: false,
|
||||
lastUse: time.Now(),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "max age disabled",
|
||||
fields: fields{
|
||||
created: time.Now().Add(-(time.Minute + time.Second)),
|
||||
invalid: false,
|
||||
lastUse: time.Now(),
|
||||
},
|
||||
config: &cache.Config{
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "last use age exceeded",
|
||||
fields: fields{
|
||||
created: time.Now().Add(-(time.Minute / 2)),
|
||||
invalid: false,
|
||||
lastUse: time.Now().Add(-(time.Second * 2)),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "last use age disabled",
|
||||
fields: fields{
|
||||
created: time.Now().Add(-(time.Minute / 2)),
|
||||
invalid: false,
|
||||
lastUse: time.Now().Add(-(time.Second * 2)),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
fields: fields{
|
||||
created: time.Now(),
|
||||
invalid: false,
|
||||
lastUse: time.Now(),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := &entry[any]{
|
||||
created: tt.fields.created,
|
||||
}
|
||||
e.invalid.Store(tt.fields.invalid)
|
||||
e.lastUse.Store(tt.fields.lastUse.UnixMicro())
|
||||
got := e.isValid(tt.config)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
21
backend/v3/storage/cache/connector/noop/noop.go
vendored
Normal file
21
backend/v3/storage/cache/connector/noop/noop.go
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
package noop
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/cache"
|
||||
)
|
||||
|
||||
type noop[I, K comparable, V cache.Entry[I, K]] struct{}
|
||||
|
||||
// NewCache returns a cache that does nothing
|
||||
func NewCache[I, K comparable, V cache.Entry[I, K]]() cache.Cache[I, K, V] {
|
||||
return noop[I, K, V]{}
|
||||
}
|
||||
|
||||
func (noop[I, K, V]) Set(context.Context, V) {}
|
||||
func (noop[I, K, V]) Get(context.Context, I, K) (value V, ok bool) { return }
|
||||
func (noop[I, K, V]) Invalidate(context.Context, I, ...K) (err error) { return }
|
||||
func (noop[I, K, V]) Delete(context.Context, I, ...K) (err error) { return }
|
||||
func (noop[I, K, V]) Prune(context.Context) (err error) { return }
|
||||
func (noop[I, K, V]) Truncate(context.Context) (err error) { return }
|
98
backend/v3/storage/cache/connector_enumer.go
vendored
Normal file
98
backend/v3/storage/cache/connector_enumer.go
vendored
Normal file
@@ -0,0 +1,98 @@
|
||||
// Code generated by "enumer -type Connector -transform snake -trimprefix Connector -linecomment -text"; DO NOT EDIT.
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const _ConnectorName = "memorypostgresredis"
|
||||
|
||||
var _ConnectorIndex = [...]uint8{0, 0, 6, 14, 19}
|
||||
|
||||
const _ConnectorLowerName = "memorypostgresredis"
|
||||
|
||||
func (i Connector) String() string {
|
||||
if i < 0 || i >= Connector(len(_ConnectorIndex)-1) {
|
||||
return fmt.Sprintf("Connector(%d)", i)
|
||||
}
|
||||
return _ConnectorName[_ConnectorIndex[i]:_ConnectorIndex[i+1]]
|
||||
}
|
||||
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
func _ConnectorNoOp() {
|
||||
var x [1]struct{}
|
||||
_ = x[ConnectorUnspecified-(0)]
|
||||
_ = x[ConnectorMemory-(1)]
|
||||
_ = x[ConnectorPostgres-(2)]
|
||||
_ = x[ConnectorRedis-(3)]
|
||||
}
|
||||
|
||||
var _ConnectorValues = []Connector{ConnectorUnspecified, ConnectorMemory, ConnectorPostgres, ConnectorRedis}
|
||||
|
||||
var _ConnectorNameToValueMap = map[string]Connector{
|
||||
_ConnectorName[0:0]: ConnectorUnspecified,
|
||||
_ConnectorLowerName[0:0]: ConnectorUnspecified,
|
||||
_ConnectorName[0:6]: ConnectorMemory,
|
||||
_ConnectorLowerName[0:6]: ConnectorMemory,
|
||||
_ConnectorName[6:14]: ConnectorPostgres,
|
||||
_ConnectorLowerName[6:14]: ConnectorPostgres,
|
||||
_ConnectorName[14:19]: ConnectorRedis,
|
||||
_ConnectorLowerName[14:19]: ConnectorRedis,
|
||||
}
|
||||
|
||||
var _ConnectorNames = []string{
|
||||
_ConnectorName[0:0],
|
||||
_ConnectorName[0:6],
|
||||
_ConnectorName[6:14],
|
||||
_ConnectorName[14:19],
|
||||
}
|
||||
|
||||
// ConnectorString retrieves an enum value from the enum constants string name.
|
||||
// Throws an error if the param is not part of the enum.
|
||||
func ConnectorString(s string) (Connector, error) {
|
||||
if val, ok := _ConnectorNameToValueMap[s]; ok {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
if val, ok := _ConnectorNameToValueMap[strings.ToLower(s)]; ok {
|
||||
return val, nil
|
||||
}
|
||||
return 0, fmt.Errorf("%s does not belong to Connector values", s)
|
||||
}
|
||||
|
||||
// ConnectorValues returns all values of the enum
|
||||
func ConnectorValues() []Connector {
|
||||
return _ConnectorValues
|
||||
}
|
||||
|
||||
// ConnectorStrings returns a slice of all String values of the enum
|
||||
func ConnectorStrings() []string {
|
||||
strs := make([]string, len(_ConnectorNames))
|
||||
copy(strs, _ConnectorNames)
|
||||
return strs
|
||||
}
|
||||
|
||||
// IsAConnector returns "true" if the value is listed in the enum definition. "false" otherwise
|
||||
func (i Connector) IsAConnector() bool {
|
||||
for _, v := range _ConnectorValues {
|
||||
if i == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// MarshalText implements the encoding.TextMarshaler interface for Connector
|
||||
func (i Connector) MarshalText() ([]byte, error) {
|
||||
return []byte(i.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements the encoding.TextUnmarshaler interface for Connector
|
||||
func (i *Connector) UnmarshalText(text []byte) error {
|
||||
var err error
|
||||
*i, err = ConnectorString(string(text))
|
||||
return err
|
||||
}
|
29
backend/v3/storage/cache/error.go
vendored
Normal file
29
backend/v3/storage/cache/error.go
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type IndexUnknownError[I comparable] struct {
|
||||
index I
|
||||
}
|
||||
|
||||
func NewIndexUnknownErr[I comparable](index I) error {
|
||||
return IndexUnknownError[I]{index}
|
||||
}
|
||||
|
||||
func (i IndexUnknownError[I]) Error() string {
|
||||
return fmt.Sprintf("index %v unknown", i.index)
|
||||
}
|
||||
|
||||
func (a IndexUnknownError[I]) Is(err error) bool {
|
||||
if b, ok := err.(IndexUnknownError[I]); ok {
|
||||
return a.index == b.index
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var (
|
||||
ErrCacheMiss = errors.New("cache miss")
|
||||
)
|
76
backend/v3/storage/cache/pruner.go
vendored
Normal file
76
backend/v3/storage/cache/pruner.go
vendored
Normal file
@@ -0,0 +1,76 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/jonboulle/clockwork"
|
||||
"github.com/zitadel/logging"
|
||||
)
|
||||
|
||||
// Pruner is an optional [Cache] interface.
|
||||
type Pruner interface {
|
||||
// Prune deletes all invalidated or expired objects.
|
||||
Prune(ctx context.Context) error
|
||||
}
|
||||
|
||||
type PrunerCache[I, K comparable, V Entry[I, K]] interface {
|
||||
Cache[I, K, V]
|
||||
Pruner
|
||||
}
|
||||
|
||||
type AutoPruneConfig struct {
|
||||
// Interval at which the cache is automatically pruned.
|
||||
// 0 or lower disables automatic pruning.
|
||||
Interval time.Duration
|
||||
|
||||
// Timeout for an automatic prune.
|
||||
// It is recommended to keep the value shorter than AutoPruneInterval
|
||||
// 0 or lower disables automatic pruning.
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
func (c AutoPruneConfig) StartAutoPrune(background context.Context, pruner Pruner, purpose Purpose) (close func()) {
|
||||
return c.startAutoPrune(background, pruner, purpose, clockwork.NewRealClock())
|
||||
}
|
||||
|
||||
func (c *AutoPruneConfig) startAutoPrune(background context.Context, pruner Pruner, purpose Purpose, clock clockwork.Clock) (close func()) {
|
||||
if c.Interval <= 0 {
|
||||
return func() {}
|
||||
}
|
||||
background, cancel := context.WithCancel(background)
|
||||
// randomize the first interval
|
||||
timer := clock.NewTimer(time.Duration(rand.Int63n(int64(c.Interval))))
|
||||
go c.pruneTimer(background, pruner, purpose, timer)
|
||||
return cancel
|
||||
}
|
||||
|
||||
func (c *AutoPruneConfig) pruneTimer(background context.Context, pruner Pruner, purpose Purpose, timer clockwork.Timer) {
|
||||
defer func() {
|
||||
if !timer.Stop() {
|
||||
<-timer.Chan()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-background.Done():
|
||||
return
|
||||
case <-timer.Chan():
|
||||
err := c.doPrune(background, pruner)
|
||||
logging.OnError(err).WithField("purpose", purpose).Error("cache auto prune")
|
||||
timer.Reset(c.Interval)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *AutoPruneConfig) doPrune(background context.Context, pruner Pruner) error {
|
||||
ctx, cancel := context.WithCancel(background)
|
||||
defer cancel()
|
||||
if c.Timeout > 0 {
|
||||
ctx, cancel = context.WithTimeout(background, c.Timeout)
|
||||
defer cancel()
|
||||
}
|
||||
return pruner.Prune(ctx)
|
||||
}
|
43
backend/v3/storage/cache/pruner_test.go
vendored
Normal file
43
backend/v3/storage/cache/pruner_test.go
vendored
Normal file
@@ -0,0 +1,43 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jonboulle/clockwork"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type testPruner struct {
|
||||
called chan struct{}
|
||||
}
|
||||
|
||||
func (p *testPruner) Prune(context.Context) error {
|
||||
p.called <- struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAutoPruneConfig_startAutoPrune(t *testing.T) {
|
||||
c := AutoPruneConfig{
|
||||
Interval: time.Second,
|
||||
Timeout: time.Millisecond,
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
pruner := testPruner{
|
||||
called: make(chan struct{}),
|
||||
}
|
||||
clock := clockwork.NewFakeClock()
|
||||
close := c.startAutoPrune(ctx, &pruner, PurposeAuthzInstance, clock)
|
||||
defer close()
|
||||
clock.Advance(time.Second)
|
||||
|
||||
select {
|
||||
case _, ok := <-pruner.called:
|
||||
assert.True(t, ok)
|
||||
case <-ctx.Done():
|
||||
t.Fatal(ctx.Err())
|
||||
}
|
||||
}
|
90
backend/v3/storage/cache/purpose_enumer.go
vendored
Normal file
90
backend/v3/storage/cache/purpose_enumer.go
vendored
Normal file
@@ -0,0 +1,90 @@
|
||||
// Code generated by "enumer -type Purpose -transform snake -trimprefix Purpose"; DO NOT EDIT.
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const _PurposeName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callback"
|
||||
|
||||
var _PurposeIndex = [...]uint8{0, 11, 25, 35, 47, 65}
|
||||
|
||||
const _PurposeLowerName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callback"
|
||||
|
||||
func (i Purpose) String() string {
|
||||
if i < 0 || i >= Purpose(len(_PurposeIndex)-1) {
|
||||
return fmt.Sprintf("Purpose(%d)", i)
|
||||
}
|
||||
return _PurposeName[_PurposeIndex[i]:_PurposeIndex[i+1]]
|
||||
}
|
||||
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
func _PurposeNoOp() {
|
||||
var x [1]struct{}
|
||||
_ = x[PurposeUnspecified-(0)]
|
||||
_ = x[PurposeAuthzInstance-(1)]
|
||||
_ = x[PurposeMilestones-(2)]
|
||||
_ = x[PurposeOrganization-(3)]
|
||||
_ = x[PurposeIdPFormCallback-(4)]
|
||||
}
|
||||
|
||||
var _PurposeValues = []Purpose{PurposeUnspecified, PurposeAuthzInstance, PurposeMilestones, PurposeOrganization, PurposeIdPFormCallback}
|
||||
|
||||
var _PurposeNameToValueMap = map[string]Purpose{
|
||||
_PurposeName[0:11]: PurposeUnspecified,
|
||||
_PurposeLowerName[0:11]: PurposeUnspecified,
|
||||
_PurposeName[11:25]: PurposeAuthzInstance,
|
||||
_PurposeLowerName[11:25]: PurposeAuthzInstance,
|
||||
_PurposeName[25:35]: PurposeMilestones,
|
||||
_PurposeLowerName[25:35]: PurposeMilestones,
|
||||
_PurposeName[35:47]: PurposeOrganization,
|
||||
_PurposeLowerName[35:47]: PurposeOrganization,
|
||||
_PurposeName[47:65]: PurposeIdPFormCallback,
|
||||
_PurposeLowerName[47:65]: PurposeIdPFormCallback,
|
||||
}
|
||||
|
||||
var _PurposeNames = []string{
|
||||
_PurposeName[0:11],
|
||||
_PurposeName[11:25],
|
||||
_PurposeName[25:35],
|
||||
_PurposeName[35:47],
|
||||
_PurposeName[47:65],
|
||||
}
|
||||
|
||||
// PurposeString retrieves an enum value from the enum constants string name.
|
||||
// Throws an error if the param is not part of the enum.
|
||||
func PurposeString(s string) (Purpose, error) {
|
||||
if val, ok := _PurposeNameToValueMap[s]; ok {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
if val, ok := _PurposeNameToValueMap[strings.ToLower(s)]; ok {
|
||||
return val, nil
|
||||
}
|
||||
return 0, fmt.Errorf("%s does not belong to Purpose values", s)
|
||||
}
|
||||
|
||||
// PurposeValues returns all values of the enum
|
||||
func PurposeValues() []Purpose {
|
||||
return _PurposeValues
|
||||
}
|
||||
|
||||
// PurposeStrings returns a slice of all String values of the enum
|
||||
func PurposeStrings() []string {
|
||||
strs := make([]string, len(_PurposeNames))
|
||||
copy(strs, _PurposeNames)
|
||||
return strs
|
||||
}
|
||||
|
||||
// IsAPurpose returns "true" if the value is listed in the enum definition. "false" otherwise
|
||||
func (i Purpose) IsAPurpose() bool {
|
||||
for _, v := range _PurposeValues {
|
||||
if i == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
9
backend/v3/storage/database/config.go
Normal file
9
backend/v3/storage/database/config.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type Connector interface {
|
||||
Connect(ctx context.Context) (Pool, error)
|
||||
}
|
60
backend/v3/storage/database/database.go
Normal file
60
backend/v3/storage/database/database.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
var (
|
||||
db *database
|
||||
)
|
||||
|
||||
type database struct {
|
||||
connector Connector
|
||||
pool Pool
|
||||
}
|
||||
|
||||
type Pool interface {
|
||||
Beginner
|
||||
QueryExecutor
|
||||
|
||||
Acquire(ctx context.Context) (Client, error)
|
||||
Close(ctx context.Context) error
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
Beginner
|
||||
QueryExecutor
|
||||
|
||||
Release(ctx context.Context) error
|
||||
}
|
||||
|
||||
type Querier interface {
|
||||
Query(ctx context.Context, stmt string, args ...any) (Rows, error)
|
||||
QueryRow(ctx context.Context, stmt string, args ...any) Row
|
||||
}
|
||||
|
||||
type Executor interface {
|
||||
Exec(ctx context.Context, stmt string, args ...any) error
|
||||
}
|
||||
|
||||
type QueryExecutor interface {
|
||||
Querier
|
||||
Executor
|
||||
}
|
||||
|
||||
type Scanner interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
type Row interface {
|
||||
Scanner
|
||||
}
|
||||
|
||||
type Rows interface {
|
||||
Row
|
||||
Next() bool
|
||||
Close() error
|
||||
Err() error
|
||||
}
|
||||
|
||||
type Query[T any] func(querier Querier) (result T, err error)
|
92
backend/v3/storage/database/dialect/config.go
Normal file
92
backend/v3/storage/database/dialect/config.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package dialect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"reflect"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/storage/database"
|
||||
"github.com/zitadel/zitadel/backend/storage/database/dialect/postgres"
|
||||
)
|
||||
|
||||
type Hook struct {
|
||||
Match func(string) bool
|
||||
Decode func(config any) (database.Connector, error)
|
||||
Name string
|
||||
Constructor func() database.Connector
|
||||
}
|
||||
|
||||
var hooks = []Hook{
|
||||
{
|
||||
Match: postgres.NameMatcher,
|
||||
Decode: postgres.DecodeConfig,
|
||||
Name: postgres.Name,
|
||||
Constructor: func() database.Connector { return new(postgres.Config) },
|
||||
},
|
||||
// {
|
||||
// Match: gosql.NameMatcher,
|
||||
// Decode: gosql.DecodeConfig,
|
||||
// Name: gosql.Name,
|
||||
// Constructor: func() database.Connector { return new(gosql.Config) },
|
||||
// },
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Dialects map[string]any `mapstructure:",remain" yaml:",inline"`
|
||||
|
||||
connector database.Connector
|
||||
}
|
||||
|
||||
func (c Config) Connect(ctx context.Context) (database.Pool, error) {
|
||||
if len(c.Dialects) != 1 {
|
||||
return nil, errors.New("Exactly one dialect must be configured")
|
||||
}
|
||||
|
||||
return c.connector.Connect(ctx)
|
||||
}
|
||||
|
||||
// Hooks implements [configure.Unmarshaller].
|
||||
func (c Config) Hooks() []viper.DecoderConfigOption {
|
||||
return []viper.DecoderConfigOption{
|
||||
viper.DecodeHook(decodeHook),
|
||||
}
|
||||
}
|
||||
|
||||
func decodeHook(from, to reflect.Value) (_ any, err error) {
|
||||
if to.Type() != reflect.TypeOf(Config{}) {
|
||||
return from.Interface(), nil
|
||||
}
|
||||
|
||||
config := new(Config)
|
||||
if err = mapstructure.Decode(from.Interface(), config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = config.decodeDialect(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (c *Config) decodeDialect() error {
|
||||
for _, hook := range hooks {
|
||||
for name, config := range c.Dialects {
|
||||
if !hook.Match(name) {
|
||||
continue
|
||||
}
|
||||
|
||||
connector, err := hook.Decode(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.connector = connector
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return errors.New("no dialect found")
|
||||
}
|
80
backend/v3/storage/database/dialect/postgres/config.go
Normal file
80
backend/v3/storage/database/dialect/postgres/config.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
var (
|
||||
_ database.Connector = (*Config)(nil)
|
||||
Name = "postgres"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
*pgxpool.Config
|
||||
|
||||
// Host string
|
||||
// Port int32
|
||||
// Database string
|
||||
// MaxOpenConns uint32
|
||||
// MaxIdleConns uint32
|
||||
// MaxConnLifetime time.Duration
|
||||
// MaxConnIdleTime time.Duration
|
||||
// User User
|
||||
// // Additional options to be appended as options=<Options>
|
||||
// // The value will be taken as is. Multiple options are space separated.
|
||||
// Options string
|
||||
|
||||
configuredFields []string
|
||||
}
|
||||
|
||||
// Connect implements [database.Connector].
|
||||
func (c *Config) Connect(ctx context.Context) (database.Pool, error) {
|
||||
pool, err := pgxpool.NewWithConfig(ctx, c.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = pool.Ping(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &pgxPool{pool}, nil
|
||||
}
|
||||
|
||||
func NameMatcher(name string) bool {
|
||||
return slices.Contains([]string{"postgres", "pg"}, strings.ToLower(name))
|
||||
}
|
||||
|
||||
func DecodeConfig(input any) (database.Connector, error) {
|
||||
switch c := input.(type) {
|
||||
case string:
|
||||
config, err := pgxpool.ParseConfig(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Config{Config: config}, nil
|
||||
case map[string]any:
|
||||
connector := new(Config)
|
||||
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||
DecodeHook: mapstructure.StringToTimeDurationHookFunc(),
|
||||
WeaklyTypedInput: true,
|
||||
Result: connector,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = decoder.Decode(c); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Config{
|
||||
Config: &pgxpool.Config{},
|
||||
}, nil
|
||||
}
|
||||
return nil, errors.New("invalid configuration")
|
||||
}
|
48
backend/v3/storage/database/dialect/postgres/conn.go
Normal file
48
backend/v3/storage/database/dialect/postgres/conn.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type pgxConn struct{ *pgxpool.Conn }
|
||||
|
||||
var _ database.Client = (*pgxConn)(nil)
|
||||
|
||||
// Release implements [database.Client].
|
||||
func (c *pgxConn) Release(_ context.Context) error {
|
||||
c.Conn.Release()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Begin implements [database.Client].
|
||||
func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
||||
tx, err := c.Conn.BeginTx(ctx, transactionOptionsToPgx(opts))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &pgxTx{tx}, nil
|
||||
}
|
||||
|
||||
// Query implements sql.Client.
|
||||
// Subtle: this method shadows the method (*Conn).Query of pgxConn.Conn.
|
||||
func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
||||
rows, err := c.Conn.Query(ctx, sql, args...)
|
||||
return &Rows{rows}, err
|
||||
}
|
||||
|
||||
// QueryRow implements sql.Client.
|
||||
// Subtle: this method shadows the method (*Conn).QueryRow of pgxConn.Conn.
|
||||
func (c *pgxConn) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
||||
return c.Conn.QueryRow(ctx, sql, args...)
|
||||
}
|
||||
|
||||
// Exec implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
|
||||
func (c *pgxConn) Exec(ctx context.Context, sql string, args ...any) error {
|
||||
_, err := c.Conn.Exec(ctx, sql, args...)
|
||||
return err
|
||||
}
|
57
backend/v3/storage/database/dialect/postgres/pool.go
Normal file
57
backend/v3/storage/database/dialect/postgres/pool.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type pgxPool struct{ *pgxpool.Pool }
|
||||
|
||||
var _ database.Pool = (*pgxPool)(nil)
|
||||
|
||||
// Acquire implements [database.Pool].
|
||||
func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) {
|
||||
conn, err := c.Pool.Acquire(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &pgxConn{conn}, nil
|
||||
}
|
||||
|
||||
// Query implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool.
|
||||
func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
||||
rows, err := c.Pool.Query(ctx, sql, args...)
|
||||
return &Rows{rows}, err
|
||||
}
|
||||
|
||||
// QueryRow implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool.
|
||||
func (c *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
||||
return c.Pool.QueryRow(ctx, sql, args...)
|
||||
}
|
||||
|
||||
// Exec implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
|
||||
func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) error {
|
||||
_, err := c.Pool.Exec(ctx, sql, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
// Begin implements [database.Pool].
|
||||
func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
||||
tx, err := c.Pool.BeginTx(ctx, transactionOptionsToPgx(opts))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &pgxTx{tx}, nil
|
||||
}
|
||||
|
||||
// Close implements [database.Pool].
|
||||
func (c *pgxPool) Close(_ context.Context) error {
|
||||
c.Pool.Close()
|
||||
return nil
|
||||
}
|
18
backend/v3/storage/database/dialect/postgres/rows.go
Normal file
18
backend/v3/storage/database/dialect/postgres/rows.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
var _ database.Rows = (*Rows)(nil)
|
||||
|
||||
type Rows struct{ pgx.Rows }
|
||||
|
||||
// Close implements [database.Rows].
|
||||
// Subtle: this method shadows the method (Rows).Close of Rows.Rows.
|
||||
func (r *Rows) Close() error {
|
||||
r.Rows.Close()
|
||||
return nil
|
||||
}
|
95
backend/v3/storage/database/dialect/postgres/tx.go
Normal file
95
backend/v3/storage/database/dialect/postgres/tx.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type pgxTx struct{ pgx.Tx }
|
||||
|
||||
var _ database.Transaction = (*pgxTx)(nil)
|
||||
|
||||
// Commit implements [database.Transaction].
|
||||
func (tx *pgxTx) Commit(ctx context.Context) error {
|
||||
return tx.Tx.Commit(ctx)
|
||||
}
|
||||
|
||||
// Rollback implements [database.Transaction].
|
||||
func (tx *pgxTx) Rollback(ctx context.Context) error {
|
||||
return tx.Tx.Rollback(ctx)
|
||||
}
|
||||
|
||||
// End implements [database.Transaction].
|
||||
func (tx *pgxTx) End(ctx context.Context, err error) error {
|
||||
if err != nil {
|
||||
tx.Rollback(ctx)
|
||||
return err
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
// Query implements [database.Transaction].
|
||||
// Subtle: this method shadows the method (Tx).Query of pgxTx.Tx.
|
||||
func (tx *pgxTx) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
||||
rows, err := tx.Tx.Query(ctx, sql, args...)
|
||||
return &Rows{rows}, err
|
||||
}
|
||||
|
||||
// QueryRow implements [database.Transaction].
|
||||
// Subtle: this method shadows the method (Tx).QueryRow of pgxTx.Tx.
|
||||
func (tx *pgxTx) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
||||
return tx.Tx.QueryRow(ctx, sql, args...)
|
||||
}
|
||||
|
||||
// Exec implements [database.Transaction].
|
||||
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
|
||||
func (tx *pgxTx) Exec(ctx context.Context, sql string, args ...any) error {
|
||||
_, err := tx.Tx.Exec(ctx, sql, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
// Begin implements [database.Transaction].
|
||||
// As postgres does not support nested transactions we use savepoints to emulate them.
|
||||
func (tx *pgxTx) Begin(ctx context.Context) (database.Transaction, error) {
|
||||
savepoint, err := tx.Tx.Begin(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &pgxTx{savepoint}, nil
|
||||
}
|
||||
|
||||
func transactionOptionsToPgx(opts *database.TransactionOptions) pgx.TxOptions {
|
||||
if opts == nil {
|
||||
return pgx.TxOptions{}
|
||||
}
|
||||
|
||||
return pgx.TxOptions{
|
||||
IsoLevel: isolationToPgx(opts.IsolationLevel),
|
||||
AccessMode: accessModeToPgx(opts.AccessMode),
|
||||
}
|
||||
}
|
||||
|
||||
func isolationToPgx(isolation database.IsolationLevel) pgx.TxIsoLevel {
|
||||
switch isolation {
|
||||
case database.IsolationLevelSerializable:
|
||||
return pgx.Serializable
|
||||
case database.IsolationLevelReadCommitted:
|
||||
return pgx.ReadCommitted
|
||||
default:
|
||||
return pgx.Serializable
|
||||
}
|
||||
}
|
||||
|
||||
func accessModeToPgx(accessMode database.AccessMode) pgx.TxAccessMode {
|
||||
switch accessMode {
|
||||
case database.AccessModeReadWrite:
|
||||
return pgx.ReadWrite
|
||||
case database.AccessModeReadOnly:
|
||||
return pgx.ReadOnly
|
||||
default:
|
||||
return pgx.ReadWrite
|
||||
}
|
||||
}
|
3
backend/v3/storage/database/gen_mock.go
Normal file
3
backend/v3/storage/database/gen_mock.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package database
|
||||
|
||||
//go:generate mockgen -typed -package mock -destination ./mock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Client,Row,Rows,Transaction
|
1067
backend/v3/storage/database/mock/database.mock.go
Normal file
1067
backend/v3/storage/database/mock/database.mock.go
Normal file
File diff suppressed because it is too large
Load Diff
160
backend/v3/storage/database/repository/clause.go
Normal file
160
backend/v3/storage/database/repository/clause.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
)
|
||||
|
||||
type field interface {
|
||||
fmt.Stringer
|
||||
}
|
||||
|
||||
type fieldDescriptor struct {
|
||||
schema string
|
||||
table string
|
||||
name string
|
||||
}
|
||||
|
||||
func (f fieldDescriptor) String() string {
|
||||
return f.schema + "." + f.table + "." + f.name
|
||||
}
|
||||
|
||||
type ignoreCaseFieldDescriptor struct {
|
||||
fieldDescriptor
|
||||
fieldNameSuffix string
|
||||
}
|
||||
|
||||
func (f ignoreCaseFieldDescriptor) String() string {
|
||||
return f.fieldDescriptor.String() + f.fieldNameSuffix
|
||||
}
|
||||
|
||||
type textFieldDescriptor struct {
|
||||
field
|
||||
isIgnoreCase bool
|
||||
}
|
||||
|
||||
type clause[Op domain.Operation] struct {
|
||||
field field
|
||||
op Op
|
||||
}
|
||||
|
||||
const (
|
||||
schema = "zitadel"
|
||||
userTable = "users"
|
||||
)
|
||||
|
||||
var userFields = map[domain.UserField]field{
|
||||
domain.UserFieldInstanceID: fieldDescriptor{
|
||||
schema: schema,
|
||||
table: userTable,
|
||||
name: "instance_id",
|
||||
},
|
||||
domain.UserFieldOrgID: fieldDescriptor{
|
||||
schema: schema,
|
||||
table: userTable,
|
||||
name: "org_id",
|
||||
},
|
||||
domain.UserFieldID: fieldDescriptor{
|
||||
schema: schema,
|
||||
table: userTable,
|
||||
name: "id",
|
||||
},
|
||||
domain.UserFieldUsername: textFieldDescriptor{
|
||||
field: ignoreCaseFieldDescriptor{
|
||||
fieldDescriptor: fieldDescriptor{
|
||||
schema: schema,
|
||||
table: userTable,
|
||||
name: "username",
|
||||
},
|
||||
fieldNameSuffix: "_lower",
|
||||
},
|
||||
},
|
||||
domain.UserHumanFieldEmail: textFieldDescriptor{
|
||||
field: ignoreCaseFieldDescriptor{
|
||||
fieldDescriptor: fieldDescriptor{
|
||||
schema: schema,
|
||||
table: userTable,
|
||||
name: "email",
|
||||
},
|
||||
fieldNameSuffix: "_lower",
|
||||
},
|
||||
},
|
||||
domain.UserHumanFieldEmailVerified: fieldDescriptor{
|
||||
schema: schema,
|
||||
table: userTable,
|
||||
name: "email_is_verified",
|
||||
},
|
||||
}
|
||||
|
||||
type textClause[V domain.Text] struct {
|
||||
clause[domain.TextOperation]
|
||||
value V
|
||||
}
|
||||
|
||||
var textOp map[domain.TextOperation]string = map[domain.TextOperation]string{
|
||||
domain.TextOperationEqual: " = ",
|
||||
domain.TextOperationNotEqual: " <> ",
|
||||
domain.TextOperationStartsWith: " LIKE ",
|
||||
domain.TextOperationStartsWithIgnoreCase: " LIKE ",
|
||||
}
|
||||
|
||||
func (tc textClause[V]) Write(stmt *statement) {
|
||||
placeholder := stmt.appendArg(tc.value)
|
||||
var (
|
||||
left, right string
|
||||
)
|
||||
switch tc.clause.op {
|
||||
case domain.TextOperationEqual:
|
||||
left = tc.clause.field.String()
|
||||
right = placeholder
|
||||
case domain.TextOperationNotEqual:
|
||||
left = tc.clause.field.String()
|
||||
right = placeholder
|
||||
case domain.TextOperationStartsWith:
|
||||
left = tc.clause.field.String()
|
||||
right = placeholder + "%"
|
||||
case domain.TextOperationStartsWithIgnoreCase:
|
||||
left = tc.clause.field.String()
|
||||
if _, ok := tc.clause.field.(ignoreCaseFieldDescriptor); !ok {
|
||||
left = "LOWER(" + left + ")"
|
||||
}
|
||||
right = "LOWER(" + placeholder + "%)"
|
||||
}
|
||||
|
||||
stmt.builder.WriteString(left)
|
||||
stmt.builder.WriteString(textOp[tc.clause.op])
|
||||
stmt.builder.WriteString(right)
|
||||
}
|
||||
|
||||
type boolClause[V domain.Bool] struct {
|
||||
clause[domain.BoolOperation]
|
||||
value V
|
||||
}
|
||||
|
||||
func (bc boolClause[V]) Write(stmt *statement) {
|
||||
if !bc.value {
|
||||
stmt.builder.WriteString("NOT ")
|
||||
}
|
||||
stmt.builder.WriteString(bc.clause.field.String())
|
||||
}
|
||||
|
||||
type numberClause[V domain.Number] struct {
|
||||
clause[domain.NumberOperation]
|
||||
value V
|
||||
}
|
||||
|
||||
var numberOp map[domain.NumberOperation]string = map[domain.NumberOperation]string{
|
||||
domain.NumberOperationEqual: " = ",
|
||||
domain.NumberOperationNotEqual: " <> ",
|
||||
domain.NumberOperationLessThan: " < ",
|
||||
domain.NumberOperationLessThanOrEqual: " <= ",
|
||||
domain.NumberOperationGreaterThan: " > ",
|
||||
domain.NumberOperationGreaterThanOrEqual: " >= ",
|
||||
}
|
||||
|
||||
func (nc numberClause[V]) Write(stmt *statement) {
|
||||
stmt.builder.WriteString(nc.clause.field.String())
|
||||
stmt.builder.WriteString(numberOp[nc.clause.op])
|
||||
stmt.builder.WriteString(stmt.appendArg(nc.value))
|
||||
}
|
45
backend/v3/storage/database/repository/crypto.go
Normal file
45
backend/v3/storage/database/repository/crypto.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
)
|
||||
|
||||
type cryptoRepo struct {
|
||||
database.QueryExecutor
|
||||
}
|
||||
|
||||
func Crypto(db database.QueryExecutor) domain.CryptoRepository {
|
||||
return &cryptoRepo{
|
||||
QueryExecutor: db,
|
||||
}
|
||||
}
|
||||
|
||||
const getEncryptionConfigQuery = "SELECT" +
|
||||
" length" +
|
||||
", expiry" +
|
||||
", should_include_lower_letters" +
|
||||
", should_include_upper_letters" +
|
||||
", should_include_digits" +
|
||||
", should_include_symbols" +
|
||||
" FROM encryption_config"
|
||||
|
||||
func (repo *cryptoRepo) GetEncryptionConfig(ctx context.Context) (*crypto.GeneratorConfig, error) {
|
||||
var config crypto.GeneratorConfig
|
||||
row := repo.QueryRow(ctx, getEncryptionConfigQuery)
|
||||
err := row.Scan(
|
||||
&config.Length,
|
||||
&config.Expiry,
|
||||
&config.IncludeLowerLetters,
|
||||
&config.IncludeUpperLetters,
|
||||
&config.IncludeDigits,
|
||||
&config.IncludeSymbols,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &config, nil
|
||||
}
|
7
backend/v3/storage/database/repository/doc.go
Normal file
7
backend/v3/storage/database/repository/doc.go
Normal file
@@ -0,0 +1,7 @@
|
||||
// Repository package provides the database repository for the application.
|
||||
// It contains the implementation of the [repository pattern](https://martinfowler.com/eaaCatalog/repository.html) for the database.
|
||||
|
||||
// funcs which need to interact with the database should create interfaces which are implemented by the
|
||||
// [query] and [exec] structs respectively their factory methods [Query] and [Execute]. The [query] struct is used for read operations, while the [exec] struct is used for write operations.
|
||||
|
||||
package repository
|
54
backend/v3/storage/database/repository/instance.go
Normal file
54
backend/v3/storage/database/repository/instance.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type instance struct {
|
||||
database.QueryExecutor
|
||||
}
|
||||
|
||||
func Instance(client database.QueryExecutor) domain.InstanceRepository {
|
||||
return &instance{QueryExecutor: client}
|
||||
}
|
||||
|
||||
func (i *instance) ByID(ctx context.Context, id string) (*domain.Instance, error) {
|
||||
var instance domain.Instance
|
||||
err := i.QueryExecutor.QueryRow(ctx, `SELECT id, name, created_at, updated_at, deleted_at FROM instances WHERE id = $1`, id).Scan(
|
||||
&instance.ID,
|
||||
&instance.Name,
|
||||
&instance.CreatedAt,
|
||||
&instance.UpdatedAt,
|
||||
&instance.DeletedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &instance, nil
|
||||
}
|
||||
|
||||
const createInstanceStmt = `INSERT INTO instances (id, name) VALUES ($1, $2) RETURNING created_at, updated_at`
|
||||
|
||||
// Create implements [domain.InstanceRepository].
|
||||
func (i *instance) Create(ctx context.Context, instance *domain.Instance) error {
|
||||
return i.QueryExecutor.QueryRow(ctx, createInstanceStmt,
|
||||
instance.ID,
|
||||
instance.Name,
|
||||
).Scan(
|
||||
&instance.CreatedAt,
|
||||
&instance.UpdatedAt,
|
||||
)
|
||||
}
|
||||
|
||||
// On implements [domain.InstanceRepository].
|
||||
func (i *instance) On(id string) domain.InstanceOperation {
|
||||
return &instanceOperation{
|
||||
QueryExecutor: i.QueryExecutor,
|
||||
id: id,
|
||||
}
|
||||
}
|
||||
|
||||
var _ domain.InstanceRepository = (*instance)(nil)
|
52
backend/v3/storage/database/repository/instance_operation.go
Normal file
52
backend/v3/storage/database/repository/instance_operation.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type instanceOperation struct {
|
||||
database.QueryExecutor
|
||||
id string
|
||||
}
|
||||
|
||||
const addInstanceAdminStmt = `INSERT INTO instance_admins (instance_id, user_id, roles) VALUES ($1, $2, $3)`
|
||||
|
||||
// AddAdmin implements [domain.InstanceOperation].
|
||||
func (i *instanceOperation) AddAdmin(ctx context.Context, userID string, roles []string) error {
|
||||
return i.QueryExecutor.Exec(ctx, addInstanceAdminStmt, i.id, userID, roles)
|
||||
}
|
||||
|
||||
// Delete implements [domain.InstanceOperation].
|
||||
func (i *instanceOperation) Delete(ctx context.Context) error {
|
||||
return i.QueryExecutor.Exec(ctx, `DELETE FROM instances WHERE id = $1`, i.id)
|
||||
}
|
||||
|
||||
const removeInstanceAdminStmt = `DELETE FROM instance_admins WHERE instance_id = $1 AND user_id = $2`
|
||||
|
||||
// RemoveAdmin implements [domain.InstanceOperation].
|
||||
func (i *instanceOperation) RemoveAdmin(ctx context.Context, userID string) error {
|
||||
return i.QueryExecutor.Exec(ctx, removeInstanceAdminStmt, i.id, userID)
|
||||
}
|
||||
|
||||
const setInstanceAdminRolesStmt = `UPDATE instance_admins SET roles = $1 WHERE instance_id = $2 AND user_id = $3`
|
||||
|
||||
// SetAdminRoles implements [domain.InstanceOperation].
|
||||
func (i *instanceOperation) SetAdminRoles(ctx context.Context, userID string, roles []string) error {
|
||||
return i.QueryExecutor.Exec(ctx, setInstanceAdminRolesStmt, roles, i.id, userID)
|
||||
}
|
||||
|
||||
const updateInstanceStmt = `UPDATE instances SET name = $1, updated_at = $2 WHERE id = $3 RETURNING updated_at`
|
||||
|
||||
// Update implements [domain.InstanceOperation].
|
||||
func (i *instanceOperation) Update(ctx context.Context, instance *domain.Instance) error {
|
||||
return i.QueryExecutor.QueryRow(ctx, updateInstanceStmt,
|
||||
instance.Name,
|
||||
instance.UpdatedAt,
|
||||
i.id,
|
||||
).Scan(&instance.UpdatedAt)
|
||||
}
|
||||
|
||||
var _ domain.InstanceOperation = (*instanceOperation)(nil)
|
17
backend/v3/storage/database/repository/query.go
Normal file
17
backend/v3/storage/database/repository/query.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type query struct{ database.Querier }
|
||||
|
||||
func Query(querier database.Querier) *query {
|
||||
return &query{Querier: querier}
|
||||
}
|
||||
|
||||
type executor struct{ database.Executor }
|
||||
|
||||
func Execute(exec database.Executor) *executor {
|
||||
return &executor{Executor: exec}
|
||||
}
|
21
backend/v3/storage/database/repository/statement.go
Normal file
21
backend/v3/storage/database/repository/statement.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package repository
|
||||
|
||||
import "strings"
|
||||
|
||||
type statement struct {
|
||||
builder strings.Builder
|
||||
args []any
|
||||
}
|
||||
|
||||
func (s *statement) appendArg(arg any) (placeholder string) {
|
||||
s.args = append(s.args, arg)
|
||||
return "$" + string(len(s.args))
|
||||
}
|
||||
|
||||
func (s *statement) appendArgs(args ...any) (placeholders []string) {
|
||||
placeholders = make([]string, len(args))
|
||||
for i, arg := range args {
|
||||
placeholders[i] = s.appendArg(arg)
|
||||
}
|
||||
return placeholders
|
||||
}
|
43
backend/v3/storage/database/repository/stmt/column.go
Normal file
43
backend/v3/storage/database/repository/stmt/column.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package stmt
|
||||
|
||||
import "fmt"
|
||||
|
||||
type Column[T any] interface {
|
||||
fmt.Stringer
|
||||
statementApplier[T]
|
||||
scanner(t *T) any
|
||||
}
|
||||
|
||||
type columnDescriptor[T any] struct {
|
||||
name string
|
||||
scan func(*T) any
|
||||
}
|
||||
|
||||
func (cd columnDescriptor[T]) scanner(t *T) any {
|
||||
return cd.scan(t)
|
||||
}
|
||||
|
||||
// Apply implements [Column].
|
||||
func (f columnDescriptor[T]) Apply(stmt *statement[T]) {
|
||||
stmt.builder.WriteString(stmt.columnPrefix())
|
||||
stmt.builder.WriteString(f.String())
|
||||
}
|
||||
|
||||
// String implements [Column].
|
||||
func (f columnDescriptor[T]) String() string {
|
||||
return f.name
|
||||
}
|
||||
|
||||
var _ Column[any] = (*columnDescriptor[any])(nil)
|
||||
|
||||
type ignoreCaseColumnDescriptor[T any] struct {
|
||||
columnDescriptor[T]
|
||||
fieldNameSuffix string
|
||||
}
|
||||
|
||||
func (f ignoreCaseColumnDescriptor[T]) ApplyIgnoreCase(stmt *statement[T]) {
|
||||
stmt.builder.WriteString(f.String())
|
||||
stmt.builder.WriteString(f.fieldNameSuffix)
|
||||
}
|
||||
|
||||
var _ Column[any] = (*ignoreCaseColumnDescriptor[any])(nil)
|
97
backend/v3/storage/database/repository/stmt/condition.go
Normal file
97
backend/v3/storage/database/repository/stmt/condition.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package stmt
|
||||
|
||||
import "fmt"
|
||||
|
||||
type statementApplier[T any] interface {
|
||||
// Apply writes the statement to the builder.
|
||||
Apply(stmt *statement[T])
|
||||
}
|
||||
|
||||
type Condition[T any] interface {
|
||||
statementApplier[T]
|
||||
}
|
||||
|
||||
type op interface {
|
||||
TextOperation | NumberOperation | ListOperation
|
||||
fmt.Stringer
|
||||
}
|
||||
|
||||
type operation[T any, O op] struct {
|
||||
o O
|
||||
}
|
||||
|
||||
func (o operation[T, O]) String() string {
|
||||
return o.o.String()
|
||||
}
|
||||
|
||||
func (o operation[T, O]) Apply(stmt *statement[T]) {
|
||||
stmt.builder.WriteString(o.o.String())
|
||||
}
|
||||
|
||||
type condition[V, T any, OP op] struct {
|
||||
field Column[T]
|
||||
op OP
|
||||
value V
|
||||
}
|
||||
|
||||
func (c *condition[V, T, OP]) Apply(stmt *statement[T]) {
|
||||
// placeholder := stmt.appendArg(c.value)
|
||||
stmt.builder.WriteString(stmt.columnPrefix())
|
||||
stmt.builder.WriteString(c.field.String())
|
||||
// stmt.builder.WriteString(c.op)
|
||||
// stmt.builder.WriteString(placeholder)
|
||||
}
|
||||
|
||||
type and[T any] struct {
|
||||
conditions []Condition[T]
|
||||
}
|
||||
|
||||
func And[T any](conditions ...Condition[T]) *and[T] {
|
||||
return &and[T]{
|
||||
conditions: conditions,
|
||||
}
|
||||
}
|
||||
|
||||
// Apply implements [Condition].
|
||||
func (a *and[T]) Apply(stmt *statement[T]) {
|
||||
if len(a.conditions) > 1 {
|
||||
stmt.builder.WriteString("(")
|
||||
defer stmt.builder.WriteString(")")
|
||||
}
|
||||
|
||||
for i, condition := range a.conditions {
|
||||
if i > 0 {
|
||||
stmt.builder.WriteString(" AND ")
|
||||
}
|
||||
condition.Apply(stmt)
|
||||
}
|
||||
}
|
||||
|
||||
var _ Condition[any] = (*and[any])(nil)
|
||||
|
||||
type or[T any] struct {
|
||||
conditions []Condition[T]
|
||||
}
|
||||
|
||||
func Or[T any](conditions ...Condition[T]) *or[T] {
|
||||
return &or[T]{
|
||||
conditions: conditions,
|
||||
}
|
||||
}
|
||||
|
||||
// Apply implements [Condition].
|
||||
func (o *or[T]) Apply(stmt *statement[T]) {
|
||||
if len(o.conditions) > 1 {
|
||||
stmt.builder.WriteString("(")
|
||||
defer stmt.builder.WriteString(")")
|
||||
}
|
||||
|
||||
for i, condition := range o.conditions {
|
||||
if i > 0 {
|
||||
stmt.builder.WriteString(" OR ")
|
||||
}
|
||||
condition.Apply(stmt)
|
||||
}
|
||||
}
|
||||
|
||||
var _ Condition[any] = (*or[any])(nil)
|
71
backend/v3/storage/database/repository/stmt/list.go
Normal file
71
backend/v3/storage/database/repository/stmt/list.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package stmt
|
||||
|
||||
type ListEntry interface {
|
||||
Number | Text | any
|
||||
}
|
||||
|
||||
type ListCondition[E ListEntry, T any] struct {
|
||||
condition[[]E, T, ListOperation]
|
||||
}
|
||||
|
||||
func (lc *ListCondition[E, T]) Apply(stmt *statement[T]) {
|
||||
placeholder := stmt.appendArg(lc.value)
|
||||
|
||||
switch lc.op {
|
||||
case ListOperationEqual, ListOperationNotEqual:
|
||||
lc.field.Apply(stmt)
|
||||
operation[T, ListOperation]{lc.op}.Apply(stmt)
|
||||
stmt.builder.WriteString(placeholder)
|
||||
case ListOperationContainsAny, ListOperationContainsAll:
|
||||
lc.field.Apply(stmt)
|
||||
operation[T, ListOperation]{lc.op}.Apply(stmt)
|
||||
stmt.builder.WriteString(placeholder)
|
||||
case ListOperationNotContainsAny, ListOperationNotContainsAll:
|
||||
stmt.builder.WriteString("NOT (")
|
||||
lc.field.Apply(stmt)
|
||||
operation[T, ListOperation]{lc.op}.Apply(stmt)
|
||||
stmt.builder.WriteString(placeholder)
|
||||
stmt.builder.WriteString(")")
|
||||
default:
|
||||
panic("unknown list operation")
|
||||
}
|
||||
}
|
||||
|
||||
type ListOperation uint8
|
||||
|
||||
const (
|
||||
// ListOperationEqual checks if the arrays are equal including the order of the elements
|
||||
ListOperationEqual ListOperation = iota + 1
|
||||
// ListOperationNotEqual checks if the arrays are not equal including the order of the elements
|
||||
ListOperationNotEqual
|
||||
|
||||
// ListOperationContains checks if the array column contains all the values of the specified array
|
||||
ListOperationContainsAll
|
||||
// ListOperationContainsAny checks if the arrays have at least one value in common
|
||||
ListOperationContainsAny
|
||||
// ListOperationContainsAll checks if the array column contains all the values of the specified array
|
||||
|
||||
// ListOperationNotContainsAll checks if the specified array is not contained by the column
|
||||
ListOperationNotContainsAll
|
||||
// ListOperationNotContainsAny checks if the arrays column contains none of the values of the specified array
|
||||
ListOperationNotContainsAny
|
||||
)
|
||||
|
||||
var listOperations = map[ListOperation]string{
|
||||
// ListOperationEqual checks if the lists are equal
|
||||
ListOperationEqual: " = ",
|
||||
// ListOperationNotEqual checks if the lists are not equal
|
||||
ListOperationNotEqual: " <> ",
|
||||
// ListOperationContainsAny checks if the arrays have at least one value in common
|
||||
ListOperationContainsAny: " && ",
|
||||
// ListOperationContainsAll checks if the array column contains all the values of the specified array
|
||||
ListOperationContainsAll: " @> ",
|
||||
// ListOperationNotContainsAny checks if the arrays column contains none of the values of the specified array
|
||||
ListOperationNotContainsAny: " && ", // Base operator for NOT (A && B)
|
||||
// ListOperationNotContainsAll checks if the array column is not contained by the specified array
|
||||
ListOperationNotContainsAll: " <@ ", // Base operator for NOT (A <@ B)
|
||||
}
|
||||
|
||||
func (lo ListOperation) String() string {
|
||||
return listOperations[lo]
|
||||
}
|
61
backend/v3/storage/database/repository/stmt/number.go
Normal file
61
backend/v3/storage/database/repository/stmt/number.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package stmt
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
type Number interface {
|
||||
constraints.Integer | constraints.Float | constraints.Complex | time.Time | time.Duration
|
||||
}
|
||||
|
||||
type between[N Number] struct {
|
||||
min, max N
|
||||
}
|
||||
|
||||
type NumberBetween[V Number, T any] struct {
|
||||
condition[between[V], T, NumberOperation]
|
||||
}
|
||||
|
||||
func (nb *NumberBetween[V, T]) Apply(stmt *statement[T]) {
|
||||
nb.field.Apply(stmt)
|
||||
stmt.builder.WriteString(" BETWEEN ")
|
||||
stmt.builder.WriteString(stmt.appendArg(nb.value.min))
|
||||
stmt.builder.WriteString(" AND ")
|
||||
stmt.builder.WriteString(stmt.appendArg(nb.value.max))
|
||||
}
|
||||
|
||||
type NumberCondition[V Number, T any] struct {
|
||||
condition[V, T, NumberOperation]
|
||||
}
|
||||
|
||||
func (nc *NumberCondition[V, T]) Apply(stmt *statement[T]) {
|
||||
nc.field.Apply(stmt)
|
||||
operation[T, NumberOperation]{nc.op}.Apply(stmt)
|
||||
stmt.builder.WriteString(stmt.appendArg(nc.value))
|
||||
}
|
||||
|
||||
type NumberOperation uint8
|
||||
|
||||
const (
|
||||
NumberOperationEqual NumberOperation = iota + 1
|
||||
NumberOperationNotEqual
|
||||
NumberOperationLessThan
|
||||
NumberOperationLessThanOrEqual
|
||||
NumberOperationGreaterThan
|
||||
NumberOperationGreaterThanOrEqual
|
||||
)
|
||||
|
||||
var numberOperations = map[NumberOperation]string{
|
||||
NumberOperationEqual: " = ",
|
||||
NumberOperationNotEqual: " <> ",
|
||||
NumberOperationLessThan: " < ",
|
||||
NumberOperationLessThanOrEqual: " <= ",
|
||||
NumberOperationGreaterThan: " > ",
|
||||
NumberOperationGreaterThanOrEqual: " >= ",
|
||||
}
|
||||
|
||||
func (no NumberOperation) String() string {
|
||||
return numberOperations[no]
|
||||
}
|
104
backend/v3/storage/database/repository/stmt/statement.go
Normal file
104
backend/v3/storage/database/repository/stmt/statement.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package stmt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type statement[T any] struct {
|
||||
builder strings.Builder
|
||||
client database.QueryExecutor
|
||||
|
||||
columns []Column[T]
|
||||
|
||||
schema string
|
||||
table string
|
||||
alias string
|
||||
|
||||
condition Condition[T]
|
||||
|
||||
limit uint32
|
||||
offset uint32
|
||||
// order by fieldname and sort direction false for asc true for desc
|
||||
// orderBy SortingColumns[C]
|
||||
args []any
|
||||
existingArgs map[any]string
|
||||
}
|
||||
|
||||
func (s *statement[T]) scanners(t *T) []any {
|
||||
scanners := make([]any, len(s.columns))
|
||||
for i, column := range s.columns {
|
||||
scanners[i] = column.scanner(t)
|
||||
}
|
||||
return scanners
|
||||
}
|
||||
|
||||
func (s *statement[T]) query() string {
|
||||
s.builder.WriteString(`SELECT `)
|
||||
for i, column := range s.columns {
|
||||
if i > 0 {
|
||||
s.builder.WriteString(", ")
|
||||
}
|
||||
column.Apply(s)
|
||||
}
|
||||
s.builder.WriteString(` FROM `)
|
||||
s.builder.WriteString(s.schema)
|
||||
s.builder.WriteRune('.')
|
||||
s.builder.WriteString(s.table)
|
||||
if s.alias != "" {
|
||||
s.builder.WriteString(" AS ")
|
||||
s.builder.WriteString(s.alias)
|
||||
}
|
||||
|
||||
s.builder.WriteString(` WHERE `)
|
||||
|
||||
s.condition.Apply(s)
|
||||
|
||||
if s.limit > 0 {
|
||||
s.builder.WriteString(` LIMIT `)
|
||||
s.builder.WriteString(s.appendArg(s.limit))
|
||||
}
|
||||
if s.offset > 0 {
|
||||
s.builder.WriteString(` OFFSET `)
|
||||
s.builder.WriteString(s.appendArg(s.offset))
|
||||
}
|
||||
|
||||
return s.builder.String()
|
||||
}
|
||||
|
||||
// func (s *statement[T]) Where(condition Condition[T]) *statement[T] {
|
||||
// s.condition = condition
|
||||
// return s
|
||||
// }
|
||||
|
||||
// func (s *statement[T]) Limit(limit uint32) *statement[T] {
|
||||
// s.limit = limit
|
||||
// return s
|
||||
// }
|
||||
|
||||
// func (s *statement[T]) Offset(offset uint32) *statement[T] {
|
||||
// s.offset = offset
|
||||
// return s
|
||||
// }
|
||||
|
||||
func (s *statement[T]) columnPrefix() string {
|
||||
if s.alias != "" {
|
||||
return s.alias + "."
|
||||
}
|
||||
return s.schema + "." + s.table + "."
|
||||
}
|
||||
|
||||
func (s *statement[T]) appendArg(arg any) string {
|
||||
if s.existingArgs == nil {
|
||||
s.existingArgs = make(map[any]string)
|
||||
}
|
||||
if existing, ok := s.existingArgs[arg]; ok {
|
||||
return existing
|
||||
}
|
||||
s.args = append(s.args, arg)
|
||||
placeholder := fmt.Sprintf("$%d", len(s.args))
|
||||
s.existingArgs[arg] = placeholder
|
||||
return placeholder
|
||||
}
|
18
backend/v3/storage/database/repository/stmt/stmt_test.go
Normal file
18
backend/v3/storage/database/repository/stmt/stmt_test.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package stmt_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt"
|
||||
)
|
||||
|
||||
func Test_Bla(t *testing.T) {
|
||||
stmt.User(nil).Where(
|
||||
stmt.Or(
|
||||
stmt.UserIDCondition("123"),
|
||||
stmt.UserIDCondition("123"),
|
||||
stmt.UserUsernameCondition(stmt.TextOperationEqualIgnoreCase, "test"),
|
||||
),
|
||||
).Limit(1).Offset(1).Get(context.Background())
|
||||
}
|
72
backend/v3/storage/database/repository/stmt/text.go
Normal file
72
backend/v3/storage/database/repository/stmt/text.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package stmt
|
||||
|
||||
type Text interface {
|
||||
~string | ~[]byte
|
||||
}
|
||||
|
||||
type TextCondition[V Text, T any] struct {
|
||||
condition[V, T, TextOperation]
|
||||
}
|
||||
|
||||
func (tc *TextCondition[V, T]) Apply(stmt *statement[T]) {
|
||||
placeholder := stmt.appendArg(tc.value)
|
||||
|
||||
switch tc.op {
|
||||
case TextOperationEqual, TextOperationNotEqual:
|
||||
tc.field.Apply(stmt)
|
||||
operation[T, TextOperation]{tc.op}.Apply(stmt)
|
||||
stmt.builder.WriteString(placeholder)
|
||||
case TextOperationEqualIgnoreCase:
|
||||
if desc, ok := tc.field.(ignoreCaseColumnDescriptor[T]); ok {
|
||||
desc.ApplyIgnoreCase(stmt)
|
||||
} else {
|
||||
stmt.builder.WriteString("LOWER(")
|
||||
tc.field.Apply(stmt)
|
||||
stmt.builder.WriteString(")")
|
||||
}
|
||||
operation[T, TextOperation]{tc.op}.Apply(stmt)
|
||||
stmt.builder.WriteString("LOWER(")
|
||||
stmt.builder.WriteString(placeholder)
|
||||
stmt.builder.WriteString(")")
|
||||
case TextOperationStartsWith:
|
||||
tc.field.Apply(stmt)
|
||||
operation[T, TextOperation]{tc.op}.Apply(stmt)
|
||||
stmt.builder.WriteString(placeholder)
|
||||
stmt.builder.WriteString("|| '%'")
|
||||
case TextOperationStartsWithIgnoreCase:
|
||||
if desc, ok := tc.field.(ignoreCaseColumnDescriptor[T]); ok {
|
||||
desc.ApplyIgnoreCase(stmt)
|
||||
} else {
|
||||
stmt.builder.WriteString("LOWER(")
|
||||
tc.field.Apply(stmt)
|
||||
stmt.builder.WriteString(")")
|
||||
}
|
||||
operation[T, TextOperation]{tc.op}.Apply(stmt)
|
||||
stmt.builder.WriteString("LOWER(")
|
||||
stmt.builder.WriteString(placeholder)
|
||||
stmt.builder.WriteString(")")
|
||||
stmt.builder.WriteString("|| '%'")
|
||||
}
|
||||
}
|
||||
|
||||
type TextOperation uint8
|
||||
|
||||
const (
|
||||
TextOperationEqual TextOperation = iota + 1
|
||||
TextOperationEqualIgnoreCase
|
||||
TextOperationNotEqual
|
||||
TextOperationStartsWith
|
||||
TextOperationStartsWithIgnoreCase
|
||||
)
|
||||
|
||||
var textOperations = map[TextOperation]string{
|
||||
TextOperationEqual: " = ",
|
||||
TextOperationEqualIgnoreCase: " = ",
|
||||
TextOperationNotEqual: " <> ",
|
||||
TextOperationStartsWith: " LIKE ",
|
||||
TextOperationStartsWithIgnoreCase: " LIKE ",
|
||||
}
|
||||
|
||||
func (to TextOperation) String() string {
|
||||
return textOperations[to]
|
||||
}
|
193
backend/v3/storage/database/repository/stmt/user.go
Normal file
193
backend/v3/storage/database/repository/stmt/user.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package stmt
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type userStatement struct {
|
||||
statement[domain.User]
|
||||
}
|
||||
|
||||
func User(client database.QueryExecutor) *userStatement {
|
||||
return &userStatement{
|
||||
statement: statement[domain.User]{
|
||||
schema: "zitadel",
|
||||
table: "users",
|
||||
alias: "u",
|
||||
client: client,
|
||||
columns: []Column[domain.User]{
|
||||
userColumns[UserInstanceID],
|
||||
userColumns[UserOrgID],
|
||||
userColumns[UserColumnID],
|
||||
userColumns[UserColumnUsername],
|
||||
userColumns[UserCreatedAt],
|
||||
userColumns[UserUpdatedAt],
|
||||
userColumns[UserDeletedAt],
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *userStatement) Where(condition Condition[domain.User]) *userStatement {
|
||||
s.condition = condition
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *userStatement) Limit(limit uint32) *userStatement {
|
||||
s.limit = limit
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *userStatement) Offset(offset uint32) *userStatement {
|
||||
s.offset = offset
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *userStatement) Get(ctx context.Context) (*domain.User, error) {
|
||||
var user domain.User
|
||||
err := s.client.QueryRow(ctx, s.query(), s.statement.args...).Scan(s.scanners(&user)...)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *userStatement) List(ctx context.Context) ([]*domain.User, error) {
|
||||
var users []*domain.User
|
||||
rows, err := s.client.Query(ctx, s.query(), s.statement.args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var user domain.User
|
||||
err = rows.Scan(s.scanners(&user)...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users = append(users, &user)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (s *userStatement) SetUsername(ctx context.Context, username string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type UserColumn uint8
|
||||
|
||||
var (
|
||||
userColumns map[UserColumn]Column[domain.User] = map[UserColumn]Column[domain.User]{
|
||||
UserInstanceID: columnDescriptor[domain.User]{
|
||||
name: "instance_id",
|
||||
scan: func(u *domain.User) any {
|
||||
return &u.InstanceID
|
||||
},
|
||||
},
|
||||
UserOrgID: columnDescriptor[domain.User]{
|
||||
name: "org_id",
|
||||
scan: func(u *domain.User) any {
|
||||
return &u.OrgID
|
||||
},
|
||||
},
|
||||
UserColumnID: columnDescriptor[domain.User]{
|
||||
name: "id",
|
||||
scan: func(u *domain.User) any {
|
||||
return &u.ID
|
||||
},
|
||||
},
|
||||
UserColumnUsername: ignoreCaseColumnDescriptor[domain.User]{
|
||||
columnDescriptor: columnDescriptor[domain.User]{
|
||||
name: "username",
|
||||
scan: func(u *domain.User) any {
|
||||
return &u.Username
|
||||
},
|
||||
},
|
||||
fieldNameSuffix: "_lower",
|
||||
},
|
||||
UserCreatedAt: columnDescriptor[domain.User]{
|
||||
name: "created_at",
|
||||
scan: func(u *domain.User) any {
|
||||
return &u.CreatedAt
|
||||
},
|
||||
},
|
||||
UserUpdatedAt: columnDescriptor[domain.User]{
|
||||
name: "updated_at",
|
||||
scan: func(u *domain.User) any {
|
||||
return &u.UpdatedAt
|
||||
},
|
||||
},
|
||||
UserDeletedAt: columnDescriptor[domain.User]{
|
||||
name: "deleted_at",
|
||||
scan: func(u *domain.User) any {
|
||||
return &u.DeletedAt
|
||||
},
|
||||
},
|
||||
}
|
||||
humanColumns = map[UserColumn]Column[domain.User]{
|
||||
UserHumanColumnEmail: ignoreCaseColumnDescriptor[domain.User]{
|
||||
columnDescriptor: columnDescriptor[domain.User]{
|
||||
name: "email",
|
||||
scan: func(u *domain.User) any {
|
||||
human, ok := u.Traits.(*domain.Human)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if human.Email == nil {
|
||||
human.Email = new(domain.Email)
|
||||
}
|
||||
return &human.Email.Address
|
||||
},
|
||||
},
|
||||
fieldNameSuffix: "_lower",
|
||||
},
|
||||
UserHumanColumnEmailVerified: columnDescriptor[domain.User]{
|
||||
name: "email_is_verified",
|
||||
scan: func(u *domain.User) any {
|
||||
human, ok := u.Traits.(*domain.Human)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if human.Email == nil {
|
||||
human.Email = new(domain.Email)
|
||||
}
|
||||
return &human.Email.IsVerified
|
||||
},
|
||||
},
|
||||
}
|
||||
machineColumns = map[UserColumn]Column[domain.User]{
|
||||
UserMachineDescription: columnDescriptor[domain.User]{
|
||||
name: "description",
|
||||
scan: func(u *domain.User) any {
|
||||
machine, ok := u.Traits.(*domain.Machine)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if machine == nil {
|
||||
machine = new(domain.Machine)
|
||||
}
|
||||
return &machine.Description
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
UserInstanceID UserColumn = iota + 1
|
||||
UserOrgID
|
||||
UserColumnID
|
||||
UserColumnUsername
|
||||
UserHumanColumnEmail
|
||||
UserHumanColumnEmailVerified
|
||||
UserMachineDescription
|
||||
UserCreatedAt
|
||||
UserUpdatedAt
|
||||
UserDeletedAt
|
||||
)
|
@@ -0,0 +1,23 @@
|
||||
package stmt
|
||||
|
||||
import "github.com/zitadel/zitadel/backend/v3/domain"
|
||||
|
||||
func UserIDCondition(id string) *TextCondition[string, domain.User] {
|
||||
return &TextCondition[string, domain.User]{
|
||||
condition: condition[string, domain.User, TextOperation]{
|
||||
field: userColumns[UserColumnID],
|
||||
op: TextOperationEqual,
|
||||
value: id,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func UserUsernameCondition(op TextOperation, username string) *TextCondition[string, domain.User] {
|
||||
return &TextCondition[string, domain.User]{
|
||||
condition: condition[string, domain.User, TextOperation]{
|
||||
field: userColumns[UserColumnUsername],
|
||||
op: op,
|
||||
value: username,
|
||||
},
|
||||
}
|
||||
}
|
135
backend/v3/storage/database/repository/stmt/v2/table.go
Normal file
135
backend/v3/storage/database/repository/stmt/v2/table.go
Normal file
@@ -0,0 +1,135 @@
|
||||
package stmt
|
||||
|
||||
// type table struct {
|
||||
// schema string
|
||||
// name string
|
||||
|
||||
// possibleJoins []*join
|
||||
|
||||
// columns []*col
|
||||
// }
|
||||
|
||||
// type col struct {
|
||||
// *table
|
||||
|
||||
// name string
|
||||
// }
|
||||
|
||||
// type join struct {
|
||||
// *table
|
||||
|
||||
// on []*joinColumns
|
||||
// }
|
||||
|
||||
// type joinColumns struct {
|
||||
// left, right *col
|
||||
// }
|
||||
|
||||
// var (
|
||||
// userTable = &table{
|
||||
// schema: "zitadel",
|
||||
// name: "users",
|
||||
// }
|
||||
// userColumns = []*col{
|
||||
// userInstanceIDColumn,
|
||||
// userOrgIDColumn,
|
||||
// userIDColumn,
|
||||
// userUsernameColumn,
|
||||
// }
|
||||
// userInstanceIDColumn = &col{
|
||||
// table: userTable,
|
||||
// name: "instance_id",
|
||||
// }
|
||||
// userOrgIDColumn = &col{
|
||||
// table: userTable,
|
||||
// name: "org_id",
|
||||
// }
|
||||
// userIDColumn = &col{
|
||||
// table: userTable,
|
||||
// name: "id",
|
||||
// }
|
||||
// userUsernameColumn = &col{
|
||||
// table: userTable,
|
||||
// name: "username",
|
||||
// }
|
||||
// userJoins = []*join{
|
||||
// {
|
||||
// table: instanceTable,
|
||||
// on: []*joinColumns{
|
||||
// {
|
||||
// left: instanceIDColumn,
|
||||
// right: userInstanceIDColumn,
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// {
|
||||
// table: orgTable,
|
||||
// on: []*joinColumns{
|
||||
// {
|
||||
// left: orgIDColumn,
|
||||
// right: userOrgIDColumn,
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
// )
|
||||
|
||||
// var (
|
||||
// instanceTable = &table{
|
||||
// schema: "zitadel",
|
||||
// name: "instances",
|
||||
// }
|
||||
// instanceColumns = []*col{
|
||||
// instanceIDColumn,
|
||||
// instanceNameColumn,
|
||||
// }
|
||||
// instanceIDColumn = &col{
|
||||
// table: instanceTable,
|
||||
// name: "id",
|
||||
// }
|
||||
// instanceNameColumn = &col{
|
||||
// table: instanceTable,
|
||||
// name: "name",
|
||||
// }
|
||||
// )
|
||||
|
||||
// var (
|
||||
// orgTable = &table{
|
||||
// schema: "zitadel",
|
||||
// name: "orgs",
|
||||
// }
|
||||
// orgColumns = []*col{
|
||||
// orgInstanceIDColumn,
|
||||
// orgIDColumn,
|
||||
// orgNameColumn,
|
||||
// }
|
||||
// orgInstanceIDColumn = &col{
|
||||
// table: orgTable,
|
||||
// name: "instance_id",
|
||||
// }
|
||||
// orgIDColumn = &col{
|
||||
// table: orgTable,
|
||||
// name: "id",
|
||||
// }
|
||||
// orgNameColumn = &col{
|
||||
// table: orgTable,
|
||||
// name: "name",
|
||||
// }
|
||||
// )
|
||||
|
||||
// func init() {
|
||||
// instanceTable.columns = instanceColumns
|
||||
// userTable.columns = userColumns
|
||||
|
||||
// userTable.possibleJoins = []join{
|
||||
// {
|
||||
// table: userTable,
|
||||
// on: []joinColumns{
|
||||
// {
|
||||
// left: userIDColumn,
|
||||
// right: userIDColumn,
|
||||
// },
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
// }
|
55
backend/v3/storage/database/repository/stmt/v3/column.go
Normal file
55
backend/v3/storage/database/repository/stmt/v3/column.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package v3
|
||||
|
||||
type Column interface {
|
||||
Name() string
|
||||
Write(builder statementBuilder)
|
||||
}
|
||||
|
||||
type ignoreCaseColumn interface {
|
||||
Column
|
||||
WriteIgnoreCase(builder statementBuilder)
|
||||
}
|
||||
|
||||
var (
|
||||
columnNameID = "id"
|
||||
columnNameName = "name"
|
||||
columnNameCreatedAt = "created_at"
|
||||
columnNameUpdatedAt = "updated_at"
|
||||
columnNameDeletedAt = "deleted_at"
|
||||
|
||||
columnNameInstanceID = "instance_id"
|
||||
|
||||
columnNameOrgID = "org_id"
|
||||
)
|
||||
|
||||
type column struct {
|
||||
table Table
|
||||
name string
|
||||
}
|
||||
|
||||
// Write implements Column.
|
||||
func (c *column) Write(builder statementBuilder) {
|
||||
c.table.writeOn(builder)
|
||||
builder.writeRune('.')
|
||||
builder.writeString(c.name)
|
||||
}
|
||||
|
||||
// Name implements [Column].
|
||||
func (c *column) Name() string {
|
||||
return c.name
|
||||
}
|
||||
|
||||
var _ Column = (*column)(nil)
|
||||
|
||||
type columnIgnoreCase struct {
|
||||
column
|
||||
suffix string
|
||||
}
|
||||
|
||||
// WriteIgnoreCase implements ignoreCaseColumn.
|
||||
func (c *columnIgnoreCase) WriteIgnoreCase(builder statementBuilder) {
|
||||
c.Write(builder)
|
||||
builder.writeString(c.suffix)
|
||||
}
|
||||
|
||||
var _ ignoreCaseColumn = (*columnIgnoreCase)(nil)
|
182
backend/v3/storage/database/repository/stmt/v3/condition.go
Normal file
182
backend/v3/storage/database/repository/stmt/v3/condition.go
Normal file
@@ -0,0 +1,182 @@
|
||||
package v3
|
||||
|
||||
type statementBuilder interface {
|
||||
write([]byte)
|
||||
writeString(string)
|
||||
writeRune(rune)
|
||||
|
||||
appendArg(any) (placeholder string)
|
||||
table() Table
|
||||
}
|
||||
|
||||
type Condition interface {
|
||||
writeOn(builder statementBuilder)
|
||||
}
|
||||
|
||||
type and struct {
|
||||
conditions []Condition
|
||||
}
|
||||
|
||||
func And(conditions ...Condition) *and {
|
||||
return &and{conditions: conditions}
|
||||
}
|
||||
|
||||
// writeOn implements [Condition].
|
||||
func (a *and) writeOn(builder statementBuilder) {
|
||||
if len(a.conditions) > 1 {
|
||||
builder.writeString("(")
|
||||
defer builder.writeString(")")
|
||||
}
|
||||
|
||||
for i, condition := range a.conditions {
|
||||
if i > 0 {
|
||||
builder.writeString(" AND ")
|
||||
}
|
||||
condition.writeOn(builder)
|
||||
}
|
||||
}
|
||||
|
||||
var _ Condition = (*and)(nil)
|
||||
|
||||
type or struct {
|
||||
conditions []Condition
|
||||
}
|
||||
|
||||
func Or(conditions ...Condition) *or {
|
||||
return &or{conditions: conditions}
|
||||
}
|
||||
|
||||
// writeOn implements [Condition].
|
||||
func (o *or) writeOn(builder statementBuilder) {
|
||||
if len(o.conditions) > 1 {
|
||||
builder.writeString("(")
|
||||
defer builder.writeString(")")
|
||||
}
|
||||
|
||||
for i, condition := range o.conditions {
|
||||
if i > 0 {
|
||||
builder.writeString(" OR ")
|
||||
}
|
||||
condition.writeOn(builder)
|
||||
}
|
||||
}
|
||||
|
||||
var _ Condition = (*or)(nil)
|
||||
|
||||
type isNull struct {
|
||||
column Column
|
||||
}
|
||||
|
||||
func IsNull(column Column) *isNull {
|
||||
return &isNull{column: column}
|
||||
}
|
||||
|
||||
// writeOn implements [Condition].
|
||||
func (cond *isNull) writeOn(builder statementBuilder) {
|
||||
cond.column.Write(builder)
|
||||
builder.writeString(" IS NULL")
|
||||
}
|
||||
|
||||
var _ Condition = (*isNull)(nil)
|
||||
|
||||
type isNotNull struct {
|
||||
column Column
|
||||
}
|
||||
|
||||
func IsNotNull(column Column) *isNotNull {
|
||||
return &isNotNull{column: column}
|
||||
}
|
||||
|
||||
// writeOn implements [Condition].
|
||||
func (cond *isNotNull) writeOn(builder statementBuilder) {
|
||||
cond.column.Write(builder)
|
||||
builder.writeString(" IS NOT NULL")
|
||||
}
|
||||
|
||||
var _ Condition = (*isNotNull)(nil)
|
||||
|
||||
type condition[Op Operator, V Value] struct {
|
||||
column Column
|
||||
operator Op
|
||||
value V
|
||||
}
|
||||
|
||||
// writeOn implements [Condition].
|
||||
func (cond condition[Op, V]) writeOn(builder statementBuilder) {
|
||||
cond.column.Write(builder)
|
||||
builder.writeString(cond.operator.String())
|
||||
builder.writeString(builder.appendArg(cond.value))
|
||||
}
|
||||
|
||||
var _ Condition = (*condition[TextOperator, string])(nil)
|
||||
|
||||
type textCondition[V Text] struct {
|
||||
condition[TextOperator, V]
|
||||
}
|
||||
|
||||
func NewTextCondition[V Text](column Column, operator TextOperator, value V) *textCondition[V] {
|
||||
return &textCondition[V]{
|
||||
condition: condition[TextOperator, V]{
|
||||
column: column,
|
||||
operator: operator,
|
||||
value: value,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// writeOn implements [Condition].
|
||||
func (cond *textCondition[V]) writeOn(builder statementBuilder) {
|
||||
switch cond.operator {
|
||||
case TextOperatorEqual, TextOperatorNotEqual:
|
||||
cond.column.Write(builder)
|
||||
builder.writeString(cond.operator.String())
|
||||
builder.writeString(builder.appendArg(cond.value))
|
||||
case TextOperatorEqualIgnoreCase, TextOperatorNotEqualIgnoreCase:
|
||||
if col, ok := cond.column.(ignoreCaseColumn); ok {
|
||||
col.WriteIgnoreCase(builder)
|
||||
} else {
|
||||
builder.writeString("LOWER(")
|
||||
cond.column.Write(builder)
|
||||
builder.writeString(")")
|
||||
}
|
||||
builder.writeString(cond.operator.String())
|
||||
builder.writeString("LOWER(")
|
||||
builder.writeString(builder.appendArg(cond.value))
|
||||
builder.writeString(")")
|
||||
case TextOperatorStartsWith:
|
||||
cond.column.Write(builder)
|
||||
builder.writeString(cond.operator.String())
|
||||
builder.writeString(builder.appendArg(cond.value))
|
||||
builder.writeString(" || '%'")
|
||||
case TextOperatorStartsWithIgnoreCase:
|
||||
if col, ok := cond.column.(ignoreCaseColumn); ok {
|
||||
col.WriteIgnoreCase(builder)
|
||||
} else {
|
||||
builder.writeString("LOWER(")
|
||||
cond.column.Write(builder)
|
||||
builder.writeString(")")
|
||||
}
|
||||
builder.writeString(cond.operator.String())
|
||||
builder.writeString("LOWER(")
|
||||
builder.writeString(builder.appendArg(cond.value))
|
||||
builder.writeString(") || '%'")
|
||||
}
|
||||
}
|
||||
|
||||
var _ Condition = (*textCondition[string])(nil)
|
||||
|
||||
type numberCondition[V Number] struct {
|
||||
condition[NumberOperator, V]
|
||||
}
|
||||
|
||||
func NewNumberCondition[V Number](column Column, operator NumberOperator, value V) *numberCondition[V] {
|
||||
return &numberCondition[V]{
|
||||
condition: condition[NumberOperator, V]{
|
||||
column: column,
|
||||
operator: operator,
|
||||
value: value,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var _ Condition = (*numberCondition[int])(nil)
|
104
backend/v3/storage/database/repository/stmt/v3/instance.go
Normal file
104
backend/v3/storage/database/repository/stmt/v3/instance.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package v3
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type Instance struct {
|
||||
id string
|
||||
name string
|
||||
|
||||
createdAt time.Time
|
||||
updatedAt time.Time
|
||||
deletedAt time.Time
|
||||
}
|
||||
|
||||
// Columns implements [object].
|
||||
func (Instance) Columns(table Table) []Column {
|
||||
return []Column{
|
||||
&column{
|
||||
table: table,
|
||||
name: columnNameID,
|
||||
},
|
||||
&column{
|
||||
table: table,
|
||||
name: columnNameName,
|
||||
},
|
||||
&column{
|
||||
table: table,
|
||||
name: columnNameCreatedAt,
|
||||
},
|
||||
&column{
|
||||
table: table,
|
||||
name: columnNameUpdatedAt,
|
||||
},
|
||||
&column{
|
||||
table: table,
|
||||
name: columnNameDeletedAt,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Scan implements [object].
|
||||
func (i Instance) Scan(row database.Scanner) error {
|
||||
return row.Scan(
|
||||
&i.id,
|
||||
&i.name,
|
||||
&i.createdAt,
|
||||
&i.updatedAt,
|
||||
&i.deletedAt,
|
||||
)
|
||||
}
|
||||
|
||||
type instanceTable struct {
|
||||
*table
|
||||
}
|
||||
|
||||
func InstanceTable() *instanceTable {
|
||||
table := &instanceTable{
|
||||
table: newTable[Instance]("zitadel", "instances"),
|
||||
}
|
||||
|
||||
table.possibleJoins = func(t Table) map[string]Column {
|
||||
switch on := t.(type) {
|
||||
case *instanceTable:
|
||||
return map[string]Column{
|
||||
columnNameID: on.IDColumn(),
|
||||
}
|
||||
case *orgTable:
|
||||
return map[string]Column{
|
||||
columnNameID: on.InstanceIDColumn(),
|
||||
}
|
||||
case *userTable:
|
||||
return map[string]Column{
|
||||
columnNameID: on.InstanceIDColumn(),
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return table
|
||||
}
|
||||
|
||||
func (i *instanceTable) IDColumn() Column {
|
||||
return i.columns[columnNameID]
|
||||
}
|
||||
|
||||
func (i *instanceTable) NameColumn() Column {
|
||||
return i.columns[columnNameName]
|
||||
}
|
||||
|
||||
func (i *instanceTable) CreatedAtColumn() Column {
|
||||
return i.columns[columnNameCreatedAt]
|
||||
}
|
||||
|
||||
func (i *instanceTable) UpdatedAtColumn() Column {
|
||||
return i.columns[columnNameUpdatedAt]
|
||||
}
|
||||
|
||||
func (i *instanceTable) DeletedAtColumn() Column {
|
||||
return i.columns[columnNameDeletedAt]
|
||||
}
|
11
backend/v3/storage/database/repository/stmt/v3/join.go
Normal file
11
backend/v3/storage/database/repository/stmt/v3/join.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package v3
|
||||
|
||||
type join struct {
|
||||
table Table
|
||||
conditions []joinCondition
|
||||
}
|
||||
|
||||
type joinCondition struct {
|
||||
left Column
|
||||
right Column
|
||||
}
|
82
backend/v3/storage/database/repository/stmt/v3/operator.go
Normal file
82
backend/v3/storage/database/repository/stmt/v3/operator.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package v3
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
type Value interface {
|
||||
Bool | Number | Text
|
||||
}
|
||||
|
||||
type Text interface {
|
||||
~string | ~[]byte
|
||||
}
|
||||
|
||||
type Number interface {
|
||||
constraints.Integer | constraints.Float | constraints.Complex | time.Time | time.Duration
|
||||
}
|
||||
|
||||
type Bool interface {
|
||||
~bool
|
||||
}
|
||||
|
||||
type Operator interface {
|
||||
fmt.Stringer
|
||||
}
|
||||
|
||||
type TextOperator uint8
|
||||
|
||||
// String implements [Operator].
|
||||
func (t TextOperator) String() string {
|
||||
return textOperators[t]
|
||||
}
|
||||
|
||||
const (
|
||||
TextOperatorEqual TextOperator = iota + 1
|
||||
TextOperatorEqualIgnoreCase
|
||||
TextOperatorNotEqual
|
||||
TextOperatorNotEqualIgnoreCase
|
||||
TextOperatorStartsWith
|
||||
TextOperatorStartsWithIgnoreCase
|
||||
)
|
||||
|
||||
var textOperators = map[TextOperator]string{
|
||||
TextOperatorEqual: " = ",
|
||||
TextOperatorEqualIgnoreCase: " LIKE ",
|
||||
TextOperatorNotEqual: " <> ",
|
||||
TextOperatorNotEqualIgnoreCase: " NOT LIKE ",
|
||||
TextOperatorStartsWith: " LIKE ",
|
||||
TextOperatorStartsWithIgnoreCase: " LIKE ",
|
||||
}
|
||||
|
||||
var _ Operator = TextOperator(0)
|
||||
|
||||
type NumberOperator uint8
|
||||
|
||||
// String implements Operator.
|
||||
func (n NumberOperator) String() string {
|
||||
return numberOperators[n]
|
||||
}
|
||||
|
||||
const (
|
||||
NumberOperatorEqual NumberOperator = iota + 1
|
||||
NumberOperatorNotEqual
|
||||
NumberOperatorLessThan
|
||||
NumberOperatorLessThanOrEqual
|
||||
NumberOperatorGreaterThan
|
||||
NumberOperatorGreaterThanOrEqual
|
||||
)
|
||||
|
||||
var numberOperators = map[NumberOperator]string{
|
||||
NumberOperatorEqual: " = ",
|
||||
NumberOperatorNotEqual: " <> ",
|
||||
NumberOperatorLessThan: " < ",
|
||||
NumberOperatorLessThanOrEqual: " <= ",
|
||||
NumberOperatorGreaterThan: " > ",
|
||||
NumberOperatorGreaterThanOrEqual: " >= ",
|
||||
}
|
||||
|
||||
var _ Operator = NumberOperator(0)
|
117
backend/v3/storage/database/repository/stmt/v3/org.go
Normal file
117
backend/v3/storage/database/repository/stmt/v3/org.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package v3
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type Org struct {
|
||||
instanceID string
|
||||
id string
|
||||
|
||||
name string
|
||||
|
||||
createdAt time.Time
|
||||
updatedAt time.Time
|
||||
deletedAt time.Time
|
||||
}
|
||||
|
||||
// Columns implements [object].
|
||||
func (Org) Columns(table Table) []Column {
|
||||
return []Column{
|
||||
&column{
|
||||
table: table,
|
||||
name: columnNameInstanceID,
|
||||
},
|
||||
&column{
|
||||
table: table,
|
||||
name: columnNameID,
|
||||
},
|
||||
&column{
|
||||
table: table,
|
||||
name: columnNameName,
|
||||
},
|
||||
&column{
|
||||
table: table,
|
||||
name: columnNameCreatedAt,
|
||||
},
|
||||
&column{
|
||||
table: table,
|
||||
name: columnNameUpdatedAt,
|
||||
},
|
||||
&column{
|
||||
table: table,
|
||||
name: columnNameDeletedAt,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Scan implements [object].
|
||||
func (o Org) Scan(row database.Scanner) error {
|
||||
return row.Scan(
|
||||
&o.instanceID,
|
||||
&o.id,
|
||||
&o.name,
|
||||
&o.createdAt,
|
||||
&o.updatedAt,
|
||||
&o.deletedAt,
|
||||
)
|
||||
}
|
||||
|
||||
type orgTable struct {
|
||||
*table
|
||||
}
|
||||
|
||||
func OrgTable() *orgTable {
|
||||
table := &orgTable{
|
||||
table: newTable[Org]("zitadel", "orgs"),
|
||||
}
|
||||
|
||||
table.possibleJoins = func(table Table) map[string]Column {
|
||||
switch on := table.(type) {
|
||||
case *instanceTable:
|
||||
return map[string]Column{
|
||||
columnNameInstanceID: on.IDColumn(),
|
||||
}
|
||||
case *orgTable:
|
||||
return map[string]Column{
|
||||
columnNameInstanceID: on.InstanceIDColumn(),
|
||||
columnNameID: on.IDColumn(),
|
||||
}
|
||||
case *userTable:
|
||||
return map[string]Column{
|
||||
columnNameInstanceID: on.InstanceIDColumn(),
|
||||
columnNameID: on.IDColumn(),
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return table
|
||||
}
|
||||
|
||||
func (o *orgTable) InstanceIDColumn() Column {
|
||||
return o.columns[columnNameInstanceID]
|
||||
}
|
||||
|
||||
func (o *orgTable) IDColumn() Column {
|
||||
return o.columns[columnNameID]
|
||||
}
|
||||
|
||||
func (o *orgTable) NameColumn() Column {
|
||||
return o.columns[columnNameName]
|
||||
}
|
||||
|
||||
func (o *orgTable) CreatedAtColumn() Column {
|
||||
return o.columns[columnNameCreatedAt]
|
||||
}
|
||||
|
||||
func (o *orgTable) UpdatedAtColumn() Column {
|
||||
return o.columns[columnNameUpdatedAt]
|
||||
}
|
||||
|
||||
func (o *orgTable) DeletedAtColumn() Column {
|
||||
return o.columns[columnNameDeletedAt]
|
||||
}
|
188
backend/v3/storage/database/repository/stmt/v3/query.go
Normal file
188
backend/v3/storage/database/repository/stmt/v3/query.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package v3
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type Query[O object] interface {
|
||||
Where(condition Condition)
|
||||
Join(tables ...Table)
|
||||
Limit(limit uint32)
|
||||
Offset(offset uint32)
|
||||
OrderBy(columns ...Column)
|
||||
|
||||
Result(ctx context.Context, client database.Querier) (*O, error)
|
||||
Results(ctx context.Context, client database.Querier) ([]O, error)
|
||||
|
||||
fmt.Stringer
|
||||
statementBuilder
|
||||
}
|
||||
|
||||
type query[O object] struct {
|
||||
*statement[O]
|
||||
joins []join
|
||||
limit uint32
|
||||
offset uint32
|
||||
orderBy []Column
|
||||
}
|
||||
|
||||
func NewQuery[O object](table Table) Query[O] {
|
||||
return &query[O]{
|
||||
statement: newStatement[O](table),
|
||||
}
|
||||
}
|
||||
|
||||
// Result implements [Query].
|
||||
func (q *query[O]) Result(ctx context.Context, client database.Querier) (*O, error) {
|
||||
var object O
|
||||
row := client.QueryRow(ctx, q.String(), q.args...)
|
||||
if err := object.Scan(row); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &object, nil
|
||||
}
|
||||
|
||||
// Results implements [Query].
|
||||
func (q *query[O]) Results(ctx context.Context, client database.Querier) ([]O, error) {
|
||||
var objects []O
|
||||
rows, err := client.Query(ctx, q.String(), q.args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var object O
|
||||
if err := object.Scan(rows); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
objects = append(objects, object)
|
||||
}
|
||||
|
||||
return objects, rows.Err()
|
||||
}
|
||||
|
||||
// Join implements [Query].
|
||||
func (q *query[O]) Join(tables ...Table) {
|
||||
for _, tbl := range tables {
|
||||
cols := q.tbl.(*table).possibleJoins(tbl)
|
||||
if len(cols) == 0 {
|
||||
panic(fmt.Sprintf("table %q does not have any possible joins with table %q", q.tbl.Name(), tbl.Name()))
|
||||
}
|
||||
|
||||
q.joins = append(q.joins, join{
|
||||
table: tbl,
|
||||
conditions: make([]joinCondition, 0, len(cols)),
|
||||
})
|
||||
|
||||
for colName, col := range cols {
|
||||
q.joins[len(q.joins)-1].conditions = append(q.joins[len(q.joins)-1].conditions, joinCondition{
|
||||
left: q.tbl.(*table).columns[colName],
|
||||
right: col,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (q *query[O]) Limit(limit uint32) {
|
||||
q.limit = limit
|
||||
}
|
||||
|
||||
func (q *query[O]) Offset(offset uint32) {
|
||||
q.offset = offset
|
||||
}
|
||||
|
||||
func (q *query[O]) OrderBy(columns ...Column) {
|
||||
for _, allowedColumn := range q.columns {
|
||||
for _, column := range columns {
|
||||
if allowedColumn.Name() == column.Name() {
|
||||
q.orderBy = append(q.orderBy, column)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// String implements [fmt.Stringer] and [Query].
|
||||
func (q *query[O]) String() string {
|
||||
q.writeSelectColumns()
|
||||
q.writeFrom()
|
||||
q.writeJoins()
|
||||
q.writeCondition()
|
||||
q.writeOrderBy()
|
||||
q.writeLimit()
|
||||
q.writeOffset()
|
||||
q.writeGroupBy()
|
||||
return q.builder.String()
|
||||
}
|
||||
|
||||
func (q *query[O]) writeSelectColumns() {
|
||||
q.builder.WriteString("SELECT ")
|
||||
for i, column := range q.columns {
|
||||
if i > 0 {
|
||||
q.builder.WriteString(", ")
|
||||
}
|
||||
q.builder.WriteString(q.tbl.Alias())
|
||||
q.builder.WriteRune('.')
|
||||
q.builder.WriteString(column.Name())
|
||||
}
|
||||
}
|
||||
|
||||
func (q *query[O]) writeJoins() {
|
||||
for _, join := range q.joins {
|
||||
q.builder.WriteString(" JOIN ")
|
||||
q.builder.WriteString(join.table.Schema())
|
||||
q.builder.WriteRune('.')
|
||||
q.builder.WriteString(join.table.Name())
|
||||
if join.table.Alias() != "" {
|
||||
q.builder.WriteString(" AS ")
|
||||
q.builder.WriteString(join.table.Alias())
|
||||
}
|
||||
|
||||
q.builder.WriteString(" ON ")
|
||||
for i, condition := range join.conditions {
|
||||
if i > 0 {
|
||||
q.builder.WriteString(" AND ")
|
||||
}
|
||||
q.builder.WriteString(condition.left.Name())
|
||||
q.builder.WriteString(" = ")
|
||||
q.builder.WriteString(condition.right.Name())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (q *query[O]) writeOrderBy() {
|
||||
if len(q.orderBy) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
q.builder.WriteString(" ORDER BY ")
|
||||
for i, order := range q.orderBy {
|
||||
if i > 0 {
|
||||
q.builder.WriteString(", ")
|
||||
}
|
||||
order.Write(q)
|
||||
}
|
||||
}
|
||||
|
||||
func (q *query[O]) writeLimit() {
|
||||
if q.limit == 0 {
|
||||
return
|
||||
}
|
||||
q.builder.WriteString(" LIMIT ")
|
||||
q.builder.WriteString(q.appendArg(q.limit))
|
||||
}
|
||||
|
||||
func (q *query[O]) writeOffset() {
|
||||
if q.offset == 0 {
|
||||
return
|
||||
}
|
||||
q.builder.WriteString(" OFFSET ")
|
||||
q.builder.WriteString(q.appendArg(q.offset))
|
||||
}
|
||||
|
||||
func (q *query[O]) writeGroupBy() {
|
||||
q.builder.WriteString(" GROUP BY ")
|
||||
}
|
85
backend/v3/storage/database/repository/stmt/v3/statement.go
Normal file
85
backend/v3/storage/database/repository/stmt/v3/statement.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package v3
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type statement[T object] struct {
|
||||
tbl Table
|
||||
columns []Column
|
||||
condition Condition
|
||||
|
||||
builder strings.Builder
|
||||
args []any
|
||||
existingArgs map[any]string
|
||||
}
|
||||
|
||||
func newStatement[O object](t Table) *statement[O] {
|
||||
var o O
|
||||
return &statement[O]{
|
||||
tbl: t,
|
||||
columns: o.Columns(t),
|
||||
}
|
||||
}
|
||||
|
||||
// Where implements [Query].
|
||||
func (stmt *statement[T]) Where(condition Condition) {
|
||||
stmt.condition = condition
|
||||
}
|
||||
|
||||
func (stmt *statement[T]) writeFrom() {
|
||||
stmt.builder.WriteString(" FROM ")
|
||||
stmt.builder.WriteString(stmt.tbl.Schema())
|
||||
stmt.builder.WriteRune('.')
|
||||
stmt.builder.WriteString(stmt.tbl.Name())
|
||||
if stmt.tbl.Alias() != "" {
|
||||
stmt.builder.WriteString(" AS ")
|
||||
stmt.builder.WriteString(stmt.tbl.Alias())
|
||||
}
|
||||
}
|
||||
|
||||
func (stmt *statement[T]) writeCondition() {
|
||||
if stmt.condition == nil {
|
||||
return
|
||||
}
|
||||
stmt.builder.WriteString(" WHERE ")
|
||||
stmt.condition.writeOn(stmt)
|
||||
}
|
||||
|
||||
// appendArg implements [statementBuilder].
|
||||
func (stmt *statement[T]) appendArg(arg any) (placeholder string) {
|
||||
if stmt.existingArgs == nil {
|
||||
stmt.existingArgs = make(map[any]string)
|
||||
}
|
||||
if placeholder, ok := stmt.existingArgs[arg]; ok {
|
||||
return placeholder
|
||||
}
|
||||
|
||||
stmt.args = append(stmt.args, arg)
|
||||
placeholder = fmt.Sprintf("$%d", len(stmt.args))
|
||||
stmt.existingArgs[arg] = placeholder
|
||||
return placeholder
|
||||
}
|
||||
|
||||
// table implements [statementBuilder].
|
||||
func (stmt *statement[T]) table() Table {
|
||||
return stmt.tbl
|
||||
}
|
||||
|
||||
// write implements [statementBuilder].
|
||||
func (stmt *statement[T]) write(data []byte) {
|
||||
stmt.builder.Write(data)
|
||||
}
|
||||
|
||||
// writeRune implements [statementBuilder].
|
||||
func (stmt *statement[T]) writeRune(r rune) {
|
||||
stmt.builder.WriteRune(r)
|
||||
}
|
||||
|
||||
// writeString implements [statementBuilder].
|
||||
func (stmt *statement[T]) writeString(s string) {
|
||||
stmt.builder.WriteString(s)
|
||||
}
|
||||
|
||||
var _ statementBuilder = (*statement[Instance])(nil)
|
84
backend/v3/storage/database/repository/stmt/v3/table.go
Normal file
84
backend/v3/storage/database/repository/stmt/v3/table.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package v3
|
||||
|
||||
import "github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
|
||||
type object interface {
|
||||
User | Org | Instance
|
||||
Columns(t Table) []Column
|
||||
Scan(s database.Scanner) error
|
||||
}
|
||||
|
||||
type Table interface {
|
||||
Schema() string
|
||||
Name() string
|
||||
Alias() string
|
||||
Columns() []Column
|
||||
|
||||
writeOn(builder statementBuilder)
|
||||
}
|
||||
|
||||
type table struct {
|
||||
schema string
|
||||
name string
|
||||
alias string
|
||||
|
||||
possibleJoins func(table Table) map[string]Column
|
||||
|
||||
columns map[string]Column
|
||||
colList []Column
|
||||
}
|
||||
|
||||
func newTable[O object](schema, name string) *table {
|
||||
t := &table{
|
||||
schema: schema,
|
||||
name: name,
|
||||
}
|
||||
|
||||
var o O
|
||||
t.colList = o.Columns(t)
|
||||
t.columns = make(map[string]Column, len(t.colList))
|
||||
for _, col := range t.colList {
|
||||
t.columns[col.Name()] = col
|
||||
}
|
||||
|
||||
return t
|
||||
}
|
||||
|
||||
// Columns implements [Table].
|
||||
func (t *table) Columns() []Column {
|
||||
if len(t.colList) > 0 {
|
||||
return t.colList
|
||||
}
|
||||
|
||||
t.colList = make([]Column, 0, len(t.columns))
|
||||
for _, column := range t.columns {
|
||||
t.colList = append(t.colList, column)
|
||||
}
|
||||
|
||||
return t.colList
|
||||
}
|
||||
|
||||
// Name implements [Table].
|
||||
func (t *table) Name() string {
|
||||
return t.name
|
||||
}
|
||||
|
||||
// Schema implements [Table].
|
||||
func (t *table) Schema() string {
|
||||
return t.schema
|
||||
}
|
||||
|
||||
// Alias implements [Table].
|
||||
func (t *table) Alias() string {
|
||||
if t.alias != "" {
|
||||
return t.alias
|
||||
}
|
||||
return t.schema + "." + t.name
|
||||
}
|
||||
|
||||
// writeOn implements [Table].
|
||||
func (t *table) writeOn(builder statementBuilder) {
|
||||
builder.writeString(t.Alias())
|
||||
}
|
||||
|
||||
var _ Table = (*table)(nil)
|
170
backend/v3/storage/database/repository/stmt/v3/user.go
Normal file
170
backend/v3/storage/database/repository/stmt/v3/user.go
Normal file
@@ -0,0 +1,170 @@
|
||||
package v3
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
instanceID string
|
||||
orgID string
|
||||
id string
|
||||
username string
|
||||
|
||||
createdAt time.Time
|
||||
updatedAt time.Time
|
||||
deletedAt time.Time
|
||||
}
|
||||
|
||||
// Columns implements [object].
|
||||
func (u User) Columns(table Table) []Column {
|
||||
return []Column{
|
||||
&column{
|
||||
table: table,
|
||||
name: columnNameInstanceID,
|
||||
},
|
||||
&column{
|
||||
table: table,
|
||||
name: columnNameOrgID,
|
||||
},
|
||||
&column{
|
||||
table: table,
|
||||
name: columnNameID,
|
||||
},
|
||||
&columnIgnoreCase{
|
||||
column: column{
|
||||
table: table,
|
||||
name: userTableUsernameColumn,
|
||||
},
|
||||
suffix: "_lower",
|
||||
},
|
||||
&column{
|
||||
table: table,
|
||||
name: columnNameCreatedAt,
|
||||
},
|
||||
&column{
|
||||
table: table,
|
||||
name: columnNameUpdatedAt,
|
||||
},
|
||||
&column{
|
||||
table: table,
|
||||
name: columnNameDeletedAt,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// Scan implements [object].
|
||||
func (u User) Scan(row database.Scanner) error {
|
||||
return row.Scan(
|
||||
&u.instanceID,
|
||||
&u.orgID,
|
||||
&u.id,
|
||||
&u.username,
|
||||
&u.createdAt,
|
||||
&u.updatedAt,
|
||||
&u.deletedAt,
|
||||
)
|
||||
}
|
||||
|
||||
type userTable struct {
|
||||
*table
|
||||
}
|
||||
|
||||
const (
|
||||
userTableUsernameColumn = "username"
|
||||
)
|
||||
|
||||
func UserTable() *userTable {
|
||||
table := &userTable{
|
||||
table: newTable[User]("zitadel", "users"),
|
||||
}
|
||||
|
||||
table.possibleJoins = func(table Table) map[string]Column {
|
||||
switch on := table.(type) {
|
||||
case *userTable:
|
||||
return map[string]Column{
|
||||
columnNameInstanceID: on.InstanceIDColumn(),
|
||||
columnNameOrgID: on.OrgIDColumn(),
|
||||
columnNameID: on.IDColumn(),
|
||||
}
|
||||
case *orgTable:
|
||||
return map[string]Column{
|
||||
columnNameInstanceID: on.InstanceIDColumn(),
|
||||
columnNameOrgID: on.IDColumn(),
|
||||
}
|
||||
case *instanceTable:
|
||||
return map[string]Column{
|
||||
columnNameInstanceID: on.IDColumn(),
|
||||
}
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return table
|
||||
}
|
||||
|
||||
func (t *userTable) InstanceIDColumn() Column {
|
||||
return t.columns[columnNameInstanceID]
|
||||
}
|
||||
|
||||
func (t *userTable) OrgIDColumn() Column {
|
||||
return t.columns[columnNameOrgID]
|
||||
}
|
||||
|
||||
func (t *userTable) IDColumn() Column {
|
||||
return t.columns[columnNameID]
|
||||
}
|
||||
|
||||
func (t *userTable) UsernameColumn() Column {
|
||||
return t.columns[userTableUsernameColumn]
|
||||
}
|
||||
|
||||
func (t *userTable) CreatedAtColumn() Column {
|
||||
return t.columns[columnNameCreatedAt]
|
||||
}
|
||||
|
||||
func (t *userTable) UpdatedAtColumn() Column {
|
||||
return t.columns[columnNameUpdatedAt]
|
||||
}
|
||||
|
||||
func (t *userTable) DeletedAtColumn() Column {
|
||||
return t.columns[columnNameDeletedAt]
|
||||
}
|
||||
|
||||
func NewUserQuery() Query[User] {
|
||||
q := NewQuery[User](UserTable())
|
||||
return q
|
||||
}
|
||||
|
||||
type userByIDCondition[T Text] struct {
|
||||
id T
|
||||
}
|
||||
|
||||
func UserByID[T Text](id T) Condition {
|
||||
return &userByIDCondition[T]{id: id}
|
||||
}
|
||||
|
||||
// writeOn implements Condition.
|
||||
func (u *userByIDCondition[T]) writeOn(builder statementBuilder) {
|
||||
NewTextCondition(builder.table().(*userTable).IDColumn(), TextOperatorEqual, u.id).writeOn(builder)
|
||||
}
|
||||
|
||||
var _ Condition = (*userByIDCondition[string])(nil)
|
||||
|
||||
type userByUsernameCondition[T Text] struct {
|
||||
username T
|
||||
operator TextOperator
|
||||
}
|
||||
|
||||
func UserByUsername[T Text](username T, operator TextOperator) Condition {
|
||||
return &userByUsernameCondition[T]{username: username, operator: operator}
|
||||
}
|
||||
|
||||
// writeOn implements Condition.
|
||||
func (u *userByUsernameCondition[T]) writeOn(builder statementBuilder) {
|
||||
NewTextCondition(builder.table().(*userTable).UsernameColumn(), u.operator, u.username).writeOn(builder)
|
||||
}
|
||||
|
||||
var _ Condition = (*userByUsernameCondition[string])(nil)
|
25
backend/v3/storage/database/repository/stmt/v3/user_test.go
Normal file
25
backend/v3/storage/database/repository/stmt/v3/user_test.go
Normal file
@@ -0,0 +1,25 @@
|
||||
package v3_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
v3 "github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt/v3"
|
||||
)
|
||||
|
||||
type user struct{}
|
||||
|
||||
func TestUser(t *testing.T) {
|
||||
query := v3.NewUserQuery()
|
||||
query.Where(
|
||||
v3.Or(
|
||||
v3.UserByID("123"),
|
||||
v3.UserByUsername("test", v3.TextOperatorStartsWithIgnoreCase),
|
||||
),
|
||||
)
|
||||
query.Limit(10)
|
||||
query.Offset(5)
|
||||
// query.OrderBy(
|
||||
|
||||
query.Result(context.TODO(), nil)
|
||||
}
|
78
backend/v3/storage/database/repository/stmt/v4/column.go
Normal file
78
backend/v3/storage/database/repository/stmt/v4/column.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package v4
|
||||
|
||||
type Change interface {
|
||||
Column
|
||||
}
|
||||
|
||||
type change[V Value] struct {
|
||||
column Column
|
||||
value V
|
||||
}
|
||||
|
||||
func newChange[V Value](col Column, value V) Change {
|
||||
return &change[V]{
|
||||
column: col,
|
||||
value: value,
|
||||
}
|
||||
}
|
||||
|
||||
func newUpdatePtrColumn[V Value](col Column, value *V) Change {
|
||||
if value == nil {
|
||||
return newChange(col, nullDBInstruction)
|
||||
}
|
||||
return newChange(col, *value)
|
||||
}
|
||||
|
||||
// writeTo implements [Change].
|
||||
func (c change[V]) writeTo(builder *statementBuilder) {
|
||||
c.column.writeTo(builder)
|
||||
builder.WriteString(" = ")
|
||||
builder.writeArg(c.value)
|
||||
}
|
||||
|
||||
type Changes []Change
|
||||
|
||||
func newChanges(cols ...Change) Change {
|
||||
return Changes(cols)
|
||||
}
|
||||
|
||||
// writeTo implements [Change].
|
||||
func (m Changes) writeTo(builder *statementBuilder) {
|
||||
for i, col := range m {
|
||||
if i > 0 {
|
||||
builder.WriteString(", ")
|
||||
}
|
||||
col.writeTo(builder)
|
||||
}
|
||||
}
|
||||
|
||||
var _ Change = Changes(nil)
|
||||
|
||||
var _ Change = (*change[string])(nil)
|
||||
|
||||
type Column interface {
|
||||
writeTo(builder *statementBuilder)
|
||||
}
|
||||
|
||||
type column struct {
|
||||
name string
|
||||
}
|
||||
|
||||
func (c column) writeTo(builder *statementBuilder) {
|
||||
builder.WriteString(c.name)
|
||||
}
|
||||
|
||||
type ignoreCaseColumn interface {
|
||||
Column
|
||||
writeIgnoreCaseTo(builder *statementBuilder)
|
||||
}
|
||||
|
||||
type ignoreCaseCol struct {
|
||||
column
|
||||
suffix string
|
||||
}
|
||||
|
||||
func (c ignoreCaseCol) writeIgnoreCaseTo(builder *statementBuilder) {
|
||||
c.column.writeTo(builder)
|
||||
builder.WriteString(c.suffix)
|
||||
}
|
112
backend/v3/storage/database/repository/stmt/v4/condition.go
Normal file
112
backend/v3/storage/database/repository/stmt/v4/condition.go
Normal file
@@ -0,0 +1,112 @@
|
||||
package v4
|
||||
|
||||
type Condition interface {
|
||||
writeTo(builder *statementBuilder)
|
||||
}
|
||||
|
||||
type and struct {
|
||||
conditions []Condition
|
||||
}
|
||||
|
||||
// writeTo implements [Condition].
|
||||
func (a *and) writeTo(builder *statementBuilder) {
|
||||
if len(a.conditions) > 1 {
|
||||
builder.WriteString("(")
|
||||
defer builder.WriteString(")")
|
||||
}
|
||||
for i, condition := range a.conditions {
|
||||
if i > 0 {
|
||||
builder.WriteString(" AND ")
|
||||
}
|
||||
condition.writeTo(builder)
|
||||
}
|
||||
}
|
||||
|
||||
func And(conditions ...Condition) *and {
|
||||
return &and{conditions: conditions}
|
||||
}
|
||||
|
||||
var _ Condition = (*and)(nil)
|
||||
|
||||
type or struct {
|
||||
conditions []Condition
|
||||
}
|
||||
|
||||
// writeTo implements [Condition].
|
||||
func (o *or) writeTo(builder *statementBuilder) {
|
||||
if len(o.conditions) > 1 {
|
||||
builder.WriteString("(")
|
||||
defer builder.WriteString(")")
|
||||
}
|
||||
for i, condition := range o.conditions {
|
||||
if i > 0 {
|
||||
builder.WriteString(" OR ")
|
||||
}
|
||||
condition.writeTo(builder)
|
||||
}
|
||||
}
|
||||
|
||||
func Or(conditions ...Condition) *or {
|
||||
return &or{conditions: conditions}
|
||||
}
|
||||
|
||||
var _ Condition = (*or)(nil)
|
||||
|
||||
type isNull struct {
|
||||
column Column
|
||||
}
|
||||
|
||||
// writeTo implements [Condition].
|
||||
func (i *isNull) writeTo(builder *statementBuilder) {
|
||||
i.column.writeTo(builder)
|
||||
builder.WriteString(" IS NULL")
|
||||
}
|
||||
|
||||
func IsNull(column Column) *isNull {
|
||||
return &isNull{column: column}
|
||||
}
|
||||
|
||||
var _ Condition = (*isNull)(nil)
|
||||
|
||||
type isNotNull struct {
|
||||
column Column
|
||||
}
|
||||
|
||||
// writeTo implements [Condition].
|
||||
func (i *isNotNull) writeTo(builder *statementBuilder) {
|
||||
i.column.writeTo(builder)
|
||||
builder.WriteString(" IS NOT NULL")
|
||||
}
|
||||
|
||||
func IsNotNull(column Column) *isNotNull {
|
||||
return &isNotNull{column: column}
|
||||
}
|
||||
|
||||
var _ Condition = (*isNotNull)(nil)
|
||||
|
||||
type valueCondition func(builder *statementBuilder)
|
||||
|
||||
func newTextCondition[V Text](col Column, op TextOperator, value V) Condition {
|
||||
return valueCondition(func(builder *statementBuilder) {
|
||||
writeTextOperation(builder, col, op, value)
|
||||
})
|
||||
}
|
||||
|
||||
func newNumberCondition[V Number](col Column, op NumberOperator, value V) Condition {
|
||||
return valueCondition(func(builder *statementBuilder) {
|
||||
writeNumberOperation(builder, col, op, value)
|
||||
})
|
||||
}
|
||||
|
||||
func newBooleanCondition[V Boolean](col Column, value V) Condition {
|
||||
return valueCondition(func(builder *statementBuilder) {
|
||||
writeBooleanOperation(builder, col, value)
|
||||
})
|
||||
}
|
||||
|
||||
// writeTo implements [Condition].
|
||||
func (c valueCondition) writeTo(builder *statementBuilder) {
|
||||
c(builder)
|
||||
}
|
||||
|
||||
var _ Condition = (*valueCondition)(nil)
|
2
backend/v3/storage/database/repository/stmt/v4/doc.go
Normal file
2
backend/v3/storage/database/repository/stmt/v4/doc.go
Normal file
@@ -0,0 +1,2 @@
|
||||
// this test focuses on queries rather than on tables
|
||||
package v4
|
149
backend/v3/storage/database/repository/stmt/v4/inheritance.sql
Normal file
149
backend/v3/storage/database/repository/stmt/v4/inheritance.sql
Normal file
@@ -0,0 +1,149 @@
|
||||
CREATE TABLE objects (
|
||||
id SERIAL PRIMARY KEY,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
deleted_at TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE OR REPLACE FUNCTION update_updated_at_column()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
NEW.updated_at = NOW();
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
CREATE TABLE instances(
|
||||
name VARCHAR(50) NOT NULL
|
||||
, PRIMARY KEY (id)
|
||||
) INHERITS (objects);
|
||||
|
||||
CREATE TRIGGER set_updated_at
|
||||
BEFORE UPDATE
|
||||
ON instances
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION update_updated_at_column();
|
||||
|
||||
CREATE TABLE instance_objects(
|
||||
instance_id INT NOT NULL
|
||||
, PRIMARY KEY (instance_id, id)
|
||||
-- as foreign keys are not inherited we need to define them on the child tables
|
||||
--, CONSTRAINT fk_instance FOREIGN KEY (instance_id) REFERENCES instances(id)
|
||||
) INHERITS (objects);
|
||||
|
||||
CREATE TABLE orgs(
|
||||
name VARCHAR(50) NOT NULL
|
||||
, PRIMARY KEY (instance_id, id)
|
||||
, CONSTRAINT fk_instance FOREIGN KEY (instance_id) REFERENCES instances(id)
|
||||
) INHERITS (instance_objects);
|
||||
|
||||
CREATE TRIGGER set_updated_at
|
||||
BEFORE UPDATE
|
||||
ON orgs
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION update_updated_at_column();
|
||||
|
||||
CREATE TABLE org_objects(
|
||||
org_id INT NOT NULL
|
||||
, PRIMARY KEY (instance_id, org_id, id)
|
||||
-- as foreign keys are not inherited we need to define them on the child tables
|
||||
-- CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id),
|
||||
-- CONSTRAINT fk_instance FOREIGN KEY (instance_id) REFERENCES instances(id)
|
||||
) INHERITS (instance_objects);
|
||||
|
||||
CREATE TABLE users (
|
||||
username VARCHAR(50) NOT NULL
|
||||
, PRIMARY KEY (instance_id, org_id, id)
|
||||
-- as foreign keys are not inherited we need to define them on the child tables
|
||||
-- , CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id)
|
||||
-- , CONSTRAINT fk_instances FOREIGN KEY (instance_id) REFERENCES instances(id)
|
||||
) INHERITS (org_objects);
|
||||
|
||||
CREATE TRIGGER set_updated_at
|
||||
BEFORE UPDATE
|
||||
ON users
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION update_updated_at_column();
|
||||
|
||||
CREATE TABLE human_users(
|
||||
first_name VARCHAR(50)
|
||||
, last_name VARCHAR(50)
|
||||
, PRIMARY KEY (instance_id, org_id, id)
|
||||
-- CONSTRAINT fk_user FOREIGN KEY (instance_id, org_id, id) REFERENCES users(instance_id, org_id, id),
|
||||
, CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id)
|
||||
, CONSTRAINT fk_instances FOREIGN KEY (instance_id) REFERENCES instances(id)
|
||||
) INHERITS (users);
|
||||
|
||||
CREATE TRIGGER set_updated_at
|
||||
BEFORE UPDATE
|
||||
ON human_users
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION update_updated_at_column();
|
||||
|
||||
CREATE TABLE machine_users(
|
||||
description VARCHAR(50)
|
||||
, PRIMARY KEY (instance_id, org_id, id)
|
||||
-- , CONSTRAINT fk_user FOREIGN KEY (instance_id, org_id, id) REFERENCES users(instance_id, org_id, id)
|
||||
, CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id)
|
||||
, CONSTRAINT fk_instances FOREIGN KEY (instance_id) REFERENCES instances(id)
|
||||
) INHERITS (users);
|
||||
|
||||
CREATE TRIGGER set_updated_at
|
||||
BEFORE UPDATE
|
||||
ON machine_users
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION update_updated_at_column();
|
||||
|
||||
|
||||
select u.*, hu.first_name, hu.last_name, mu.description from users u
|
||||
left join human_users hu on u.instance_id = hu.instance_id and u.org_id = hu.org_id and u.id = hu.id
|
||||
left join machine_users mu on u.instance_id = mu.instance_id and u.org_id = mu.org_id and u.id = mu.id
|
||||
-- where
|
||||
-- u.instance_id = 1
|
||||
-- and u.org_id = 3
|
||||
-- and u.id = 7
|
||||
;
|
||||
|
||||
create view users_view as (
|
||||
SELECT
|
||||
id
|
||||
, created_at
|
||||
, updated_at
|
||||
, deleted_at
|
||||
, instance_id
|
||||
, org_id
|
||||
, username
|
||||
, first_name
|
||||
, last_name
|
||||
, description
|
||||
FROM (
|
||||
(SELECT
|
||||
id
|
||||
, created_at
|
||||
, updated_at
|
||||
, deleted_at
|
||||
, instance_id
|
||||
, org_id
|
||||
, username
|
||||
, first_name
|
||||
, last_name
|
||||
, NULL AS description
|
||||
FROM
|
||||
human_users)
|
||||
|
||||
UNION
|
||||
|
||||
(SELECT
|
||||
id
|
||||
, created_at
|
||||
, updated_at
|
||||
, deleted_at
|
||||
, instance_id
|
||||
, org_id
|
||||
, username
|
||||
, NULL AS first_name
|
||||
, NULL AS last_name
|
||||
, description
|
||||
FROM
|
||||
machine_users)
|
||||
));
|
139
backend/v3/storage/database/repository/stmt/v4/operators.go
Normal file
139
backend/v3/storage/database/repository/stmt/v4/operators.go
Normal file
@@ -0,0 +1,139 @@
|
||||
package v4
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
type Value interface {
|
||||
Boolean | Number | Text | databaseInstruction
|
||||
}
|
||||
|
||||
type Operator interface {
|
||||
BooleanOperator | NumberOperator | TextOperator
|
||||
}
|
||||
|
||||
type Text interface {
|
||||
~string | ~[]byte
|
||||
}
|
||||
|
||||
type TextOperator uint8
|
||||
|
||||
const (
|
||||
// TextOperatorEqual compares two strings for equality.
|
||||
TextOperatorEqual TextOperator = iota + 1
|
||||
// TextOperatorEqualIgnoreCase compares two strings for equality, ignoring case.
|
||||
TextOperatorEqualIgnoreCase
|
||||
// TextOperatorNotEqual compares two strings for inequality.
|
||||
TextOperatorNotEqual
|
||||
// TextOperatorNotEqualIgnoreCase compares two strings for inequality, ignoring case.
|
||||
TextOperatorNotEqualIgnoreCase
|
||||
// TextOperatorStartsWith checks if the first string starts with the second.
|
||||
TextOperatorStartsWith
|
||||
// TextOperatorStartsWithIgnoreCase checks if the first string starts with the second, ignoring case.
|
||||
TextOperatorStartsWithIgnoreCase
|
||||
)
|
||||
|
||||
var textOperators = map[TextOperator]string{
|
||||
TextOperatorEqual: " = ",
|
||||
TextOperatorEqualIgnoreCase: " LIKE ",
|
||||
TextOperatorNotEqual: " <> ",
|
||||
TextOperatorNotEqualIgnoreCase: " NOT LIKE ",
|
||||
TextOperatorStartsWith: " LIKE ",
|
||||
TextOperatorStartsWithIgnoreCase: " LIKE ",
|
||||
}
|
||||
|
||||
func writeTextOperation[T Text](builder *statementBuilder, col Column, op TextOperator, value T) {
|
||||
switch op {
|
||||
case TextOperatorEqual, TextOperatorNotEqual:
|
||||
col.writeTo(builder)
|
||||
builder.WriteString(textOperators[op])
|
||||
builder.WriteString(builder.appendArg(value))
|
||||
case TextOperatorEqualIgnoreCase, TextOperatorNotEqualIgnoreCase:
|
||||
if ignoreCaseCol, ok := col.(ignoreCaseColumn); ok {
|
||||
ignoreCaseCol.writeIgnoreCaseTo(builder)
|
||||
} else {
|
||||
builder.WriteString("LOWER(")
|
||||
col.writeTo(builder)
|
||||
builder.WriteString(")")
|
||||
}
|
||||
builder.WriteString(textOperators[op])
|
||||
builder.WriteString("LOWER(")
|
||||
builder.WriteString(builder.appendArg(value))
|
||||
builder.WriteString(")")
|
||||
case TextOperatorStartsWith:
|
||||
col.writeTo(builder)
|
||||
builder.WriteString(textOperators[op])
|
||||
builder.WriteString(builder.appendArg(value))
|
||||
builder.WriteString(" || '%'")
|
||||
case TextOperatorStartsWithIgnoreCase:
|
||||
if ignoreCaseCol, ok := col.(ignoreCaseColumn); ok {
|
||||
ignoreCaseCol.writeIgnoreCaseTo(builder)
|
||||
} else {
|
||||
builder.WriteString("LOWER(")
|
||||
col.writeTo(builder)
|
||||
builder.WriteString(")")
|
||||
}
|
||||
builder.WriteString(textOperators[op])
|
||||
builder.WriteString("LOWER(")
|
||||
builder.WriteString(builder.appendArg(value))
|
||||
builder.WriteString(")")
|
||||
builder.WriteString(" || '%'")
|
||||
default:
|
||||
panic("unsupported text operation")
|
||||
}
|
||||
}
|
||||
|
||||
type Number interface {
|
||||
constraints.Integer | constraints.Float | constraints.Complex | time.Time | time.Duration
|
||||
}
|
||||
|
||||
type NumberOperator uint8
|
||||
|
||||
const (
|
||||
// NumberOperatorEqual compares two numbers for equality.
|
||||
NumberOperatorEqual NumberOperator = iota + 1
|
||||
// NumberOperatorNotEqual compares two numbers for inequality.
|
||||
NumberOperatorNotEqual
|
||||
// NumberOperatorLessThan compares two numbers to check if the first is less than the second.
|
||||
NumberOperatorLessThan
|
||||
// NumberOperatorLessThanOrEqual compares two numbers to check if the first is less than or equal to the second.
|
||||
NumberOperatorAtLeast
|
||||
// NumberOperatorGreaterThan compares two numbers to check if the first is greater than the second.
|
||||
NumberOperatorGreaterThan
|
||||
// NumberOperatorGreaterThanOrEqual compares two numbers to check if the first is greater than or equal to the second.
|
||||
NumberOperatorAtMost
|
||||
)
|
||||
|
||||
var numberOperators = map[NumberOperator]string{
|
||||
NumberOperatorEqual: " = ",
|
||||
NumberOperatorNotEqual: " <> ",
|
||||
NumberOperatorLessThan: " < ",
|
||||
NumberOperatorAtLeast: " <= ",
|
||||
NumberOperatorGreaterThan: " > ",
|
||||
NumberOperatorAtMost: " >= ",
|
||||
}
|
||||
|
||||
func writeNumberOperation[T Number](builder *statementBuilder, col Column, op NumberOperator, value T) {
|
||||
col.writeTo(builder)
|
||||
builder.WriteString(numberOperators[op])
|
||||
builder.WriteString(builder.appendArg(value))
|
||||
}
|
||||
|
||||
type Boolean interface {
|
||||
~bool
|
||||
}
|
||||
|
||||
type BooleanOperator uint8
|
||||
|
||||
const (
|
||||
BooleanOperatorIsTrue BooleanOperator = iota + 1
|
||||
BooleanOperatorIsFalse
|
||||
)
|
||||
|
||||
func writeBooleanOperation[T Boolean](builder *statementBuilder, col Column, value T) {
|
||||
col.writeTo(builder)
|
||||
builder.WriteString(" IS ")
|
||||
builder.WriteString(builder.appendArg(value))
|
||||
}
|
18
backend/v3/storage/database/repository/stmt/v4/org.go
Normal file
18
backend/v3/storage/database/repository/stmt/v4/org.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package v4
|
||||
|
||||
type Org struct {
|
||||
InstanceID string
|
||||
ID string
|
||||
Name string
|
||||
Dates
|
||||
}
|
||||
|
||||
type GetOrg struct{}
|
||||
|
||||
type ListOrgs struct{}
|
||||
|
||||
type CreateOrg struct{}
|
||||
|
||||
type UpdateOrg struct{}
|
||||
|
||||
type DeleteOrg struct{}
|
46
backend/v3/storage/database/repository/stmt/v4/statement.go
Normal file
46
backend/v3/storage/database/repository/stmt/v4/statement.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package v4
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type databaseInstruction string
|
||||
|
||||
const (
|
||||
nowDBInstruction databaseInstruction = "NOW()"
|
||||
nullDBInstruction databaseInstruction = "NULL"
|
||||
)
|
||||
|
||||
type statementBuilder struct {
|
||||
strings.Builder
|
||||
args []any
|
||||
existingArgs map[any]string
|
||||
}
|
||||
|
||||
func (b *statementBuilder) writeArg(arg any) {
|
||||
b.WriteString(b.appendArg(arg))
|
||||
}
|
||||
|
||||
func (b *statementBuilder) appendArg(arg any) (placeholder string) {
|
||||
if b.existingArgs == nil {
|
||||
b.existingArgs = make(map[any]string)
|
||||
}
|
||||
if placeholder, ok := b.existingArgs[arg]; ok {
|
||||
return placeholder
|
||||
}
|
||||
if instruction, ok := arg.(databaseInstruction); ok {
|
||||
return string(instruction)
|
||||
}
|
||||
|
||||
b.args = append(b.args, arg)
|
||||
placeholder = "$" + strconv.Itoa(len(b.args))
|
||||
b.existingArgs[arg] = placeholder
|
||||
return placeholder
|
||||
}
|
||||
|
||||
func (b *statementBuilder) appendArgs(args ...any) {
|
||||
for _, arg := range args {
|
||||
b.appendArg(arg)
|
||||
}
|
||||
}
|
239
backend/v3/storage/database/repository/stmt/v4/user.go
Normal file
239
backend/v3/storage/database/repository/stmt/v4/user.go
Normal file
@@ -0,0 +1,239 @@
|
||||
package v4
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type Dates struct {
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt time.Time
|
||||
}
|
||||
|
||||
type User struct {
|
||||
InstanceID string
|
||||
OrgID string
|
||||
ID string
|
||||
Username string
|
||||
Traits userTrait
|
||||
Dates
|
||||
}
|
||||
|
||||
type UserType string
|
||||
|
||||
type userTrait interface {
|
||||
userTrait()
|
||||
Type() UserType
|
||||
}
|
||||
|
||||
const userQuery = `SELECT u.instance_id, u.org_id, u.id, u.username, u.type, u.created_at, u.updated_at, u.deleted_at,` +
|
||||
` h.first_name, h.last_name, h.email_address, h.email_verified_at, h.phone_number, h.phone_verified_at, m.description` +
|
||||
` FROM users u` +
|
||||
` LEFT JOIN user_humans h ON u.instance_id = h.instance_id AND u.org_id = h.org_id AND u.id = h.id` +
|
||||
` LEFT JOIN user_machines m ON u.instance_id = m.instance_id AND u.org_id = m.org_id AND u.id = m.id`
|
||||
|
||||
type user struct {
|
||||
builder statementBuilder
|
||||
client database.QueryExecutor
|
||||
|
||||
condition Condition
|
||||
}
|
||||
|
||||
func UserRepository(client database.QueryExecutor) *user {
|
||||
return &user{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
func (u *user) WithCondition(condition Condition) *user {
|
||||
u.condition = condition
|
||||
return u
|
||||
}
|
||||
|
||||
func (u *user) Get(ctx context.Context) (*User, error) {
|
||||
u.builder.WriteString(userQuery)
|
||||
u.writeCondition()
|
||||
return scanUser(u.client.QueryRow(ctx, u.builder.String(), u.builder.args...))
|
||||
}
|
||||
|
||||
func (u *user) List(ctx context.Context) (users []*User, err error) {
|
||||
u.builder.WriteString(userQuery)
|
||||
u.writeCondition()
|
||||
|
||||
rows, err := u.client.Query(ctx, u.builder.String(), u.builder.args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
closeErr := rows.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = closeErr
|
||||
}()
|
||||
for rows.Next() {
|
||||
user, err := scanUser(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users = append(users, user)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
const (
|
||||
createUserCte = `WITH user AS (` +
|
||||
`INSERT INTO users (instance_id, org_id, id, username, type) VALUES ($1, $2, $3, $4, $5)` +
|
||||
` RETURNING *)`
|
||||
createHumanStmt = createUserCte + ` INSERT INTO user_humans h (instance_id, org_id, user_id, first_name, last_name, email_address, email_verified_at, phone_number, phone_verified_at)` +
|
||||
` SELECT u.instance_id, u.org_id, u.id, $6, $7, $8, $9, $10, $11` +
|
||||
` FROM user u` +
|
||||
` RETURNING u.created_at, u.updated_at, u.deleted_at`
|
||||
createMachineStmt = createUserCte + ` INSERT INTO user_machines (instance_id, org_id, user_id, description)` +
|
||||
` SELECT u.instance_id, u.org_id, u.id, $6` +
|
||||
` FROM user u` +
|
||||
` RETURNING u.created_at, u.updated_at`
|
||||
)
|
||||
|
||||
func (u *user) Create(ctx context.Context, user *User) error {
|
||||
u.builder.appendArgs(user.InstanceID, user.OrgID, user.ID, user.Username, user.Traits.Type())
|
||||
switch trait := user.Traits.(type) {
|
||||
case *Human:
|
||||
u.builder.WriteString(createHumanStmt)
|
||||
u.builder.appendArgs(trait.FirstName, trait.LastName, trait.Email.Address, trait.Email.VerifiedAt, trait.Phone.Number, trait.Phone.VerifiedAt)
|
||||
case *Machine:
|
||||
u.builder.WriteString(createMachineStmt)
|
||||
u.builder.appendArgs(trait.Description)
|
||||
}
|
||||
return u.client.QueryRow(ctx, u.builder.String(), u.builder.args...).Scan(user.CreatedAt, user.UpdatedAt)
|
||||
}
|
||||
|
||||
func (u *user) InstanceIDColumn() Column {
|
||||
return column{name: "u.instance_id"}
|
||||
}
|
||||
|
||||
func (u *user) InstanceIDCondition(instanceID string) Condition {
|
||||
return newTextCondition(u.InstanceIDColumn(), TextOperatorEqual, instanceID)
|
||||
}
|
||||
|
||||
func (u *user) OrgIDColumn() Column {
|
||||
return column{name: "u.org_id"}
|
||||
}
|
||||
|
||||
func (u *user) OrgIDCondition(orgID string) Condition {
|
||||
return newTextCondition(u.OrgIDColumn(), TextOperatorEqual, orgID)
|
||||
}
|
||||
|
||||
func (u *user) IDColumn() Column {
|
||||
return column{name: "u.id"}
|
||||
}
|
||||
|
||||
func (u *user) IDCondition(userID string) Condition {
|
||||
return newTextCondition(u.IDColumn(), TextOperatorEqual, userID)
|
||||
}
|
||||
|
||||
func (u *user) UsernameColumn() Column {
|
||||
return ignoreCaseCol{
|
||||
column: column{name: "u.username"},
|
||||
suffix: "_lower",
|
||||
}
|
||||
}
|
||||
|
||||
func (u user) SetUsername(username string) Change {
|
||||
return newChange(u.UsernameColumn(), username)
|
||||
}
|
||||
|
||||
func (u *user) UsernameCondition(op TextOperator, username string) Condition {
|
||||
return newTextCondition(u.UsernameColumn(), op, username)
|
||||
}
|
||||
|
||||
func (u *user) CreatedAtColumn() Column {
|
||||
return column{name: "u.created_at"}
|
||||
}
|
||||
|
||||
func (u *user) CreatedAtCondition(op NumberOperator, createdAt time.Time) Condition {
|
||||
return newNumberCondition(u.CreatedAtColumn(), op, createdAt)
|
||||
}
|
||||
|
||||
func (u *user) UpdatedAtColumn() Column {
|
||||
return column{name: "u.updated_at"}
|
||||
}
|
||||
|
||||
func (u *user) UpdatedAtCondition(op NumberOperator, updatedAt time.Time) Condition {
|
||||
return newNumberCondition(u.UpdatedAtColumn(), op, updatedAt)
|
||||
}
|
||||
|
||||
func (u *user) DeletedAtColumn() Column {
|
||||
return column{name: "u.deleted_at"}
|
||||
}
|
||||
|
||||
func (u *user) DeletedCondition(isDeleted bool) Condition {
|
||||
if isDeleted {
|
||||
return IsNotNull(u.DeletedAtColumn())
|
||||
}
|
||||
return IsNull(u.DeletedAtColumn())
|
||||
}
|
||||
|
||||
func (u *user) DeletedAtCondition(op NumberOperator, deletedAt time.Time) Condition {
|
||||
return newNumberCondition(u.DeletedAtColumn(), op, deletedAt)
|
||||
}
|
||||
|
||||
func (u *user) writeCondition() {
|
||||
if u.condition == nil {
|
||||
return
|
||||
}
|
||||
u.builder.WriteString(" WHERE ")
|
||||
u.condition.writeTo(&u.builder)
|
||||
}
|
||||
|
||||
func scanUser(scanner database.Scanner) (*User, error) {
|
||||
var (
|
||||
user User
|
||||
human Human
|
||||
email Email
|
||||
phone Phone
|
||||
machine Machine
|
||||
typ UserType
|
||||
)
|
||||
err := scanner.Scan(
|
||||
&user.InstanceID,
|
||||
&user.OrgID,
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&typ,
|
||||
&user.Dates.CreatedAt,
|
||||
&user.Dates.UpdatedAt,
|
||||
&user.Dates.DeletedAt,
|
||||
&human.FirstName,
|
||||
&human.LastName,
|
||||
&email.Address,
|
||||
&email.VerifiedAt,
|
||||
&phone.Number,
|
||||
&phone.VerifiedAt,
|
||||
&machine.Description,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch typ {
|
||||
case UserTypeHuman:
|
||||
if email.Address != "" {
|
||||
human.Email = &email
|
||||
}
|
||||
if phone.Number != "" {
|
||||
human.Phone = &phone
|
||||
}
|
||||
user.Traits = &human
|
||||
case UserTypeMachine:
|
||||
user.Traits = &machine
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
187
backend/v3/storage/database/repository/stmt/v4/user_human.go
Normal file
187
backend/v3/storage/database/repository/stmt/v4/user_human.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package v4
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Human struct {
|
||||
FirstName string
|
||||
LastName string
|
||||
Email *Email
|
||||
Phone *Phone
|
||||
}
|
||||
|
||||
const UserTypeHuman UserType = "human"
|
||||
|
||||
func (Human) userTrait() {}
|
||||
|
||||
func (h Human) Type() UserType {
|
||||
return UserTypeHuman
|
||||
}
|
||||
|
||||
var _ userTrait = (*Human)(nil)
|
||||
|
||||
type Email struct {
|
||||
Address string
|
||||
Verification
|
||||
}
|
||||
|
||||
type Phone struct {
|
||||
Number string
|
||||
Verification
|
||||
}
|
||||
|
||||
type Verification struct {
|
||||
VerifiedAt time.Time
|
||||
}
|
||||
|
||||
type userHuman struct {
|
||||
*user
|
||||
}
|
||||
|
||||
func (u *user) Human() *userHuman {
|
||||
return &userHuman{user: u}
|
||||
}
|
||||
|
||||
const userEmailQuery = `SELECT h.email_address, h.email_verified_at FROM user_humans h`
|
||||
|
||||
func (u *userHuman) GetEmail(ctx context.Context) (*Email, error) {
|
||||
var email Email
|
||||
|
||||
u.builder.WriteString(userEmailQuery)
|
||||
u.writeCondition()
|
||||
|
||||
err := u.client.QueryRow(ctx, u.builder.String(), u.builder.args...).Scan(
|
||||
&email.Address,
|
||||
&email.Verification.VerifiedAt,
|
||||
)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &email, nil
|
||||
}
|
||||
|
||||
func (h userHuman) Update(ctx context.Context, changes ...Change) error {
|
||||
h.builder.WriteString(`UPDATE human_users h SET `)
|
||||
Changes(changes).writeTo(&h.builder)
|
||||
h.writeCondition()
|
||||
|
||||
stmt := h.builder.String()
|
||||
|
||||
return h.client.Exec(ctx, stmt, h.builder.args...)
|
||||
}
|
||||
|
||||
func (h userHuman) SetFirstName(firstName string) Change {
|
||||
return newChange(h.FirstNameColumn(), firstName)
|
||||
}
|
||||
|
||||
func (h userHuman) FirstNameColumn() Column {
|
||||
return column{"h.first_name"}
|
||||
}
|
||||
|
||||
func (h userHuman) FirstNameCondition(op TextOperator, firstName string) Condition {
|
||||
return newTextCondition(h.FirstNameColumn(), op, firstName)
|
||||
}
|
||||
|
||||
func (h userHuman) SetLastName(lastName string) Change {
|
||||
return newChange(h.LastNameColumn(), lastName)
|
||||
}
|
||||
|
||||
func (h userHuman) LastNameColumn() Column {
|
||||
return column{"h.last_name"}
|
||||
}
|
||||
|
||||
func (h userHuman) LastNameCondition(op TextOperator, lastName string) Condition {
|
||||
return newTextCondition(h.LastNameColumn(), op, lastName)
|
||||
}
|
||||
|
||||
func (h userHuman) EmailAddressColumn() Column {
|
||||
return ignoreCaseCol{
|
||||
column: column{"h.email_address"},
|
||||
suffix: "_lower",
|
||||
}
|
||||
}
|
||||
|
||||
func (h userHuman) EmailAddressCondition(op TextOperator, email string) Condition {
|
||||
return newTextCondition(h.EmailAddressColumn(), op, email)
|
||||
}
|
||||
|
||||
func (h userHuman) EmailVerifiedAtColumn() Column {
|
||||
return column{"h.email_verified_at"}
|
||||
}
|
||||
|
||||
func (h *userHuman) EmailAddressVerifiedCondition(isVerified bool) Condition {
|
||||
if isVerified {
|
||||
return IsNotNull(h.EmailVerifiedAtColumn())
|
||||
}
|
||||
return IsNull(h.EmailVerifiedAtColumn())
|
||||
}
|
||||
|
||||
func (h userHuman) EmailVerifiedAtCondition(op TextOperator, emailVerifiedAt string) Condition {
|
||||
return newTextCondition(h.EmailVerifiedAtColumn(), op, emailVerifiedAt)
|
||||
}
|
||||
|
||||
func (h userHuman) SetEmailAddress(address string) Change {
|
||||
return newChange(h.EmailAddressColumn(), address)
|
||||
}
|
||||
|
||||
// SetEmailVerified sets the verified column of the email
|
||||
// if at is zero the statement uses the database timestamp
|
||||
func (h userHuman) SetEmailVerified(at time.Time) Change {
|
||||
if at.IsZero() {
|
||||
return newChange(h.EmailVerifiedAtColumn(), nowDBInstruction)
|
||||
}
|
||||
return newChange(h.EmailVerifiedAtColumn(), at)
|
||||
}
|
||||
|
||||
func (h userHuman) SetEmail(address string, verified *time.Time) Change {
|
||||
return newChanges(
|
||||
h.SetEmailAddress(address),
|
||||
newUpdatePtrColumn(h.EmailVerifiedAtColumn(), verified),
|
||||
)
|
||||
}
|
||||
|
||||
func (h userHuman) PhoneNumberColumn() Column {
|
||||
return column{"h.phone_number"}
|
||||
}
|
||||
|
||||
func (h userHuman) SetPhoneNumber(number string) Change {
|
||||
return newChange(h.PhoneNumberColumn(), number)
|
||||
}
|
||||
|
||||
func (h userHuman) PhoneNumberCondition(op TextOperator, phoneNumber string) Condition {
|
||||
return newTextCondition(h.PhoneNumberColumn(), op, phoneNumber)
|
||||
}
|
||||
|
||||
func (h userHuman) PhoneVerifiedAtColumn() Column {
|
||||
return column{"h.phone_verified_at"}
|
||||
}
|
||||
|
||||
func (h userHuman) PhoneNumberVerifiedCondition(isVerified bool) Condition {
|
||||
if isVerified {
|
||||
return IsNotNull(h.PhoneVerifiedAtColumn())
|
||||
}
|
||||
return IsNull(h.PhoneVerifiedAtColumn())
|
||||
}
|
||||
|
||||
// SetPhoneVerified sets the verified column of the phone
|
||||
// if at is zero the statement uses the database timestamp
|
||||
func (h userHuman) SetPhoneVerified(at time.Time) Change {
|
||||
if at.IsZero() {
|
||||
return newChange(h.PhoneVerifiedAtColumn(), nowDBInstruction)
|
||||
}
|
||||
return newChange(h.PhoneVerifiedAtColumn(), at)
|
||||
}
|
||||
|
||||
func (h userHuman) PhoneVerifiedAtCondition(op TextOperator, phoneVerifiedAt string) Condition {
|
||||
return newTextCondition(h.PhoneVerifiedAtColumn(), op, phoneVerifiedAt)
|
||||
}
|
||||
|
||||
func (h userHuman) SetPhone(number string, verifiedAt *time.Time) Change {
|
||||
return newChanges(
|
||||
h.SetPhoneNumber(number),
|
||||
newUpdatePtrColumn(h.PhoneVerifiedAtColumn(), verifiedAt),
|
||||
)
|
||||
}
|
@@ -0,0 +1,41 @@
|
||||
package v4
|
||||
|
||||
import "context"
|
||||
|
||||
type Machine struct {
|
||||
Description string
|
||||
}
|
||||
|
||||
func (Machine) userTrait() {}
|
||||
|
||||
func (m Machine) Type() UserType {
|
||||
return UserTypeMachine
|
||||
}
|
||||
|
||||
const UserTypeMachine UserType = "machine"
|
||||
|
||||
var _ userTrait = (*Machine)(nil)
|
||||
|
||||
type userMachine struct {
|
||||
*user
|
||||
}
|
||||
|
||||
func (u *user) Machine() *userMachine {
|
||||
return &userMachine{user: u}
|
||||
}
|
||||
|
||||
func (m userMachine) Update(ctx context.Context, cols ...Change) (*Machine, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (userMachine) DescriptionColumn() Column {
|
||||
return column{"m.description"}
|
||||
}
|
||||
|
||||
func (m userMachine) SetDescription(description string) Change {
|
||||
return newChange(m.DescriptionColumn(), description)
|
||||
}
|
||||
|
||||
func (m userMachine) DescriptionCondition(op TextOperator, description string) Condition {
|
||||
return newTextCondition(m.DescriptionColumn(), op, description)
|
||||
}
|
65
backend/v3/storage/database/repository/stmt/v4/user_test.go
Normal file
65
backend/v3/storage/database/repository/stmt/v4/user_test.go
Normal file
@@ -0,0 +1,65 @@
|
||||
package v4_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
v4 "github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt/v4"
|
||||
)
|
||||
|
||||
func TestQueryUser(t *testing.T) {
|
||||
t.Run("User filters", func(t *testing.T) {
|
||||
user := v4.UserRepository(nil)
|
||||
user.WithCondition(
|
||||
v4.And(
|
||||
v4.Or(
|
||||
user.IDCondition("test"),
|
||||
user.IDCondition("2"),
|
||||
),
|
||||
user.UsernameCondition(v4.TextOperatorStartsWithIgnoreCase, "test"),
|
||||
),
|
||||
).Get(context.Background())
|
||||
})
|
||||
|
||||
t.Run("machine and human filters", func(t *testing.T) {
|
||||
user := v4.UserRepository(nil)
|
||||
machine := user.Machine()
|
||||
human := user.Human()
|
||||
user.WithCondition(
|
||||
v4.And(
|
||||
user.UsernameCondition(v4.TextOperatorStartsWithIgnoreCase, "test"),
|
||||
v4.Or(
|
||||
machine.DescriptionCondition(v4.TextOperatorStartsWithIgnoreCase, "test"),
|
||||
human.EmailAddressVerifiedCondition(true),
|
||||
v4.IsNotNull(machine.DescriptionColumn()),
|
||||
),
|
||||
),
|
||||
)
|
||||
human.GetEmail(context.Background())
|
||||
})
|
||||
}
|
||||
|
||||
type dbInstruction string
|
||||
|
||||
func TestArg(t *testing.T) {
|
||||
var bla any = "asdf"
|
||||
instr, ok := bla.(dbInstruction)
|
||||
assert.False(t, ok)
|
||||
assert.Empty(t, instr)
|
||||
bla = dbInstruction("asdf")
|
||||
instr, ok = bla.(dbInstruction)
|
||||
assert.True(t, ok)
|
||||
assert.Equal(t, instr, dbInstruction("asdf"))
|
||||
}
|
||||
|
||||
func TestWriteUser(t *testing.T) {
|
||||
t.Run("update user", func(t *testing.T) {
|
||||
user := v4.UserRepository(nil)
|
||||
user.WithCondition(user.IDCondition("test")).Human().Update(
|
||||
context.Background(),
|
||||
user.SetUsername("test"),
|
||||
)
|
||||
|
||||
})
|
||||
}
|
39
backend/v3/storage/database/repository/user.go
Normal file
39
backend/v3/storage/database/repository/user.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type user struct {
|
||||
database.QueryExecutor
|
||||
}
|
||||
|
||||
func User(client database.QueryExecutor) domain.UserRepository {
|
||||
// return &user{QueryExecutor: client}
|
||||
return nil
|
||||
}
|
||||
|
||||
// On implements [domain.UserRepository].
|
||||
func (exec *user) On(clauses ...domain.UserClause) domain.UserOperation {
|
||||
return &userOperation{
|
||||
QueryExecutor: exec.QueryExecutor,
|
||||
clauses: clauses,
|
||||
}
|
||||
}
|
||||
|
||||
// OnHuman implements [domain.UserRepository].
|
||||
func (exec *user) OnHuman(clauses ...domain.UserClause) domain.HumanOperation {
|
||||
return &humanOperation{
|
||||
userOperation: *exec.On(clauses...).(*userOperation),
|
||||
}
|
||||
}
|
||||
|
||||
// OnMachine implements [domain.UserRepository].
|
||||
func (exec *user) OnMachine(clauses ...domain.UserClause) domain.MachineOperation {
|
||||
return &machineOperation{
|
||||
userOperation: *exec.On(clauses...).(*userOperation),
|
||||
}
|
||||
}
|
||||
|
||||
// var _ domain.UserRepository = (*user)(nil)
|
@@ -0,0 +1,36 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
)
|
||||
|
||||
type humanOperation struct {
|
||||
userOperation
|
||||
}
|
||||
|
||||
// GetEmail implements domain.HumanOperation.
|
||||
func (h *humanOperation) GetEmail(ctx context.Context) (*domain.Email, error) {
|
||||
var email domain.Email
|
||||
err := h.QueryExecutor.QueryRow(ctx, `SELECT email, is_email_verified FROM human_users WHERE id = $1`, h.clauses).Scan(
|
||||
&email.Address,
|
||||
&email.IsVerified,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &email, nil
|
||||
}
|
||||
|
||||
// SetEmail implements domain.HumanOperation.
|
||||
func (h *humanOperation) SetEmail(ctx context.Context, email string) error {
|
||||
return h.QueryExecutor.Exec(ctx, `UPDATE human_users SET email = $1 WHERE id = $2`, email, h.clauses)
|
||||
}
|
||||
|
||||
// SetEmailVerified implements domain.HumanOperation.
|
||||
func (h *humanOperation) SetEmailVerified(ctx context.Context, email string) error {
|
||||
return h.QueryExecutor.Exec(ctx, `UPDATE human_users SET is_email_verified = $1 WHERE id = $2 AND email = $3`, true, h.clauses, email)
|
||||
}
|
||||
|
||||
var _ domain.HumanOperation = (*humanOperation)(nil)
|
@@ -0,0 +1,18 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
)
|
||||
|
||||
type machineOperation struct {
|
||||
userOperation
|
||||
}
|
||||
|
||||
// SetDescription implements domain.MachineOperation.
|
||||
func (m *machineOperation) SetDescription(ctx context.Context, description string) error {
|
||||
return m.QueryExecutor.Exec(ctx, `UPDATE machines SET description = $1 WHERE id = $2`, description, m.clauses)
|
||||
}
|
||||
|
||||
var _ domain.MachineOperation = (*machineOperation)(nil)
|
68
backend/v3/storage/database/repository/user_operation.go
Normal file
68
backend/v3/storage/database/repository/user_operation.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type userOperation struct {
|
||||
database.QueryExecutor
|
||||
clauses []domain.UserClause
|
||||
}
|
||||
|
||||
// Delete implements [domain.UserOperation].
|
||||
func (u *userOperation) Delete(ctx context.Context) error {
|
||||
return u.QueryExecutor.Exec(ctx, `DELETE FROM users WHERE id = $1`, u.clauses)
|
||||
}
|
||||
|
||||
// SetUsername implements [domain.UserOperation].
|
||||
func (u *userOperation) SetUsername(ctx context.Context, username string) error {
|
||||
var stmt statement
|
||||
|
||||
stmt.builder.WriteString(`UPDATE users SET username = $1 WHERE `)
|
||||
stmt.appendArg(username)
|
||||
clausesToSQL(&stmt, u.clauses)
|
||||
return u.QueryExecutor.Exec(ctx, stmt.builder.String(), stmt.args...)
|
||||
}
|
||||
|
||||
var _ domain.UserOperation = (*userOperation)(nil)
|
||||
|
||||
func UserIDQuery(id string) domain.UserClause {
|
||||
return textClause[string]{
|
||||
clause: clause[domain.TextOperation]{
|
||||
field: userFields[domain.UserFieldID],
|
||||
op: domain.TextOperationEqual,
|
||||
},
|
||||
value: id,
|
||||
}
|
||||
}
|
||||
|
||||
func HumanEmailQuery(op domain.TextOperation, email string) domain.UserClause {
|
||||
return textClause[string]{
|
||||
clause: clause[domain.TextOperation]{
|
||||
field: userFields[domain.UserHumanFieldEmail],
|
||||
op: op,
|
||||
},
|
||||
value: email,
|
||||
}
|
||||
}
|
||||
|
||||
func HumanEmailVerifiedQuery(op domain.BoolOperation) domain.UserClause {
|
||||
return boolClause[domain.BoolOperation]{
|
||||
clause: clause[domain.BoolOperation]{
|
||||
field: userFields[domain.UserHumanFieldEmailVerified],
|
||||
op: op,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func clausesToSQL(stmt *statement, clauses []domain.UserClause) {
|
||||
for _, clause := range clauses {
|
||||
|
||||
stmt.builder.WriteString(userFields[clause.Field()].String())
|
||||
stmt.builder.WriteString(clause.Operation().String())
|
||||
stmt.appendArg(clause.Args()...)
|
||||
}
|
||||
}
|
36
backend/v3/storage/database/tx.go
Normal file
36
backend/v3/storage/database/tx.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package database
|
||||
|
||||
import "context"
|
||||
|
||||
type Transaction interface {
|
||||
Commit(ctx context.Context) error
|
||||
Rollback(ctx context.Context) error
|
||||
End(ctx context.Context, err error) error
|
||||
|
||||
Begin(ctx context.Context) (Transaction, error)
|
||||
|
||||
QueryExecutor
|
||||
}
|
||||
|
||||
type Beginner interface {
|
||||
Begin(ctx context.Context, opts *TransactionOptions) (Transaction, error)
|
||||
}
|
||||
|
||||
type TransactionOptions struct {
|
||||
IsolationLevel IsolationLevel
|
||||
AccessMode AccessMode
|
||||
}
|
||||
|
||||
type IsolationLevel uint8
|
||||
|
||||
const (
|
||||
IsolationLevelSerializable IsolationLevel = iota
|
||||
IsolationLevelReadCommitted
|
||||
)
|
||||
|
||||
type AccessMode uint8
|
||||
|
||||
const (
|
||||
AccessModeReadWrite AccessMode = iota
|
||||
AccessModeReadOnly
|
||||
)
|
23
backend/v3/storage/eventstore/event.go
Normal file
23
backend/v3/storage/eventstore/event.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package eventstore
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type Event struct {
|
||||
AggregateType string `json:"aggregateType"`
|
||||
AggregateID string `json:"aggregateId"`
|
||||
Type string `json:"type"`
|
||||
Payload any `json:"payload,omitempty"`
|
||||
}
|
||||
|
||||
func Publish(ctx context.Context, events []*Event, db database.Executor) error {
|
||||
for _, event := range events {
|
||||
if err := db.Exec(ctx, `INSERT INTO events (aggregate_type, aggregate_id) VALUES ($1, $2)`, event.AggregateType, event.AggregateID); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
7
backend/v3/telemetry/logging/logger.go
Normal file
7
backend/v3/telemetry/logging/logger.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package logging
|
||||
|
||||
import "log/slog"
|
||||
|
||||
type Logger struct {
|
||||
*slog.Logger
|
||||
}
|
23
backend/v3/telemetry/tracing/tracer.go
Normal file
23
backend/v3/telemetry/tracing/tracer.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package tracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
"go.opentelemetry.io/otel/trace/noop"
|
||||
)
|
||||
|
||||
type Tracer struct {
|
||||
trace.Tracer
|
||||
}
|
||||
|
||||
var noopTracer = Tracer{
|
||||
Tracer: noop.NewTracerProvider().Tracer(""),
|
||||
}
|
||||
|
||||
func (t *Tracer) Start(ctx context.Context, spanName string, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
|
||||
if t.Tracer == nil {
|
||||
return noopTracer.Start(ctx, spanName, opts...)
|
||||
}
|
||||
return t.Tracer.Start(ctx, spanName, opts...)
|
||||
}
|
Reference in New Issue
Block a user