multiple tries

This commit is contained in:
adlerhurst
2025-04-29 06:03:47 +02:00
parent 77c4cc8185
commit 986c62b61a
131 changed files with 9805 additions and 47 deletions

View File

@@ -0,0 +1,19 @@
package v2
import (
"github.com/zitadel/zitadel/backend/v3/telemetry/logging"
"github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
)
var (
logger logging.Logger
tracer tracing.Tracer
)
func SetLogger(l logging.Logger) {
logger = l
}
func SetTracer(t tracing.Tracer) {
tracer = t
}

View File

@@ -0,0 +1,93 @@
package userv2
import (
"context"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/pkg/grpc/user/v2"
)
func SetEmail(ctx context.Context, req *user.SetEmailRequest) (resp *user.SetEmailResponse, err error) {
var (
verification domain.SetEmailOpt
returnCode *domain.ReturnCodeCommand
)
switch req.GetVerification().(type) {
case *user.SetEmailRequest_IsVerified:
verification = domain.NewEmailVerifiedCommand(req.GetUserId(), req.GetIsVerified())
case *user.SetEmailRequest_SendCode:
verification = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
case *user.SetEmailRequest_ReturnCode:
returnCode = domain.NewReturnCodeCommand(req.GetUserId())
verification = returnCode
default:
verification = domain.NewSendCodeCommand(req.GetUserId(), nil)
}
err = domain.Invoke(ctx, domain.NewSetEmailCommand(req.GetUserId(), req.GetEmail(), verification))
if err != nil {
return nil, err
}
var code *string
if returnCode != nil && returnCode.Code != "" {
code = &returnCode.Code
}
return &user.SetEmailResponse{
VerificationCode: code,
}, nil
}
func SendEmailCode(ctx context.Context, req *user.SendEmailCodeRequest) (resp *user.SendEmailCodeResponse, err error) {
var (
returnCode *domain.ReturnCodeCommand
cmd domain.Commander
)
switch req.GetVerification().(type) {
case *user.SendEmailCodeRequest_SendCode:
cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
case *user.SendEmailCodeRequest_ReturnCode:
returnCode = domain.NewReturnCodeCommand(req.GetUserId())
cmd = returnCode
default:
cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
}
err = domain.Invoke(ctx, cmd)
if err != nil {
return nil, err
}
resp = new(user.SendEmailCodeResponse)
if returnCode != nil {
resp.VerificationCode = &returnCode.Code
}
return resp, nil
}
func ResendEmailCode(ctx context.Context, req *user.ResendEmailCodeRequest) (resp *user.SendEmailCodeResponse, err error) {
var (
returnCode *domain.ReturnCodeCommand
cmd domain.Commander
)
switch req.GetVerification().(type) {
case *user.ResendEmailCodeRequest_SendCode:
cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
case *user.ResendEmailCodeRequest_ReturnCode:
returnCode = domain.NewReturnCodeCommand(req.GetUserId())
cmd = returnCode
default:
cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
}
err = domain.Invoke(ctx, cmd)
if err != nil {
return nil, err
}
resp = new(user.SendEmailCodeResponse)
if returnCode != nil {
resp.VerificationCode = &returnCode.Code
}
return resp, nil
}

View File

@@ -0,0 +1,19 @@
package userv2
import (
"github.com/zitadel/zitadel/backend/v3/telemetry/logging"
"github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
)
var (
logger logging.Logger
tracer tracing.Tracer
)
func SetLogger(l logging.Logger) {
logger = l
}
func SetTracer(t tracing.Tracer) {
tracer = t
}

12
backend/v3/doc.go Normal file
View File

@@ -0,0 +1,12 @@
// the test used the manly relies on the following patterns:
// - domain:
// - hexagonal architecture, it defines its dependencies as interfaces and the dependencies must use the objects defined by this package
// - command pattern which implements the changes
// - the invoker decorates the commands by checking for events and tracing
// - the database connections are manged in this package
// - the database connections are passed to the repositories
//
// - storage:
// - repository pattern, the repositories are defined as interfaces and the implementations are in the storage package
// - the repositories are used by the domain package to access the database
package v3

View File

@@ -0,0 +1,105 @@
package domain
import (
"context"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type Commander interface {
Execute(ctx context.Context, opts *CommandOpts) (err error)
}
type Invoker interface {
Invoke(ctx context.Context, command Commander, opts *CommandOpts) error
}
type CommandOpts struct {
DB database.QueryExecutor
Invoker Invoker
}
type ensureTxOpts struct {
*database.TransactionOptions
}
type EnsureTransactionOpt func(*ensureTxOpts)
// EnsureTx ensures that the DB is a transaction. If it is not, it will start a new transaction.
// The returned close function will end the transaction. If the DB is already a transaction, the close function
// will do nothing because another [Commander] is already responsible for ending the transaction.
func (o *CommandOpts) EnsureTx(ctx context.Context, opts ...EnsureTransactionOpt) (close func(context.Context, error) error, err error) {
beginner, ok := o.DB.(database.Beginner)
if !ok {
// db is already a transaction
return func(_ context.Context, err error) error {
return err
}, nil
}
txOpts := &ensureTxOpts{
TransactionOptions: new(database.TransactionOptions),
}
for _, opt := range opts {
opt(txOpts)
}
tx, err := beginner.Begin(ctx, txOpts.TransactionOptions)
if err != nil {
return nil, err
}
o.DB = tx
return func(ctx context.Context, err error) error {
return tx.End(ctx, err)
}, nil
}
// EnsureClient ensures that the o.DB is a client. If it is not, it will get a new client from the [database.Pool].
// The returned close function will release the client. If the o.DB is already a client or transaction, the close function
// will do nothing because another [Commander] is already responsible for releasing the client.
func (o *CommandOpts) EnsureClient(ctx context.Context) (close func(_ context.Context) error, err error) {
pool, ok := o.DB.(database.Pool)
if !ok {
// o.DB is already a client
return func(_ context.Context) error {
return nil
}, nil
}
client, err := pool.Acquire(ctx)
if err != nil {
return nil, err
}
o.DB = client
return func(ctx context.Context) error {
return client.Release(ctx)
}, nil
}
func (o *CommandOpts) Invoke(ctx context.Context, command Commander) error {
if o.Invoker == nil {
return command.Execute(ctx, o)
}
return o.Invoker.Invoke(ctx, command, o)
}
func DefaultOpts(invoker Invoker) *CommandOpts {
if invoker == nil {
invoker = &noopInvoker{}
}
return &CommandOpts{
DB: pool,
Invoker: invoker,
}
}
type noopInvoker struct {
next Invoker
}
func (i *noopInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) error {
if i.next != nil {
return i.next.Invoke(ctx, command, opts)
}
return command.Execute(ctx, opts)
}

View File

@@ -0,0 +1,76 @@
package domain
import (
"context"
v4 "github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt/v4"
"github.com/zitadel/zitadel/backend/v3/storage/eventstore"
)
type CreateUserCommand struct {
user *User
email *SetEmailCommand
}
var (
_ Commander = (*CreateUserCommand)(nil)
_ eventer = (*CreateUserCommand)(nil)
)
func NewCreateHumanCommand(username string, opts ...CreateHumanOpt) *CreateUserCommand {
cmd := &CreateUserCommand{
user: &User{
User: v4.User{
Username: username,
Traits: &v4.Human{},
},
},
}
for _, opt := range opts {
opt.applyOnCreateHuman(cmd)
}
return cmd
}
// Events implements [eventer].
func (c *CreateUserCommand) Events() []*eventstore.Event {
panic("unimplemented")
}
// Execute implements [Commander].
func (c *CreateUserCommand) Execute(ctx context.Context, opts *CommandOpts) error {
if err := c.ensureUserID(); err != nil {
return err
}
c.email.UserID = c.user.ID
if err := opts.Invoke(ctx, c.email); err != nil {
return err
}
return nil
}
type CreateHumanOpt interface {
applyOnCreateHuman(*CreateUserCommand)
}
type createHumanIDOpt string
// applyOnCreateHuman implements [CreateHumanOpt].
func (c createHumanIDOpt) applyOnCreateHuman(cmd *CreateUserCommand) {
cmd.user.ID = string(c)
}
var _ CreateHumanOpt = (*createHumanIDOpt)(nil)
func CreateHumanWithID(id string) CreateHumanOpt {
return createHumanIDOpt(id)
}
func (c *CreateUserCommand) ensureUserID() (err error) {
if c.user.ID != "" {
return nil
}
c.user.ID, err = generateID()
return err
}

View File

@@ -0,0 +1,26 @@
package domain
import (
"context"
"github.com/zitadel/zitadel/internal/crypto"
)
type generateCodeCommand struct {
code string
value *crypto.CryptoValue
}
type CryptoRepository interface {
GetEncryptionConfig(ctx context.Context) (*crypto.GeneratorConfig, error)
}
func (cmd *generateCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error {
config, err := cryptoRepo(opts.DB).GetEncryptionConfig(ctx)
if err != nil {
return err
}
generator := crypto.NewEncryptionGenerator(*config, userCodeAlgorithm)
cmd.value, cmd.code, err = crypto.NewCode(generator)
return err
}

View File

@@ -0,0 +1,52 @@
package domain
import (
"math/rand/v2"
"strconv"
"github.com/zitadel/zitadel/backend/v3/storage/cache"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
"github.com/zitadel/zitadel/internal/crypto"
)
var (
pool database.Pool
userCodeAlgorithm crypto.EncryptionAlgorithm
tracer tracing.Tracer
// userRepo func(database.QueryExecutor) UserRepository
instanceRepo func(database.QueryExecutor) InstanceRepository
cryptoRepo func(database.QueryExecutor) CryptoRepository
orgRepo func(database.QueryExecutor) OrgRepository
instanceCache cache.Cache[string, string, *Instance]
generateID func() (string, error) = func() (string, error) {
return strconv.FormatUint(rand.Uint64(), 10), nil
}
)
func SetPool(p database.Pool) {
pool = p
}
func SetUserCodeAlgorithm(algorithm crypto.EncryptionAlgorithm) {
userCodeAlgorithm = algorithm
}
func SetTracer(t tracing.Tracer) {
tracer = t
}
// func SetUserRepository(repo func(database.QueryExecutor) UserRepository) {
// userRepo = repo
// }
func SetInstanceRepository(repo func(database.QueryExecutor) InstanceRepository) {
instanceRepo = repo
}
func SetCryptoRepository(repo func(database.QueryExecutor) CryptoRepository) {
cryptoRepo = repo
}

View File

@@ -0,0 +1,45 @@
package domain_test
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
. "github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
"github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
)
func TestExample(t *testing.T) {
ctx := context.Background()
// SetPool(pool)
exporter, err := stdouttrace.New(stdouttrace.WithPrettyPrint())
require.NoError(t, err)
tracerProvider := sdktrace.NewTracerProvider(
sdktrace.WithSyncer(exporter),
)
otel.SetTracerProvider(tracerProvider)
SetTracer(tracing.Tracer{Tracer: tracerProvider.Tracer("test")})
defer func() { assert.NoError(t, tracerProvider.Shutdown(ctx)) }()
SetUserRepository(repository.User)
SetInstanceRepository(repository.Instance)
SetCryptoRepository(repository.Crypto)
t.Run("verified email", func(t *testing.T) {
err := Invoke(ctx, NewSetEmailCommand("u1", "test@example.com", NewEmailVerifiedCommand("u1", true)))
assert.NoError(t, err)
})
t.Run("unverified email", func(t *testing.T) {
err := Invoke(ctx, NewSetEmailCommand("u2", "test2@example.com", NewEmailVerifiedCommand("u2", false)))
assert.NoError(t, err)
})
}

View File

@@ -0,0 +1,155 @@
package domain
import (
"context"
v4 "github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt/v4"
)
type EmailVerifiedCommand struct {
UserID string `json:"userId"`
Email *Email `json:"email"`
}
func NewEmailVerifiedCommand(userID string, isVerified bool) *EmailVerifiedCommand {
return &EmailVerifiedCommand{
UserID: userID,
Email: &Email{
IsVerified: isVerified,
},
}
}
var (
_ Commander = (*EmailVerifiedCommand)(nil)
_ SetEmailOpt = (*EmailVerifiedCommand)(nil)
)
// Execute implements [Commander]
func (cmd *EmailVerifiedCommand) Execute(ctx context.Context, opts *CommandOpts) error {
return userRepo(opts.DB).Human().ByID(cmd.UserID).Exec().SetEmailVerified(ctx, cmd.Email.Address)
}
// applyOnSetEmail implements [SetEmailOpt]
func (cmd *EmailVerifiedCommand) applyOnSetEmail(setEmailCmd *SetEmailCommand) {
cmd.UserID = setEmailCmd.UserID
cmd.Email.Address = setEmailCmd.Email
setEmailCmd.verification = cmd
}
type SendCodeCommand struct {
UserID string `json:"userId"`
Email string `json:"email"`
URLTemplate *string `json:"urlTemplate"`
generator *generateCodeCommand
}
var (
_ Commander = (*SendCodeCommand)(nil)
_ SetEmailOpt = (*SendCodeCommand)(nil)
)
func NewSendCodeCommand(userID string, urlTemplate *string) *SendCodeCommand {
return &SendCodeCommand{
UserID: userID,
generator: &generateCodeCommand{},
URLTemplate: urlTemplate,
}
}
// Execute implements [Commander]
func (cmd *SendCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error {
if err := cmd.ensureEmail(ctx, opts); err != nil {
return err
}
if err := cmd.ensureURL(ctx, opts); err != nil {
return err
}
if err := opts.Invoker.Invoke(ctx, cmd.generator, opts); err != nil {
return err
}
// TODO: queue notification
return nil
}
func (cmd *SendCodeCommand) ensureEmail(ctx context.Context, opts *CommandOpts) error {
if cmd.Email != "" {
return nil
}
email, err := userRepo(opts.DB).Human().ByID(cmd.UserID).Exec().GetEmail(ctx)
if err != nil || email.IsVerified {
return err
}
cmd.Email = email.Address
return nil
}
func (cmd *SendCodeCommand) ensureURL(ctx context.Context, opts *CommandOpts) error {
if cmd.URLTemplate != nil && *cmd.URLTemplate != "" {
return nil
}
_, _ = ctx, opts
// TODO: load default template
return nil
}
// applyOnSetEmail implements [SetEmailOpt]
func (cmd *SendCodeCommand) applyOnSetEmail(setEmailCmd *SetEmailCommand) {
cmd.UserID = setEmailCmd.UserID
cmd.Email = setEmailCmd.Email
setEmailCmd.verification = cmd
}
type ReturnCodeCommand struct {
UserID string `json:"userId"`
Email string `json:"email"`
Code string `json:"code"`
generator *generateCodeCommand
}
var (
_ Commander = (*ReturnCodeCommand)(nil)
_ SetEmailOpt = (*ReturnCodeCommand)(nil)
)
func NewReturnCodeCommand(userID string) *ReturnCodeCommand {
return &ReturnCodeCommand{
UserID: userID,
generator: &generateCodeCommand{},
}
}
// Execute implements [Commander]
func (cmd *ReturnCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error {
if err := cmd.ensureEmail(ctx, opts); err != nil {
return err
}
if err := opts.Invoker.Invoke(ctx, cmd.generator, opts); err != nil {
return err
}
cmd.Code = cmd.generator.code
return nil
}
func (cmd *ReturnCodeCommand) ensureEmail(ctx context.Context, opts *CommandOpts) error {
if cmd.Email != "" {
return nil
}
user := v4.UserRepository(opts.DB)
user.WithCondition(user.IDCondition(cmd.UserID))
email, err := user.he.GetEmail(ctx)
if err != nil || email.IsVerified {
return err
}
cmd.Email = email.Address
return nil
}
// applyOnSetEmail implements [SetEmailOpt]
func (cmd *ReturnCodeCommand) applyOnSetEmail(setEmailCmd *SetEmailCommand) {
cmd.UserID = setEmailCmd.UserID
cmd.Email = setEmailCmd.Email
setEmailCmd.verification = cmd
}

View File

@@ -0,0 +1,7 @@
package domain
import "errors"
var (
ErrNoAdminSpecified = errors.New("at least one admin must be specified")
)

View File

@@ -0,0 +1,36 @@
package domain
import (
"context"
"time"
)
type Instance struct {
ID string `json:"id"`
Name string `json:"name"`
CreatedAt time.Time `json:"-"`
UpdatedAt time.Time `json:"-"`
DeletedAt time.Time `json:"-"`
}
// Keys implements the [cache.Entry].
func (i *Instance) Keys(index string) (key []string) {
// TODO: Return the correct keys for the instance cache, e.g., i.ID, i.Domain
return []string{}
}
type InstanceRepository interface {
ByID(ctx context.Context, id string) (*Instance, error)
Create(ctx context.Context, instance *Instance) error
On(id string) InstanceOperation
}
type InstanceOperation interface {
AdminRepository
Update(ctx context.Context, instance *Instance) error
Delete(ctx context.Context) error
}
type CreateInstance struct {
Name string `json:"name"`
}

View File

@@ -0,0 +1,94 @@
package domain
import (
"context"
"fmt"
"github.com/zitadel/zitadel/backend/v3/storage/eventstore"
)
var defaultInvoker = newEventStoreInvoker(newTraceInvoker(nil))
func Invoke(ctx context.Context, cmd Commander) error {
invoker := newEventStoreInvoker(newTraceInvoker(nil))
opts := &CommandOpts{
Invoker: invoker.collector,
}
return invoker.Invoke(ctx, cmd, opts)
}
type eventStoreInvoker struct {
collector *eventCollector
}
func newEventStoreInvoker(next Invoker) *eventStoreInvoker {
return &eventStoreInvoker{collector: &eventCollector{next: next}}
}
func (i *eventStoreInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
err = i.collector.Invoke(ctx, command, opts)
if err != nil {
return err
}
if len(i.collector.events) > 0 {
err = eventstore.Publish(ctx, i.collector.events, opts.DB)
if err != nil {
return err
}
}
return nil
}
type eventCollector struct {
next Invoker
events []*eventstore.Event
}
type eventer interface {
Events() []*eventstore.Event
}
func (i *eventCollector) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
if e, ok := command.(eventer); ok && len(e.Events()) > 0 {
// we need to ensure all commands are executed in the same transaction
close, err := opts.EnsureTx(ctx)
if err != nil {
return err
}
defer func() { err = close(ctx, err) }()
i.events = append(i.events, e.Events()...)
}
if i.next != nil {
err = i.next.Invoke(ctx, command, opts)
} else {
err = command.Execute(ctx, opts)
}
if err != nil {
return err
}
return nil
}
type traceInvoker struct {
next Invoker
}
func newTraceInvoker(next Invoker) *traceInvoker {
return &traceInvoker{next: next}
}
func (i *traceInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
ctx, span := tracer.Start(ctx, fmt.Sprintf("%T", command))
defer span.End()
if i.next != nil {
err = i.next.Invoke(ctx, command, opts)
} else {
err = command.Execute(ctx, opts)
}
if err != nil {
span.RecordError(err)
}
return err
}

39
backend/v3/domain/org.go Normal file
View File

@@ -0,0 +1,39 @@
package domain
import (
"context"
"time"
)
type Org struct {
ID string `json:"id"`
Name string `json:"name"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
type OrgRepository interface {
ByID(ctx context.Context, orgID string) (*Org, error)
Create(ctx context.Context, org *Org) error
On(id string) OrgOperation
}
type OrgOperation interface {
AdminRepository
DomainRepository
Update(ctx context.Context, org *Org) error
Delete(ctx context.Context) error
}
type AdminRepository interface {
AddAdmin(ctx context.Context, userID string, roles []string) error
SetAdminRoles(ctx context.Context, userID string, roles []string) error
RemoveAdmin(ctx context.Context, userID string) error
}
type DomainRepository interface {
AddDomain(ctx context.Context, domain string) error
SetDomainVerified(ctx context.Context, domain string) error
RemoveDomain(ctx context.Context, domain string) error
}

View File

@@ -0,0 +1,74 @@
package domain
import (
"context"
)
type AddOrgCommand struct {
ID string `json:"id"`
Name string `json:"name"`
Admins []AddAdminCommand `json:"admins"`
}
func NewAddOrgCommand(name string, admins ...AddAdminCommand) *AddOrgCommand {
return &AddOrgCommand{
Name: name,
Admins: admins,
}
}
// Execute implements Commander.
func (cmd *AddOrgCommand) Execute(ctx context.Context, opts *CommandOpts) (err error) {
if len(cmd.Admins) == 0 {
return ErrNoAdminSpecified
}
if err = cmd.ensureID(); err != nil {
return err
}
close, err := opts.EnsureTx(ctx)
if err != nil {
return err
}
defer func() { err = close(ctx, err) }()
err = orgRepo(opts.DB).Create(ctx, &Org{
ID: cmd.ID,
Name: cmd.Name,
})
if err != nil {
return err
}
return nil
}
var (
_ Commander = (*AddOrgCommand)(nil)
)
func (cmd *AddOrgCommand) ensureID() (err error) {
if cmd.ID != "" {
return nil
}
cmd.ID, err = generateID()
return err
}
type AddAdminCommand struct {
UserID string `json:"userId"`
Roles []string `json:"roles"`
}
// Execute implements Commander.
func (a *AddAdminCommand) Execute(ctx context.Context, opts *CommandOpts) (err error) {
close, err := opts.EnsureTx(ctx)
if err != nil {
return err
}
defer func() { err = close(ctx, err) }()
return nil
}
var (
_ Commander = (*AddAdminCommand)(nil)
)

View File

@@ -0,0 +1,82 @@
package domain
import (
"time"
"golang.org/x/exp/constraints"
)
type Operation interface {
// TextOperation |
// NumberOperation |
// BoolOperation
op()
}
type clause[F ~uint8, Op Operation] struct {
field F
op Op
}
func (c *clause[F, Op]) Field() F {
return c.field
}
func (c *clause[F, Op]) Operation() Op {
return c.op
}
type Text interface {
~string | ~[]byte
}
type TextOperation uint8
const (
TextOperationEqual TextOperation = iota
TextOperationNotEqual
TextOperationStartsWith
TextOperationStartsWithIgnoreCase
)
func (TextOperation) op() {}
type Number interface {
constraints.Integer | constraints.Float | constraints.Complex | time.Time
}
type NumberOperation uint8
const (
NumberOperationEqual NumberOperation = iota
NumberOperationNotEqual
NumberOperationLessThan
NumberOperationLessThanOrEqual
NumberOperationGreaterThan
NumberOperationGreaterThanOrEqual
)
func (NumberOperation) op() {}
type Bool interface {
~bool
}
type BoolOperation uint8
const (
BoolOperationIs BoolOperation = iota
BoolOperationNot
)
func (BoolOperation) op() {}
type ListOperation uint8
const (
ListOperationContains ListOperation = iota
ListOperationNotContains
)
func (ListOperation) op() {}

View File

@@ -0,0 +1,64 @@
package domain
import (
"context"
"github.com/zitadel/zitadel/backend/v3/storage/eventstore"
)
type SetEmailCommand struct {
UserID string `json:"userId"`
Email string `json:"email"`
verification Commander
}
var (
_ Commander = (*SetEmailCommand)(nil)
_ eventer = (*SetEmailCommand)(nil)
_ CreateHumanOpt = (*SetEmailCommand)(nil)
)
type SetEmailOpt interface {
applyOnSetEmail(*SetEmailCommand)
}
func NewSetEmailCommand(userID, email string, verificationType SetEmailOpt) *SetEmailCommand {
cmd := &SetEmailCommand{
UserID: userID,
Email: email,
}
verificationType.applyOnSetEmail(cmd)
return cmd
}
func (cmd *SetEmailCommand) Execute(ctx context.Context, opts *CommandOpts) error {
close, err := opts.EnsureTx(ctx)
if err != nil {
return err
}
defer func() { err = close(ctx, err) }()
// userStatement(opts.DB).Human().ByID(cmd.UserID).SetEmail(ctx, cmd.Email)
err = userRepo(opts.DB).Human().ByID(cmd.UserID).Exec().SetEmail(ctx, cmd.Email)
if err != nil {
return err
}
return opts.Invoke(ctx, cmd.verification)
}
// Events implements [eventer].
func (cmd *SetEmailCommand) Events() []*eventstore.Event {
return []*eventstore.Event{
{
AggregateType: "user",
AggregateID: cmd.UserID,
Type: "user.email.set",
Payload: cmd,
},
}
}
// applyOnCreateHuman implements [CreateHumanOpt].
func (cmd *SetEmailCommand) applyOnCreateHuman(createUserCmd *CreateUserCommand[Human]) {
createUserCmd.email = cmd
}

193
backend/v3/domain/user.go Normal file
View File

@@ -0,0 +1,193 @@
package domain
import (
"context"
"time"
v4 "github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt/v4"
)
type userColumns interface {
// TODO: move v4.columns to domain
InstanceIDColumn() column
OrgIDColumn() column
IDColumn() column
usernameColumn() column
CreatedAtColumn() column
UpdatedAtColumn() column
DeletedAtColumn() column
}
type userConditions interface {
InstanceIDCondition(instanceID string) v4.Condition
OrgIDCondition(orgID string) v4.Condition
IDCondition(userID string) v4.Condition
UsernameCondition(op v4.TextOperator, username string) v4.Condition
CreatedAtCondition(op v4.NumberOperator, createdAt time.Time) v4.Condition
UpdatedAtCondition(op v4.NumberOperator, updatedAt time.Time) v4.Condition
DeletedCondition(isDeleted bool) v4.Condition
DeletedAtCondition(op v4.NumberOperator, deletedAt time.Time) v4.Condition
}
type UserRepository interface {
userColumns
userConditions
// TODO: move condition to domain
WithCondition(condition v4.Condition) UserRepository
Get(ctx context.Context) (*User, error)
List(ctx context.Context) ([]*User, error)
Create(ctx context.Context, user *User) error
Delete(ctx context.Context) error
Human() HumanRepository
Machine() MachineRepository
}
type humanColumns interface {
FirstNameColumn() column
LastNameColumn() column
EmailAddressColumn() column
EmailVerifiedAtColumn() column
PhoneNumberColumn() column
PhoneVerifiedAtColumn() column
}
type humanConditions interface {
FirstNameCondition(op v4.TextOperator, firstName string) v4.Condition
LastNameCondition(op v4.TextOperator, lastName string) v4.Condition
EmailAddressCondition(op v4.TextOperator, email string) v4.Condition
EmailAddressVerifiedCondition(isVerified bool) v4.Condition
EmailVerifiedAtCondition(op v4.TextOperator, emailVerifiedAt string) v4.Condition
PhoneNumberCondition(op v4.TextOperator, phoneNumber string) v4.Condition
PhoneNumberVerifiedCondition(isVerified bool) v4.Condition
PhoneVerifiedAtCondition(op v4.TextOperator, phoneVerifiedAt string) v4.Condition
}
type HumanRepository interface {
humanColumns
humanConditions
GetEmail(ctx context.Context) (*Email, error)
// TODO: replace any with add email update columns
SetEmail(ctx context.Context, columns ...any) error
}
type machineColumns interface {
DescriptionColumn() column
}
type machineConditions interface {
DescriptionCondition(op v4.TextOperator, description string) v4.Condition
}
type MachineRepository interface {
machineColumns
machineConditions
}
// type UserRepository interface {
// // Get(ctx context.Context, clauses ...UserClause) (*User, error)
// // Search(ctx context.Context, clauses ...UserClause) ([]*User, error)
// UserQuery[UserOperation]
// Human() HumanQuery
// Machine() MachineQuery
// }
// type UserQuery[Op UserOperation] interface {
// ByID(id string) UserQuery[Op]
// Username(username string) UserQuery[Op]
// Exec() Op
// }
// type HumanQuery interface {
// UserQuery[HumanOperation]
// Email(op TextOperation, email string) HumanQuery
// HumanOperation
// }
// type MachineQuery interface {
// UserQuery[MachineOperation]
// MachineOperation
// }
// type UserClause interface {
// Field() UserField
// Operation() Operation
// Args() []any
// }
// type UserField uint8
// const (
// // Fields used for all users
// UserFieldInstanceID UserField = iota + 1
// UserFieldOrgID
// UserFieldID
// UserFieldUsername
// // Fields used for human users
// UserHumanFieldEmail
// UserHumanFieldEmailVerified
// // Fields used for machine users
// UserMachineFieldDescription
// )
// type userByIDClause struct {
// id string
// }
// func (c *userByIDClause) Field() UserField {
// return UserFieldID
// }
// func (c *userByIDClause) Operation() Operation {
// return TextOperationEqual
// }
// func (c *userByIDClause) Args() []any {
// return []any{c.id}
// }
// type UserOperation interface {
// Delete(ctx context.Context) error
// SetUsername(ctx context.Context, username string) error
// }
// type HumanOperation interface {
// UserOperation
// SetEmail(ctx context.Context, email string) error
// SetEmailVerified(ctx context.Context, email string) error
// GetEmail(ctx context.Context) (*Email, error)
// }
// type MachineOperation interface {
// UserOperation
// SetDescription(ctx context.Context, description string) error
// }
type User struct {
v4.User
}
// type userTraits interface {
// isUserTraits()
// }
// type Human struct {
// Email *Email `json:"email"`
// }
// func (*Human) isUserTraits() {}
// type Machine struct {
// Description string `json:"description"`
// }
// func (*Machine) isUserTraits() {}
// type Email struct {
// Address string `json:"address"`
// IsVerified bool `json:"isVerified"`
// }

112
backend/v3/storage/cache/cache.go vendored Normal file
View File

@@ -0,0 +1,112 @@
// Package cache provides abstraction of cache implementations that can be used by zitadel.
package cache
import (
"context"
"time"
"github.com/zitadel/logging"
)
// Purpose describes which object types are stored by a cache.
type Purpose int
//go:generate enumer -type Purpose -transform snake -trimprefix Purpose
const (
PurposeUnspecified Purpose = iota
PurposeAuthzInstance
PurposeMilestones
PurposeOrganization
PurposeIdPFormCallback
)
// Cache stores objects with a value of type `V`.
// Objects may be referred to by one or more indices.
// Implementations may encode the value for storage.
// This means non-exported fields may be lost and objects
// with function values may fail to encode.
// See https://pkg.go.dev/encoding/json#Marshal for example.
//
// `I` is the type by which indices are identified,
// typically an enum for type-safe access.
// Indices are defined when calling the constructor of an implementation of this interface.
// It is illegal to refer to an idex not defined during construction.
//
// `K` is the type used as key in each index.
// Due to the limitations in type constraints, all indices use the same key type.
//
// Implementations are free to use stricter type constraints or fixed typing.
type Cache[I, K comparable, V Entry[I, K]] interface {
// Get an object through specified index.
// An [IndexUnknownError] may be returned if the index is unknown.
// [ErrCacheMiss] is returned if the key was not found in the index,
// or the object is not valid.
Get(ctx context.Context, index I, key K) (V, bool)
// Set an object.
// Keys are created on each index based in the [Entry.Keys] method.
// If any key maps to an existing object, the object is invalidated,
// regardless if the object has other keys defined in the new entry.
// This to prevent ghost objects when an entry reduces the amount of keys
// for a given index.
Set(ctx context.Context, value V)
// Invalidate an object through specified index.
// Implementations may choose to instantly delete the object,
// defer until prune or a separate cleanup routine.
// Invalidated object are no longer returned from Get.
// It is safe to call Invalidate multiple times or on non-existing entries.
Invalidate(ctx context.Context, index I, key ...K) error
// Delete one or more keys from a specific index.
// An [IndexUnknownError] may be returned if the index is unknown.
// The referred object is not invalidated and may still be accessible though
// other indices and keys.
// It is safe to call Delete multiple times or on non-existing entries
Delete(ctx context.Context, index I, key ...K) error
// Truncate deletes all cached objects.
Truncate(ctx context.Context) error
}
// Entry contains a value of type `V` to be cached.
//
// `I` is the type by which indices are identified,
// typically an enum for type-safe access.
//
// `K` is the type used as key in an index.
// Due to the limitations in type constraints, all indices use the same key type.
type Entry[I, K comparable] interface {
// Keys returns which keys map to the object in a specified index.
// May return nil if the index in unknown or when there are no keys.
Keys(index I) (key []K)
}
type Connector int
//go:generate enumer -type Connector -transform snake -trimprefix Connector -linecomment -text
const (
// Empty line comment ensures empty string for unspecified value
ConnectorUnspecified Connector = iota //
ConnectorMemory
ConnectorPostgres
ConnectorRedis
)
type Config struct {
Connector Connector
// Age since an object was added to the cache,
// after which the object is considered invalid.
// 0 disables max age checks.
MaxAge time.Duration
// Age since last use (Get) of an object,
// after which the object is considered invalid.
// 0 disables last use age checks.
LastUseAge time.Duration
// Log allows logging of the specific cache.
// By default only errors are logged to stdout.
Log *logging.Config
}

View File

@@ -0,0 +1,49 @@
// Package connector provides glue between the [cache.Cache] interface and implementations from the connector sub-packages.
package connector
import (
"context"
"fmt"
"github.com/zitadel/zitadel/backend/v3/storage/cache"
"github.com/zitadel/zitadel/backend/v3/storage/cache/connector/gomap"
"github.com/zitadel/zitadel/backend/v3/storage/cache/connector/noop"
)
type CachesConfig struct {
Connectors struct {
Memory gomap.Config
}
Instance *cache.Config
Milestones *cache.Config
Organization *cache.Config
IdPFormCallbacks *cache.Config
}
type Connectors struct {
Config CachesConfig
Memory *gomap.Connector
}
func StartConnectors(conf *CachesConfig) (Connectors, error) {
if conf == nil {
return Connectors{}, nil
}
return Connectors{
Config: *conf,
Memory: gomap.NewConnector(conf.Connectors.Memory),
}, nil
}
func StartCache[I ~int, K ~string, V cache.Entry[I, K]](background context.Context, indices []I, purpose cache.Purpose, conf *cache.Config, connectors Connectors) (cache.Cache[I, K, V], error) {
if conf == nil || conf.Connector == cache.ConnectorUnspecified {
return noop.NewCache[I, K, V](), nil
}
if conf.Connector == cache.ConnectorMemory && connectors.Memory != nil {
c := gomap.NewCache[I, K, V](background, indices, *conf)
connectors.Memory.Config.StartAutoPrune(background, c, purpose)
return c, nil
}
return nil, fmt.Errorf("cache connector %q not enabled", conf.Connector)
}

View File

@@ -0,0 +1,23 @@
package gomap
import (
"github.com/zitadel/zitadel/backend/v3/storage/cache"
)
type Config struct {
Enabled bool
AutoPrune cache.AutoPruneConfig
}
type Connector struct {
Config cache.AutoPruneConfig
}
func NewConnector(config Config) *Connector {
if !config.Enabled {
return nil
}
return &Connector{
Config: config.AutoPrune,
}
}

View File

@@ -0,0 +1,200 @@
package gomap
import (
"context"
"errors"
"log/slog"
"maps"
"os"
"sync"
"sync/atomic"
"time"
"github.com/zitadel/zitadel/backend/v3/storage/cache"
)
type mapCache[I, K comparable, V cache.Entry[I, K]] struct {
config *cache.Config
indexMap map[I]*index[K, V]
logger *slog.Logger
}
// NewCache returns an in-memory Cache implementation based on the builtin go map type.
// Object values are stored as-is and there is no encoding or decoding involved.
func NewCache[I, K comparable, V cache.Entry[I, K]](background context.Context, indices []I, config cache.Config) cache.PrunerCache[I, K, V] {
m := &mapCache[I, K, V]{
config: &config,
indexMap: make(map[I]*index[K, V], len(indices)),
logger: slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
AddSource: true,
Level: slog.LevelError,
})),
}
if config.Log != nil {
m.logger = config.Log.Slog()
}
m.logger.InfoContext(background, "map cache logging enabled")
for _, name := range indices {
m.indexMap[name] = &index[K, V]{
config: m.config,
entries: make(map[K]*entry[V]),
}
}
return m
}
func (c *mapCache[I, K, V]) Get(ctx context.Context, index I, key K) (value V, ok bool) {
i, ok := c.indexMap[index]
if !ok {
c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key)
return value, false
}
entry, err := i.Get(key)
if err == nil {
c.logger.DebugContext(ctx, "map cache get", "index", index, "key", key)
return entry.value, true
}
if errors.Is(err, cache.ErrCacheMiss) {
c.logger.InfoContext(ctx, "map cache get", "err", err, "index", index, "key", key)
return value, false
}
c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key)
return value, false
}
func (c *mapCache[I, K, V]) Set(ctx context.Context, value V) {
now := time.Now()
entry := &entry[V]{
value: value,
created: now,
}
entry.lastUse.Store(now.UnixMicro())
for name, i := range c.indexMap {
keys := value.Keys(name)
i.Set(keys, entry)
c.logger.DebugContext(ctx, "map cache set", "index", name, "keys", keys)
}
}
func (c *mapCache[I, K, V]) Invalidate(ctx context.Context, index I, keys ...K) error {
i, ok := c.indexMap[index]
if !ok {
return cache.NewIndexUnknownErr(index)
}
i.Invalidate(keys)
c.logger.DebugContext(ctx, "map cache invalidate", "index", index, "keys", keys)
return nil
}
func (c *mapCache[I, K, V]) Delete(ctx context.Context, index I, keys ...K) error {
i, ok := c.indexMap[index]
if !ok {
return cache.NewIndexUnknownErr(index)
}
i.Delete(keys)
c.logger.DebugContext(ctx, "map cache delete", "index", index, "keys", keys)
return nil
}
func (c *mapCache[I, K, V]) Prune(ctx context.Context) error {
for name, index := range c.indexMap {
index.Prune()
c.logger.DebugContext(ctx, "map cache prune", "index", name)
}
return nil
}
func (c *mapCache[I, K, V]) Truncate(ctx context.Context) error {
for name, index := range c.indexMap {
index.Truncate()
c.logger.DebugContext(ctx, "map cache truncate", "index", name)
}
return nil
}
type index[K comparable, V any] struct {
mutex sync.RWMutex
config *cache.Config
entries map[K]*entry[V]
}
func (i *index[K, V]) Get(key K) (*entry[V], error) {
i.mutex.RLock()
entry, ok := i.entries[key]
i.mutex.RUnlock()
if ok && entry.isValid(i.config) {
return entry, nil
}
return nil, cache.ErrCacheMiss
}
func (c *index[K, V]) Set(keys []K, entry *entry[V]) {
c.mutex.Lock()
for _, key := range keys {
c.entries[key] = entry
}
c.mutex.Unlock()
}
func (i *index[K, V]) Invalidate(keys []K) {
i.mutex.RLock()
for _, key := range keys {
if entry, ok := i.entries[key]; ok {
entry.invalid.Store(true)
}
}
i.mutex.RUnlock()
}
func (c *index[K, V]) Delete(keys []K) {
c.mutex.Lock()
for _, key := range keys {
delete(c.entries, key)
}
c.mutex.Unlock()
}
func (c *index[K, V]) Prune() {
c.mutex.Lock()
maps.DeleteFunc(c.entries, func(_ K, entry *entry[V]) bool {
return !entry.isValid(c.config)
})
c.mutex.Unlock()
}
func (c *index[K, V]) Truncate() {
c.mutex.Lock()
c.entries = make(map[K]*entry[V])
c.mutex.Unlock()
}
type entry[V any] struct {
value V
created time.Time
invalid atomic.Bool
lastUse atomic.Int64 // UnixMicro time
}
func (e *entry[V]) isValid(c *cache.Config) bool {
if e.invalid.Load() {
return false
}
now := time.Now()
if c.MaxAge > 0 {
if e.created.Add(c.MaxAge).Before(now) {
e.invalid.Store(true)
return false
}
}
if c.LastUseAge > 0 {
lastUse := e.lastUse.Load()
if time.UnixMicro(lastUse).Add(c.LastUseAge).Before(now) {
e.invalid.Store(true)
return false
}
e.lastUse.CompareAndSwap(lastUse, now.UnixMicro())
}
return true
}

View File

@@ -0,0 +1,329 @@
package gomap
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/backend/v3/storage/cache"
)
type testIndex int
const (
testIndexID testIndex = iota
testIndexName
)
var testIndices = []testIndex{
testIndexID,
testIndexName,
}
type testObject struct {
id string
names []string
}
func (o *testObject) Keys(index testIndex) []string {
switch index {
case testIndexID:
return []string{o.id}
case testIndexName:
return o.names
default:
return nil
}
}
func Test_mapCache_Get(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
obj := &testObject{
id: "id",
names: []string{"foo", "bar"},
}
c.Set(context.Background(), obj)
type args struct {
index testIndex
key string
}
tests := []struct {
name string
args args
want *testObject
wantOk bool
}{
{
name: "ok",
args: args{
index: testIndexID,
key: "id",
},
want: obj,
wantOk: true,
},
{
name: "miss",
args: args{
index: testIndexID,
key: "spanac",
},
want: nil,
wantOk: false,
},
{
name: "unknown index",
args: args{
index: 99,
key: "id",
},
want: nil,
wantOk: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok := c.Get(context.Background(), tt.args.index, tt.args.key)
assert.Equal(t, tt.want, got)
assert.Equal(t, tt.wantOk, ok)
})
}
}
func Test_mapCache_Invalidate(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
obj := &testObject{
id: "id",
names: []string{"foo", "bar"},
}
c.Set(context.Background(), obj)
err := c.Invalidate(context.Background(), testIndexName, "bar")
require.NoError(t, err)
got, ok := c.Get(context.Background(), testIndexID, "id")
assert.Nil(t, got)
assert.False(t, ok)
}
func Test_mapCache_Delete(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
obj := &testObject{
id: "id",
names: []string{"foo", "bar"},
}
c.Set(context.Background(), obj)
err := c.Delete(context.Background(), testIndexName, "bar")
require.NoError(t, err)
// Shouldn't find object by deleted name
got, ok := c.Get(context.Background(), testIndexName, "bar")
assert.Nil(t, got)
assert.False(t, ok)
// Should find object by other name
got, ok = c.Get(context.Background(), testIndexName, "foo")
assert.Equal(t, obj, got)
assert.True(t, ok)
// Should find object by id
got, ok = c.Get(context.Background(), testIndexID, "id")
assert.Equal(t, obj, got)
assert.True(t, ok)
}
func Test_mapCache_Prune(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
objects := []*testObject{
{
id: "id1",
names: []string{"foo", "bar"},
},
{
id: "id2",
names: []string{"hello"},
},
}
for _, obj := range objects {
c.Set(context.Background(), obj)
}
// invalidate one entry
err := c.Invalidate(context.Background(), testIndexName, "bar")
require.NoError(t, err)
err = c.(cache.Pruner).Prune(context.Background())
require.NoError(t, err)
// Other object should still be found
got, ok := c.Get(context.Background(), testIndexID, "id2")
assert.Equal(t, objects[1], got)
assert.True(t, ok)
}
func Test_mapCache_Truncate(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
objects := []*testObject{
{
id: "id1",
names: []string{"foo", "bar"},
},
{
id: "id2",
names: []string{"hello"},
},
}
for _, obj := range objects {
c.Set(context.Background(), obj)
}
err := c.Truncate(context.Background())
require.NoError(t, err)
mc := c.(*mapCache[testIndex, string, *testObject])
for _, index := range mc.indexMap {
index.mutex.RLock()
assert.Len(t, index.entries, 0)
index.mutex.RUnlock()
}
}
func Test_entry_isValid(t *testing.T) {
type fields struct {
created time.Time
invalid bool
lastUse time.Time
}
tests := []struct {
name string
fields fields
config *cache.Config
want bool
}{
{
name: "invalid",
fields: fields{
created: time.Now(),
invalid: true,
lastUse: time.Now(),
},
config: &cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
want: false,
},
{
name: "max age exceeded",
fields: fields{
created: time.Now().Add(-(time.Minute + time.Second)),
invalid: false,
lastUse: time.Now(),
},
config: &cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
want: false,
},
{
name: "max age disabled",
fields: fields{
created: time.Now().Add(-(time.Minute + time.Second)),
invalid: false,
lastUse: time.Now(),
},
config: &cache.Config{
LastUseAge: time.Second,
},
want: true,
},
{
name: "last use age exceeded",
fields: fields{
created: time.Now().Add(-(time.Minute / 2)),
invalid: false,
lastUse: time.Now().Add(-(time.Second * 2)),
},
config: &cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
want: false,
},
{
name: "last use age disabled",
fields: fields{
created: time.Now().Add(-(time.Minute / 2)),
invalid: false,
lastUse: time.Now().Add(-(time.Second * 2)),
},
config: &cache.Config{
MaxAge: time.Minute,
},
want: true,
},
{
name: "valid",
fields: fields{
created: time.Now(),
invalid: false,
lastUse: time.Now(),
},
config: &cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &entry[any]{
created: tt.fields.created,
}
e.invalid.Store(tt.fields.invalid)
e.lastUse.Store(tt.fields.lastUse.UnixMicro())
got := e.isValid(tt.config)
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -0,0 +1,21 @@
package noop
import (
"context"
"github.com/zitadel/zitadel/backend/v3/storage/cache"
)
type noop[I, K comparable, V cache.Entry[I, K]] struct{}
// NewCache returns a cache that does nothing
func NewCache[I, K comparable, V cache.Entry[I, K]]() cache.Cache[I, K, V] {
return noop[I, K, V]{}
}
func (noop[I, K, V]) Set(context.Context, V) {}
func (noop[I, K, V]) Get(context.Context, I, K) (value V, ok bool) { return }
func (noop[I, K, V]) Invalidate(context.Context, I, ...K) (err error) { return }
func (noop[I, K, V]) Delete(context.Context, I, ...K) (err error) { return }
func (noop[I, K, V]) Prune(context.Context) (err error) { return }
func (noop[I, K, V]) Truncate(context.Context) (err error) { return }

View File

@@ -0,0 +1,98 @@
// Code generated by "enumer -type Connector -transform snake -trimprefix Connector -linecomment -text"; DO NOT EDIT.
package cache
import (
"fmt"
"strings"
)
const _ConnectorName = "memorypostgresredis"
var _ConnectorIndex = [...]uint8{0, 0, 6, 14, 19}
const _ConnectorLowerName = "memorypostgresredis"
func (i Connector) String() string {
if i < 0 || i >= Connector(len(_ConnectorIndex)-1) {
return fmt.Sprintf("Connector(%d)", i)
}
return _ConnectorName[_ConnectorIndex[i]:_ConnectorIndex[i+1]]
}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _ConnectorNoOp() {
var x [1]struct{}
_ = x[ConnectorUnspecified-(0)]
_ = x[ConnectorMemory-(1)]
_ = x[ConnectorPostgres-(2)]
_ = x[ConnectorRedis-(3)]
}
var _ConnectorValues = []Connector{ConnectorUnspecified, ConnectorMemory, ConnectorPostgres, ConnectorRedis}
var _ConnectorNameToValueMap = map[string]Connector{
_ConnectorName[0:0]: ConnectorUnspecified,
_ConnectorLowerName[0:0]: ConnectorUnspecified,
_ConnectorName[0:6]: ConnectorMemory,
_ConnectorLowerName[0:6]: ConnectorMemory,
_ConnectorName[6:14]: ConnectorPostgres,
_ConnectorLowerName[6:14]: ConnectorPostgres,
_ConnectorName[14:19]: ConnectorRedis,
_ConnectorLowerName[14:19]: ConnectorRedis,
}
var _ConnectorNames = []string{
_ConnectorName[0:0],
_ConnectorName[0:6],
_ConnectorName[6:14],
_ConnectorName[14:19],
}
// ConnectorString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func ConnectorString(s string) (Connector, error) {
if val, ok := _ConnectorNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _ConnectorNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to Connector values", s)
}
// ConnectorValues returns all values of the enum
func ConnectorValues() []Connector {
return _ConnectorValues
}
// ConnectorStrings returns a slice of all String values of the enum
func ConnectorStrings() []string {
strs := make([]string, len(_ConnectorNames))
copy(strs, _ConnectorNames)
return strs
}
// IsAConnector returns "true" if the value is listed in the enum definition. "false" otherwise
func (i Connector) IsAConnector() bool {
for _, v := range _ConnectorValues {
if i == v {
return true
}
}
return false
}
// MarshalText implements the encoding.TextMarshaler interface for Connector
func (i Connector) MarshalText() ([]byte, error) {
return []byte(i.String()), nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface for Connector
func (i *Connector) UnmarshalText(text []byte) error {
var err error
*i, err = ConnectorString(string(text))
return err
}

29
backend/v3/storage/cache/error.go vendored Normal file
View File

@@ -0,0 +1,29 @@
package cache
import (
"errors"
"fmt"
)
type IndexUnknownError[I comparable] struct {
index I
}
func NewIndexUnknownErr[I comparable](index I) error {
return IndexUnknownError[I]{index}
}
func (i IndexUnknownError[I]) Error() string {
return fmt.Sprintf("index %v unknown", i.index)
}
func (a IndexUnknownError[I]) Is(err error) bool {
if b, ok := err.(IndexUnknownError[I]); ok {
return a.index == b.index
}
return false
}
var (
ErrCacheMiss = errors.New("cache miss")
)

76
backend/v3/storage/cache/pruner.go vendored Normal file
View File

@@ -0,0 +1,76 @@
package cache
import (
"context"
"math/rand"
"time"
"github.com/jonboulle/clockwork"
"github.com/zitadel/logging"
)
// Pruner is an optional [Cache] interface.
type Pruner interface {
// Prune deletes all invalidated or expired objects.
Prune(ctx context.Context) error
}
type PrunerCache[I, K comparable, V Entry[I, K]] interface {
Cache[I, K, V]
Pruner
}
type AutoPruneConfig struct {
// Interval at which the cache is automatically pruned.
// 0 or lower disables automatic pruning.
Interval time.Duration
// Timeout for an automatic prune.
// It is recommended to keep the value shorter than AutoPruneInterval
// 0 or lower disables automatic pruning.
Timeout time.Duration
}
func (c AutoPruneConfig) StartAutoPrune(background context.Context, pruner Pruner, purpose Purpose) (close func()) {
return c.startAutoPrune(background, pruner, purpose, clockwork.NewRealClock())
}
func (c *AutoPruneConfig) startAutoPrune(background context.Context, pruner Pruner, purpose Purpose, clock clockwork.Clock) (close func()) {
if c.Interval <= 0 {
return func() {}
}
background, cancel := context.WithCancel(background)
// randomize the first interval
timer := clock.NewTimer(time.Duration(rand.Int63n(int64(c.Interval))))
go c.pruneTimer(background, pruner, purpose, timer)
return cancel
}
func (c *AutoPruneConfig) pruneTimer(background context.Context, pruner Pruner, purpose Purpose, timer clockwork.Timer) {
defer func() {
if !timer.Stop() {
<-timer.Chan()
}
}()
for {
select {
case <-background.Done():
return
case <-timer.Chan():
err := c.doPrune(background, pruner)
logging.OnError(err).WithField("purpose", purpose).Error("cache auto prune")
timer.Reset(c.Interval)
}
}
}
func (c *AutoPruneConfig) doPrune(background context.Context, pruner Pruner) error {
ctx, cancel := context.WithCancel(background)
defer cancel()
if c.Timeout > 0 {
ctx, cancel = context.WithTimeout(background, c.Timeout)
defer cancel()
}
return pruner.Prune(ctx)
}

43
backend/v3/storage/cache/pruner_test.go vendored Normal file
View File

@@ -0,0 +1,43 @@
package cache
import (
"context"
"testing"
"time"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/assert"
)
type testPruner struct {
called chan struct{}
}
func (p *testPruner) Prune(context.Context) error {
p.called <- struct{}{}
return nil
}
func TestAutoPruneConfig_startAutoPrune(t *testing.T) {
c := AutoPruneConfig{
Interval: time.Second,
Timeout: time.Millisecond,
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
pruner := testPruner{
called: make(chan struct{}),
}
clock := clockwork.NewFakeClock()
close := c.startAutoPrune(ctx, &pruner, PurposeAuthzInstance, clock)
defer close()
clock.Advance(time.Second)
select {
case _, ok := <-pruner.called:
assert.True(t, ok)
case <-ctx.Done():
t.Fatal(ctx.Err())
}
}

View File

@@ -0,0 +1,90 @@
// Code generated by "enumer -type Purpose -transform snake -trimprefix Purpose"; DO NOT EDIT.
package cache
import (
"fmt"
"strings"
)
const _PurposeName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callback"
var _PurposeIndex = [...]uint8{0, 11, 25, 35, 47, 65}
const _PurposeLowerName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callback"
func (i Purpose) String() string {
if i < 0 || i >= Purpose(len(_PurposeIndex)-1) {
return fmt.Sprintf("Purpose(%d)", i)
}
return _PurposeName[_PurposeIndex[i]:_PurposeIndex[i+1]]
}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _PurposeNoOp() {
var x [1]struct{}
_ = x[PurposeUnspecified-(0)]
_ = x[PurposeAuthzInstance-(1)]
_ = x[PurposeMilestones-(2)]
_ = x[PurposeOrganization-(3)]
_ = x[PurposeIdPFormCallback-(4)]
}
var _PurposeValues = []Purpose{PurposeUnspecified, PurposeAuthzInstance, PurposeMilestones, PurposeOrganization, PurposeIdPFormCallback}
var _PurposeNameToValueMap = map[string]Purpose{
_PurposeName[0:11]: PurposeUnspecified,
_PurposeLowerName[0:11]: PurposeUnspecified,
_PurposeName[11:25]: PurposeAuthzInstance,
_PurposeLowerName[11:25]: PurposeAuthzInstance,
_PurposeName[25:35]: PurposeMilestones,
_PurposeLowerName[25:35]: PurposeMilestones,
_PurposeName[35:47]: PurposeOrganization,
_PurposeLowerName[35:47]: PurposeOrganization,
_PurposeName[47:65]: PurposeIdPFormCallback,
_PurposeLowerName[47:65]: PurposeIdPFormCallback,
}
var _PurposeNames = []string{
_PurposeName[0:11],
_PurposeName[11:25],
_PurposeName[25:35],
_PurposeName[35:47],
_PurposeName[47:65],
}
// PurposeString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func PurposeString(s string) (Purpose, error) {
if val, ok := _PurposeNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _PurposeNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to Purpose values", s)
}
// PurposeValues returns all values of the enum
func PurposeValues() []Purpose {
return _PurposeValues
}
// PurposeStrings returns a slice of all String values of the enum
func PurposeStrings() []string {
strs := make([]string, len(_PurposeNames))
copy(strs, _PurposeNames)
return strs
}
// IsAPurpose returns "true" if the value is listed in the enum definition. "false" otherwise
func (i Purpose) IsAPurpose() bool {
for _, v := range _PurposeValues {
if i == v {
return true
}
}
return false
}

View File

@@ -0,0 +1,9 @@
package database
import (
"context"
)
type Connector interface {
Connect(ctx context.Context) (Pool, error)
}

View File

@@ -0,0 +1,60 @@
package database
import (
"context"
)
var (
db *database
)
type database struct {
connector Connector
pool Pool
}
type Pool interface {
Beginner
QueryExecutor
Acquire(ctx context.Context) (Client, error)
Close(ctx context.Context) error
}
type Client interface {
Beginner
QueryExecutor
Release(ctx context.Context) error
}
type Querier interface {
Query(ctx context.Context, stmt string, args ...any) (Rows, error)
QueryRow(ctx context.Context, stmt string, args ...any) Row
}
type Executor interface {
Exec(ctx context.Context, stmt string, args ...any) error
}
type QueryExecutor interface {
Querier
Executor
}
type Scanner interface {
Scan(dest ...any) error
}
type Row interface {
Scanner
}
type Rows interface {
Row
Next() bool
Close() error
Err() error
}
type Query[T any] func(querier Querier) (result T, err error)

View File

@@ -0,0 +1,92 @@
package dialect
import (
"context"
"errors"
"reflect"
"github.com/mitchellh/mapstructure"
"github.com/spf13/viper"
"github.com/zitadel/zitadel/backend/storage/database"
"github.com/zitadel/zitadel/backend/storage/database/dialect/postgres"
)
type Hook struct {
Match func(string) bool
Decode func(config any) (database.Connector, error)
Name string
Constructor func() database.Connector
}
var hooks = []Hook{
{
Match: postgres.NameMatcher,
Decode: postgres.DecodeConfig,
Name: postgres.Name,
Constructor: func() database.Connector { return new(postgres.Config) },
},
// {
// Match: gosql.NameMatcher,
// Decode: gosql.DecodeConfig,
// Name: gosql.Name,
// Constructor: func() database.Connector { return new(gosql.Config) },
// },
}
type Config struct {
Dialects map[string]any `mapstructure:",remain" yaml:",inline"`
connector database.Connector
}
func (c Config) Connect(ctx context.Context) (database.Pool, error) {
if len(c.Dialects) != 1 {
return nil, errors.New("Exactly one dialect must be configured")
}
return c.connector.Connect(ctx)
}
// Hooks implements [configure.Unmarshaller].
func (c Config) Hooks() []viper.DecoderConfigOption {
return []viper.DecoderConfigOption{
viper.DecodeHook(decodeHook),
}
}
func decodeHook(from, to reflect.Value) (_ any, err error) {
if to.Type() != reflect.TypeOf(Config{}) {
return from.Interface(), nil
}
config := new(Config)
if err = mapstructure.Decode(from.Interface(), config); err != nil {
return nil, err
}
if err = config.decodeDialect(); err != nil {
return nil, err
}
return config, nil
}
func (c *Config) decodeDialect() error {
for _, hook := range hooks {
for name, config := range c.Dialects {
if !hook.Match(name) {
continue
}
connector, err := hook.Decode(config)
if err != nil {
return err
}
c.connector = connector
return nil
}
}
return errors.New("no dialect found")
}

View File

@@ -0,0 +1,80 @@
package postgres
import (
"context"
"errors"
"slices"
"strings"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/mitchellh/mapstructure"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
var (
_ database.Connector = (*Config)(nil)
Name = "postgres"
)
type Config struct {
*pgxpool.Config
// Host string
// Port int32
// Database string
// MaxOpenConns uint32
// MaxIdleConns uint32
// MaxConnLifetime time.Duration
// MaxConnIdleTime time.Duration
// User User
// // Additional options to be appended as options=<Options>
// // The value will be taken as is. Multiple options are space separated.
// Options string
configuredFields []string
}
// Connect implements [database.Connector].
func (c *Config) Connect(ctx context.Context) (database.Pool, error) {
pool, err := pgxpool.NewWithConfig(ctx, c.Config)
if err != nil {
return nil, err
}
if err = pool.Ping(ctx); err != nil {
return nil, err
}
return &pgxPool{pool}, nil
}
func NameMatcher(name string) bool {
return slices.Contains([]string{"postgres", "pg"}, strings.ToLower(name))
}
func DecodeConfig(input any) (database.Connector, error) {
switch c := input.(type) {
case string:
config, err := pgxpool.ParseConfig(c)
if err != nil {
return nil, err
}
return &Config{Config: config}, nil
case map[string]any:
connector := new(Config)
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
DecodeHook: mapstructure.StringToTimeDurationHookFunc(),
WeaklyTypedInput: true,
Result: connector,
})
if err != nil {
return nil, err
}
if err = decoder.Decode(c); err != nil {
return nil, err
}
return &Config{
Config: &pgxpool.Config{},
}, nil
}
return nil, errors.New("invalid configuration")
}

View File

@@ -0,0 +1,48 @@
package postgres
import (
"context"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type pgxConn struct{ *pgxpool.Conn }
var _ database.Client = (*pgxConn)(nil)
// Release implements [database.Client].
func (c *pgxConn) Release(_ context.Context) error {
c.Conn.Release()
return nil
}
// Begin implements [database.Client].
func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
tx, err := c.Conn.BeginTx(ctx, transactionOptionsToPgx(opts))
if err != nil {
return nil, err
}
return &pgxTx{tx}, nil
}
// Query implements sql.Client.
// Subtle: this method shadows the method (*Conn).Query of pgxConn.Conn.
func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
rows, err := c.Conn.Query(ctx, sql, args...)
return &Rows{rows}, err
}
// QueryRow implements sql.Client.
// Subtle: this method shadows the method (*Conn).QueryRow of pgxConn.Conn.
func (c *pgxConn) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return c.Conn.QueryRow(ctx, sql, args...)
}
// Exec implements [database.Pool].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (c *pgxConn) Exec(ctx context.Context, sql string, args ...any) error {
_, err := c.Conn.Exec(ctx, sql, args...)
return err
}

View File

@@ -0,0 +1,57 @@
package postgres
import (
"context"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type pgxPool struct{ *pgxpool.Pool }
var _ database.Pool = (*pgxPool)(nil)
// Acquire implements [database.Pool].
func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) {
conn, err := c.Pool.Acquire(ctx)
if err != nil {
return nil, err
}
return &pgxConn{conn}, nil
}
// Query implements [database.Pool].
// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool.
func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
rows, err := c.Pool.Query(ctx, sql, args...)
return &Rows{rows}, err
}
// QueryRow implements [database.Pool].
// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool.
func (c *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return c.Pool.QueryRow(ctx, sql, args...)
}
// Exec implements [database.Pool].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) error {
_, err := c.Pool.Exec(ctx, sql, args...)
return err
}
// Begin implements [database.Pool].
func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
tx, err := c.Pool.BeginTx(ctx, transactionOptionsToPgx(opts))
if err != nil {
return nil, err
}
return &pgxTx{tx}, nil
}
// Close implements [database.Pool].
func (c *pgxPool) Close(_ context.Context) error {
c.Pool.Close()
return nil
}

View File

@@ -0,0 +1,18 @@
package postgres
import (
"github.com/jackc/pgx/v5"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
var _ database.Rows = (*Rows)(nil)
type Rows struct{ pgx.Rows }
// Close implements [database.Rows].
// Subtle: this method shadows the method (Rows).Close of Rows.Rows.
func (r *Rows) Close() error {
r.Rows.Close()
return nil
}

View File

@@ -0,0 +1,95 @@
package postgres
import (
"context"
"github.com/jackc/pgx/v5"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type pgxTx struct{ pgx.Tx }
var _ database.Transaction = (*pgxTx)(nil)
// Commit implements [database.Transaction].
func (tx *pgxTx) Commit(ctx context.Context) error {
return tx.Tx.Commit(ctx)
}
// Rollback implements [database.Transaction].
func (tx *pgxTx) Rollback(ctx context.Context) error {
return tx.Tx.Rollback(ctx)
}
// End implements [database.Transaction].
func (tx *pgxTx) End(ctx context.Context, err error) error {
if err != nil {
tx.Rollback(ctx)
return err
}
return tx.Commit(ctx)
}
// Query implements [database.Transaction].
// Subtle: this method shadows the method (Tx).Query of pgxTx.Tx.
func (tx *pgxTx) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
rows, err := tx.Tx.Query(ctx, sql, args...)
return &Rows{rows}, err
}
// QueryRow implements [database.Transaction].
// Subtle: this method shadows the method (Tx).QueryRow of pgxTx.Tx.
func (tx *pgxTx) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return tx.Tx.QueryRow(ctx, sql, args...)
}
// Exec implements [database.Transaction].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (tx *pgxTx) Exec(ctx context.Context, sql string, args ...any) error {
_, err := tx.Tx.Exec(ctx, sql, args...)
return err
}
// Begin implements [database.Transaction].
// As postgres does not support nested transactions we use savepoints to emulate them.
func (tx *pgxTx) Begin(ctx context.Context) (database.Transaction, error) {
savepoint, err := tx.Tx.Begin(ctx)
if err != nil {
return nil, err
}
return &pgxTx{savepoint}, nil
}
func transactionOptionsToPgx(opts *database.TransactionOptions) pgx.TxOptions {
if opts == nil {
return pgx.TxOptions{}
}
return pgx.TxOptions{
IsoLevel: isolationToPgx(opts.IsolationLevel),
AccessMode: accessModeToPgx(opts.AccessMode),
}
}
func isolationToPgx(isolation database.IsolationLevel) pgx.TxIsoLevel {
switch isolation {
case database.IsolationLevelSerializable:
return pgx.Serializable
case database.IsolationLevelReadCommitted:
return pgx.ReadCommitted
default:
return pgx.Serializable
}
}
func accessModeToPgx(accessMode database.AccessMode) pgx.TxAccessMode {
switch accessMode {
case database.AccessModeReadWrite:
return pgx.ReadWrite
case database.AccessModeReadOnly:
return pgx.ReadOnly
default:
return pgx.ReadWrite
}
}

View File

@@ -0,0 +1,3 @@
package database
//go:generate mockgen -typed -package mock -destination ./mock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Client,Row,Rows,Transaction

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,160 @@
package repository
import (
"fmt"
"github.com/zitadel/zitadel/backend/v3/domain"
)
type field interface {
fmt.Stringer
}
type fieldDescriptor struct {
schema string
table string
name string
}
func (f fieldDescriptor) String() string {
return f.schema + "." + f.table + "." + f.name
}
type ignoreCaseFieldDescriptor struct {
fieldDescriptor
fieldNameSuffix string
}
func (f ignoreCaseFieldDescriptor) String() string {
return f.fieldDescriptor.String() + f.fieldNameSuffix
}
type textFieldDescriptor struct {
field
isIgnoreCase bool
}
type clause[Op domain.Operation] struct {
field field
op Op
}
const (
schema = "zitadel"
userTable = "users"
)
var userFields = map[domain.UserField]field{
domain.UserFieldInstanceID: fieldDescriptor{
schema: schema,
table: userTable,
name: "instance_id",
},
domain.UserFieldOrgID: fieldDescriptor{
schema: schema,
table: userTable,
name: "org_id",
},
domain.UserFieldID: fieldDescriptor{
schema: schema,
table: userTable,
name: "id",
},
domain.UserFieldUsername: textFieldDescriptor{
field: ignoreCaseFieldDescriptor{
fieldDescriptor: fieldDescriptor{
schema: schema,
table: userTable,
name: "username",
},
fieldNameSuffix: "_lower",
},
},
domain.UserHumanFieldEmail: textFieldDescriptor{
field: ignoreCaseFieldDescriptor{
fieldDescriptor: fieldDescriptor{
schema: schema,
table: userTable,
name: "email",
},
fieldNameSuffix: "_lower",
},
},
domain.UserHumanFieldEmailVerified: fieldDescriptor{
schema: schema,
table: userTable,
name: "email_is_verified",
},
}
type textClause[V domain.Text] struct {
clause[domain.TextOperation]
value V
}
var textOp map[domain.TextOperation]string = map[domain.TextOperation]string{
domain.TextOperationEqual: " = ",
domain.TextOperationNotEqual: " <> ",
domain.TextOperationStartsWith: " LIKE ",
domain.TextOperationStartsWithIgnoreCase: " LIKE ",
}
func (tc textClause[V]) Write(stmt *statement) {
placeholder := stmt.appendArg(tc.value)
var (
left, right string
)
switch tc.clause.op {
case domain.TextOperationEqual:
left = tc.clause.field.String()
right = placeholder
case domain.TextOperationNotEqual:
left = tc.clause.field.String()
right = placeholder
case domain.TextOperationStartsWith:
left = tc.clause.field.String()
right = placeholder + "%"
case domain.TextOperationStartsWithIgnoreCase:
left = tc.clause.field.String()
if _, ok := tc.clause.field.(ignoreCaseFieldDescriptor); !ok {
left = "LOWER(" + left + ")"
}
right = "LOWER(" + placeholder + "%)"
}
stmt.builder.WriteString(left)
stmt.builder.WriteString(textOp[tc.clause.op])
stmt.builder.WriteString(right)
}
type boolClause[V domain.Bool] struct {
clause[domain.BoolOperation]
value V
}
func (bc boolClause[V]) Write(stmt *statement) {
if !bc.value {
stmt.builder.WriteString("NOT ")
}
stmt.builder.WriteString(bc.clause.field.String())
}
type numberClause[V domain.Number] struct {
clause[domain.NumberOperation]
value V
}
var numberOp map[domain.NumberOperation]string = map[domain.NumberOperation]string{
domain.NumberOperationEqual: " = ",
domain.NumberOperationNotEqual: " <> ",
domain.NumberOperationLessThan: " < ",
domain.NumberOperationLessThanOrEqual: " <= ",
domain.NumberOperationGreaterThan: " > ",
domain.NumberOperationGreaterThanOrEqual: " >= ",
}
func (nc numberClause[V]) Write(stmt *statement) {
stmt.builder.WriteString(nc.clause.field.String())
stmt.builder.WriteString(numberOp[nc.clause.op])
stmt.builder.WriteString(stmt.appendArg(nc.value))
}

View File

@@ -0,0 +1,45 @@
package repository
import (
"context"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/internal/crypto"
)
type cryptoRepo struct {
database.QueryExecutor
}
func Crypto(db database.QueryExecutor) domain.CryptoRepository {
return &cryptoRepo{
QueryExecutor: db,
}
}
const getEncryptionConfigQuery = "SELECT" +
" length" +
", expiry" +
", should_include_lower_letters" +
", should_include_upper_letters" +
", should_include_digits" +
", should_include_symbols" +
" FROM encryption_config"
func (repo *cryptoRepo) GetEncryptionConfig(ctx context.Context) (*crypto.GeneratorConfig, error) {
var config crypto.GeneratorConfig
row := repo.QueryRow(ctx, getEncryptionConfigQuery)
err := row.Scan(
&config.Length,
&config.Expiry,
&config.IncludeLowerLetters,
&config.IncludeUpperLetters,
&config.IncludeDigits,
&config.IncludeSymbols,
)
if err != nil {
return nil, err
}
return &config, nil
}

View File

@@ -0,0 +1,7 @@
// Repository package provides the database repository for the application.
// It contains the implementation of the [repository pattern](https://martinfowler.com/eaaCatalog/repository.html) for the database.
// funcs which need to interact with the database should create interfaces which are implemented by the
// [query] and [exec] structs respectively their factory methods [Query] and [Execute]. The [query] struct is used for read operations, while the [exec] struct is used for write operations.
package repository

View File

@@ -0,0 +1,54 @@
package repository
import (
"context"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type instance struct {
database.QueryExecutor
}
func Instance(client database.QueryExecutor) domain.InstanceRepository {
return &instance{QueryExecutor: client}
}
func (i *instance) ByID(ctx context.Context, id string) (*domain.Instance, error) {
var instance domain.Instance
err := i.QueryExecutor.QueryRow(ctx, `SELECT id, name, created_at, updated_at, deleted_at FROM instances WHERE id = $1`, id).Scan(
&instance.ID,
&instance.Name,
&instance.CreatedAt,
&instance.UpdatedAt,
&instance.DeletedAt,
)
if err != nil {
return nil, err
}
return &instance, nil
}
const createInstanceStmt = `INSERT INTO instances (id, name) VALUES ($1, $2) RETURNING created_at, updated_at`
// Create implements [domain.InstanceRepository].
func (i *instance) Create(ctx context.Context, instance *domain.Instance) error {
return i.QueryExecutor.QueryRow(ctx, createInstanceStmt,
instance.ID,
instance.Name,
).Scan(
&instance.CreatedAt,
&instance.UpdatedAt,
)
}
// On implements [domain.InstanceRepository].
func (i *instance) On(id string) domain.InstanceOperation {
return &instanceOperation{
QueryExecutor: i.QueryExecutor,
id: id,
}
}
var _ domain.InstanceRepository = (*instance)(nil)

View File

@@ -0,0 +1,52 @@
package repository
import (
"context"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type instanceOperation struct {
database.QueryExecutor
id string
}
const addInstanceAdminStmt = `INSERT INTO instance_admins (instance_id, user_id, roles) VALUES ($1, $2, $3)`
// AddAdmin implements [domain.InstanceOperation].
func (i *instanceOperation) AddAdmin(ctx context.Context, userID string, roles []string) error {
return i.QueryExecutor.Exec(ctx, addInstanceAdminStmt, i.id, userID, roles)
}
// Delete implements [domain.InstanceOperation].
func (i *instanceOperation) Delete(ctx context.Context) error {
return i.QueryExecutor.Exec(ctx, `DELETE FROM instances WHERE id = $1`, i.id)
}
const removeInstanceAdminStmt = `DELETE FROM instance_admins WHERE instance_id = $1 AND user_id = $2`
// RemoveAdmin implements [domain.InstanceOperation].
func (i *instanceOperation) RemoveAdmin(ctx context.Context, userID string) error {
return i.QueryExecutor.Exec(ctx, removeInstanceAdminStmt, i.id, userID)
}
const setInstanceAdminRolesStmt = `UPDATE instance_admins SET roles = $1 WHERE instance_id = $2 AND user_id = $3`
// SetAdminRoles implements [domain.InstanceOperation].
func (i *instanceOperation) SetAdminRoles(ctx context.Context, userID string, roles []string) error {
return i.QueryExecutor.Exec(ctx, setInstanceAdminRolesStmt, roles, i.id, userID)
}
const updateInstanceStmt = `UPDATE instances SET name = $1, updated_at = $2 WHERE id = $3 RETURNING updated_at`
// Update implements [domain.InstanceOperation].
func (i *instanceOperation) Update(ctx context.Context, instance *domain.Instance) error {
return i.QueryExecutor.QueryRow(ctx, updateInstanceStmt,
instance.Name,
instance.UpdatedAt,
i.id,
).Scan(&instance.UpdatedAt)
}
var _ domain.InstanceOperation = (*instanceOperation)(nil)

View File

@@ -0,0 +1,17 @@
package repository
import (
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type query struct{ database.Querier }
func Query(querier database.Querier) *query {
return &query{Querier: querier}
}
type executor struct{ database.Executor }
func Execute(exec database.Executor) *executor {
return &executor{Executor: exec}
}

View File

@@ -0,0 +1,21 @@
package repository
import "strings"
type statement struct {
builder strings.Builder
args []any
}
func (s *statement) appendArg(arg any) (placeholder string) {
s.args = append(s.args, arg)
return "$" + string(len(s.args))
}
func (s *statement) appendArgs(args ...any) (placeholders []string) {
placeholders = make([]string, len(args))
for i, arg := range args {
placeholders[i] = s.appendArg(arg)
}
return placeholders
}

View File

@@ -0,0 +1,43 @@
package stmt
import "fmt"
type Column[T any] interface {
fmt.Stringer
statementApplier[T]
scanner(t *T) any
}
type columnDescriptor[T any] struct {
name string
scan func(*T) any
}
func (cd columnDescriptor[T]) scanner(t *T) any {
return cd.scan(t)
}
// Apply implements [Column].
func (f columnDescriptor[T]) Apply(stmt *statement[T]) {
stmt.builder.WriteString(stmt.columnPrefix())
stmt.builder.WriteString(f.String())
}
// String implements [Column].
func (f columnDescriptor[T]) String() string {
return f.name
}
var _ Column[any] = (*columnDescriptor[any])(nil)
type ignoreCaseColumnDescriptor[T any] struct {
columnDescriptor[T]
fieldNameSuffix string
}
func (f ignoreCaseColumnDescriptor[T]) ApplyIgnoreCase(stmt *statement[T]) {
stmt.builder.WriteString(f.String())
stmt.builder.WriteString(f.fieldNameSuffix)
}
var _ Column[any] = (*ignoreCaseColumnDescriptor[any])(nil)

View File

@@ -0,0 +1,97 @@
package stmt
import "fmt"
type statementApplier[T any] interface {
// Apply writes the statement to the builder.
Apply(stmt *statement[T])
}
type Condition[T any] interface {
statementApplier[T]
}
type op interface {
TextOperation | NumberOperation | ListOperation
fmt.Stringer
}
type operation[T any, O op] struct {
o O
}
func (o operation[T, O]) String() string {
return o.o.String()
}
func (o operation[T, O]) Apply(stmt *statement[T]) {
stmt.builder.WriteString(o.o.String())
}
type condition[V, T any, OP op] struct {
field Column[T]
op OP
value V
}
func (c *condition[V, T, OP]) Apply(stmt *statement[T]) {
// placeholder := stmt.appendArg(c.value)
stmt.builder.WriteString(stmt.columnPrefix())
stmt.builder.WriteString(c.field.String())
// stmt.builder.WriteString(c.op)
// stmt.builder.WriteString(placeholder)
}
type and[T any] struct {
conditions []Condition[T]
}
func And[T any](conditions ...Condition[T]) *and[T] {
return &and[T]{
conditions: conditions,
}
}
// Apply implements [Condition].
func (a *and[T]) Apply(stmt *statement[T]) {
if len(a.conditions) > 1 {
stmt.builder.WriteString("(")
defer stmt.builder.WriteString(")")
}
for i, condition := range a.conditions {
if i > 0 {
stmt.builder.WriteString(" AND ")
}
condition.Apply(stmt)
}
}
var _ Condition[any] = (*and[any])(nil)
type or[T any] struct {
conditions []Condition[T]
}
func Or[T any](conditions ...Condition[T]) *or[T] {
return &or[T]{
conditions: conditions,
}
}
// Apply implements [Condition].
func (o *or[T]) Apply(stmt *statement[T]) {
if len(o.conditions) > 1 {
stmt.builder.WriteString("(")
defer stmt.builder.WriteString(")")
}
for i, condition := range o.conditions {
if i > 0 {
stmt.builder.WriteString(" OR ")
}
condition.Apply(stmt)
}
}
var _ Condition[any] = (*or[any])(nil)

View File

@@ -0,0 +1,71 @@
package stmt
type ListEntry interface {
Number | Text | any
}
type ListCondition[E ListEntry, T any] struct {
condition[[]E, T, ListOperation]
}
func (lc *ListCondition[E, T]) Apply(stmt *statement[T]) {
placeholder := stmt.appendArg(lc.value)
switch lc.op {
case ListOperationEqual, ListOperationNotEqual:
lc.field.Apply(stmt)
operation[T, ListOperation]{lc.op}.Apply(stmt)
stmt.builder.WriteString(placeholder)
case ListOperationContainsAny, ListOperationContainsAll:
lc.field.Apply(stmt)
operation[T, ListOperation]{lc.op}.Apply(stmt)
stmt.builder.WriteString(placeholder)
case ListOperationNotContainsAny, ListOperationNotContainsAll:
stmt.builder.WriteString("NOT (")
lc.field.Apply(stmt)
operation[T, ListOperation]{lc.op}.Apply(stmt)
stmt.builder.WriteString(placeholder)
stmt.builder.WriteString(")")
default:
panic("unknown list operation")
}
}
type ListOperation uint8
const (
// ListOperationEqual checks if the arrays are equal including the order of the elements
ListOperationEqual ListOperation = iota + 1
// ListOperationNotEqual checks if the arrays are not equal including the order of the elements
ListOperationNotEqual
// ListOperationContains checks if the array column contains all the values of the specified array
ListOperationContainsAll
// ListOperationContainsAny checks if the arrays have at least one value in common
ListOperationContainsAny
// ListOperationContainsAll checks if the array column contains all the values of the specified array
// ListOperationNotContainsAll checks if the specified array is not contained by the column
ListOperationNotContainsAll
// ListOperationNotContainsAny checks if the arrays column contains none of the values of the specified array
ListOperationNotContainsAny
)
var listOperations = map[ListOperation]string{
// ListOperationEqual checks if the lists are equal
ListOperationEqual: " = ",
// ListOperationNotEqual checks if the lists are not equal
ListOperationNotEqual: " <> ",
// ListOperationContainsAny checks if the arrays have at least one value in common
ListOperationContainsAny: " && ",
// ListOperationContainsAll checks if the array column contains all the values of the specified array
ListOperationContainsAll: " @> ",
// ListOperationNotContainsAny checks if the arrays column contains none of the values of the specified array
ListOperationNotContainsAny: " && ", // Base operator for NOT (A && B)
// ListOperationNotContainsAll checks if the array column is not contained by the specified array
ListOperationNotContainsAll: " <@ ", // Base operator for NOT (A <@ B)
}
func (lo ListOperation) String() string {
return listOperations[lo]
}

View File

@@ -0,0 +1,61 @@
package stmt
import (
"time"
"golang.org/x/exp/constraints"
)
type Number interface {
constraints.Integer | constraints.Float | constraints.Complex | time.Time | time.Duration
}
type between[N Number] struct {
min, max N
}
type NumberBetween[V Number, T any] struct {
condition[between[V], T, NumberOperation]
}
func (nb *NumberBetween[V, T]) Apply(stmt *statement[T]) {
nb.field.Apply(stmt)
stmt.builder.WriteString(" BETWEEN ")
stmt.builder.WriteString(stmt.appendArg(nb.value.min))
stmt.builder.WriteString(" AND ")
stmt.builder.WriteString(stmt.appendArg(nb.value.max))
}
type NumberCondition[V Number, T any] struct {
condition[V, T, NumberOperation]
}
func (nc *NumberCondition[V, T]) Apply(stmt *statement[T]) {
nc.field.Apply(stmt)
operation[T, NumberOperation]{nc.op}.Apply(stmt)
stmt.builder.WriteString(stmt.appendArg(nc.value))
}
type NumberOperation uint8
const (
NumberOperationEqual NumberOperation = iota + 1
NumberOperationNotEqual
NumberOperationLessThan
NumberOperationLessThanOrEqual
NumberOperationGreaterThan
NumberOperationGreaterThanOrEqual
)
var numberOperations = map[NumberOperation]string{
NumberOperationEqual: " = ",
NumberOperationNotEqual: " <> ",
NumberOperationLessThan: " < ",
NumberOperationLessThanOrEqual: " <= ",
NumberOperationGreaterThan: " > ",
NumberOperationGreaterThanOrEqual: " >= ",
}
func (no NumberOperation) String() string {
return numberOperations[no]
}

View File

@@ -0,0 +1,104 @@
package stmt
import (
"fmt"
"strings"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type statement[T any] struct {
builder strings.Builder
client database.QueryExecutor
columns []Column[T]
schema string
table string
alias string
condition Condition[T]
limit uint32
offset uint32
// order by fieldname and sort direction false for asc true for desc
// orderBy SortingColumns[C]
args []any
existingArgs map[any]string
}
func (s *statement[T]) scanners(t *T) []any {
scanners := make([]any, len(s.columns))
for i, column := range s.columns {
scanners[i] = column.scanner(t)
}
return scanners
}
func (s *statement[T]) query() string {
s.builder.WriteString(`SELECT `)
for i, column := range s.columns {
if i > 0 {
s.builder.WriteString(", ")
}
column.Apply(s)
}
s.builder.WriteString(` FROM `)
s.builder.WriteString(s.schema)
s.builder.WriteRune('.')
s.builder.WriteString(s.table)
if s.alias != "" {
s.builder.WriteString(" AS ")
s.builder.WriteString(s.alias)
}
s.builder.WriteString(` WHERE `)
s.condition.Apply(s)
if s.limit > 0 {
s.builder.WriteString(` LIMIT `)
s.builder.WriteString(s.appendArg(s.limit))
}
if s.offset > 0 {
s.builder.WriteString(` OFFSET `)
s.builder.WriteString(s.appendArg(s.offset))
}
return s.builder.String()
}
// func (s *statement[T]) Where(condition Condition[T]) *statement[T] {
// s.condition = condition
// return s
// }
// func (s *statement[T]) Limit(limit uint32) *statement[T] {
// s.limit = limit
// return s
// }
// func (s *statement[T]) Offset(offset uint32) *statement[T] {
// s.offset = offset
// return s
// }
func (s *statement[T]) columnPrefix() string {
if s.alias != "" {
return s.alias + "."
}
return s.schema + "." + s.table + "."
}
func (s *statement[T]) appendArg(arg any) string {
if s.existingArgs == nil {
s.existingArgs = make(map[any]string)
}
if existing, ok := s.existingArgs[arg]; ok {
return existing
}
s.args = append(s.args, arg)
placeholder := fmt.Sprintf("$%d", len(s.args))
s.existingArgs[arg] = placeholder
return placeholder
}

View File

@@ -0,0 +1,18 @@
package stmt_test
import (
"context"
"testing"
"github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt"
)
func Test_Bla(t *testing.T) {
stmt.User(nil).Where(
stmt.Or(
stmt.UserIDCondition("123"),
stmt.UserIDCondition("123"),
stmt.UserUsernameCondition(stmt.TextOperationEqualIgnoreCase, "test"),
),
).Limit(1).Offset(1).Get(context.Background())
}

View File

@@ -0,0 +1,72 @@
package stmt
type Text interface {
~string | ~[]byte
}
type TextCondition[V Text, T any] struct {
condition[V, T, TextOperation]
}
func (tc *TextCondition[V, T]) Apply(stmt *statement[T]) {
placeholder := stmt.appendArg(tc.value)
switch tc.op {
case TextOperationEqual, TextOperationNotEqual:
tc.field.Apply(stmt)
operation[T, TextOperation]{tc.op}.Apply(stmt)
stmt.builder.WriteString(placeholder)
case TextOperationEqualIgnoreCase:
if desc, ok := tc.field.(ignoreCaseColumnDescriptor[T]); ok {
desc.ApplyIgnoreCase(stmt)
} else {
stmt.builder.WriteString("LOWER(")
tc.field.Apply(stmt)
stmt.builder.WriteString(")")
}
operation[T, TextOperation]{tc.op}.Apply(stmt)
stmt.builder.WriteString("LOWER(")
stmt.builder.WriteString(placeholder)
stmt.builder.WriteString(")")
case TextOperationStartsWith:
tc.field.Apply(stmt)
operation[T, TextOperation]{tc.op}.Apply(stmt)
stmt.builder.WriteString(placeholder)
stmt.builder.WriteString("|| '%'")
case TextOperationStartsWithIgnoreCase:
if desc, ok := tc.field.(ignoreCaseColumnDescriptor[T]); ok {
desc.ApplyIgnoreCase(stmt)
} else {
stmt.builder.WriteString("LOWER(")
tc.field.Apply(stmt)
stmt.builder.WriteString(")")
}
operation[T, TextOperation]{tc.op}.Apply(stmt)
stmt.builder.WriteString("LOWER(")
stmt.builder.WriteString(placeholder)
stmt.builder.WriteString(")")
stmt.builder.WriteString("|| '%'")
}
}
type TextOperation uint8
const (
TextOperationEqual TextOperation = iota + 1
TextOperationEqualIgnoreCase
TextOperationNotEqual
TextOperationStartsWith
TextOperationStartsWithIgnoreCase
)
var textOperations = map[TextOperation]string{
TextOperationEqual: " = ",
TextOperationEqualIgnoreCase: " = ",
TextOperationNotEqual: " <> ",
TextOperationStartsWith: " LIKE ",
TextOperationStartsWithIgnoreCase: " LIKE ",
}
func (to TextOperation) String() string {
return textOperations[to]
}

View File

@@ -0,0 +1,193 @@
package stmt
import (
"context"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type userStatement struct {
statement[domain.User]
}
func User(client database.QueryExecutor) *userStatement {
return &userStatement{
statement: statement[domain.User]{
schema: "zitadel",
table: "users",
alias: "u",
client: client,
columns: []Column[domain.User]{
userColumns[UserInstanceID],
userColumns[UserOrgID],
userColumns[UserColumnID],
userColumns[UserColumnUsername],
userColumns[UserCreatedAt],
userColumns[UserUpdatedAt],
userColumns[UserDeletedAt],
},
},
}
}
func (s *userStatement) Where(condition Condition[domain.User]) *userStatement {
s.condition = condition
return s
}
func (s *userStatement) Limit(limit uint32) *userStatement {
s.limit = limit
return s
}
func (s *userStatement) Offset(offset uint32) *userStatement {
s.offset = offset
return s
}
func (s *userStatement) Get(ctx context.Context) (*domain.User, error) {
var user domain.User
err := s.client.QueryRow(ctx, s.query(), s.statement.args...).Scan(s.scanners(&user)...)
if err != nil {
return nil, err
}
return &user, nil
}
func (s *userStatement) List(ctx context.Context) ([]*domain.User, error) {
var users []*domain.User
rows, err := s.client.Query(ctx, s.query(), s.statement.args...)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var user domain.User
err = rows.Scan(s.scanners(&user)...)
if err != nil {
return nil, err
}
users = append(users, &user)
}
return users, nil
}
func (s *userStatement) SetUsername(ctx context.Context, username string) error {
return nil
}
type UserColumn uint8
var (
userColumns map[UserColumn]Column[domain.User] = map[UserColumn]Column[domain.User]{
UserInstanceID: columnDescriptor[domain.User]{
name: "instance_id",
scan: func(u *domain.User) any {
return &u.InstanceID
},
},
UserOrgID: columnDescriptor[domain.User]{
name: "org_id",
scan: func(u *domain.User) any {
return &u.OrgID
},
},
UserColumnID: columnDescriptor[domain.User]{
name: "id",
scan: func(u *domain.User) any {
return &u.ID
},
},
UserColumnUsername: ignoreCaseColumnDescriptor[domain.User]{
columnDescriptor: columnDescriptor[domain.User]{
name: "username",
scan: func(u *domain.User) any {
return &u.Username
},
},
fieldNameSuffix: "_lower",
},
UserCreatedAt: columnDescriptor[domain.User]{
name: "created_at",
scan: func(u *domain.User) any {
return &u.CreatedAt
},
},
UserUpdatedAt: columnDescriptor[domain.User]{
name: "updated_at",
scan: func(u *domain.User) any {
return &u.UpdatedAt
},
},
UserDeletedAt: columnDescriptor[domain.User]{
name: "deleted_at",
scan: func(u *domain.User) any {
return &u.DeletedAt
},
},
}
humanColumns = map[UserColumn]Column[domain.User]{
UserHumanColumnEmail: ignoreCaseColumnDescriptor[domain.User]{
columnDescriptor: columnDescriptor[domain.User]{
name: "email",
scan: func(u *domain.User) any {
human, ok := u.Traits.(*domain.Human)
if !ok {
return nil
}
if human.Email == nil {
human.Email = new(domain.Email)
}
return &human.Email.Address
},
},
fieldNameSuffix: "_lower",
},
UserHumanColumnEmailVerified: columnDescriptor[domain.User]{
name: "email_is_verified",
scan: func(u *domain.User) any {
human, ok := u.Traits.(*domain.Human)
if !ok {
return nil
}
if human.Email == nil {
human.Email = new(domain.Email)
}
return &human.Email.IsVerified
},
},
}
machineColumns = map[UserColumn]Column[domain.User]{
UserMachineDescription: columnDescriptor[domain.User]{
name: "description",
scan: func(u *domain.User) any {
machine, ok := u.Traits.(*domain.Machine)
if !ok {
return nil
}
if machine == nil {
machine = new(domain.Machine)
}
return &machine.Description
},
},
}
)
const (
UserInstanceID UserColumn = iota + 1
UserOrgID
UserColumnID
UserColumnUsername
UserHumanColumnEmail
UserHumanColumnEmailVerified
UserMachineDescription
UserCreatedAt
UserUpdatedAt
UserDeletedAt
)

View File

@@ -0,0 +1,23 @@
package stmt
import "github.com/zitadel/zitadel/backend/v3/domain"
func UserIDCondition(id string) *TextCondition[string, domain.User] {
return &TextCondition[string, domain.User]{
condition: condition[string, domain.User, TextOperation]{
field: userColumns[UserColumnID],
op: TextOperationEqual,
value: id,
},
}
}
func UserUsernameCondition(op TextOperation, username string) *TextCondition[string, domain.User] {
return &TextCondition[string, domain.User]{
condition: condition[string, domain.User, TextOperation]{
field: userColumns[UserColumnUsername],
op: op,
value: username,
},
}
}

View File

@@ -0,0 +1,135 @@
package stmt
// type table struct {
// schema string
// name string
// possibleJoins []*join
// columns []*col
// }
// type col struct {
// *table
// name string
// }
// type join struct {
// *table
// on []*joinColumns
// }
// type joinColumns struct {
// left, right *col
// }
// var (
// userTable = &table{
// schema: "zitadel",
// name: "users",
// }
// userColumns = []*col{
// userInstanceIDColumn,
// userOrgIDColumn,
// userIDColumn,
// userUsernameColumn,
// }
// userInstanceIDColumn = &col{
// table: userTable,
// name: "instance_id",
// }
// userOrgIDColumn = &col{
// table: userTable,
// name: "org_id",
// }
// userIDColumn = &col{
// table: userTable,
// name: "id",
// }
// userUsernameColumn = &col{
// table: userTable,
// name: "username",
// }
// userJoins = []*join{
// {
// table: instanceTable,
// on: []*joinColumns{
// {
// left: instanceIDColumn,
// right: userInstanceIDColumn,
// },
// },
// },
// {
// table: orgTable,
// on: []*joinColumns{
// {
// left: orgIDColumn,
// right: userOrgIDColumn,
// },
// },
// },
// }
// )
// var (
// instanceTable = &table{
// schema: "zitadel",
// name: "instances",
// }
// instanceColumns = []*col{
// instanceIDColumn,
// instanceNameColumn,
// }
// instanceIDColumn = &col{
// table: instanceTable,
// name: "id",
// }
// instanceNameColumn = &col{
// table: instanceTable,
// name: "name",
// }
// )
// var (
// orgTable = &table{
// schema: "zitadel",
// name: "orgs",
// }
// orgColumns = []*col{
// orgInstanceIDColumn,
// orgIDColumn,
// orgNameColumn,
// }
// orgInstanceIDColumn = &col{
// table: orgTable,
// name: "instance_id",
// }
// orgIDColumn = &col{
// table: orgTable,
// name: "id",
// }
// orgNameColumn = &col{
// table: orgTable,
// name: "name",
// }
// )
// func init() {
// instanceTable.columns = instanceColumns
// userTable.columns = userColumns
// userTable.possibleJoins = []join{
// {
// table: userTable,
// on: []joinColumns{
// {
// left: userIDColumn,
// right: userIDColumn,
// },
// },
// },
// }
// }

View File

@@ -0,0 +1,55 @@
package v3
type Column interface {
Name() string
Write(builder statementBuilder)
}
type ignoreCaseColumn interface {
Column
WriteIgnoreCase(builder statementBuilder)
}
var (
columnNameID = "id"
columnNameName = "name"
columnNameCreatedAt = "created_at"
columnNameUpdatedAt = "updated_at"
columnNameDeletedAt = "deleted_at"
columnNameInstanceID = "instance_id"
columnNameOrgID = "org_id"
)
type column struct {
table Table
name string
}
// Write implements Column.
func (c *column) Write(builder statementBuilder) {
c.table.writeOn(builder)
builder.writeRune('.')
builder.writeString(c.name)
}
// Name implements [Column].
func (c *column) Name() string {
return c.name
}
var _ Column = (*column)(nil)
type columnIgnoreCase struct {
column
suffix string
}
// WriteIgnoreCase implements ignoreCaseColumn.
func (c *columnIgnoreCase) WriteIgnoreCase(builder statementBuilder) {
c.Write(builder)
builder.writeString(c.suffix)
}
var _ ignoreCaseColumn = (*columnIgnoreCase)(nil)

View File

@@ -0,0 +1,182 @@
package v3
type statementBuilder interface {
write([]byte)
writeString(string)
writeRune(rune)
appendArg(any) (placeholder string)
table() Table
}
type Condition interface {
writeOn(builder statementBuilder)
}
type and struct {
conditions []Condition
}
func And(conditions ...Condition) *and {
return &and{conditions: conditions}
}
// writeOn implements [Condition].
func (a *and) writeOn(builder statementBuilder) {
if len(a.conditions) > 1 {
builder.writeString("(")
defer builder.writeString(")")
}
for i, condition := range a.conditions {
if i > 0 {
builder.writeString(" AND ")
}
condition.writeOn(builder)
}
}
var _ Condition = (*and)(nil)
type or struct {
conditions []Condition
}
func Or(conditions ...Condition) *or {
return &or{conditions: conditions}
}
// writeOn implements [Condition].
func (o *or) writeOn(builder statementBuilder) {
if len(o.conditions) > 1 {
builder.writeString("(")
defer builder.writeString(")")
}
for i, condition := range o.conditions {
if i > 0 {
builder.writeString(" OR ")
}
condition.writeOn(builder)
}
}
var _ Condition = (*or)(nil)
type isNull struct {
column Column
}
func IsNull(column Column) *isNull {
return &isNull{column: column}
}
// writeOn implements [Condition].
func (cond *isNull) writeOn(builder statementBuilder) {
cond.column.Write(builder)
builder.writeString(" IS NULL")
}
var _ Condition = (*isNull)(nil)
type isNotNull struct {
column Column
}
func IsNotNull(column Column) *isNotNull {
return &isNotNull{column: column}
}
// writeOn implements [Condition].
func (cond *isNotNull) writeOn(builder statementBuilder) {
cond.column.Write(builder)
builder.writeString(" IS NOT NULL")
}
var _ Condition = (*isNotNull)(nil)
type condition[Op Operator, V Value] struct {
column Column
operator Op
value V
}
// writeOn implements [Condition].
func (cond condition[Op, V]) writeOn(builder statementBuilder) {
cond.column.Write(builder)
builder.writeString(cond.operator.String())
builder.writeString(builder.appendArg(cond.value))
}
var _ Condition = (*condition[TextOperator, string])(nil)
type textCondition[V Text] struct {
condition[TextOperator, V]
}
func NewTextCondition[V Text](column Column, operator TextOperator, value V) *textCondition[V] {
return &textCondition[V]{
condition: condition[TextOperator, V]{
column: column,
operator: operator,
value: value,
},
}
}
// writeOn implements [Condition].
func (cond *textCondition[V]) writeOn(builder statementBuilder) {
switch cond.operator {
case TextOperatorEqual, TextOperatorNotEqual:
cond.column.Write(builder)
builder.writeString(cond.operator.String())
builder.writeString(builder.appendArg(cond.value))
case TextOperatorEqualIgnoreCase, TextOperatorNotEqualIgnoreCase:
if col, ok := cond.column.(ignoreCaseColumn); ok {
col.WriteIgnoreCase(builder)
} else {
builder.writeString("LOWER(")
cond.column.Write(builder)
builder.writeString(")")
}
builder.writeString(cond.operator.String())
builder.writeString("LOWER(")
builder.writeString(builder.appendArg(cond.value))
builder.writeString(")")
case TextOperatorStartsWith:
cond.column.Write(builder)
builder.writeString(cond.operator.String())
builder.writeString(builder.appendArg(cond.value))
builder.writeString(" || '%'")
case TextOperatorStartsWithIgnoreCase:
if col, ok := cond.column.(ignoreCaseColumn); ok {
col.WriteIgnoreCase(builder)
} else {
builder.writeString("LOWER(")
cond.column.Write(builder)
builder.writeString(")")
}
builder.writeString(cond.operator.String())
builder.writeString("LOWER(")
builder.writeString(builder.appendArg(cond.value))
builder.writeString(") || '%'")
}
}
var _ Condition = (*textCondition[string])(nil)
type numberCondition[V Number] struct {
condition[NumberOperator, V]
}
func NewNumberCondition[V Number](column Column, operator NumberOperator, value V) *numberCondition[V] {
return &numberCondition[V]{
condition: condition[NumberOperator, V]{
column: column,
operator: operator,
value: value,
},
}
}
var _ Condition = (*numberCondition[int])(nil)

View File

@@ -0,0 +1,104 @@
package v3
import (
"time"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type Instance struct {
id string
name string
createdAt time.Time
updatedAt time.Time
deletedAt time.Time
}
// Columns implements [object].
func (Instance) Columns(table Table) []Column {
return []Column{
&column{
table: table,
name: columnNameID,
},
&column{
table: table,
name: columnNameName,
},
&column{
table: table,
name: columnNameCreatedAt,
},
&column{
table: table,
name: columnNameUpdatedAt,
},
&column{
table: table,
name: columnNameDeletedAt,
},
}
}
// Scan implements [object].
func (i Instance) Scan(row database.Scanner) error {
return row.Scan(
&i.id,
&i.name,
&i.createdAt,
&i.updatedAt,
&i.deletedAt,
)
}
type instanceTable struct {
*table
}
func InstanceTable() *instanceTable {
table := &instanceTable{
table: newTable[Instance]("zitadel", "instances"),
}
table.possibleJoins = func(t Table) map[string]Column {
switch on := t.(type) {
case *instanceTable:
return map[string]Column{
columnNameID: on.IDColumn(),
}
case *orgTable:
return map[string]Column{
columnNameID: on.InstanceIDColumn(),
}
case *userTable:
return map[string]Column{
columnNameID: on.InstanceIDColumn(),
}
default:
return nil
}
}
return table
}
func (i *instanceTable) IDColumn() Column {
return i.columns[columnNameID]
}
func (i *instanceTable) NameColumn() Column {
return i.columns[columnNameName]
}
func (i *instanceTable) CreatedAtColumn() Column {
return i.columns[columnNameCreatedAt]
}
func (i *instanceTable) UpdatedAtColumn() Column {
return i.columns[columnNameUpdatedAt]
}
func (i *instanceTable) DeletedAtColumn() Column {
return i.columns[columnNameDeletedAt]
}

View File

@@ -0,0 +1,11 @@
package v3
type join struct {
table Table
conditions []joinCondition
}
type joinCondition struct {
left Column
right Column
}

View File

@@ -0,0 +1,82 @@
package v3
import (
"fmt"
"time"
"golang.org/x/exp/constraints"
)
type Value interface {
Bool | Number | Text
}
type Text interface {
~string | ~[]byte
}
type Number interface {
constraints.Integer | constraints.Float | constraints.Complex | time.Time | time.Duration
}
type Bool interface {
~bool
}
type Operator interface {
fmt.Stringer
}
type TextOperator uint8
// String implements [Operator].
func (t TextOperator) String() string {
return textOperators[t]
}
const (
TextOperatorEqual TextOperator = iota + 1
TextOperatorEqualIgnoreCase
TextOperatorNotEqual
TextOperatorNotEqualIgnoreCase
TextOperatorStartsWith
TextOperatorStartsWithIgnoreCase
)
var textOperators = map[TextOperator]string{
TextOperatorEqual: " = ",
TextOperatorEqualIgnoreCase: " LIKE ",
TextOperatorNotEqual: " <> ",
TextOperatorNotEqualIgnoreCase: " NOT LIKE ",
TextOperatorStartsWith: " LIKE ",
TextOperatorStartsWithIgnoreCase: " LIKE ",
}
var _ Operator = TextOperator(0)
type NumberOperator uint8
// String implements Operator.
func (n NumberOperator) String() string {
return numberOperators[n]
}
const (
NumberOperatorEqual NumberOperator = iota + 1
NumberOperatorNotEqual
NumberOperatorLessThan
NumberOperatorLessThanOrEqual
NumberOperatorGreaterThan
NumberOperatorGreaterThanOrEqual
)
var numberOperators = map[NumberOperator]string{
NumberOperatorEqual: " = ",
NumberOperatorNotEqual: " <> ",
NumberOperatorLessThan: " < ",
NumberOperatorLessThanOrEqual: " <= ",
NumberOperatorGreaterThan: " > ",
NumberOperatorGreaterThanOrEqual: " >= ",
}
var _ Operator = NumberOperator(0)

View File

@@ -0,0 +1,117 @@
package v3
import (
"time"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type Org struct {
instanceID string
id string
name string
createdAt time.Time
updatedAt time.Time
deletedAt time.Time
}
// Columns implements [object].
func (Org) Columns(table Table) []Column {
return []Column{
&column{
table: table,
name: columnNameInstanceID,
},
&column{
table: table,
name: columnNameID,
},
&column{
table: table,
name: columnNameName,
},
&column{
table: table,
name: columnNameCreatedAt,
},
&column{
table: table,
name: columnNameUpdatedAt,
},
&column{
table: table,
name: columnNameDeletedAt,
},
}
}
// Scan implements [object].
func (o Org) Scan(row database.Scanner) error {
return row.Scan(
&o.instanceID,
&o.id,
&o.name,
&o.createdAt,
&o.updatedAt,
&o.deletedAt,
)
}
type orgTable struct {
*table
}
func OrgTable() *orgTable {
table := &orgTable{
table: newTable[Org]("zitadel", "orgs"),
}
table.possibleJoins = func(table Table) map[string]Column {
switch on := table.(type) {
case *instanceTable:
return map[string]Column{
columnNameInstanceID: on.IDColumn(),
}
case *orgTable:
return map[string]Column{
columnNameInstanceID: on.InstanceIDColumn(),
columnNameID: on.IDColumn(),
}
case *userTable:
return map[string]Column{
columnNameInstanceID: on.InstanceIDColumn(),
columnNameID: on.IDColumn(),
}
default:
return nil
}
}
return table
}
func (o *orgTable) InstanceIDColumn() Column {
return o.columns[columnNameInstanceID]
}
func (o *orgTable) IDColumn() Column {
return o.columns[columnNameID]
}
func (o *orgTable) NameColumn() Column {
return o.columns[columnNameName]
}
func (o *orgTable) CreatedAtColumn() Column {
return o.columns[columnNameCreatedAt]
}
func (o *orgTable) UpdatedAtColumn() Column {
return o.columns[columnNameUpdatedAt]
}
func (o *orgTable) DeletedAtColumn() Column {
return o.columns[columnNameDeletedAt]
}

View File

@@ -0,0 +1,188 @@
package v3
import (
"context"
"fmt"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type Query[O object] interface {
Where(condition Condition)
Join(tables ...Table)
Limit(limit uint32)
Offset(offset uint32)
OrderBy(columns ...Column)
Result(ctx context.Context, client database.Querier) (*O, error)
Results(ctx context.Context, client database.Querier) ([]O, error)
fmt.Stringer
statementBuilder
}
type query[O object] struct {
*statement[O]
joins []join
limit uint32
offset uint32
orderBy []Column
}
func NewQuery[O object](table Table) Query[O] {
return &query[O]{
statement: newStatement[O](table),
}
}
// Result implements [Query].
func (q *query[O]) Result(ctx context.Context, client database.Querier) (*O, error) {
var object O
row := client.QueryRow(ctx, q.String(), q.args...)
if err := object.Scan(row); err != nil {
return nil, err
}
return &object, nil
}
// Results implements [Query].
func (q *query[O]) Results(ctx context.Context, client database.Querier) ([]O, error) {
var objects []O
rows, err := client.Query(ctx, q.String(), q.args...)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var object O
if err := object.Scan(rows); err != nil {
return nil, err
}
objects = append(objects, object)
}
return objects, rows.Err()
}
// Join implements [Query].
func (q *query[O]) Join(tables ...Table) {
for _, tbl := range tables {
cols := q.tbl.(*table).possibleJoins(tbl)
if len(cols) == 0 {
panic(fmt.Sprintf("table %q does not have any possible joins with table %q", q.tbl.Name(), tbl.Name()))
}
q.joins = append(q.joins, join{
table: tbl,
conditions: make([]joinCondition, 0, len(cols)),
})
for colName, col := range cols {
q.joins[len(q.joins)-1].conditions = append(q.joins[len(q.joins)-1].conditions, joinCondition{
left: q.tbl.(*table).columns[colName],
right: col,
})
}
}
}
func (q *query[O]) Limit(limit uint32) {
q.limit = limit
}
func (q *query[O]) Offset(offset uint32) {
q.offset = offset
}
func (q *query[O]) OrderBy(columns ...Column) {
for _, allowedColumn := range q.columns {
for _, column := range columns {
if allowedColumn.Name() == column.Name() {
q.orderBy = append(q.orderBy, column)
}
}
}
}
// String implements [fmt.Stringer] and [Query].
func (q *query[O]) String() string {
q.writeSelectColumns()
q.writeFrom()
q.writeJoins()
q.writeCondition()
q.writeOrderBy()
q.writeLimit()
q.writeOffset()
q.writeGroupBy()
return q.builder.String()
}
func (q *query[O]) writeSelectColumns() {
q.builder.WriteString("SELECT ")
for i, column := range q.columns {
if i > 0 {
q.builder.WriteString(", ")
}
q.builder.WriteString(q.tbl.Alias())
q.builder.WriteRune('.')
q.builder.WriteString(column.Name())
}
}
func (q *query[O]) writeJoins() {
for _, join := range q.joins {
q.builder.WriteString(" JOIN ")
q.builder.WriteString(join.table.Schema())
q.builder.WriteRune('.')
q.builder.WriteString(join.table.Name())
if join.table.Alias() != "" {
q.builder.WriteString(" AS ")
q.builder.WriteString(join.table.Alias())
}
q.builder.WriteString(" ON ")
for i, condition := range join.conditions {
if i > 0 {
q.builder.WriteString(" AND ")
}
q.builder.WriteString(condition.left.Name())
q.builder.WriteString(" = ")
q.builder.WriteString(condition.right.Name())
}
}
}
func (q *query[O]) writeOrderBy() {
if len(q.orderBy) == 0 {
return
}
q.builder.WriteString(" ORDER BY ")
for i, order := range q.orderBy {
if i > 0 {
q.builder.WriteString(", ")
}
order.Write(q)
}
}
func (q *query[O]) writeLimit() {
if q.limit == 0 {
return
}
q.builder.WriteString(" LIMIT ")
q.builder.WriteString(q.appendArg(q.limit))
}
func (q *query[O]) writeOffset() {
if q.offset == 0 {
return
}
q.builder.WriteString(" OFFSET ")
q.builder.WriteString(q.appendArg(q.offset))
}
func (q *query[O]) writeGroupBy() {
q.builder.WriteString(" GROUP BY ")
}

View File

@@ -0,0 +1,85 @@
package v3
import (
"fmt"
"strings"
)
type statement[T object] struct {
tbl Table
columns []Column
condition Condition
builder strings.Builder
args []any
existingArgs map[any]string
}
func newStatement[O object](t Table) *statement[O] {
var o O
return &statement[O]{
tbl: t,
columns: o.Columns(t),
}
}
// Where implements [Query].
func (stmt *statement[T]) Where(condition Condition) {
stmt.condition = condition
}
func (stmt *statement[T]) writeFrom() {
stmt.builder.WriteString(" FROM ")
stmt.builder.WriteString(stmt.tbl.Schema())
stmt.builder.WriteRune('.')
stmt.builder.WriteString(stmt.tbl.Name())
if stmt.tbl.Alias() != "" {
stmt.builder.WriteString(" AS ")
stmt.builder.WriteString(stmt.tbl.Alias())
}
}
func (stmt *statement[T]) writeCondition() {
if stmt.condition == nil {
return
}
stmt.builder.WriteString(" WHERE ")
stmt.condition.writeOn(stmt)
}
// appendArg implements [statementBuilder].
func (stmt *statement[T]) appendArg(arg any) (placeholder string) {
if stmt.existingArgs == nil {
stmt.existingArgs = make(map[any]string)
}
if placeholder, ok := stmt.existingArgs[arg]; ok {
return placeholder
}
stmt.args = append(stmt.args, arg)
placeholder = fmt.Sprintf("$%d", len(stmt.args))
stmt.existingArgs[arg] = placeholder
return placeholder
}
// table implements [statementBuilder].
func (stmt *statement[T]) table() Table {
return stmt.tbl
}
// write implements [statementBuilder].
func (stmt *statement[T]) write(data []byte) {
stmt.builder.Write(data)
}
// writeRune implements [statementBuilder].
func (stmt *statement[T]) writeRune(r rune) {
stmt.builder.WriteRune(r)
}
// writeString implements [statementBuilder].
func (stmt *statement[T]) writeString(s string) {
stmt.builder.WriteString(s)
}
var _ statementBuilder = (*statement[Instance])(nil)

View File

@@ -0,0 +1,84 @@
package v3
import "github.com/zitadel/zitadel/backend/v3/storage/database"
type object interface {
User | Org | Instance
Columns(t Table) []Column
Scan(s database.Scanner) error
}
type Table interface {
Schema() string
Name() string
Alias() string
Columns() []Column
writeOn(builder statementBuilder)
}
type table struct {
schema string
name string
alias string
possibleJoins func(table Table) map[string]Column
columns map[string]Column
colList []Column
}
func newTable[O object](schema, name string) *table {
t := &table{
schema: schema,
name: name,
}
var o O
t.colList = o.Columns(t)
t.columns = make(map[string]Column, len(t.colList))
for _, col := range t.colList {
t.columns[col.Name()] = col
}
return t
}
// Columns implements [Table].
func (t *table) Columns() []Column {
if len(t.colList) > 0 {
return t.colList
}
t.colList = make([]Column, 0, len(t.columns))
for _, column := range t.columns {
t.colList = append(t.colList, column)
}
return t.colList
}
// Name implements [Table].
func (t *table) Name() string {
return t.name
}
// Schema implements [Table].
func (t *table) Schema() string {
return t.schema
}
// Alias implements [Table].
func (t *table) Alias() string {
if t.alias != "" {
return t.alias
}
return t.schema + "." + t.name
}
// writeOn implements [Table].
func (t *table) writeOn(builder statementBuilder) {
builder.writeString(t.Alias())
}
var _ Table = (*table)(nil)

View File

@@ -0,0 +1,170 @@
package v3
import (
"time"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type User struct {
instanceID string
orgID string
id string
username string
createdAt time.Time
updatedAt time.Time
deletedAt time.Time
}
// Columns implements [object].
func (u User) Columns(table Table) []Column {
return []Column{
&column{
table: table,
name: columnNameInstanceID,
},
&column{
table: table,
name: columnNameOrgID,
},
&column{
table: table,
name: columnNameID,
},
&columnIgnoreCase{
column: column{
table: table,
name: userTableUsernameColumn,
},
suffix: "_lower",
},
&column{
table: table,
name: columnNameCreatedAt,
},
&column{
table: table,
name: columnNameUpdatedAt,
},
&column{
table: table,
name: columnNameDeletedAt,
},
}
}
// Scan implements [object].
func (u User) Scan(row database.Scanner) error {
return row.Scan(
&u.instanceID,
&u.orgID,
&u.id,
&u.username,
&u.createdAt,
&u.updatedAt,
&u.deletedAt,
)
}
type userTable struct {
*table
}
const (
userTableUsernameColumn = "username"
)
func UserTable() *userTable {
table := &userTable{
table: newTable[User]("zitadel", "users"),
}
table.possibleJoins = func(table Table) map[string]Column {
switch on := table.(type) {
case *userTable:
return map[string]Column{
columnNameInstanceID: on.InstanceIDColumn(),
columnNameOrgID: on.OrgIDColumn(),
columnNameID: on.IDColumn(),
}
case *orgTable:
return map[string]Column{
columnNameInstanceID: on.InstanceIDColumn(),
columnNameOrgID: on.IDColumn(),
}
case *instanceTable:
return map[string]Column{
columnNameInstanceID: on.IDColumn(),
}
default:
return nil
}
}
return table
}
func (t *userTable) InstanceIDColumn() Column {
return t.columns[columnNameInstanceID]
}
func (t *userTable) OrgIDColumn() Column {
return t.columns[columnNameOrgID]
}
func (t *userTable) IDColumn() Column {
return t.columns[columnNameID]
}
func (t *userTable) UsernameColumn() Column {
return t.columns[userTableUsernameColumn]
}
func (t *userTable) CreatedAtColumn() Column {
return t.columns[columnNameCreatedAt]
}
func (t *userTable) UpdatedAtColumn() Column {
return t.columns[columnNameUpdatedAt]
}
func (t *userTable) DeletedAtColumn() Column {
return t.columns[columnNameDeletedAt]
}
func NewUserQuery() Query[User] {
q := NewQuery[User](UserTable())
return q
}
type userByIDCondition[T Text] struct {
id T
}
func UserByID[T Text](id T) Condition {
return &userByIDCondition[T]{id: id}
}
// writeOn implements Condition.
func (u *userByIDCondition[T]) writeOn(builder statementBuilder) {
NewTextCondition(builder.table().(*userTable).IDColumn(), TextOperatorEqual, u.id).writeOn(builder)
}
var _ Condition = (*userByIDCondition[string])(nil)
type userByUsernameCondition[T Text] struct {
username T
operator TextOperator
}
func UserByUsername[T Text](username T, operator TextOperator) Condition {
return &userByUsernameCondition[T]{username: username, operator: operator}
}
// writeOn implements Condition.
func (u *userByUsernameCondition[T]) writeOn(builder statementBuilder) {
NewTextCondition(builder.table().(*userTable).UsernameColumn(), u.operator, u.username).writeOn(builder)
}
var _ Condition = (*userByUsernameCondition[string])(nil)

View File

@@ -0,0 +1,25 @@
package v3_test
import (
"context"
"testing"
v3 "github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt/v3"
)
type user struct{}
func TestUser(t *testing.T) {
query := v3.NewUserQuery()
query.Where(
v3.Or(
v3.UserByID("123"),
v3.UserByUsername("test", v3.TextOperatorStartsWithIgnoreCase),
),
)
query.Limit(10)
query.Offset(5)
// query.OrderBy(
query.Result(context.TODO(), nil)
}

View File

@@ -0,0 +1,78 @@
package v4
type Change interface {
Column
}
type change[V Value] struct {
column Column
value V
}
func newChange[V Value](col Column, value V) Change {
return &change[V]{
column: col,
value: value,
}
}
func newUpdatePtrColumn[V Value](col Column, value *V) Change {
if value == nil {
return newChange(col, nullDBInstruction)
}
return newChange(col, *value)
}
// writeTo implements [Change].
func (c change[V]) writeTo(builder *statementBuilder) {
c.column.writeTo(builder)
builder.WriteString(" = ")
builder.writeArg(c.value)
}
type Changes []Change
func newChanges(cols ...Change) Change {
return Changes(cols)
}
// writeTo implements [Change].
func (m Changes) writeTo(builder *statementBuilder) {
for i, col := range m {
if i > 0 {
builder.WriteString(", ")
}
col.writeTo(builder)
}
}
var _ Change = Changes(nil)
var _ Change = (*change[string])(nil)
type Column interface {
writeTo(builder *statementBuilder)
}
type column struct {
name string
}
func (c column) writeTo(builder *statementBuilder) {
builder.WriteString(c.name)
}
type ignoreCaseColumn interface {
Column
writeIgnoreCaseTo(builder *statementBuilder)
}
type ignoreCaseCol struct {
column
suffix string
}
func (c ignoreCaseCol) writeIgnoreCaseTo(builder *statementBuilder) {
c.column.writeTo(builder)
builder.WriteString(c.suffix)
}

View File

@@ -0,0 +1,112 @@
package v4
type Condition interface {
writeTo(builder *statementBuilder)
}
type and struct {
conditions []Condition
}
// writeTo implements [Condition].
func (a *and) writeTo(builder *statementBuilder) {
if len(a.conditions) > 1 {
builder.WriteString("(")
defer builder.WriteString(")")
}
for i, condition := range a.conditions {
if i > 0 {
builder.WriteString(" AND ")
}
condition.writeTo(builder)
}
}
func And(conditions ...Condition) *and {
return &and{conditions: conditions}
}
var _ Condition = (*and)(nil)
type or struct {
conditions []Condition
}
// writeTo implements [Condition].
func (o *or) writeTo(builder *statementBuilder) {
if len(o.conditions) > 1 {
builder.WriteString("(")
defer builder.WriteString(")")
}
for i, condition := range o.conditions {
if i > 0 {
builder.WriteString(" OR ")
}
condition.writeTo(builder)
}
}
func Or(conditions ...Condition) *or {
return &or{conditions: conditions}
}
var _ Condition = (*or)(nil)
type isNull struct {
column Column
}
// writeTo implements [Condition].
func (i *isNull) writeTo(builder *statementBuilder) {
i.column.writeTo(builder)
builder.WriteString(" IS NULL")
}
func IsNull(column Column) *isNull {
return &isNull{column: column}
}
var _ Condition = (*isNull)(nil)
type isNotNull struct {
column Column
}
// writeTo implements [Condition].
func (i *isNotNull) writeTo(builder *statementBuilder) {
i.column.writeTo(builder)
builder.WriteString(" IS NOT NULL")
}
func IsNotNull(column Column) *isNotNull {
return &isNotNull{column: column}
}
var _ Condition = (*isNotNull)(nil)
type valueCondition func(builder *statementBuilder)
func newTextCondition[V Text](col Column, op TextOperator, value V) Condition {
return valueCondition(func(builder *statementBuilder) {
writeTextOperation(builder, col, op, value)
})
}
func newNumberCondition[V Number](col Column, op NumberOperator, value V) Condition {
return valueCondition(func(builder *statementBuilder) {
writeNumberOperation(builder, col, op, value)
})
}
func newBooleanCondition[V Boolean](col Column, value V) Condition {
return valueCondition(func(builder *statementBuilder) {
writeBooleanOperation(builder, col, value)
})
}
// writeTo implements [Condition].
func (c valueCondition) writeTo(builder *statementBuilder) {
c(builder)
}
var _ Condition = (*valueCondition)(nil)

View File

@@ -0,0 +1,2 @@
// this test focuses on queries rather than on tables
package v4

View File

@@ -0,0 +1,149 @@
CREATE TABLE objects (
id SERIAL PRIMARY KEY,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
deleted_at TIMESTAMP
);
CREATE OR REPLACE FUNCTION update_updated_at_column()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at = NOW();
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TABLE instances(
name VARCHAR(50) NOT NULL
, PRIMARY KEY (id)
) INHERITS (objects);
CREATE TRIGGER set_updated_at
BEFORE UPDATE
ON instances
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
CREATE TABLE instance_objects(
instance_id INT NOT NULL
, PRIMARY KEY (instance_id, id)
-- as foreign keys are not inherited we need to define them on the child tables
--, CONSTRAINT fk_instance FOREIGN KEY (instance_id) REFERENCES instances(id)
) INHERITS (objects);
CREATE TABLE orgs(
name VARCHAR(50) NOT NULL
, PRIMARY KEY (instance_id, id)
, CONSTRAINT fk_instance FOREIGN KEY (instance_id) REFERENCES instances(id)
) INHERITS (instance_objects);
CREATE TRIGGER set_updated_at
BEFORE UPDATE
ON orgs
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
CREATE TABLE org_objects(
org_id INT NOT NULL
, PRIMARY KEY (instance_id, org_id, id)
-- as foreign keys are not inherited we need to define them on the child tables
-- CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id),
-- CONSTRAINT fk_instance FOREIGN KEY (instance_id) REFERENCES instances(id)
) INHERITS (instance_objects);
CREATE TABLE users (
username VARCHAR(50) NOT NULL
, PRIMARY KEY (instance_id, org_id, id)
-- as foreign keys are not inherited we need to define them on the child tables
-- , CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id)
-- , CONSTRAINT fk_instances FOREIGN KEY (instance_id) REFERENCES instances(id)
) INHERITS (org_objects);
CREATE TRIGGER set_updated_at
BEFORE UPDATE
ON users
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
CREATE TABLE human_users(
first_name VARCHAR(50)
, last_name VARCHAR(50)
, PRIMARY KEY (instance_id, org_id, id)
-- CONSTRAINT fk_user FOREIGN KEY (instance_id, org_id, id) REFERENCES users(instance_id, org_id, id),
, CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id)
, CONSTRAINT fk_instances FOREIGN KEY (instance_id) REFERENCES instances(id)
) INHERITS (users);
CREATE TRIGGER set_updated_at
BEFORE UPDATE
ON human_users
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
CREATE TABLE machine_users(
description VARCHAR(50)
, PRIMARY KEY (instance_id, org_id, id)
-- , CONSTRAINT fk_user FOREIGN KEY (instance_id, org_id, id) REFERENCES users(instance_id, org_id, id)
, CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id)
, CONSTRAINT fk_instances FOREIGN KEY (instance_id) REFERENCES instances(id)
) INHERITS (users);
CREATE TRIGGER set_updated_at
BEFORE UPDATE
ON machine_users
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
select u.*, hu.first_name, hu.last_name, mu.description from users u
left join human_users hu on u.instance_id = hu.instance_id and u.org_id = hu.org_id and u.id = hu.id
left join machine_users mu on u.instance_id = mu.instance_id and u.org_id = mu.org_id and u.id = mu.id
-- where
-- u.instance_id = 1
-- and u.org_id = 3
-- and u.id = 7
;
create view users_view as (
SELECT
id
, created_at
, updated_at
, deleted_at
, instance_id
, org_id
, username
, first_name
, last_name
, description
FROM (
(SELECT
id
, created_at
, updated_at
, deleted_at
, instance_id
, org_id
, username
, first_name
, last_name
, NULL AS description
FROM
human_users)
UNION
(SELECT
id
, created_at
, updated_at
, deleted_at
, instance_id
, org_id
, username
, NULL AS first_name
, NULL AS last_name
, description
FROM
machine_users)
));

View File

@@ -0,0 +1,139 @@
package v4
import (
"time"
"golang.org/x/exp/constraints"
)
type Value interface {
Boolean | Number | Text | databaseInstruction
}
type Operator interface {
BooleanOperator | NumberOperator | TextOperator
}
type Text interface {
~string | ~[]byte
}
type TextOperator uint8
const (
// TextOperatorEqual compares two strings for equality.
TextOperatorEqual TextOperator = iota + 1
// TextOperatorEqualIgnoreCase compares two strings for equality, ignoring case.
TextOperatorEqualIgnoreCase
// TextOperatorNotEqual compares two strings for inequality.
TextOperatorNotEqual
// TextOperatorNotEqualIgnoreCase compares two strings for inequality, ignoring case.
TextOperatorNotEqualIgnoreCase
// TextOperatorStartsWith checks if the first string starts with the second.
TextOperatorStartsWith
// TextOperatorStartsWithIgnoreCase checks if the first string starts with the second, ignoring case.
TextOperatorStartsWithIgnoreCase
)
var textOperators = map[TextOperator]string{
TextOperatorEqual: " = ",
TextOperatorEqualIgnoreCase: " LIKE ",
TextOperatorNotEqual: " <> ",
TextOperatorNotEqualIgnoreCase: " NOT LIKE ",
TextOperatorStartsWith: " LIKE ",
TextOperatorStartsWithIgnoreCase: " LIKE ",
}
func writeTextOperation[T Text](builder *statementBuilder, col Column, op TextOperator, value T) {
switch op {
case TextOperatorEqual, TextOperatorNotEqual:
col.writeTo(builder)
builder.WriteString(textOperators[op])
builder.WriteString(builder.appendArg(value))
case TextOperatorEqualIgnoreCase, TextOperatorNotEqualIgnoreCase:
if ignoreCaseCol, ok := col.(ignoreCaseColumn); ok {
ignoreCaseCol.writeIgnoreCaseTo(builder)
} else {
builder.WriteString("LOWER(")
col.writeTo(builder)
builder.WriteString(")")
}
builder.WriteString(textOperators[op])
builder.WriteString("LOWER(")
builder.WriteString(builder.appendArg(value))
builder.WriteString(")")
case TextOperatorStartsWith:
col.writeTo(builder)
builder.WriteString(textOperators[op])
builder.WriteString(builder.appendArg(value))
builder.WriteString(" || '%'")
case TextOperatorStartsWithIgnoreCase:
if ignoreCaseCol, ok := col.(ignoreCaseColumn); ok {
ignoreCaseCol.writeIgnoreCaseTo(builder)
} else {
builder.WriteString("LOWER(")
col.writeTo(builder)
builder.WriteString(")")
}
builder.WriteString(textOperators[op])
builder.WriteString("LOWER(")
builder.WriteString(builder.appendArg(value))
builder.WriteString(")")
builder.WriteString(" || '%'")
default:
panic("unsupported text operation")
}
}
type Number interface {
constraints.Integer | constraints.Float | constraints.Complex | time.Time | time.Duration
}
type NumberOperator uint8
const (
// NumberOperatorEqual compares two numbers for equality.
NumberOperatorEqual NumberOperator = iota + 1
// NumberOperatorNotEqual compares two numbers for inequality.
NumberOperatorNotEqual
// NumberOperatorLessThan compares two numbers to check if the first is less than the second.
NumberOperatorLessThan
// NumberOperatorLessThanOrEqual compares two numbers to check if the first is less than or equal to the second.
NumberOperatorAtLeast
// NumberOperatorGreaterThan compares two numbers to check if the first is greater than the second.
NumberOperatorGreaterThan
// NumberOperatorGreaterThanOrEqual compares two numbers to check if the first is greater than or equal to the second.
NumberOperatorAtMost
)
var numberOperators = map[NumberOperator]string{
NumberOperatorEqual: " = ",
NumberOperatorNotEqual: " <> ",
NumberOperatorLessThan: " < ",
NumberOperatorAtLeast: " <= ",
NumberOperatorGreaterThan: " > ",
NumberOperatorAtMost: " >= ",
}
func writeNumberOperation[T Number](builder *statementBuilder, col Column, op NumberOperator, value T) {
col.writeTo(builder)
builder.WriteString(numberOperators[op])
builder.WriteString(builder.appendArg(value))
}
type Boolean interface {
~bool
}
type BooleanOperator uint8
const (
BooleanOperatorIsTrue BooleanOperator = iota + 1
BooleanOperatorIsFalse
)
func writeBooleanOperation[T Boolean](builder *statementBuilder, col Column, value T) {
col.writeTo(builder)
builder.WriteString(" IS ")
builder.WriteString(builder.appendArg(value))
}

View File

@@ -0,0 +1,18 @@
package v4
type Org struct {
InstanceID string
ID string
Name string
Dates
}
type GetOrg struct{}
type ListOrgs struct{}
type CreateOrg struct{}
type UpdateOrg struct{}
type DeleteOrg struct{}

View File

@@ -0,0 +1,46 @@
package v4
import (
"strconv"
"strings"
)
type databaseInstruction string
const (
nowDBInstruction databaseInstruction = "NOW()"
nullDBInstruction databaseInstruction = "NULL"
)
type statementBuilder struct {
strings.Builder
args []any
existingArgs map[any]string
}
func (b *statementBuilder) writeArg(arg any) {
b.WriteString(b.appendArg(arg))
}
func (b *statementBuilder) appendArg(arg any) (placeholder string) {
if b.existingArgs == nil {
b.existingArgs = make(map[any]string)
}
if placeholder, ok := b.existingArgs[arg]; ok {
return placeholder
}
if instruction, ok := arg.(databaseInstruction); ok {
return string(instruction)
}
b.args = append(b.args, arg)
placeholder = "$" + strconv.Itoa(len(b.args))
b.existingArgs[arg] = placeholder
return placeholder
}
func (b *statementBuilder) appendArgs(args ...any) {
for _, arg := range args {
b.appendArg(arg)
}
}

View File

@@ -0,0 +1,239 @@
package v4
import (
"context"
"time"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type Dates struct {
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt time.Time
}
type User struct {
InstanceID string
OrgID string
ID string
Username string
Traits userTrait
Dates
}
type UserType string
type userTrait interface {
userTrait()
Type() UserType
}
const userQuery = `SELECT u.instance_id, u.org_id, u.id, u.username, u.type, u.created_at, u.updated_at, u.deleted_at,` +
` h.first_name, h.last_name, h.email_address, h.email_verified_at, h.phone_number, h.phone_verified_at, m.description` +
` FROM users u` +
` LEFT JOIN user_humans h ON u.instance_id = h.instance_id AND u.org_id = h.org_id AND u.id = h.id` +
` LEFT JOIN user_machines m ON u.instance_id = m.instance_id AND u.org_id = m.org_id AND u.id = m.id`
type user struct {
builder statementBuilder
client database.QueryExecutor
condition Condition
}
func UserRepository(client database.QueryExecutor) *user {
return &user{
client: client,
}
}
func (u *user) WithCondition(condition Condition) *user {
u.condition = condition
return u
}
func (u *user) Get(ctx context.Context) (*User, error) {
u.builder.WriteString(userQuery)
u.writeCondition()
return scanUser(u.client.QueryRow(ctx, u.builder.String(), u.builder.args...))
}
func (u *user) List(ctx context.Context) (users []*User, err error) {
u.builder.WriteString(userQuery)
u.writeCondition()
rows, err := u.client.Query(ctx, u.builder.String(), u.builder.args...)
if err != nil {
return nil, err
}
defer func() {
closeErr := rows.Close()
if err != nil {
return
}
err = closeErr
}()
for rows.Next() {
user, err := scanUser(rows)
if err != nil {
return nil, err
}
users = append(users, user)
}
if err := rows.Err(); err != nil {
return nil, err
}
return users, nil
}
const (
createUserCte = `WITH user AS (` +
`INSERT INTO users (instance_id, org_id, id, username, type) VALUES ($1, $2, $3, $4, $5)` +
` RETURNING *)`
createHumanStmt = createUserCte + ` INSERT INTO user_humans h (instance_id, org_id, user_id, first_name, last_name, email_address, email_verified_at, phone_number, phone_verified_at)` +
` SELECT u.instance_id, u.org_id, u.id, $6, $7, $8, $9, $10, $11` +
` FROM user u` +
` RETURNING u.created_at, u.updated_at, u.deleted_at`
createMachineStmt = createUserCte + ` INSERT INTO user_machines (instance_id, org_id, user_id, description)` +
` SELECT u.instance_id, u.org_id, u.id, $6` +
` FROM user u` +
` RETURNING u.created_at, u.updated_at`
)
func (u *user) Create(ctx context.Context, user *User) error {
u.builder.appendArgs(user.InstanceID, user.OrgID, user.ID, user.Username, user.Traits.Type())
switch trait := user.Traits.(type) {
case *Human:
u.builder.WriteString(createHumanStmt)
u.builder.appendArgs(trait.FirstName, trait.LastName, trait.Email.Address, trait.Email.VerifiedAt, trait.Phone.Number, trait.Phone.VerifiedAt)
case *Machine:
u.builder.WriteString(createMachineStmt)
u.builder.appendArgs(trait.Description)
}
return u.client.QueryRow(ctx, u.builder.String(), u.builder.args...).Scan(user.CreatedAt, user.UpdatedAt)
}
func (u *user) InstanceIDColumn() Column {
return column{name: "u.instance_id"}
}
func (u *user) InstanceIDCondition(instanceID string) Condition {
return newTextCondition(u.InstanceIDColumn(), TextOperatorEqual, instanceID)
}
func (u *user) OrgIDColumn() Column {
return column{name: "u.org_id"}
}
func (u *user) OrgIDCondition(orgID string) Condition {
return newTextCondition(u.OrgIDColumn(), TextOperatorEqual, orgID)
}
func (u *user) IDColumn() Column {
return column{name: "u.id"}
}
func (u *user) IDCondition(userID string) Condition {
return newTextCondition(u.IDColumn(), TextOperatorEqual, userID)
}
func (u *user) UsernameColumn() Column {
return ignoreCaseCol{
column: column{name: "u.username"},
suffix: "_lower",
}
}
func (u user) SetUsername(username string) Change {
return newChange(u.UsernameColumn(), username)
}
func (u *user) UsernameCondition(op TextOperator, username string) Condition {
return newTextCondition(u.UsernameColumn(), op, username)
}
func (u *user) CreatedAtColumn() Column {
return column{name: "u.created_at"}
}
func (u *user) CreatedAtCondition(op NumberOperator, createdAt time.Time) Condition {
return newNumberCondition(u.CreatedAtColumn(), op, createdAt)
}
func (u *user) UpdatedAtColumn() Column {
return column{name: "u.updated_at"}
}
func (u *user) UpdatedAtCondition(op NumberOperator, updatedAt time.Time) Condition {
return newNumberCondition(u.UpdatedAtColumn(), op, updatedAt)
}
func (u *user) DeletedAtColumn() Column {
return column{name: "u.deleted_at"}
}
func (u *user) DeletedCondition(isDeleted bool) Condition {
if isDeleted {
return IsNotNull(u.DeletedAtColumn())
}
return IsNull(u.DeletedAtColumn())
}
func (u *user) DeletedAtCondition(op NumberOperator, deletedAt time.Time) Condition {
return newNumberCondition(u.DeletedAtColumn(), op, deletedAt)
}
func (u *user) writeCondition() {
if u.condition == nil {
return
}
u.builder.WriteString(" WHERE ")
u.condition.writeTo(&u.builder)
}
func scanUser(scanner database.Scanner) (*User, error) {
var (
user User
human Human
email Email
phone Phone
machine Machine
typ UserType
)
err := scanner.Scan(
&user.InstanceID,
&user.OrgID,
&user.ID,
&user.Username,
&typ,
&user.Dates.CreatedAt,
&user.Dates.UpdatedAt,
&user.Dates.DeletedAt,
&human.FirstName,
&human.LastName,
&email.Address,
&email.VerifiedAt,
&phone.Number,
&phone.VerifiedAt,
&machine.Description,
)
if err != nil {
return nil, err
}
switch typ {
case UserTypeHuman:
if email.Address != "" {
human.Email = &email
}
if phone.Number != "" {
human.Phone = &phone
}
user.Traits = &human
case UserTypeMachine:
user.Traits = &machine
}
return &user, nil
}

View File

@@ -0,0 +1,187 @@
package v4
import (
"context"
"time"
)
type Human struct {
FirstName string
LastName string
Email *Email
Phone *Phone
}
const UserTypeHuman UserType = "human"
func (Human) userTrait() {}
func (h Human) Type() UserType {
return UserTypeHuman
}
var _ userTrait = (*Human)(nil)
type Email struct {
Address string
Verification
}
type Phone struct {
Number string
Verification
}
type Verification struct {
VerifiedAt time.Time
}
type userHuman struct {
*user
}
func (u *user) Human() *userHuman {
return &userHuman{user: u}
}
const userEmailQuery = `SELECT h.email_address, h.email_verified_at FROM user_humans h`
func (u *userHuman) GetEmail(ctx context.Context) (*Email, error) {
var email Email
u.builder.WriteString(userEmailQuery)
u.writeCondition()
err := u.client.QueryRow(ctx, u.builder.String(), u.builder.args...).Scan(
&email.Address,
&email.Verification.VerifiedAt,
)
if err != nil {
return nil, err
}
return &email, nil
}
func (h userHuman) Update(ctx context.Context, changes ...Change) error {
h.builder.WriteString(`UPDATE human_users h SET `)
Changes(changes).writeTo(&h.builder)
h.writeCondition()
stmt := h.builder.String()
return h.client.Exec(ctx, stmt, h.builder.args...)
}
func (h userHuman) SetFirstName(firstName string) Change {
return newChange(h.FirstNameColumn(), firstName)
}
func (h userHuman) FirstNameColumn() Column {
return column{"h.first_name"}
}
func (h userHuman) FirstNameCondition(op TextOperator, firstName string) Condition {
return newTextCondition(h.FirstNameColumn(), op, firstName)
}
func (h userHuman) SetLastName(lastName string) Change {
return newChange(h.LastNameColumn(), lastName)
}
func (h userHuman) LastNameColumn() Column {
return column{"h.last_name"}
}
func (h userHuman) LastNameCondition(op TextOperator, lastName string) Condition {
return newTextCondition(h.LastNameColumn(), op, lastName)
}
func (h userHuman) EmailAddressColumn() Column {
return ignoreCaseCol{
column: column{"h.email_address"},
suffix: "_lower",
}
}
func (h userHuman) EmailAddressCondition(op TextOperator, email string) Condition {
return newTextCondition(h.EmailAddressColumn(), op, email)
}
func (h userHuman) EmailVerifiedAtColumn() Column {
return column{"h.email_verified_at"}
}
func (h *userHuman) EmailAddressVerifiedCondition(isVerified bool) Condition {
if isVerified {
return IsNotNull(h.EmailVerifiedAtColumn())
}
return IsNull(h.EmailVerifiedAtColumn())
}
func (h userHuman) EmailVerifiedAtCondition(op TextOperator, emailVerifiedAt string) Condition {
return newTextCondition(h.EmailVerifiedAtColumn(), op, emailVerifiedAt)
}
func (h userHuman) SetEmailAddress(address string) Change {
return newChange(h.EmailAddressColumn(), address)
}
// SetEmailVerified sets the verified column of the email
// if at is zero the statement uses the database timestamp
func (h userHuman) SetEmailVerified(at time.Time) Change {
if at.IsZero() {
return newChange(h.EmailVerifiedAtColumn(), nowDBInstruction)
}
return newChange(h.EmailVerifiedAtColumn(), at)
}
func (h userHuman) SetEmail(address string, verified *time.Time) Change {
return newChanges(
h.SetEmailAddress(address),
newUpdatePtrColumn(h.EmailVerifiedAtColumn(), verified),
)
}
func (h userHuman) PhoneNumberColumn() Column {
return column{"h.phone_number"}
}
func (h userHuman) SetPhoneNumber(number string) Change {
return newChange(h.PhoneNumberColumn(), number)
}
func (h userHuman) PhoneNumberCondition(op TextOperator, phoneNumber string) Condition {
return newTextCondition(h.PhoneNumberColumn(), op, phoneNumber)
}
func (h userHuman) PhoneVerifiedAtColumn() Column {
return column{"h.phone_verified_at"}
}
func (h userHuman) PhoneNumberVerifiedCondition(isVerified bool) Condition {
if isVerified {
return IsNotNull(h.PhoneVerifiedAtColumn())
}
return IsNull(h.PhoneVerifiedAtColumn())
}
// SetPhoneVerified sets the verified column of the phone
// if at is zero the statement uses the database timestamp
func (h userHuman) SetPhoneVerified(at time.Time) Change {
if at.IsZero() {
return newChange(h.PhoneVerifiedAtColumn(), nowDBInstruction)
}
return newChange(h.PhoneVerifiedAtColumn(), at)
}
func (h userHuman) PhoneVerifiedAtCondition(op TextOperator, phoneVerifiedAt string) Condition {
return newTextCondition(h.PhoneVerifiedAtColumn(), op, phoneVerifiedAt)
}
func (h userHuman) SetPhone(number string, verifiedAt *time.Time) Change {
return newChanges(
h.SetPhoneNumber(number),
newUpdatePtrColumn(h.PhoneVerifiedAtColumn(), verifiedAt),
)
}

View File

@@ -0,0 +1,41 @@
package v4
import "context"
type Machine struct {
Description string
}
func (Machine) userTrait() {}
func (m Machine) Type() UserType {
return UserTypeMachine
}
const UserTypeMachine UserType = "machine"
var _ userTrait = (*Machine)(nil)
type userMachine struct {
*user
}
func (u *user) Machine() *userMachine {
return &userMachine{user: u}
}
func (m userMachine) Update(ctx context.Context, cols ...Change) (*Machine, error) {
return nil, nil
}
func (userMachine) DescriptionColumn() Column {
return column{"m.description"}
}
func (m userMachine) SetDescription(description string) Change {
return newChange(m.DescriptionColumn(), description)
}
func (m userMachine) DescriptionCondition(op TextOperator, description string) Condition {
return newTextCondition(m.DescriptionColumn(), op, description)
}

View File

@@ -0,0 +1,65 @@
package v4_test
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
v4 "github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt/v4"
)
func TestQueryUser(t *testing.T) {
t.Run("User filters", func(t *testing.T) {
user := v4.UserRepository(nil)
user.WithCondition(
v4.And(
v4.Or(
user.IDCondition("test"),
user.IDCondition("2"),
),
user.UsernameCondition(v4.TextOperatorStartsWithIgnoreCase, "test"),
),
).Get(context.Background())
})
t.Run("machine and human filters", func(t *testing.T) {
user := v4.UserRepository(nil)
machine := user.Machine()
human := user.Human()
user.WithCondition(
v4.And(
user.UsernameCondition(v4.TextOperatorStartsWithIgnoreCase, "test"),
v4.Or(
machine.DescriptionCondition(v4.TextOperatorStartsWithIgnoreCase, "test"),
human.EmailAddressVerifiedCondition(true),
v4.IsNotNull(machine.DescriptionColumn()),
),
),
)
human.GetEmail(context.Background())
})
}
type dbInstruction string
func TestArg(t *testing.T) {
var bla any = "asdf"
instr, ok := bla.(dbInstruction)
assert.False(t, ok)
assert.Empty(t, instr)
bla = dbInstruction("asdf")
instr, ok = bla.(dbInstruction)
assert.True(t, ok)
assert.Equal(t, instr, dbInstruction("asdf"))
}
func TestWriteUser(t *testing.T) {
t.Run("update user", func(t *testing.T) {
user := v4.UserRepository(nil)
user.WithCondition(user.IDCondition("test")).Human().Update(
context.Background(),
user.SetUsername("test"),
)
})
}

View File

@@ -0,0 +1,39 @@
package repository
import (
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type user struct {
database.QueryExecutor
}
func User(client database.QueryExecutor) domain.UserRepository {
// return &user{QueryExecutor: client}
return nil
}
// On implements [domain.UserRepository].
func (exec *user) On(clauses ...domain.UserClause) domain.UserOperation {
return &userOperation{
QueryExecutor: exec.QueryExecutor,
clauses: clauses,
}
}
// OnHuman implements [domain.UserRepository].
func (exec *user) OnHuman(clauses ...domain.UserClause) domain.HumanOperation {
return &humanOperation{
userOperation: *exec.On(clauses...).(*userOperation),
}
}
// OnMachine implements [domain.UserRepository].
func (exec *user) OnMachine(clauses ...domain.UserClause) domain.MachineOperation {
return &machineOperation{
userOperation: *exec.On(clauses...).(*userOperation),
}
}
// var _ domain.UserRepository = (*user)(nil)

View File

@@ -0,0 +1,36 @@
package repository
import (
"context"
"github.com/zitadel/zitadel/backend/v3/domain"
)
type humanOperation struct {
userOperation
}
// GetEmail implements domain.HumanOperation.
func (h *humanOperation) GetEmail(ctx context.Context) (*domain.Email, error) {
var email domain.Email
err := h.QueryExecutor.QueryRow(ctx, `SELECT email, is_email_verified FROM human_users WHERE id = $1`, h.clauses).Scan(
&email.Address,
&email.IsVerified,
)
if err != nil {
return nil, err
}
return &email, nil
}
// SetEmail implements domain.HumanOperation.
func (h *humanOperation) SetEmail(ctx context.Context, email string) error {
return h.QueryExecutor.Exec(ctx, `UPDATE human_users SET email = $1 WHERE id = $2`, email, h.clauses)
}
// SetEmailVerified implements domain.HumanOperation.
func (h *humanOperation) SetEmailVerified(ctx context.Context, email string) error {
return h.QueryExecutor.Exec(ctx, `UPDATE human_users SET is_email_verified = $1 WHERE id = $2 AND email = $3`, true, h.clauses, email)
}
var _ domain.HumanOperation = (*humanOperation)(nil)

View File

@@ -0,0 +1,18 @@
package repository
import (
"context"
"github.com/zitadel/zitadel/backend/v3/domain"
)
type machineOperation struct {
userOperation
}
// SetDescription implements domain.MachineOperation.
func (m *machineOperation) SetDescription(ctx context.Context, description string) error {
return m.QueryExecutor.Exec(ctx, `UPDATE machines SET description = $1 WHERE id = $2`, description, m.clauses)
}
var _ domain.MachineOperation = (*machineOperation)(nil)

View File

@@ -0,0 +1,68 @@
package repository
import (
"context"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type userOperation struct {
database.QueryExecutor
clauses []domain.UserClause
}
// Delete implements [domain.UserOperation].
func (u *userOperation) Delete(ctx context.Context) error {
return u.QueryExecutor.Exec(ctx, `DELETE FROM users WHERE id = $1`, u.clauses)
}
// SetUsername implements [domain.UserOperation].
func (u *userOperation) SetUsername(ctx context.Context, username string) error {
var stmt statement
stmt.builder.WriteString(`UPDATE users SET username = $1 WHERE `)
stmt.appendArg(username)
clausesToSQL(&stmt, u.clauses)
return u.QueryExecutor.Exec(ctx, stmt.builder.String(), stmt.args...)
}
var _ domain.UserOperation = (*userOperation)(nil)
func UserIDQuery(id string) domain.UserClause {
return textClause[string]{
clause: clause[domain.TextOperation]{
field: userFields[domain.UserFieldID],
op: domain.TextOperationEqual,
},
value: id,
}
}
func HumanEmailQuery(op domain.TextOperation, email string) domain.UserClause {
return textClause[string]{
clause: clause[domain.TextOperation]{
field: userFields[domain.UserHumanFieldEmail],
op: op,
},
value: email,
}
}
func HumanEmailVerifiedQuery(op domain.BoolOperation) domain.UserClause {
return boolClause[domain.BoolOperation]{
clause: clause[domain.BoolOperation]{
field: userFields[domain.UserHumanFieldEmailVerified],
op: op,
},
}
}
func clausesToSQL(stmt *statement, clauses []domain.UserClause) {
for _, clause := range clauses {
stmt.builder.WriteString(userFields[clause.Field()].String())
stmt.builder.WriteString(clause.Operation().String())
stmt.appendArg(clause.Args()...)
}
}

View File

@@ -0,0 +1,36 @@
package database
import "context"
type Transaction interface {
Commit(ctx context.Context) error
Rollback(ctx context.Context) error
End(ctx context.Context, err error) error
Begin(ctx context.Context) (Transaction, error)
QueryExecutor
}
type Beginner interface {
Begin(ctx context.Context, opts *TransactionOptions) (Transaction, error)
}
type TransactionOptions struct {
IsolationLevel IsolationLevel
AccessMode AccessMode
}
type IsolationLevel uint8
const (
IsolationLevelSerializable IsolationLevel = iota
IsolationLevelReadCommitted
)
type AccessMode uint8
const (
AccessModeReadWrite AccessMode = iota
AccessModeReadOnly
)

View File

@@ -0,0 +1,23 @@
package eventstore
import (
"context"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type Event struct {
AggregateType string `json:"aggregateType"`
AggregateID string `json:"aggregateId"`
Type string `json:"type"`
Payload any `json:"payload,omitempty"`
}
func Publish(ctx context.Context, events []*Event, db database.Executor) error {
for _, event := range events {
if err := db.Exec(ctx, `INSERT INTO events (aggregate_type, aggregate_id) VALUES ($1, $2)`, event.AggregateType, event.AggregateID); err != nil {
return err
}
}
return nil
}

View File

@@ -0,0 +1,7 @@
package logging
import "log/slog"
type Logger struct {
*slog.Logger
}

View File

@@ -0,0 +1,23 @@
package tracing
import (
"context"
"go.opentelemetry.io/otel/trace"
"go.opentelemetry.io/otel/trace/noop"
)
type Tracer struct {
trace.Tracer
}
var noopTracer = Tracer{
Tracer: noop.NewTracerProvider().Tracer(""),
}
func (t *Tracer) Start(ctx context.Context, spanName string, opts ...trace.SpanStartOption) (context.Context, trace.Span) {
if t.Tracer == nil {
return noopTracer.Start(ctx, spanName, opts...)
}
return t.Tracer.Start(ctx, spanName, opts...)
}