mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 18:07:31 +00:00
multiple tries
This commit is contained in:
56
backend/command/client/grpc/api.go
Normal file
56
backend/command/client/grpc/api.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package grpc
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/command/command"
|
||||
"github.com/zitadel/zitadel/backend/command/query"
|
||||
"github.com/zitadel/zitadel/backend/command/receiver"
|
||||
"github.com/zitadel/zitadel/backend/command/receiver/cache"
|
||||
"github.com/zitadel/zitadel/backend/storage/database"
|
||||
"github.com/zitadel/zitadel/backend/telemetry/logging"
|
||||
"github.com/zitadel/zitadel/backend/telemetry/tracing"
|
||||
)
|
||||
|
||||
type api struct {
|
||||
db database.Pool
|
||||
|
||||
manipulator receiver.InstanceManipulator
|
||||
reader receiver.InstanceReader
|
||||
tracer *tracing.Tracer
|
||||
logger *logging.Logger
|
||||
cache cache.Cache[receiver.InstanceIndex, string, *receiver.Instance]
|
||||
}
|
||||
|
||||
func (a *api) CreateInstance(ctx context.Context) error {
|
||||
instance := &receiver.Instance{
|
||||
ID: "123",
|
||||
Name: "test",
|
||||
}
|
||||
return command.Trace(
|
||||
a.tracer,
|
||||
command.SetCache(a.cache,
|
||||
command.Activity(a.logger, command.CreateInstance(a.manipulator, instance)),
|
||||
instance,
|
||||
),
|
||||
).Execute(ctx)
|
||||
}
|
||||
|
||||
func (a *api) DeleteInstance(ctx context.Context) error {
|
||||
return command.Trace(
|
||||
a.tracer,
|
||||
command.DeleteCache(a.cache,
|
||||
command.Activity(
|
||||
a.logger,
|
||||
command.DeleteInstance(a.manipulator, &receiver.Instance{
|
||||
ID: "123",
|
||||
})),
|
||||
receiver.InstanceByID,
|
||||
"123",
|
||||
)).Execute(ctx)
|
||||
}
|
||||
|
||||
func (a *api) InstanceByID(ctx context.Context) (*receiver.Instance, error) {
|
||||
q := query.InstanceByID(a.reader, "123")
|
||||
return q.Execute(ctx)
|
||||
}
|
102
backend/command/command/caching.go
Normal file
102
backend/command/command/caching.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/command/receiver/cache"
|
||||
)
|
||||
|
||||
type setCache[I, K comparable, V cache.Entry[I, K]] struct {
|
||||
cache cache.Cache[I, K, V]
|
||||
command Command
|
||||
entry V
|
||||
}
|
||||
|
||||
// SetCache decorates the command, if the command is executed without error it will set the cache entry.
|
||||
func SetCache[I, K comparable, V cache.Entry[I, K]](cache cache.Cache[I, K, V], command Command, entry V) Command {
|
||||
return &setCache[I, K, V]{
|
||||
cache: cache,
|
||||
command: command,
|
||||
entry: entry,
|
||||
}
|
||||
}
|
||||
|
||||
var _ Command = (*setCache[any, any, cache.Entry[any, any]])(nil)
|
||||
|
||||
// Execute implements [Command].
|
||||
func (s *setCache[I, K, V]) Execute(ctx context.Context) error {
|
||||
if err := s.command.Execute(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
s.cache.Set(ctx, s.entry)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Name implements [Command].
|
||||
func (s *setCache[I, K, V]) Name() string {
|
||||
return s.command.Name()
|
||||
}
|
||||
|
||||
type deleteCache[I, K comparable, V cache.Entry[I, K]] struct {
|
||||
cache cache.Cache[I, K, V]
|
||||
command Command
|
||||
index I
|
||||
keys []K
|
||||
}
|
||||
|
||||
// DeleteCache decorates the command, if the command is executed without error it will delete the cache entry.
|
||||
func DeleteCache[I, K comparable, V cache.Entry[I, K]](cache cache.Cache[I, K, V], command Command, index I, keys ...K) Command {
|
||||
return &deleteCache[I, K, V]{
|
||||
cache: cache,
|
||||
command: command,
|
||||
index: index,
|
||||
keys: keys,
|
||||
}
|
||||
}
|
||||
|
||||
var _ Command = (*deleteCache[any, any, cache.Entry[any, any]])(nil)
|
||||
|
||||
// Execute implements [Command].
|
||||
func (s *deleteCache[I, K, V]) Execute(ctx context.Context) error {
|
||||
if err := s.command.Execute(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.cache.Delete(ctx, s.index, s.keys...)
|
||||
}
|
||||
|
||||
// Name implements [Command].
|
||||
func (s *deleteCache[I, K, V]) Name() string {
|
||||
return s.command.Name()
|
||||
}
|
||||
|
||||
type invalidateCache[I, K comparable, V cache.Entry[I, K]] struct {
|
||||
cache cache.Cache[I, K, V]
|
||||
command Command
|
||||
index I
|
||||
keys []K
|
||||
}
|
||||
|
||||
// InvalidateCache decorates the command, if the command is executed without error it will invalidate the cache entry.
|
||||
func InvalidateCache[I, K comparable, V cache.Entry[I, K]](cache cache.Cache[I, K, V], command Command, index I, keys ...K) Command {
|
||||
return &invalidateCache[I, K, V]{
|
||||
cache: cache,
|
||||
command: command,
|
||||
index: index,
|
||||
keys: keys,
|
||||
}
|
||||
}
|
||||
|
||||
var _ Command = (*invalidateCache[any, any, cache.Entry[any, any]])(nil)
|
||||
|
||||
// Execute implements [Command].
|
||||
func (s *invalidateCache[I, K, V]) Execute(ctx context.Context) error {
|
||||
if err := s.command.Execute(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
return s.cache.Invalidate(ctx, s.index, s.keys...)
|
||||
}
|
||||
|
||||
// Name implements [Command].
|
||||
func (s *invalidateCache[I, K, V]) Name() string {
|
||||
return s.command.Name()
|
||||
}
|
@@ -1,6 +1,10 @@
|
||||
package command
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/command/receiver/cache"
|
||||
)
|
||||
|
||||
type Command interface {
|
||||
Execute(context.Context) error
|
||||
@@ -20,3 +24,8 @@ func (b *Batch) Execute(ctx context.Context) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type CacheableCommand[I, K comparable, V cache.Entry[I, K]] interface {
|
||||
Command
|
||||
Entry() V
|
||||
}
|
||||
|
@@ -65,7 +65,7 @@ func UpdateInstance(receiver receiver.InstanceManipulator, instance *receiver.In
|
||||
|
||||
func (u *updateInstance) Execute(ctx context.Context) error {
|
||||
u.Instance.Name = u.name
|
||||
// return u.receiver.(ctx, u.Instance)
|
||||
// return u.receiver.Update(ctx, u.Instance)
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@@ -8,21 +8,28 @@ import (
|
||||
"github.com/zitadel/zitadel/backend/telemetry/logging"
|
||||
)
|
||||
|
||||
type Logger struct {
|
||||
type logger struct {
|
||||
level slog.Level
|
||||
*logging.Logger
|
||||
cmd Command
|
||||
}
|
||||
|
||||
func Activity(l *logging.Logger, command Command) *Logger {
|
||||
return &Logger{
|
||||
// Activity decorates the commands execute method with logging.
|
||||
// It logs the command name, duration, and success or failure of the command.
|
||||
func Activity(l *logging.Logger, command Command) Command {
|
||||
return &logger{
|
||||
Logger: l.With(slog.String("type", "activity")),
|
||||
level: slog.LevelInfo,
|
||||
cmd: command,
|
||||
}
|
||||
}
|
||||
|
||||
func (l *Logger) Execute(ctx context.Context) error {
|
||||
// Name implements [Command].
|
||||
func (l *logger) Name() string {
|
||||
return l.cmd.Name()
|
||||
}
|
||||
|
||||
func (l *logger) Execute(ctx context.Context) error {
|
||||
start := time.Now()
|
||||
log := l.Logger.With(slog.String("command", l.cmd.Name()))
|
||||
log.InfoContext(ctx, "execute")
|
||||
|
@@ -6,12 +6,27 @@ import (
|
||||
"github.com/zitadel/zitadel/backend/telemetry/tracing"
|
||||
)
|
||||
|
||||
type Trace struct {
|
||||
type trace struct {
|
||||
command Command
|
||||
tracer *tracing.Tracer
|
||||
}
|
||||
|
||||
func (t *Trace) Execute(ctx context.Context) error {
|
||||
// Trace decorates the commands execute method with tracing.
|
||||
// It creates a span with the command name and records any errors that occur during execution.
|
||||
// The span is ended after the command is executed.
|
||||
func Trace(tracer *tracing.Tracer, command Command) Command {
|
||||
return &trace{
|
||||
command: command,
|
||||
tracer: tracer,
|
||||
}
|
||||
}
|
||||
|
||||
// Name implements [Command].
|
||||
func (l *trace) Name() string {
|
||||
return l.command.Name()
|
||||
}
|
||||
|
||||
func (t *trace) Execute(ctx context.Context) error {
|
||||
ctx, span := t.tracer.Start(ctx, t.command.Name())
|
||||
defer span.End()
|
||||
err := t.command.Execute(ctx)
|
||||
|
@@ -1,38 +0,0 @@
|
||||
package invoker
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/command/command"
|
||||
"github.com/zitadel/zitadel/backend/command/query"
|
||||
"github.com/zitadel/zitadel/backend/command/receiver"
|
||||
"github.com/zitadel/zitadel/backend/command/receiver/db"
|
||||
"github.com/zitadel/zitadel/backend/storage/database"
|
||||
)
|
||||
|
||||
type api struct {
|
||||
db database.Pool
|
||||
|
||||
manipulator receiver.InstanceManipulator
|
||||
reader receiver.InstanceReader
|
||||
}
|
||||
|
||||
func (a *api) CreateInstance(ctx context.Context) error {
|
||||
cmd := command.CreateInstance(db.NewInstance(a.db), &receiver.Instance{
|
||||
ID: "123",
|
||||
Name: "test",
|
||||
})
|
||||
return cmd.Execute(ctx)
|
||||
}
|
||||
|
||||
func (a *api) DeleteInstance(ctx context.Context) error {
|
||||
cmd := command.DeleteInstance(db.NewInstance(a.db), &receiver.Instance{
|
||||
ID: "123",
|
||||
})
|
||||
return cmd.Execute(ctx)
|
||||
}
|
||||
|
||||
func (a *api) InstanceByID(ctx context.Context) (*receiver.Instance, error) {
|
||||
q := query.InstanceByID(a.reader, "123")
|
||||
return q.Execute(ctx)
|
||||
}
|
112
backend/command/receiver/cache/cache.go
vendored
Normal file
112
backend/command/receiver/cache/cache.go
vendored
Normal file
@@ -0,0 +1,112 @@
|
||||
// Package cache provides abstraction of cache implementations that can be used by zitadel.
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
)
|
||||
|
||||
// Purpose describes which object types are stored by a cache.
|
||||
type Purpose int
|
||||
|
||||
//go:generate enumer -type Purpose -transform snake -trimprefix Purpose
|
||||
const (
|
||||
PurposeUnspecified Purpose = iota
|
||||
PurposeAuthzInstance
|
||||
PurposeMilestones
|
||||
PurposeOrganization
|
||||
PurposeIdPFormCallback
|
||||
)
|
||||
|
||||
// Cache stores objects with a value of type `V`.
|
||||
// Objects may be referred to by one or more indices.
|
||||
// Implementations may encode the value for storage.
|
||||
// This means non-exported fields may be lost and objects
|
||||
// with function values may fail to encode.
|
||||
// See https://pkg.go.dev/encoding/json#Marshal for example.
|
||||
//
|
||||
// `I` is the type by which indices are identified,
|
||||
// typically an enum for type-safe access.
|
||||
// Indices are defined when calling the constructor of an implementation of this interface.
|
||||
// It is illegal to refer to an idex not defined during construction.
|
||||
//
|
||||
// `K` is the type used as key in each index.
|
||||
// Due to the limitations in type constraints, all indices use the same key type.
|
||||
//
|
||||
// Implementations are free to use stricter type constraints or fixed typing.
|
||||
type Cache[I, K comparable, V Entry[I, K]] interface {
|
||||
// Get an object through specified index.
|
||||
// An [IndexUnknownError] may be returned if the index is unknown.
|
||||
// [ErrCacheMiss] is returned if the key was not found in the index,
|
||||
// or the object is not valid.
|
||||
Get(ctx context.Context, index I, key K) (V, bool)
|
||||
|
||||
// Set an object.
|
||||
// Keys are created on each index based in the [Entry.Keys] method.
|
||||
// If any key maps to an existing object, the object is invalidated,
|
||||
// regardless if the object has other keys defined in the new entry.
|
||||
// This to prevent ghost objects when an entry reduces the amount of keys
|
||||
// for a given index.
|
||||
Set(ctx context.Context, value V)
|
||||
|
||||
// Invalidate an object through specified index.
|
||||
// Implementations may choose to instantly delete the object,
|
||||
// defer until prune or a separate cleanup routine.
|
||||
// Invalidated object are no longer returned from Get.
|
||||
// It is safe to call Invalidate multiple times or on non-existing entries.
|
||||
Invalidate(ctx context.Context, index I, key ...K) error
|
||||
|
||||
// Delete one or more keys from a specific index.
|
||||
// An [IndexUnknownError] may be returned if the index is unknown.
|
||||
// The referred object is not invalidated and may still be accessible though
|
||||
// other indices and keys.
|
||||
// It is safe to call Delete multiple times or on non-existing entries
|
||||
Delete(ctx context.Context, index I, key ...K) error
|
||||
|
||||
// Truncate deletes all cached objects.
|
||||
Truncate(ctx context.Context) error
|
||||
}
|
||||
|
||||
// Entry contains a value of type `V` to be cached.
|
||||
//
|
||||
// `I` is the type by which indices are identified,
|
||||
// typically an enum for type-safe access.
|
||||
//
|
||||
// `K` is the type used as key in an index.
|
||||
// Due to the limitations in type constraints, all indices use the same key type.
|
||||
type Entry[I, K comparable] interface {
|
||||
// Keys returns which keys map to the object in a specified index.
|
||||
// May return nil if the index in unknown or when there are no keys.
|
||||
Keys(index I) (key []K)
|
||||
}
|
||||
|
||||
type Connector int
|
||||
|
||||
//go:generate enumer -type Connector -transform snake -trimprefix Connector -linecomment -text
|
||||
const (
|
||||
// Empty line comment ensures empty string for unspecified value
|
||||
ConnectorUnspecified Connector = iota //
|
||||
ConnectorMemory
|
||||
ConnectorPostgres
|
||||
ConnectorRedis
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Connector Connector
|
||||
|
||||
// Age since an object was added to the cache,
|
||||
// after which the object is considered invalid.
|
||||
// 0 disables max age checks.
|
||||
MaxAge time.Duration
|
||||
|
||||
// Age since last use (Get) of an object,
|
||||
// after which the object is considered invalid.
|
||||
// 0 disables last use age checks.
|
||||
LastUseAge time.Duration
|
||||
|
||||
// Log allows logging of the specific cache.
|
||||
// By default only errors are logged to stdout.
|
||||
Log *logging.Config
|
||||
}
|
49
backend/command/receiver/cache/connector/connector.go
vendored
Normal file
49
backend/command/receiver/cache/connector/connector.go
vendored
Normal file
@@ -0,0 +1,49 @@
|
||||
// Package connector provides glue between the [cache.Cache] interface and implementations from the connector sub-packages.
|
||||
package connector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/storage/cache"
|
||||
"github.com/zitadel/zitadel/backend/storage/cache/connector/gomap"
|
||||
"github.com/zitadel/zitadel/backend/storage/cache/connector/noop"
|
||||
)
|
||||
|
||||
type CachesConfig struct {
|
||||
Connectors struct {
|
||||
Memory gomap.Config
|
||||
}
|
||||
Instance *cache.Config
|
||||
Milestones *cache.Config
|
||||
Organization *cache.Config
|
||||
IdPFormCallbacks *cache.Config
|
||||
}
|
||||
|
||||
type Connectors struct {
|
||||
Config CachesConfig
|
||||
Memory *gomap.Connector
|
||||
}
|
||||
|
||||
func StartConnectors(conf *CachesConfig) (Connectors, error) {
|
||||
if conf == nil {
|
||||
return Connectors{}, nil
|
||||
}
|
||||
return Connectors{
|
||||
Config: *conf,
|
||||
Memory: gomap.NewConnector(conf.Connectors.Memory),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func StartCache[I ~int, K ~string, V cache.Entry[I, K]](background context.Context, indices []I, purpose cache.Purpose, conf *cache.Config, connectors Connectors) (cache.Cache[I, K, V], error) {
|
||||
if conf == nil || conf.Connector == cache.ConnectorUnspecified {
|
||||
return noop.NewCache[I, K, V](), nil
|
||||
}
|
||||
if conf.Connector == cache.ConnectorMemory && connectors.Memory != nil {
|
||||
c := gomap.NewCache[I, K, V](background, indices, *conf)
|
||||
connectors.Memory.Config.StartAutoPrune(background, c, purpose)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("cache connector %q not enabled", conf.Connector)
|
||||
}
|
23
backend/command/receiver/cache/connector/gomap/connector.go
vendored
Normal file
23
backend/command/receiver/cache/connector/gomap/connector.go
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
package gomap
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/backend/storage/cache"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Enabled bool
|
||||
AutoPrune cache.AutoPruneConfig
|
||||
}
|
||||
|
||||
type Connector struct {
|
||||
Config cache.AutoPruneConfig
|
||||
}
|
||||
|
||||
func NewConnector(config Config) *Connector {
|
||||
if !config.Enabled {
|
||||
return nil
|
||||
}
|
||||
return &Connector{
|
||||
Config: config.AutoPrune,
|
||||
}
|
||||
}
|
200
backend/command/receiver/cache/connector/gomap/gomap.go
vendored
Normal file
200
backend/command/receiver/cache/connector/gomap/gomap.go
vendored
Normal file
@@ -0,0 +1,200 @@
|
||||
package gomap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/storage/cache"
|
||||
)
|
||||
|
||||
type mapCache[I, K comparable, V cache.Entry[I, K]] struct {
|
||||
config *cache.Config
|
||||
indexMap map[I]*index[K, V]
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewCache returns an in-memory Cache implementation based on the builtin go map type.
|
||||
// Object values are stored as-is and there is no encoding or decoding involved.
|
||||
func NewCache[I, K comparable, V cache.Entry[I, K]](background context.Context, indices []I, config cache.Config) cache.PrunerCache[I, K, V] {
|
||||
m := &mapCache[I, K, V]{
|
||||
config: &config,
|
||||
indexMap: make(map[I]*index[K, V], len(indices)),
|
||||
logger: slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
AddSource: true,
|
||||
Level: slog.LevelError,
|
||||
})),
|
||||
}
|
||||
if config.Log != nil {
|
||||
m.logger = config.Log.Slog()
|
||||
}
|
||||
m.logger.InfoContext(background, "map cache logging enabled")
|
||||
|
||||
for _, name := range indices {
|
||||
m.indexMap[name] = &index[K, V]{
|
||||
config: m.config,
|
||||
entries: make(map[K]*entry[V]),
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Get(ctx context.Context, index I, key K) (value V, ok bool) {
|
||||
i, ok := c.indexMap[index]
|
||||
if !ok {
|
||||
c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key)
|
||||
return value, false
|
||||
}
|
||||
entry, err := i.Get(key)
|
||||
if err == nil {
|
||||
c.logger.DebugContext(ctx, "map cache get", "index", index, "key", key)
|
||||
return entry.value, true
|
||||
}
|
||||
if errors.Is(err, cache.ErrCacheMiss) {
|
||||
c.logger.InfoContext(ctx, "map cache get", "err", err, "index", index, "key", key)
|
||||
return value, false
|
||||
}
|
||||
c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key)
|
||||
return value, false
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Set(ctx context.Context, value V) {
|
||||
now := time.Now()
|
||||
entry := &entry[V]{
|
||||
value: value,
|
||||
created: now,
|
||||
}
|
||||
entry.lastUse.Store(now.UnixMicro())
|
||||
|
||||
for name, i := range c.indexMap {
|
||||
keys := value.Keys(name)
|
||||
i.Set(keys, entry)
|
||||
c.logger.DebugContext(ctx, "map cache set", "index", name, "keys", keys)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Invalidate(ctx context.Context, index I, keys ...K) error {
|
||||
i, ok := c.indexMap[index]
|
||||
if !ok {
|
||||
return cache.NewIndexUnknownErr(index)
|
||||
}
|
||||
i.Invalidate(keys)
|
||||
c.logger.DebugContext(ctx, "map cache invalidate", "index", index, "keys", keys)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Delete(ctx context.Context, index I, keys ...K) error {
|
||||
i, ok := c.indexMap[index]
|
||||
if !ok {
|
||||
return cache.NewIndexUnknownErr(index)
|
||||
}
|
||||
i.Delete(keys)
|
||||
c.logger.DebugContext(ctx, "map cache delete", "index", index, "keys", keys)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Prune(ctx context.Context) error {
|
||||
for name, index := range c.indexMap {
|
||||
index.Prune()
|
||||
c.logger.DebugContext(ctx, "map cache prune", "index", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Truncate(ctx context.Context) error {
|
||||
for name, index := range c.indexMap {
|
||||
index.Truncate()
|
||||
c.logger.DebugContext(ctx, "map cache truncate", "index", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type index[K comparable, V any] struct {
|
||||
mutex sync.RWMutex
|
||||
config *cache.Config
|
||||
entries map[K]*entry[V]
|
||||
}
|
||||
|
||||
func (i *index[K, V]) Get(key K) (*entry[V], error) {
|
||||
i.mutex.RLock()
|
||||
entry, ok := i.entries[key]
|
||||
i.mutex.RUnlock()
|
||||
if ok && entry.isValid(i.config) {
|
||||
return entry, nil
|
||||
}
|
||||
return nil, cache.ErrCacheMiss
|
||||
}
|
||||
|
||||
func (c *index[K, V]) Set(keys []K, entry *entry[V]) {
|
||||
c.mutex.Lock()
|
||||
for _, key := range keys {
|
||||
c.entries[key] = entry
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (i *index[K, V]) Invalidate(keys []K) {
|
||||
i.mutex.RLock()
|
||||
for _, key := range keys {
|
||||
if entry, ok := i.entries[key]; ok {
|
||||
entry.invalid.Store(true)
|
||||
}
|
||||
}
|
||||
i.mutex.RUnlock()
|
||||
}
|
||||
|
||||
func (c *index[K, V]) Delete(keys []K) {
|
||||
c.mutex.Lock()
|
||||
for _, key := range keys {
|
||||
delete(c.entries, key)
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (c *index[K, V]) Prune() {
|
||||
c.mutex.Lock()
|
||||
maps.DeleteFunc(c.entries, func(_ K, entry *entry[V]) bool {
|
||||
return !entry.isValid(c.config)
|
||||
})
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (c *index[K, V]) Truncate() {
|
||||
c.mutex.Lock()
|
||||
c.entries = make(map[K]*entry[V])
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
type entry[V any] struct {
|
||||
value V
|
||||
created time.Time
|
||||
invalid atomic.Bool
|
||||
lastUse atomic.Int64 // UnixMicro time
|
||||
}
|
||||
|
||||
func (e *entry[V]) isValid(c *cache.Config) bool {
|
||||
if e.invalid.Load() {
|
||||
return false
|
||||
}
|
||||
now := time.Now()
|
||||
if c.MaxAge > 0 {
|
||||
if e.created.Add(c.MaxAge).Before(now) {
|
||||
e.invalid.Store(true)
|
||||
return false
|
||||
}
|
||||
}
|
||||
if c.LastUseAge > 0 {
|
||||
lastUse := e.lastUse.Load()
|
||||
if time.UnixMicro(lastUse).Add(c.LastUseAge).Before(now) {
|
||||
e.invalid.Store(true)
|
||||
return false
|
||||
}
|
||||
e.lastUse.CompareAndSwap(lastUse, now.UnixMicro())
|
||||
}
|
||||
return true
|
||||
}
|
329
backend/command/receiver/cache/connector/gomap/gomap_test.go
vendored
Normal file
329
backend/command/receiver/cache/connector/gomap/gomap_test.go
vendored
Normal file
@@ -0,0 +1,329 @@
|
||||
package gomap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/storage/cache"
|
||||
)
|
||||
|
||||
type testIndex int
|
||||
|
||||
const (
|
||||
testIndexID testIndex = iota
|
||||
testIndexName
|
||||
)
|
||||
|
||||
var testIndices = []testIndex{
|
||||
testIndexID,
|
||||
testIndexName,
|
||||
}
|
||||
|
||||
type testObject struct {
|
||||
id string
|
||||
names []string
|
||||
}
|
||||
|
||||
func (o *testObject) Keys(index testIndex) []string {
|
||||
switch index {
|
||||
case testIndexID:
|
||||
return []string{o.id}
|
||||
case testIndexName:
|
||||
return o.names
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func Test_mapCache_Get(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
obj := &testObject{
|
||||
id: "id",
|
||||
names: []string{"foo", "bar"},
|
||||
}
|
||||
c.Set(context.Background(), obj)
|
||||
|
||||
type args struct {
|
||||
index testIndex
|
||||
key string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *testObject
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
args: args{
|
||||
index: testIndexID,
|
||||
key: "id",
|
||||
},
|
||||
want: obj,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "miss",
|
||||
args: args{
|
||||
index: testIndexID,
|
||||
key: "spanac",
|
||||
},
|
||||
want: nil,
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "unknown index",
|
||||
args: args{
|
||||
index: 99,
|
||||
key: "id",
|
||||
},
|
||||
want: nil,
|
||||
wantOk: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, ok := c.Get(context.Background(), tt.args.index, tt.args.key)
|
||||
assert.Equal(t, tt.want, got)
|
||||
assert.Equal(t, tt.wantOk, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_mapCache_Invalidate(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
obj := &testObject{
|
||||
id: "id",
|
||||
names: []string{"foo", "bar"},
|
||||
}
|
||||
c.Set(context.Background(), obj)
|
||||
err := c.Invalidate(context.Background(), testIndexName, "bar")
|
||||
require.NoError(t, err)
|
||||
got, ok := c.Get(context.Background(), testIndexID, "id")
|
||||
assert.Nil(t, got)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func Test_mapCache_Delete(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
obj := &testObject{
|
||||
id: "id",
|
||||
names: []string{"foo", "bar"},
|
||||
}
|
||||
c.Set(context.Background(), obj)
|
||||
err := c.Delete(context.Background(), testIndexName, "bar")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Shouldn't find object by deleted name
|
||||
got, ok := c.Get(context.Background(), testIndexName, "bar")
|
||||
assert.Nil(t, got)
|
||||
assert.False(t, ok)
|
||||
|
||||
// Should find object by other name
|
||||
got, ok = c.Get(context.Background(), testIndexName, "foo")
|
||||
assert.Equal(t, obj, got)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Should find object by id
|
||||
got, ok = c.Get(context.Background(), testIndexID, "id")
|
||||
assert.Equal(t, obj, got)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
func Test_mapCache_Prune(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
|
||||
objects := []*testObject{
|
||||
{
|
||||
id: "id1",
|
||||
names: []string{"foo", "bar"},
|
||||
},
|
||||
{
|
||||
id: "id2",
|
||||
names: []string{"hello"},
|
||||
},
|
||||
}
|
||||
for _, obj := range objects {
|
||||
c.Set(context.Background(), obj)
|
||||
}
|
||||
// invalidate one entry
|
||||
err := c.Invalidate(context.Background(), testIndexName, "bar")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = c.(cache.Pruner).Prune(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Other object should still be found
|
||||
got, ok := c.Get(context.Background(), testIndexID, "id2")
|
||||
assert.Equal(t, objects[1], got)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
func Test_mapCache_Truncate(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
objects := []*testObject{
|
||||
{
|
||||
id: "id1",
|
||||
names: []string{"foo", "bar"},
|
||||
},
|
||||
{
|
||||
id: "id2",
|
||||
names: []string{"hello"},
|
||||
},
|
||||
}
|
||||
for _, obj := range objects {
|
||||
c.Set(context.Background(), obj)
|
||||
}
|
||||
|
||||
err := c.Truncate(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
mc := c.(*mapCache[testIndex, string, *testObject])
|
||||
for _, index := range mc.indexMap {
|
||||
index.mutex.RLock()
|
||||
assert.Len(t, index.entries, 0)
|
||||
index.mutex.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
func Test_entry_isValid(t *testing.T) {
|
||||
type fields struct {
|
||||
created time.Time
|
||||
invalid bool
|
||||
lastUse time.Time
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
config *cache.Config
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "invalid",
|
||||
fields: fields{
|
||||
created: time.Now(),
|
||||
invalid: true,
|
||||
lastUse: time.Now(),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "max age exceeded",
|
||||
fields: fields{
|
||||
created: time.Now().Add(-(time.Minute + time.Second)),
|
||||
invalid: false,
|
||||
lastUse: time.Now(),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "max age disabled",
|
||||
fields: fields{
|
||||
created: time.Now().Add(-(time.Minute + time.Second)),
|
||||
invalid: false,
|
||||
lastUse: time.Now(),
|
||||
},
|
||||
config: &cache.Config{
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "last use age exceeded",
|
||||
fields: fields{
|
||||
created: time.Now().Add(-(time.Minute / 2)),
|
||||
invalid: false,
|
||||
lastUse: time.Now().Add(-(time.Second * 2)),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "last use age disabled",
|
||||
fields: fields{
|
||||
created: time.Now().Add(-(time.Minute / 2)),
|
||||
invalid: false,
|
||||
lastUse: time.Now().Add(-(time.Second * 2)),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
fields: fields{
|
||||
created: time.Now(),
|
||||
invalid: false,
|
||||
lastUse: time.Now(),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := &entry[any]{
|
||||
created: tt.fields.created,
|
||||
}
|
||||
e.invalid.Store(tt.fields.invalid)
|
||||
e.lastUse.Store(tt.fields.lastUse.UnixMicro())
|
||||
got := e.isValid(tt.config)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
21
backend/command/receiver/cache/connector/noop/noop.go
vendored
Normal file
21
backend/command/receiver/cache/connector/noop/noop.go
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
package noop
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/storage/cache"
|
||||
)
|
||||
|
||||
type noop[I, K comparable, V cache.Entry[I, K]] struct{}
|
||||
|
||||
// NewCache returns a cache that does nothing
|
||||
func NewCache[I, K comparable, V cache.Entry[I, K]]() cache.Cache[I, K, V] {
|
||||
return noop[I, K, V]{}
|
||||
}
|
||||
|
||||
func (noop[I, K, V]) Set(context.Context, V) {}
|
||||
func (noop[I, K, V]) Get(context.Context, I, K) (value V, ok bool) { return }
|
||||
func (noop[I, K, V]) Invalidate(context.Context, I, ...K) (err error) { return }
|
||||
func (noop[I, K, V]) Delete(context.Context, I, ...K) (err error) { return }
|
||||
func (noop[I, K, V]) Prune(context.Context) (err error) { return }
|
||||
func (noop[I, K, V]) Truncate(context.Context) (err error) { return }
|
98
backend/command/receiver/cache/connector_enumer.go
vendored
Normal file
98
backend/command/receiver/cache/connector_enumer.go
vendored
Normal file
@@ -0,0 +1,98 @@
|
||||
// Code generated by "enumer -type Connector -transform snake -trimprefix Connector -linecomment -text"; DO NOT EDIT.
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const _ConnectorName = "memorypostgresredis"
|
||||
|
||||
var _ConnectorIndex = [...]uint8{0, 0, 6, 14, 19}
|
||||
|
||||
const _ConnectorLowerName = "memorypostgresredis"
|
||||
|
||||
func (i Connector) String() string {
|
||||
if i < 0 || i >= Connector(len(_ConnectorIndex)-1) {
|
||||
return fmt.Sprintf("Connector(%d)", i)
|
||||
}
|
||||
return _ConnectorName[_ConnectorIndex[i]:_ConnectorIndex[i+1]]
|
||||
}
|
||||
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
func _ConnectorNoOp() {
|
||||
var x [1]struct{}
|
||||
_ = x[ConnectorUnspecified-(0)]
|
||||
_ = x[ConnectorMemory-(1)]
|
||||
_ = x[ConnectorPostgres-(2)]
|
||||
_ = x[ConnectorRedis-(3)]
|
||||
}
|
||||
|
||||
var _ConnectorValues = []Connector{ConnectorUnspecified, ConnectorMemory, ConnectorPostgres, ConnectorRedis}
|
||||
|
||||
var _ConnectorNameToValueMap = map[string]Connector{
|
||||
_ConnectorName[0:0]: ConnectorUnspecified,
|
||||
_ConnectorLowerName[0:0]: ConnectorUnspecified,
|
||||
_ConnectorName[0:6]: ConnectorMemory,
|
||||
_ConnectorLowerName[0:6]: ConnectorMemory,
|
||||
_ConnectorName[6:14]: ConnectorPostgres,
|
||||
_ConnectorLowerName[6:14]: ConnectorPostgres,
|
||||
_ConnectorName[14:19]: ConnectorRedis,
|
||||
_ConnectorLowerName[14:19]: ConnectorRedis,
|
||||
}
|
||||
|
||||
var _ConnectorNames = []string{
|
||||
_ConnectorName[0:0],
|
||||
_ConnectorName[0:6],
|
||||
_ConnectorName[6:14],
|
||||
_ConnectorName[14:19],
|
||||
}
|
||||
|
||||
// ConnectorString retrieves an enum value from the enum constants string name.
|
||||
// Throws an error if the param is not part of the enum.
|
||||
func ConnectorString(s string) (Connector, error) {
|
||||
if val, ok := _ConnectorNameToValueMap[s]; ok {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
if val, ok := _ConnectorNameToValueMap[strings.ToLower(s)]; ok {
|
||||
return val, nil
|
||||
}
|
||||
return 0, fmt.Errorf("%s does not belong to Connector values", s)
|
||||
}
|
||||
|
||||
// ConnectorValues returns all values of the enum
|
||||
func ConnectorValues() []Connector {
|
||||
return _ConnectorValues
|
||||
}
|
||||
|
||||
// ConnectorStrings returns a slice of all String values of the enum
|
||||
func ConnectorStrings() []string {
|
||||
strs := make([]string, len(_ConnectorNames))
|
||||
copy(strs, _ConnectorNames)
|
||||
return strs
|
||||
}
|
||||
|
||||
// IsAConnector returns "true" if the value is listed in the enum definition. "false" otherwise
|
||||
func (i Connector) IsAConnector() bool {
|
||||
for _, v := range _ConnectorValues {
|
||||
if i == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// MarshalText implements the encoding.TextMarshaler interface for Connector
|
||||
func (i Connector) MarshalText() ([]byte, error) {
|
||||
return []byte(i.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements the encoding.TextUnmarshaler interface for Connector
|
||||
func (i *Connector) UnmarshalText(text []byte) error {
|
||||
var err error
|
||||
*i, err = ConnectorString(string(text))
|
||||
return err
|
||||
}
|
29
backend/command/receiver/cache/error.go
vendored
Normal file
29
backend/command/receiver/cache/error.go
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type IndexUnknownError[I comparable] struct {
|
||||
index I
|
||||
}
|
||||
|
||||
func NewIndexUnknownErr[I comparable](index I) error {
|
||||
return IndexUnknownError[I]{index}
|
||||
}
|
||||
|
||||
func (i IndexUnknownError[I]) Error() string {
|
||||
return fmt.Sprintf("index %v unknown", i.index)
|
||||
}
|
||||
|
||||
func (a IndexUnknownError[I]) Is(err error) bool {
|
||||
if b, ok := err.(IndexUnknownError[I]); ok {
|
||||
return a.index == b.index
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var (
|
||||
ErrCacheMiss = errors.New("cache miss")
|
||||
)
|
76
backend/command/receiver/cache/pruner.go
vendored
Normal file
76
backend/command/receiver/cache/pruner.go
vendored
Normal file
@@ -0,0 +1,76 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/jonboulle/clockwork"
|
||||
"github.com/zitadel/logging"
|
||||
)
|
||||
|
||||
// Pruner is an optional [Cache] interface.
|
||||
type Pruner interface {
|
||||
// Prune deletes all invalidated or expired objects.
|
||||
Prune(ctx context.Context) error
|
||||
}
|
||||
|
||||
type PrunerCache[I, K comparable, V Entry[I, K]] interface {
|
||||
Cache[I, K, V]
|
||||
Pruner
|
||||
}
|
||||
|
||||
type AutoPruneConfig struct {
|
||||
// Interval at which the cache is automatically pruned.
|
||||
// 0 or lower disables automatic pruning.
|
||||
Interval time.Duration
|
||||
|
||||
// Timeout for an automatic prune.
|
||||
// It is recommended to keep the value shorter than AutoPruneInterval
|
||||
// 0 or lower disables automatic pruning.
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
func (c AutoPruneConfig) StartAutoPrune(background context.Context, pruner Pruner, purpose Purpose) (close func()) {
|
||||
return c.startAutoPrune(background, pruner, purpose, clockwork.NewRealClock())
|
||||
}
|
||||
|
||||
func (c *AutoPruneConfig) startAutoPrune(background context.Context, pruner Pruner, purpose Purpose, clock clockwork.Clock) (close func()) {
|
||||
if c.Interval <= 0 {
|
||||
return func() {}
|
||||
}
|
||||
background, cancel := context.WithCancel(background)
|
||||
// randomize the first interval
|
||||
timer := clock.NewTimer(time.Duration(rand.Int63n(int64(c.Interval))))
|
||||
go c.pruneTimer(background, pruner, purpose, timer)
|
||||
return cancel
|
||||
}
|
||||
|
||||
func (c *AutoPruneConfig) pruneTimer(background context.Context, pruner Pruner, purpose Purpose, timer clockwork.Timer) {
|
||||
defer func() {
|
||||
if !timer.Stop() {
|
||||
<-timer.Chan()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-background.Done():
|
||||
return
|
||||
case <-timer.Chan():
|
||||
err := c.doPrune(background, pruner)
|
||||
logging.OnError(err).WithField("purpose", purpose).Error("cache auto prune")
|
||||
timer.Reset(c.Interval)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *AutoPruneConfig) doPrune(background context.Context, pruner Pruner) error {
|
||||
ctx, cancel := context.WithCancel(background)
|
||||
defer cancel()
|
||||
if c.Timeout > 0 {
|
||||
ctx, cancel = context.WithTimeout(background, c.Timeout)
|
||||
defer cancel()
|
||||
}
|
||||
return pruner.Prune(ctx)
|
||||
}
|
43
backend/command/receiver/cache/pruner_test.go
vendored
Normal file
43
backend/command/receiver/cache/pruner_test.go
vendored
Normal file
@@ -0,0 +1,43 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jonboulle/clockwork"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type testPruner struct {
|
||||
called chan struct{}
|
||||
}
|
||||
|
||||
func (p *testPruner) Prune(context.Context) error {
|
||||
p.called <- struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAutoPruneConfig_startAutoPrune(t *testing.T) {
|
||||
c := AutoPruneConfig{
|
||||
Interval: time.Second,
|
||||
Timeout: time.Millisecond,
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
pruner := testPruner{
|
||||
called: make(chan struct{}),
|
||||
}
|
||||
clock := clockwork.NewFakeClock()
|
||||
close := c.startAutoPrune(ctx, &pruner, PurposeAuthzInstance, clock)
|
||||
defer close()
|
||||
clock.Advance(time.Second)
|
||||
|
||||
select {
|
||||
case _, ok := <-pruner.called:
|
||||
assert.True(t, ok)
|
||||
case <-ctx.Done():
|
||||
t.Fatal(ctx.Err())
|
||||
}
|
||||
}
|
90
backend/command/receiver/cache/purpose_enumer.go
vendored
Normal file
90
backend/command/receiver/cache/purpose_enumer.go
vendored
Normal file
@@ -0,0 +1,90 @@
|
||||
// Code generated by "enumer -type Purpose -transform snake -trimprefix Purpose"; DO NOT EDIT.
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const _PurposeName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callback"
|
||||
|
||||
var _PurposeIndex = [...]uint8{0, 11, 25, 35, 47, 65}
|
||||
|
||||
const _PurposeLowerName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callback"
|
||||
|
||||
func (i Purpose) String() string {
|
||||
if i < 0 || i >= Purpose(len(_PurposeIndex)-1) {
|
||||
return fmt.Sprintf("Purpose(%d)", i)
|
||||
}
|
||||
return _PurposeName[_PurposeIndex[i]:_PurposeIndex[i+1]]
|
||||
}
|
||||
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
func _PurposeNoOp() {
|
||||
var x [1]struct{}
|
||||
_ = x[PurposeUnspecified-(0)]
|
||||
_ = x[PurposeAuthzInstance-(1)]
|
||||
_ = x[PurposeMilestones-(2)]
|
||||
_ = x[PurposeOrganization-(3)]
|
||||
_ = x[PurposeIdPFormCallback-(4)]
|
||||
}
|
||||
|
||||
var _PurposeValues = []Purpose{PurposeUnspecified, PurposeAuthzInstance, PurposeMilestones, PurposeOrganization, PurposeIdPFormCallback}
|
||||
|
||||
var _PurposeNameToValueMap = map[string]Purpose{
|
||||
_PurposeName[0:11]: PurposeUnspecified,
|
||||
_PurposeLowerName[0:11]: PurposeUnspecified,
|
||||
_PurposeName[11:25]: PurposeAuthzInstance,
|
||||
_PurposeLowerName[11:25]: PurposeAuthzInstance,
|
||||
_PurposeName[25:35]: PurposeMilestones,
|
||||
_PurposeLowerName[25:35]: PurposeMilestones,
|
||||
_PurposeName[35:47]: PurposeOrganization,
|
||||
_PurposeLowerName[35:47]: PurposeOrganization,
|
||||
_PurposeName[47:65]: PurposeIdPFormCallback,
|
||||
_PurposeLowerName[47:65]: PurposeIdPFormCallback,
|
||||
}
|
||||
|
||||
var _PurposeNames = []string{
|
||||
_PurposeName[0:11],
|
||||
_PurposeName[11:25],
|
||||
_PurposeName[25:35],
|
||||
_PurposeName[35:47],
|
||||
_PurposeName[47:65],
|
||||
}
|
||||
|
||||
// PurposeString retrieves an enum value from the enum constants string name.
|
||||
// Throws an error if the param is not part of the enum.
|
||||
func PurposeString(s string) (Purpose, error) {
|
||||
if val, ok := _PurposeNameToValueMap[s]; ok {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
if val, ok := _PurposeNameToValueMap[strings.ToLower(s)]; ok {
|
||||
return val, nil
|
||||
}
|
||||
return 0, fmt.Errorf("%s does not belong to Purpose values", s)
|
||||
}
|
||||
|
||||
// PurposeValues returns all values of the enum
|
||||
func PurposeValues() []Purpose {
|
||||
return _PurposeValues
|
||||
}
|
||||
|
||||
// PurposeStrings returns a slice of all String values of the enum
|
||||
func PurposeStrings() []string {
|
||||
strs := make([]string, len(_PurposeNames))
|
||||
copy(strs, _PurposeNames)
|
||||
return strs
|
||||
}
|
||||
|
||||
// IsAPurpose returns "true" if the value is listed in the enum definition. "false" otherwise
|
||||
func (i Purpose) IsAPurpose() bool {
|
||||
for _, v := range _PurposeValues {
|
||||
if i == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
@@ -1,6 +1,10 @@
|
||||
package receiver
|
||||
|
||||
import "context"
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/command/receiver/cache"
|
||||
)
|
||||
|
||||
type InstanceState uint8
|
||||
|
||||
@@ -16,6 +20,31 @@ type Instance struct {
|
||||
Domains []*Domain
|
||||
}
|
||||
|
||||
type InstanceIndex uint8
|
||||
|
||||
var InstanceIndices = []InstanceIndex{
|
||||
InstanceByID,
|
||||
InstanceByDomain,
|
||||
}
|
||||
|
||||
const (
|
||||
InstanceByID InstanceIndex = iota
|
||||
InstanceByDomain
|
||||
)
|
||||
|
||||
var _ cache.Entry[InstanceIndex, string] = (*Instance)(nil)
|
||||
|
||||
// Keys implements [cache.Entry].
|
||||
func (i *Instance) Keys(index InstanceIndex) (key []string) {
|
||||
switch index {
|
||||
case InstanceByID:
|
||||
return []string{i.ID}
|
||||
case InstanceByDomain:
|
||||
return []string{i.Name}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type InstanceManipulator interface {
|
||||
Create(ctx context.Context, instance *Instance) error
|
||||
Delete(ctx context.Context, instance *Instance) error
|
||||
|
6
backend/command/v2/api/doc.go
Normal file
6
backend/command/v2/api/doc.go
Normal file
@@ -0,0 +1,6 @@
|
||||
// The API package implements the protobuf stubs
|
||||
|
||||
// It uses the Chain of responsibility pattern to handle requests in a modular way
|
||||
// It implements the client or invoker of the command pattern.
|
||||
// The client is responsible for creating the concrete command and setting its receiver.
|
||||
package api
|
35
backend/command/v2/api/user/v2/email.go
Normal file
35
backend/command/v2/api/user/v2/email.go
Normal file
@@ -0,0 +1,35 @@
|
||||
package userv2
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/zitadel/zitadel/backend/command/v2/domain"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/user/v2"
|
||||
)
|
||||
|
||||
func (s *Server) SetEmail(ctx context.Context, req *user.SetEmailRequest) (resp *user.SetEmailResponse, err error) {
|
||||
request := &domain.SetUserEmail{
|
||||
UserID: req.GetUserId(),
|
||||
Email: req.GetEmail(),
|
||||
}
|
||||
switch req.GetVerification().(type) {
|
||||
case *user.SetEmailRequest_IsVerified:
|
||||
request.IsVerified = gu.Ptr(req.GetIsVerified())
|
||||
case *user.SetEmailRequest_SendCode:
|
||||
request.SendCode = &domain.SendCode{
|
||||
URLTemplate: req.GetSendCode().UrlTemplate,
|
||||
}
|
||||
case *user.SetEmailRequest_ReturnCode:
|
||||
request.ReturnCode = new(domain.ReturnCode)
|
||||
}
|
||||
if err := s.domain.SetUserEmail(ctx, request); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response := new(user.SetEmailResponse)
|
||||
if request.ReturnCode != nil {
|
||||
response.VerificationCode = &request.ReturnCode.Code
|
||||
}
|
||||
return response, nil
|
||||
}
|
12
backend/command/v2/api/user/v2/server.go
Normal file
12
backend/command/v2/api/user/v2/server.go
Normal file
@@ -0,0 +1,12 @@
|
||||
package userv2
|
||||
|
||||
import (
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/command/v2/domain"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
tracer trace.Tracer
|
||||
domain *domain.Domain
|
||||
}
|
41
backend/command/v2/domain/command/generate_code.go
Normal file
41
backend/command/v2/domain/command/generate_code.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/command/v2/pattern"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
)
|
||||
|
||||
type generateCode struct {
|
||||
set func(code string)
|
||||
generator pattern.Query[crypto.Generator]
|
||||
}
|
||||
|
||||
func GenerateCode(set func(code string), generator pattern.Query[crypto.Generator]) *generateCode {
|
||||
return &generateCode{
|
||||
set: set,
|
||||
generator: generator,
|
||||
}
|
||||
}
|
||||
|
||||
var _ pattern.Command = (*generateCode)(nil)
|
||||
|
||||
// Execute implements [pattern.Command].
|
||||
func (cmd *generateCode) Execute(ctx context.Context) error {
|
||||
if err := cmd.generator.Execute(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
value, code, err := crypto.NewCode(cmd.generator.Result())
|
||||
_ = value
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.set(code)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Name implements [pattern.Command].
|
||||
func (*generateCode) Name() string {
|
||||
return "command.generate_code"
|
||||
}
|
41
backend/command/v2/domain/command/send_email_code.go
Normal file
41
backend/command/v2/domain/command/send_email_code.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/command/v2/pattern"
|
||||
)
|
||||
|
||||
var _ pattern.Command = (*sendEmailCode)(nil)
|
||||
|
||||
type sendEmailCode struct {
|
||||
UserID string `json:"userId"`
|
||||
Email string `json:"email"`
|
||||
URLTemplate *string `json:"urlTemplate"`
|
||||
code string `json:"-"`
|
||||
}
|
||||
|
||||
func SendEmailCode(userID, email string, urlTemplate *string) pattern.Command {
|
||||
cmd := &sendEmailCode{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
URLTemplate: urlTemplate,
|
||||
}
|
||||
|
||||
return pattern.Batch(GenerateCode(cmd.SetCode, generateCode))
|
||||
}
|
||||
|
||||
// Name implements [pattern.Command].
|
||||
func (c *sendEmailCode) Name() string {
|
||||
return "user.v2.email.send_code"
|
||||
}
|
||||
|
||||
// Execute implements [pattern.Command].
|
||||
func (c *sendEmailCode) Execute(ctx context.Context) error {
|
||||
// Implementation of the command execution
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *sendEmailCode) SetCode(code string) {
|
||||
c.code = code
|
||||
}
|
39
backend/command/v2/domain/command/set_email.go
Normal file
39
backend/command/v2/domain/command/set_email.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/command/v2/storage/eventstore"
|
||||
)
|
||||
|
||||
var (
|
||||
_ eventstore.EventCommander = (*setEmail)(nil)
|
||||
)
|
||||
|
||||
type setEmail struct {
|
||||
UserID string `json:"userId"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
func SetEmail(userID, email string) *setEmail {
|
||||
return &setEmail{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
}
|
||||
}
|
||||
|
||||
// Event implements [eventstore.EventCommander].
|
||||
func (c *setEmail) Event() *eventstore.Event {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// Name implements [pattern.Command].
|
||||
func (c *setEmail) Name() string {
|
||||
return "user.v2.set_email"
|
||||
}
|
||||
|
||||
// Execute implements [pattern.Command].
|
||||
func (c *setEmail) Execute(ctx context.Context) error {
|
||||
// Implementation of the command execution
|
||||
return nil
|
||||
}
|
32
backend/command/v2/domain/command/verify_email.go
Normal file
32
backend/command/v2/domain/command/verify_email.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package command
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/command/v2/pattern"
|
||||
)
|
||||
|
||||
var _ pattern.Command = (*verifyEmail)(nil)
|
||||
|
||||
type verifyEmail struct {
|
||||
UserID string `json:"userId"`
|
||||
Email string `json:"email"`
|
||||
}
|
||||
|
||||
func VerifyEmail(userID, email string) *verifyEmail {
|
||||
return &verifyEmail{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
}
|
||||
}
|
||||
|
||||
// Name implements [pattern.Command].
|
||||
func (c *verifyEmail) Name() string {
|
||||
return "user.v2.verify_email"
|
||||
}
|
||||
|
||||
// Execute implements [pattern.Command].
|
||||
func (c *verifyEmail) Execute(ctx context.Context) error {
|
||||
// Implementation of the command execution
|
||||
return nil
|
||||
}
|
13
backend/command/v2/domain/domain.go
Normal file
13
backend/command/v2/domain/domain.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/backend/command/v2/storage/database"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
)
|
||||
|
||||
type Domain struct {
|
||||
pool database.Pool
|
||||
tracer trace.Tracer
|
||||
userCodeAlg crypto.EncryptionAlgorithm
|
||||
}
|
6
backend/command/v2/domain/email.go
Normal file
6
backend/command/v2/domain/email.go
Normal file
@@ -0,0 +1,6 @@
|
||||
package domain
|
||||
|
||||
type Email struct {
|
||||
Address string
|
||||
Verified bool
|
||||
}
|
42
backend/command/v2/domain/query/encryption_generator.go
Normal file
42
backend/command/v2/domain/query/encryption_generator.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
)
|
||||
|
||||
type encryptionConfigReceiver interface {
|
||||
GetEncryptionConfig(ctx context.Context) (*crypto.GeneratorConfig, error)
|
||||
}
|
||||
|
||||
type encryptionGenerator struct {
|
||||
receiver encryptionConfigReceiver
|
||||
algorithm crypto.EncryptionAlgorithm
|
||||
|
||||
res crypto.Generator
|
||||
}
|
||||
|
||||
func QueryEncryptionGenerator(receiver encryptionConfigReceiver, algorithm crypto.EncryptionAlgorithm) *encryptionGenerator {
|
||||
return &encryptionGenerator{
|
||||
receiver: receiver,
|
||||
algorithm: algorithm,
|
||||
}
|
||||
}
|
||||
|
||||
func (q *encryptionGenerator) Execute(ctx context.Context) error {
|
||||
config, err := q.receiver.GetEncryptionConfig(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
q.res = crypto.NewEncryptionGenerator(*config, q.algorithm)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (q *encryptionGenerator) Name() string {
|
||||
return "query.encryption_generator"
|
||||
}
|
||||
|
||||
func (q *encryptionGenerator) Result() crypto.Generator {
|
||||
return q.res
|
||||
}
|
38
backend/command/v2/domain/query/return_email_code.go
Normal file
38
backend/command/v2/domain/query/return_email_code.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/command/v2/pattern"
|
||||
)
|
||||
|
||||
var _ pattern.Query[string] = (*returnEmailCode)(nil)
|
||||
|
||||
type returnEmailCode struct {
|
||||
UserID string `json:"userId"`
|
||||
Email string `json:"email"`
|
||||
code string `json:"-"`
|
||||
}
|
||||
|
||||
func ReturnEmailCode(userID, email string) *returnEmailCode {
|
||||
return &returnEmailCode{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
}
|
||||
}
|
||||
|
||||
// Name implements [pattern.Command].
|
||||
func (c *returnEmailCode) Name() string {
|
||||
return "user.v2.email.return_code"
|
||||
}
|
||||
|
||||
// Execute implements [pattern.Command].
|
||||
func (c *returnEmailCode) Execute(ctx context.Context) error {
|
||||
// Implementation of the command execution
|
||||
return nil
|
||||
}
|
||||
|
||||
// Result implements [pattern.Query].
|
||||
func (c *returnEmailCode) Result() string {
|
||||
return c.code
|
||||
}
|
38
backend/command/v2/domain/query/user_by_id.go
Normal file
38
backend/command/v2/domain/query/user_by_id.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package query
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/command/v2/domain"
|
||||
"github.com/zitadel/zitadel/backend/command/v2/pattern"
|
||||
"github.com/zitadel/zitadel/backend/command/v2/storage/database"
|
||||
)
|
||||
|
||||
type UserByIDQuery struct {
|
||||
querier database.Querier
|
||||
UserID string `json:"userId"`
|
||||
res *domain.User
|
||||
}
|
||||
|
||||
var _ pattern.Query[*domain.User] = (*UserByIDQuery)(nil)
|
||||
|
||||
// Name implements [pattern.Command].
|
||||
func (q *UserByIDQuery) Name() string {
|
||||
return "user.v2.by_id"
|
||||
}
|
||||
|
||||
// Execute implements [pattern.Command].
|
||||
func (q *UserByIDQuery) Execute(ctx context.Context) error {
|
||||
var res *domain.User
|
||||
err := q.querier.QueryRow(ctx, "SELECT id, username, email FROM users WHERE id = $1", q.UserID).Scan(&res.ID, &res.Username, &res.Email.Address)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
q.res = res
|
||||
return nil
|
||||
}
|
||||
|
||||
// Result implements [pattern.Query].
|
||||
func (q *UserByIDQuery) Result() *domain.User {
|
||||
return q.res
|
||||
}
|
77
backend/command/v2/domain/user.go
Normal file
77
backend/command/v2/domain/user.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/command/v2/domain/command"
|
||||
"github.com/zitadel/zitadel/backend/command/v2/domain/query"
|
||||
"github.com/zitadel/zitadel/backend/command/v2/pattern"
|
||||
"github.com/zitadel/zitadel/backend/command/v2/storage/database"
|
||||
"github.com/zitadel/zitadel/backend/command/v2/telemetry/tracing"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
ID string
|
||||
Username string
|
||||
Email Email
|
||||
}
|
||||
|
||||
type SetUserEmail struct {
|
||||
UserID string
|
||||
Email string
|
||||
|
||||
IsVerified *bool
|
||||
ReturnCode *ReturnCode
|
||||
SendCode *SendCode
|
||||
|
||||
code string
|
||||
client database.QueryExecutor
|
||||
}
|
||||
|
||||
func (e *SetUserEmail) SetCode(code string) {
|
||||
e.code = code
|
||||
}
|
||||
|
||||
type ReturnCode struct {
|
||||
// Code is the code to be sent to the user
|
||||
Code string
|
||||
}
|
||||
|
||||
type SendCode struct {
|
||||
// URLTemplate is the template for the URL that is rendered into the message
|
||||
URLTemplate *string
|
||||
}
|
||||
|
||||
func (d *Domain) SetUserEmail(ctx context.Context, req *SetUserEmail) error {
|
||||
batch := pattern.Batch(
|
||||
tracing.Trace(d.tracer, command.SetEmail(req.UserID, req.Email)),
|
||||
)
|
||||
|
||||
if req.IsVerified == nil {
|
||||
batch.Append(command.GenerateCode(
|
||||
req.SetCode,
|
||||
query.QueryEncryptionGenerator(
|
||||
database.Query(d.pool),
|
||||
d.userCodeAlg,
|
||||
),
|
||||
))
|
||||
} else {
|
||||
batch.Append(command.VerifyEmail(req.UserID, req.Email))
|
||||
}
|
||||
|
||||
// if !req.GetVerification().GetIsVerified() {
|
||||
// batch.
|
||||
|
||||
// switch req.GetVerification().(type) {
|
||||
// case *user.SetEmailRequest_IsVerified:
|
||||
// batch.Append(tracing.Trace(s.tracer, command.VerifyEmail(req.GetUserId(), req.GetEmail())))
|
||||
// case *user.SetEmailRequest_SendCode:
|
||||
// batch.Append(tracing.Trace(s.tracer, command.SendEmailCode(req.GetUserId(), req.GetEmail(), req.GetSendCode().UrlTemplate)))
|
||||
// case *user.SetEmailRequest_ReturnCode:
|
||||
// batch.Append(tracing.Trace(s.tracer, query.ReturnEmailCode(req.GetUserId(), req.GetEmail())))
|
||||
// }
|
||||
|
||||
// if err := batch.Execute(ctx); err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
}
|
100
backend/command/v2/pattern/command.go
Normal file
100
backend/command/v2/pattern/command.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package pattern
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/command/v2/storage/database"
|
||||
)
|
||||
|
||||
// Command implements the command pattern.
|
||||
// It is used to encapsulate a request as an object, thereby allowing for parameterization of clients with queues, requests, and operations.
|
||||
// The command pattern allows for the decoupling of the sender and receiver of a request.
|
||||
// It is often used in conjunction with the invoker pattern, which is responsible for executing the command.
|
||||
// The command pattern is a behavioral design pattern that turns a request into a stand-alone object.
|
||||
// This object contains all the information about the request.
|
||||
// The command pattern is useful for implementing undo/redo functionality, logging, and queuing requests.
|
||||
// It is also useful for implementing the macro command pattern, which allows for the execution of a series of commands as a single command.
|
||||
// The command pattern is also used in event-driven architectures, where events are encapsulated as commands.
|
||||
type Command interface {
|
||||
Execute(ctx context.Context) error
|
||||
Name() string
|
||||
}
|
||||
|
||||
type Query[T any] interface {
|
||||
Command
|
||||
Result() T
|
||||
}
|
||||
|
||||
type Invoker struct{}
|
||||
|
||||
// func bla() {
|
||||
// sync.Pool{
|
||||
// New: func() any {
|
||||
// return new(Invoker)
|
||||
// },
|
||||
// }
|
||||
// }
|
||||
|
||||
type Transaction struct {
|
||||
beginner database.Beginner
|
||||
cmd Command
|
||||
opts *database.TransactionOptions
|
||||
}
|
||||
|
||||
func (t *Transaction) Execute(ctx context.Context) error {
|
||||
tx, err := t.beginner.Begin(ctx, t.opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { err = tx.End(ctx, err) }()
|
||||
return t.cmd.Execute(ctx)
|
||||
}
|
||||
|
||||
func (t *Transaction) Name() string {
|
||||
return t.cmd.Name()
|
||||
}
|
||||
|
||||
type batch struct {
|
||||
Commands []Command
|
||||
}
|
||||
|
||||
func Batch(cmds ...Command) *batch {
|
||||
return &batch{
|
||||
Commands: cmds,
|
||||
}
|
||||
}
|
||||
|
||||
func (b *batch) Execute(ctx context.Context) error {
|
||||
for _, cmd := range b.Commands {
|
||||
if err := cmd.Execute(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *batch) Name() string {
|
||||
return "batch"
|
||||
}
|
||||
|
||||
func (b *batch) Append(cmds ...Command) {
|
||||
b.Commands = append(b.Commands, cmds...)
|
||||
}
|
||||
|
||||
type NoopCommand struct{}
|
||||
|
||||
func (c *NoopCommand) Execute(_ context.Context) error {
|
||||
return nil
|
||||
}
|
||||
func (c *NoopCommand) Name() string {
|
||||
return "noop"
|
||||
}
|
||||
|
||||
type NoopQuery[T any] struct {
|
||||
NoopCommand
|
||||
}
|
||||
|
||||
func (q *NoopQuery[T]) Result() T {
|
||||
var zero T
|
||||
return zero
|
||||
}
|
9
backend/command/v2/storage/database/config.go
Normal file
9
backend/command/v2/storage/database/config.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type Connector interface {
|
||||
Connect(ctx context.Context) (Pool, error)
|
||||
}
|
54
backend/command/v2/storage/database/database.go
Normal file
54
backend/command/v2/storage/database/database.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
var (
|
||||
db *database
|
||||
)
|
||||
|
||||
type database struct {
|
||||
connector Connector
|
||||
pool Pool
|
||||
}
|
||||
|
||||
type Pool interface {
|
||||
Beginner
|
||||
QueryExecutor
|
||||
|
||||
Acquire(ctx context.Context) (Client, error)
|
||||
Close(ctx context.Context) error
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
Beginner
|
||||
QueryExecutor
|
||||
|
||||
Release(ctx context.Context) error
|
||||
}
|
||||
|
||||
type Querier interface {
|
||||
Query(ctx context.Context, stmt string, args ...any) (Rows, error)
|
||||
QueryRow(ctx context.Context, stmt string, args ...any) Row
|
||||
}
|
||||
|
||||
type Executor interface {
|
||||
Exec(ctx context.Context, stmt string, args ...any) error
|
||||
}
|
||||
|
||||
type Row interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
type Rows interface {
|
||||
Row
|
||||
Next() bool
|
||||
Close() error
|
||||
Err() error
|
||||
}
|
||||
|
||||
type QueryExecutor interface {
|
||||
Querier
|
||||
Executor
|
||||
}
|
92
backend/command/v2/storage/database/dialect/config.go
Normal file
92
backend/command/v2/storage/database/dialect/config.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package dialect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"reflect"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/storage/database"
|
||||
"github.com/zitadel/zitadel/backend/storage/database/dialect/postgres"
|
||||
)
|
||||
|
||||
type Hook struct {
|
||||
Match func(string) bool
|
||||
Decode func(config any) (database.Connector, error)
|
||||
Name string
|
||||
Constructor func() database.Connector
|
||||
}
|
||||
|
||||
var hooks = []Hook{
|
||||
{
|
||||
Match: postgres.NameMatcher,
|
||||
Decode: postgres.DecodeConfig,
|
||||
Name: postgres.Name,
|
||||
Constructor: func() database.Connector { return new(postgres.Config) },
|
||||
},
|
||||
// {
|
||||
// Match: gosql.NameMatcher,
|
||||
// Decode: gosql.DecodeConfig,
|
||||
// Name: gosql.Name,
|
||||
// Constructor: func() database.Connector { return new(gosql.Config) },
|
||||
// },
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Dialects map[string]any `mapstructure:",remain" yaml:",inline"`
|
||||
|
||||
connector database.Connector
|
||||
}
|
||||
|
||||
func (c Config) Connect(ctx context.Context) (database.Pool, error) {
|
||||
if len(c.Dialects) != 1 {
|
||||
return nil, errors.New("Exactly one dialect must be configured")
|
||||
}
|
||||
|
||||
return c.connector.Connect(ctx)
|
||||
}
|
||||
|
||||
// Hooks implements [configure.Unmarshaller].
|
||||
func (c Config) Hooks() []viper.DecoderConfigOption {
|
||||
return []viper.DecoderConfigOption{
|
||||
viper.DecodeHook(decodeHook),
|
||||
}
|
||||
}
|
||||
|
||||
func decodeHook(from, to reflect.Value) (_ any, err error) {
|
||||
if to.Type() != reflect.TypeOf(Config{}) {
|
||||
return from.Interface(), nil
|
||||
}
|
||||
|
||||
config := new(Config)
|
||||
if err = mapstructure.Decode(from.Interface(), config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = config.decodeDialect(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (c *Config) decodeDialect() error {
|
||||
for _, hook := range hooks {
|
||||
for name, config := range c.Dialects {
|
||||
if !hook.Match(name) {
|
||||
continue
|
||||
}
|
||||
|
||||
connector, err := hook.Decode(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.connector = connector
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return errors.New("no dialect found")
|
||||
}
|
@@ -0,0 +1,80 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/command/v2/storage/database"
|
||||
)
|
||||
|
||||
var (
|
||||
_ database.Connector = (*Config)(nil)
|
||||
Name = "postgres"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
config *pgxpool.Config
|
||||
|
||||
// Host string
|
||||
// Port int32
|
||||
// Database string
|
||||
// MaxOpenConns uint32
|
||||
// MaxIdleConns uint32
|
||||
// MaxConnLifetime time.Duration
|
||||
// MaxConnIdleTime time.Duration
|
||||
// User User
|
||||
// // Additional options to be appended as options=<Options>
|
||||
// // The value will be taken as is. Multiple options are space separated.
|
||||
// Options string
|
||||
|
||||
configuredFields []string
|
||||
}
|
||||
|
||||
// Connect implements [database.Connector].
|
||||
func (c *Config) Connect(ctx context.Context) (database.Pool, error) {
|
||||
pool, err := pgxpool.NewWithConfig(ctx, c.config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = pool.Ping(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &pgxPool{pool}, nil
|
||||
}
|
||||
|
||||
func NameMatcher(name string) bool {
|
||||
return slices.Contains([]string{"postgres", "pg"}, strings.ToLower(name))
|
||||
}
|
||||
|
||||
func DecodeConfig(input any) (database.Connector, error) {
|
||||
switch c := input.(type) {
|
||||
case string:
|
||||
config, err := pgxpool.ParseConfig(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Config{config: config}, nil
|
||||
case map[string]any:
|
||||
connector := new(Config)
|
||||
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||
DecodeHook: mapstructure.StringToTimeDurationHookFunc(),
|
||||
WeaklyTypedInput: true,
|
||||
Result: connector,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = decoder.Decode(c); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Config{
|
||||
config: &pgxpool.Config{},
|
||||
}, nil
|
||||
}
|
||||
return nil, errors.New("invalid configuration")
|
||||
}
|
48
backend/command/v2/storage/database/dialect/postgres/conn.go
Normal file
48
backend/command/v2/storage/database/dialect/postgres/conn.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/command/v2/storage/database"
|
||||
)
|
||||
|
||||
type pgxConn struct{ *pgxpool.Conn }
|
||||
|
||||
var _ database.Client = (*pgxConn)(nil)
|
||||
|
||||
// Release implements [database.Client].
|
||||
func (c *pgxConn) Release(_ context.Context) error {
|
||||
c.Conn.Release()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Begin implements [database.Client].
|
||||
func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
||||
tx, err := c.Conn.BeginTx(ctx, transactionOptionsToPgx(opts))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &pgxTx{tx}, nil
|
||||
}
|
||||
|
||||
// Query implements sql.Client.
|
||||
// Subtle: this method shadows the method (*Conn).Query of pgxConn.Conn.
|
||||
func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
||||
rows, err := c.Conn.Query(ctx, sql, args...)
|
||||
return &Rows{rows}, err
|
||||
}
|
||||
|
||||
// QueryRow implements sql.Client.
|
||||
// Subtle: this method shadows the method (*Conn).QueryRow of pgxConn.Conn.
|
||||
func (c *pgxConn) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
||||
return c.Conn.QueryRow(ctx, sql, args...)
|
||||
}
|
||||
|
||||
// Exec implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
|
||||
func (c *pgxConn) Exec(ctx context.Context, sql string, args ...any) error {
|
||||
_, err := c.Conn.Exec(ctx, sql, args...)
|
||||
return err
|
||||
}
|
57
backend/command/v2/storage/database/dialect/postgres/pool.go
Normal file
57
backend/command/v2/storage/database/dialect/postgres/pool.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/command/v2/storage/database"
|
||||
)
|
||||
|
||||
type pgxPool struct{ *pgxpool.Pool }
|
||||
|
||||
var _ database.Pool = (*pgxPool)(nil)
|
||||
|
||||
// Acquire implements [database.Pool].
|
||||
func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) {
|
||||
conn, err := c.Pool.Acquire(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &pgxConn{conn}, nil
|
||||
}
|
||||
|
||||
// Query implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool.
|
||||
func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
||||
rows, err := c.Pool.Query(ctx, sql, args...)
|
||||
return &Rows{rows}, err
|
||||
}
|
||||
|
||||
// QueryRow implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool.
|
||||
func (c *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
||||
return c.Pool.QueryRow(ctx, sql, args...)
|
||||
}
|
||||
|
||||
// Exec implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
|
||||
func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) error {
|
||||
_, err := c.Pool.Exec(ctx, sql, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
// Begin implements [database.Pool].
|
||||
func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
||||
tx, err := c.Pool.BeginTx(ctx, transactionOptionsToPgx(opts))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &pgxTx{tx}, nil
|
||||
}
|
||||
|
||||
// Close implements [database.Pool].
|
||||
func (c *pgxPool) Close(_ context.Context) error {
|
||||
c.Pool.Close()
|
||||
return nil
|
||||
}
|
18
backend/command/v2/storage/database/dialect/postgres/rows.go
Normal file
18
backend/command/v2/storage/database/dialect/postgres/rows.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/command/v2/storage/database"
|
||||
)
|
||||
|
||||
var _ database.Rows = (*Rows)(nil)
|
||||
|
||||
type Rows struct{ pgx.Rows }
|
||||
|
||||
// Close implements [database.Rows].
|
||||
// Subtle: this method shadows the method (Rows).Close of Rows.Rows.
|
||||
func (r *Rows) Close() error {
|
||||
r.Rows.Close()
|
||||
return nil
|
||||
}
|
95
backend/command/v2/storage/database/dialect/postgres/tx.go
Normal file
95
backend/command/v2/storage/database/dialect/postgres/tx.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/command/v2/storage/database"
|
||||
)
|
||||
|
||||
type pgxTx struct{ pgx.Tx }
|
||||
|
||||
var _ database.Transaction = (*pgxTx)(nil)
|
||||
|
||||
// Commit implements [database.Transaction].
|
||||
func (tx *pgxTx) Commit(ctx context.Context) error {
|
||||
return tx.Tx.Commit(ctx)
|
||||
}
|
||||
|
||||
// Rollback implements [database.Transaction].
|
||||
func (tx *pgxTx) Rollback(ctx context.Context) error {
|
||||
return tx.Tx.Rollback(ctx)
|
||||
}
|
||||
|
||||
// End implements [database.Transaction].
|
||||
func (tx *pgxTx) End(ctx context.Context, err error) error {
|
||||
if err != nil {
|
||||
tx.Rollback(ctx)
|
||||
return err
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
// Query implements [database.Transaction].
|
||||
// Subtle: this method shadows the method (Tx).Query of pgxTx.Tx.
|
||||
func (tx *pgxTx) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
||||
rows, err := tx.Tx.Query(ctx, sql, args...)
|
||||
return &Rows{rows}, err
|
||||
}
|
||||
|
||||
// QueryRow implements [database.Transaction].
|
||||
// Subtle: this method shadows the method (Tx).QueryRow of pgxTx.Tx.
|
||||
func (tx *pgxTx) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
||||
return tx.Tx.QueryRow(ctx, sql, args...)
|
||||
}
|
||||
|
||||
// Exec implements [database.Transaction].
|
||||
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
|
||||
func (tx *pgxTx) Exec(ctx context.Context, sql string, args ...any) error {
|
||||
_, err := tx.Tx.Exec(ctx, sql, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
// Begin implements [database.Transaction].
|
||||
// As postgres does not support nested transactions we use savepoints to emulate them.
|
||||
func (tx *pgxTx) Begin(ctx context.Context) (database.Transaction, error) {
|
||||
savepoint, err := tx.Tx.Begin(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &pgxTx{savepoint}, nil
|
||||
}
|
||||
|
||||
func transactionOptionsToPgx(opts *database.TransactionOptions) pgx.TxOptions {
|
||||
if opts == nil {
|
||||
return pgx.TxOptions{}
|
||||
}
|
||||
|
||||
return pgx.TxOptions{
|
||||
IsoLevel: isolationToPgx(opts.IsolationLevel),
|
||||
AccessMode: accessModeToPgx(opts.AccessMode),
|
||||
}
|
||||
}
|
||||
|
||||
func isolationToPgx(isolation database.IsolationLevel) pgx.TxIsoLevel {
|
||||
switch isolation {
|
||||
case database.IsolationLevelSerializable:
|
||||
return pgx.Serializable
|
||||
case database.IsolationLevelReadCommitted:
|
||||
return pgx.ReadCommitted
|
||||
default:
|
||||
return pgx.Serializable
|
||||
}
|
||||
}
|
||||
|
||||
func accessModeToPgx(accessMode database.AccessMode) pgx.TxAccessMode {
|
||||
switch accessMode {
|
||||
case database.AccessModeReadWrite:
|
||||
return pgx.ReadWrite
|
||||
case database.AccessModeReadOnly:
|
||||
return pgx.ReadOnly
|
||||
default:
|
||||
return pgx.ReadWrite
|
||||
}
|
||||
}
|
39
backend/command/v2/storage/database/secret_generator.go
Normal file
39
backend/command/v2/storage/database/secret_generator.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
)
|
||||
|
||||
type query struct{ Querier }
|
||||
|
||||
func Query(querier Querier) *query {
|
||||
return &query{Querier: querier}
|
||||
}
|
||||
|
||||
const getEncryptionConfigQuery = "SELECT" +
|
||||
" length" +
|
||||
", expiry" +
|
||||
", should_include_lower_letters" +
|
||||
", should_include_upper_letters" +
|
||||
", should_include_digits" +
|
||||
", should_include_symbols" +
|
||||
" FROM encryption_config"
|
||||
|
||||
func (q query) GetEncryptionConfig(ctx context.Context) (*crypto.GeneratorConfig, error) {
|
||||
var config crypto.GeneratorConfig
|
||||
row := q.QueryRow(ctx, getEncryptionConfigQuery)
|
||||
err := row.Scan(
|
||||
&config.Length,
|
||||
&config.Expiry,
|
||||
&config.IncludeLowerLetters,
|
||||
&config.IncludeUpperLetters,
|
||||
&config.IncludeDigits,
|
||||
&config.IncludeSymbols,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &config, nil
|
||||
}
|
36
backend/command/v2/storage/database/tx.go
Normal file
36
backend/command/v2/storage/database/tx.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package database
|
||||
|
||||
import "context"
|
||||
|
||||
type Transaction interface {
|
||||
Commit(ctx context.Context) error
|
||||
Rollback(ctx context.Context) error
|
||||
End(ctx context.Context, err error) error
|
||||
|
||||
Begin(ctx context.Context) (Transaction, error)
|
||||
|
||||
QueryExecutor
|
||||
}
|
||||
|
||||
type Beginner interface {
|
||||
Begin(ctx context.Context, opts *TransactionOptions) (Transaction, error)
|
||||
}
|
||||
|
||||
type TransactionOptions struct {
|
||||
IsolationLevel IsolationLevel
|
||||
AccessMode AccessMode
|
||||
}
|
||||
|
||||
type IsolationLevel uint8
|
||||
|
||||
const (
|
||||
IsolationLevelSerializable IsolationLevel = iota
|
||||
IsolationLevelReadCommitted
|
||||
)
|
||||
|
||||
type AccessMode uint8
|
||||
|
||||
const (
|
||||
AccessModeReadWrite AccessMode = iota
|
||||
AccessModeReadOnly
|
||||
)
|
13
backend/command/v2/storage/eventstore/event.go
Normal file
13
backend/command/v2/storage/eventstore/event.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package eventstore
|
||||
|
||||
import "github.com/zitadel/zitadel/backend/command/v2/pattern"
|
||||
|
||||
type Event struct {
|
||||
AggregateType string `json:"aggregateType"`
|
||||
AggregateID string `json:"aggregateId"`
|
||||
}
|
||||
|
||||
type EventCommander interface {
|
||||
pattern.Command
|
||||
Event() *Event
|
||||
}
|
55
backend/command/v2/telemetry/tracing/command.go
Normal file
55
backend/command/v2/telemetry/tracing/command.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package tracing
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"go.opentelemetry.io/otel/trace"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/command/v2/pattern"
|
||||
)
|
||||
|
||||
type command struct {
|
||||
trace.Tracer
|
||||
cmd pattern.Command
|
||||
}
|
||||
|
||||
func Trace(tracer trace.Tracer, cmd pattern.Command) pattern.Command {
|
||||
return &command{
|
||||
Tracer: tracer,
|
||||
cmd: cmd,
|
||||
}
|
||||
}
|
||||
|
||||
func (cmd *command) Name() string {
|
||||
return cmd.cmd.Name()
|
||||
}
|
||||
|
||||
func (cmd *command) Execute(ctx context.Context) error {
|
||||
ctx, span := cmd.Tracer.Start(ctx, cmd.Name())
|
||||
defer span.End()
|
||||
|
||||
err := cmd.cmd.Execute(ctx)
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
type query[T any] struct {
|
||||
command
|
||||
query pattern.Query[T]
|
||||
}
|
||||
|
||||
func Query[T any](tracer trace.Tracer, q pattern.Query[T]) pattern.Query[T] {
|
||||
return &query[T]{
|
||||
command: command{
|
||||
Tracer: tracer,
|
||||
cmd: q,
|
||||
},
|
||||
query: q,
|
||||
}
|
||||
}
|
||||
|
||||
func (q *query[T]) Result() T {
|
||||
return q.query.Result()
|
||||
}
|
19
backend/v3/api/instance/v2/server.go
Normal file
19
backend/v3/api/instance/v2/server.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package v2
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/backend/v3/telemetry/logging"
|
||||
"github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
|
||||
)
|
||||
|
||||
var (
|
||||
logger logging.Logger
|
||||
tracer tracing.Tracer
|
||||
)
|
||||
|
||||
func SetLogger(l logging.Logger) {
|
||||
logger = l
|
||||
}
|
||||
|
||||
func SetTracer(t tracing.Tracer) {
|
||||
tracer = t
|
||||
}
|
93
backend/v3/api/user/v2/email.go
Normal file
93
backend/v3/api/user/v2/email.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package userv2
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/user/v2"
|
||||
)
|
||||
|
||||
func SetEmail(ctx context.Context, req *user.SetEmailRequest) (resp *user.SetEmailResponse, err error) {
|
||||
var (
|
||||
verification domain.SetEmailOpt
|
||||
returnCode *domain.ReturnCodeCommand
|
||||
)
|
||||
|
||||
switch req.GetVerification().(type) {
|
||||
case *user.SetEmailRequest_IsVerified:
|
||||
verification = domain.NewEmailVerifiedCommand(req.GetUserId(), req.GetIsVerified())
|
||||
case *user.SetEmailRequest_SendCode:
|
||||
verification = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
|
||||
case *user.SetEmailRequest_ReturnCode:
|
||||
returnCode = domain.NewReturnCodeCommand(req.GetUserId())
|
||||
verification = returnCode
|
||||
default:
|
||||
verification = domain.NewSendCodeCommand(req.GetUserId(), nil)
|
||||
}
|
||||
|
||||
err = domain.Invoke(ctx, domain.NewSetEmailCommand(req.GetUserId(), req.GetEmail(), verification))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var code *string
|
||||
if returnCode != nil && returnCode.Code != "" {
|
||||
code = &returnCode.Code
|
||||
}
|
||||
|
||||
return &user.SetEmailResponse{
|
||||
VerificationCode: code,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func SendEmailCode(ctx context.Context, req *user.SendEmailCodeRequest) (resp *user.SendEmailCodeResponse, err error) {
|
||||
var (
|
||||
returnCode *domain.ReturnCodeCommand
|
||||
cmd domain.Commander
|
||||
)
|
||||
|
||||
switch req.GetVerification().(type) {
|
||||
case *user.SendEmailCodeRequest_SendCode:
|
||||
cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
|
||||
case *user.SendEmailCodeRequest_ReturnCode:
|
||||
returnCode = domain.NewReturnCodeCommand(req.GetUserId())
|
||||
cmd = returnCode
|
||||
default:
|
||||
cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
|
||||
}
|
||||
err = domain.Invoke(ctx, cmd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp = new(user.SendEmailCodeResponse)
|
||||
if returnCode != nil {
|
||||
resp.VerificationCode = &returnCode.Code
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
func ResendEmailCode(ctx context.Context, req *user.ResendEmailCodeRequest) (resp *user.SendEmailCodeResponse, err error) {
|
||||
var (
|
||||
returnCode *domain.ReturnCodeCommand
|
||||
cmd domain.Commander
|
||||
)
|
||||
|
||||
switch req.GetVerification().(type) {
|
||||
case *user.ResendEmailCodeRequest_SendCode:
|
||||
cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
|
||||
case *user.ResendEmailCodeRequest_ReturnCode:
|
||||
returnCode = domain.NewReturnCodeCommand(req.GetUserId())
|
||||
cmd = returnCode
|
||||
default:
|
||||
cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
|
||||
}
|
||||
err = domain.Invoke(ctx, cmd)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resp = new(user.SendEmailCodeResponse)
|
||||
if returnCode != nil {
|
||||
resp.VerificationCode = &returnCode.Code
|
||||
}
|
||||
return resp, nil
|
||||
}
|
19
backend/v3/api/user/v2/server.go
Normal file
19
backend/v3/api/user/v2/server.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package userv2
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/backend/v3/telemetry/logging"
|
||||
"github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
|
||||
)
|
||||
|
||||
var (
|
||||
logger logging.Logger
|
||||
tracer tracing.Tracer
|
||||
)
|
||||
|
||||
func SetLogger(l logging.Logger) {
|
||||
logger = l
|
||||
}
|
||||
|
||||
func SetTracer(t tracing.Tracer) {
|
||||
tracer = t
|
||||
}
|
12
backend/v3/doc.go
Normal file
12
backend/v3/doc.go
Normal file
@@ -0,0 +1,12 @@
|
||||
// the test used the manly relies on the following patterns:
|
||||
// - domain:
|
||||
// - hexagonal architecture, it defines its dependencies as interfaces and the dependencies must use the objects defined by this package
|
||||
// - command pattern which implements the changes
|
||||
// - the invoker decorates the commands by checking for events and tracing
|
||||
// - the database connections are manged in this package
|
||||
// - the database connections are passed to the repositories
|
||||
//
|
||||
// - storage:
|
||||
// - repository pattern, the repositories are defined as interfaces and the implementations are in the storage package
|
||||
// - the repositories are used by the domain package to access the database
|
||||
package v3
|
105
backend/v3/domain/command.go
Normal file
105
backend/v3/domain/command.go
Normal file
@@ -0,0 +1,105 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type Commander interface {
|
||||
Execute(ctx context.Context, opts *CommandOpts) (err error)
|
||||
}
|
||||
|
||||
type Invoker interface {
|
||||
Invoke(ctx context.Context, command Commander, opts *CommandOpts) error
|
||||
}
|
||||
|
||||
type CommandOpts struct {
|
||||
DB database.QueryExecutor
|
||||
Invoker Invoker
|
||||
}
|
||||
|
||||
type ensureTxOpts struct {
|
||||
*database.TransactionOptions
|
||||
}
|
||||
|
||||
type EnsureTransactionOpt func(*ensureTxOpts)
|
||||
|
||||
// EnsureTx ensures that the DB is a transaction. If it is not, it will start a new transaction.
|
||||
// The returned close function will end the transaction. If the DB is already a transaction, the close function
|
||||
// will do nothing because another [Commander] is already responsible for ending the transaction.
|
||||
func (o *CommandOpts) EnsureTx(ctx context.Context, opts ...EnsureTransactionOpt) (close func(context.Context, error) error, err error) {
|
||||
beginner, ok := o.DB.(database.Beginner)
|
||||
if !ok {
|
||||
// db is already a transaction
|
||||
return func(_ context.Context, err error) error {
|
||||
return err
|
||||
}, nil
|
||||
}
|
||||
|
||||
txOpts := &ensureTxOpts{
|
||||
TransactionOptions: new(database.TransactionOptions),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(txOpts)
|
||||
}
|
||||
|
||||
tx, err := beginner.Begin(ctx, txOpts.TransactionOptions)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
o.DB = tx
|
||||
|
||||
return func(ctx context.Context, err error) error {
|
||||
return tx.End(ctx, err)
|
||||
}, nil
|
||||
}
|
||||
|
||||
// EnsureClient ensures that the o.DB is a client. If it is not, it will get a new client from the [database.Pool].
|
||||
// The returned close function will release the client. If the o.DB is already a client or transaction, the close function
|
||||
// will do nothing because another [Commander] is already responsible for releasing the client.
|
||||
func (o *CommandOpts) EnsureClient(ctx context.Context) (close func(_ context.Context) error, err error) {
|
||||
pool, ok := o.DB.(database.Pool)
|
||||
if !ok {
|
||||
// o.DB is already a client
|
||||
return func(_ context.Context) error {
|
||||
return nil
|
||||
}, nil
|
||||
}
|
||||
client, err := pool.Acquire(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
o.DB = client
|
||||
return func(ctx context.Context) error {
|
||||
return client.Release(ctx)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (o *CommandOpts) Invoke(ctx context.Context, command Commander) error {
|
||||
if o.Invoker == nil {
|
||||
return command.Execute(ctx, o)
|
||||
}
|
||||
return o.Invoker.Invoke(ctx, command, o)
|
||||
}
|
||||
|
||||
func DefaultOpts(invoker Invoker) *CommandOpts {
|
||||
if invoker == nil {
|
||||
invoker = &noopInvoker{}
|
||||
}
|
||||
return &CommandOpts{
|
||||
DB: pool,
|
||||
Invoker: invoker,
|
||||
}
|
||||
}
|
||||
|
||||
type noopInvoker struct {
|
||||
next Invoker
|
||||
}
|
||||
|
||||
func (i *noopInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) error {
|
||||
if i.next != nil {
|
||||
return i.next.Invoke(ctx, command, opts)
|
||||
}
|
||||
return command.Execute(ctx, opts)
|
||||
}
|
76
backend/v3/domain/create_user.go
Normal file
76
backend/v3/domain/create_user.go
Normal file
@@ -0,0 +1,76 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
v4 "github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt/v4"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/eventstore"
|
||||
)
|
||||
|
||||
type CreateUserCommand struct {
|
||||
user *User
|
||||
email *SetEmailCommand
|
||||
}
|
||||
|
||||
var (
|
||||
_ Commander = (*CreateUserCommand)(nil)
|
||||
_ eventer = (*CreateUserCommand)(nil)
|
||||
)
|
||||
|
||||
func NewCreateHumanCommand(username string, opts ...CreateHumanOpt) *CreateUserCommand {
|
||||
cmd := &CreateUserCommand{
|
||||
user: &User{
|
||||
User: v4.User{
|
||||
Username: username,
|
||||
Traits: &v4.Human{},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, opt := range opts {
|
||||
opt.applyOnCreateHuman(cmd)
|
||||
}
|
||||
return cmd
|
||||
}
|
||||
|
||||
// Events implements [eventer].
|
||||
func (c *CreateUserCommand) Events() []*eventstore.Event {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// Execute implements [Commander].
|
||||
func (c *CreateUserCommand) Execute(ctx context.Context, opts *CommandOpts) error {
|
||||
if err := c.ensureUserID(); err != nil {
|
||||
return err
|
||||
}
|
||||
c.email.UserID = c.user.ID
|
||||
if err := opts.Invoke(ctx, c.email); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type CreateHumanOpt interface {
|
||||
applyOnCreateHuman(*CreateUserCommand)
|
||||
}
|
||||
|
||||
type createHumanIDOpt string
|
||||
|
||||
// applyOnCreateHuman implements [CreateHumanOpt].
|
||||
func (c createHumanIDOpt) applyOnCreateHuman(cmd *CreateUserCommand) {
|
||||
cmd.user.ID = string(c)
|
||||
}
|
||||
|
||||
var _ CreateHumanOpt = (*createHumanIDOpt)(nil)
|
||||
|
||||
func CreateHumanWithID(id string) CreateHumanOpt {
|
||||
return createHumanIDOpt(id)
|
||||
}
|
||||
|
||||
func (c *CreateUserCommand) ensureUserID() (err error) {
|
||||
if c.user.ID != "" {
|
||||
return nil
|
||||
}
|
||||
c.user.ID, err = generateID()
|
||||
return err
|
||||
}
|
26
backend/v3/domain/crypto.go
Normal file
26
backend/v3/domain/crypto.go
Normal file
@@ -0,0 +1,26 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
)
|
||||
|
||||
type generateCodeCommand struct {
|
||||
code string
|
||||
value *crypto.CryptoValue
|
||||
}
|
||||
|
||||
type CryptoRepository interface {
|
||||
GetEncryptionConfig(ctx context.Context) (*crypto.GeneratorConfig, error)
|
||||
}
|
||||
|
||||
func (cmd *generateCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error {
|
||||
config, err := cryptoRepo(opts.DB).GetEncryptionConfig(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
generator := crypto.NewEncryptionGenerator(*config, userCodeAlgorithm)
|
||||
cmd.value, cmd.code, err = crypto.NewCode(generator)
|
||||
return err
|
||||
}
|
52
backend/v3/domain/domain.go
Normal file
52
backend/v3/domain/domain.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"math/rand/v2"
|
||||
"strconv"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/cache"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
"github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
)
|
||||
|
||||
var (
|
||||
pool database.Pool
|
||||
userCodeAlgorithm crypto.EncryptionAlgorithm
|
||||
tracer tracing.Tracer
|
||||
|
||||
// userRepo func(database.QueryExecutor) UserRepository
|
||||
instanceRepo func(database.QueryExecutor) InstanceRepository
|
||||
cryptoRepo func(database.QueryExecutor) CryptoRepository
|
||||
orgRepo func(database.QueryExecutor) OrgRepository
|
||||
|
||||
instanceCache cache.Cache[string, string, *Instance]
|
||||
|
||||
generateID func() (string, error) = func() (string, error) {
|
||||
return strconv.FormatUint(rand.Uint64(), 10), nil
|
||||
}
|
||||
)
|
||||
|
||||
func SetPool(p database.Pool) {
|
||||
pool = p
|
||||
}
|
||||
|
||||
func SetUserCodeAlgorithm(algorithm crypto.EncryptionAlgorithm) {
|
||||
userCodeAlgorithm = algorithm
|
||||
}
|
||||
|
||||
func SetTracer(t tracing.Tracer) {
|
||||
tracer = t
|
||||
}
|
||||
|
||||
// func SetUserRepository(repo func(database.QueryExecutor) UserRepository) {
|
||||
// userRepo = repo
|
||||
// }
|
||||
|
||||
func SetInstanceRepository(repo func(database.QueryExecutor) InstanceRepository) {
|
||||
instanceRepo = repo
|
||||
}
|
||||
|
||||
func SetCryptoRepository(repo func(database.QueryExecutor) CryptoRepository) {
|
||||
cryptoRepo = repo
|
||||
}
|
45
backend/v3/domain/domain_test.go
Normal file
45
backend/v3/domain/domain_test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package domain_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"go.opentelemetry.io/otel"
|
||||
"go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
|
||||
sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||
|
||||
. "github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
|
||||
"github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
|
||||
)
|
||||
|
||||
func TestExample(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
// SetPool(pool)
|
||||
|
||||
exporter, err := stdouttrace.New(stdouttrace.WithPrettyPrint())
|
||||
require.NoError(t, err)
|
||||
tracerProvider := sdktrace.NewTracerProvider(
|
||||
sdktrace.WithSyncer(exporter),
|
||||
)
|
||||
otel.SetTracerProvider(tracerProvider)
|
||||
SetTracer(tracing.Tracer{Tracer: tracerProvider.Tracer("test")})
|
||||
defer func() { assert.NoError(t, tracerProvider.Shutdown(ctx)) }()
|
||||
|
||||
SetUserRepository(repository.User)
|
||||
SetInstanceRepository(repository.Instance)
|
||||
SetCryptoRepository(repository.Crypto)
|
||||
|
||||
t.Run("verified email", func(t *testing.T) {
|
||||
err := Invoke(ctx, NewSetEmailCommand("u1", "test@example.com", NewEmailVerifiedCommand("u1", true)))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("unverified email", func(t *testing.T) {
|
||||
err := Invoke(ctx, NewSetEmailCommand("u2", "test2@example.com", NewEmailVerifiedCommand("u2", false)))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
}
|
155
backend/v3/domain/email_verification.go
Normal file
155
backend/v3/domain/email_verification.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
v4 "github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt/v4"
|
||||
)
|
||||
|
||||
type EmailVerifiedCommand struct {
|
||||
UserID string `json:"userId"`
|
||||
Email *Email `json:"email"`
|
||||
}
|
||||
|
||||
func NewEmailVerifiedCommand(userID string, isVerified bool) *EmailVerifiedCommand {
|
||||
return &EmailVerifiedCommand{
|
||||
UserID: userID,
|
||||
Email: &Email{
|
||||
IsVerified: isVerified,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
_ Commander = (*EmailVerifiedCommand)(nil)
|
||||
_ SetEmailOpt = (*EmailVerifiedCommand)(nil)
|
||||
)
|
||||
|
||||
// Execute implements [Commander]
|
||||
func (cmd *EmailVerifiedCommand) Execute(ctx context.Context, opts *CommandOpts) error {
|
||||
return userRepo(opts.DB).Human().ByID(cmd.UserID).Exec().SetEmailVerified(ctx, cmd.Email.Address)
|
||||
}
|
||||
|
||||
// applyOnSetEmail implements [SetEmailOpt]
|
||||
func (cmd *EmailVerifiedCommand) applyOnSetEmail(setEmailCmd *SetEmailCommand) {
|
||||
cmd.UserID = setEmailCmd.UserID
|
||||
cmd.Email.Address = setEmailCmd.Email
|
||||
setEmailCmd.verification = cmd
|
||||
}
|
||||
|
||||
type SendCodeCommand struct {
|
||||
UserID string `json:"userId"`
|
||||
Email string `json:"email"`
|
||||
URLTemplate *string `json:"urlTemplate"`
|
||||
generator *generateCodeCommand
|
||||
}
|
||||
|
||||
var (
|
||||
_ Commander = (*SendCodeCommand)(nil)
|
||||
_ SetEmailOpt = (*SendCodeCommand)(nil)
|
||||
)
|
||||
|
||||
func NewSendCodeCommand(userID string, urlTemplate *string) *SendCodeCommand {
|
||||
return &SendCodeCommand{
|
||||
UserID: userID,
|
||||
generator: &generateCodeCommand{},
|
||||
URLTemplate: urlTemplate,
|
||||
}
|
||||
}
|
||||
|
||||
// Execute implements [Commander]
|
||||
func (cmd *SendCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error {
|
||||
if err := cmd.ensureEmail(ctx, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := cmd.ensureURL(ctx, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := opts.Invoker.Invoke(ctx, cmd.generator, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO: queue notification
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cmd *SendCodeCommand) ensureEmail(ctx context.Context, opts *CommandOpts) error {
|
||||
if cmd.Email != "" {
|
||||
return nil
|
||||
}
|
||||
email, err := userRepo(opts.DB).Human().ByID(cmd.UserID).Exec().GetEmail(ctx)
|
||||
if err != nil || email.IsVerified {
|
||||
return err
|
||||
}
|
||||
cmd.Email = email.Address
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cmd *SendCodeCommand) ensureURL(ctx context.Context, opts *CommandOpts) error {
|
||||
if cmd.URLTemplate != nil && *cmd.URLTemplate != "" {
|
||||
return nil
|
||||
}
|
||||
_, _ = ctx, opts
|
||||
// TODO: load default template
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyOnSetEmail implements [SetEmailOpt]
|
||||
func (cmd *SendCodeCommand) applyOnSetEmail(setEmailCmd *SetEmailCommand) {
|
||||
cmd.UserID = setEmailCmd.UserID
|
||||
cmd.Email = setEmailCmd.Email
|
||||
setEmailCmd.verification = cmd
|
||||
}
|
||||
|
||||
type ReturnCodeCommand struct {
|
||||
UserID string `json:"userId"`
|
||||
Email string `json:"email"`
|
||||
Code string `json:"code"`
|
||||
generator *generateCodeCommand
|
||||
}
|
||||
|
||||
var (
|
||||
_ Commander = (*ReturnCodeCommand)(nil)
|
||||
_ SetEmailOpt = (*ReturnCodeCommand)(nil)
|
||||
)
|
||||
|
||||
func NewReturnCodeCommand(userID string) *ReturnCodeCommand {
|
||||
return &ReturnCodeCommand{
|
||||
UserID: userID,
|
||||
generator: &generateCodeCommand{},
|
||||
}
|
||||
}
|
||||
|
||||
// Execute implements [Commander]
|
||||
func (cmd *ReturnCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error {
|
||||
if err := cmd.ensureEmail(ctx, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := opts.Invoker.Invoke(ctx, cmd.generator, opts); err != nil {
|
||||
return err
|
||||
}
|
||||
cmd.Code = cmd.generator.code
|
||||
return nil
|
||||
}
|
||||
|
||||
func (cmd *ReturnCodeCommand) ensureEmail(ctx context.Context, opts *CommandOpts) error {
|
||||
if cmd.Email != "" {
|
||||
return nil
|
||||
}
|
||||
user := v4.UserRepository(opts.DB)
|
||||
user.WithCondition(user.IDCondition(cmd.UserID))
|
||||
email, err := user.he.GetEmail(ctx)
|
||||
if err != nil || email.IsVerified {
|
||||
return err
|
||||
}
|
||||
cmd.Email = email.Address
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyOnSetEmail implements [SetEmailOpt]
|
||||
func (cmd *ReturnCodeCommand) applyOnSetEmail(setEmailCmd *SetEmailCommand) {
|
||||
cmd.UserID = setEmailCmd.UserID
|
||||
cmd.Email = setEmailCmd.Email
|
||||
setEmailCmd.verification = cmd
|
||||
}
|
7
backend/v3/domain/errors.go
Normal file
7
backend/v3/domain/errors.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package domain
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrNoAdminSpecified = errors.New("at least one admin must be specified")
|
||||
)
|
36
backend/v3/domain/instance.go
Normal file
36
backend/v3/domain/instance.go
Normal file
@@ -0,0 +1,36 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Instance struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
CreatedAt time.Time `json:"-"`
|
||||
UpdatedAt time.Time `json:"-"`
|
||||
DeletedAt time.Time `json:"-"`
|
||||
}
|
||||
|
||||
// Keys implements the [cache.Entry].
|
||||
func (i *Instance) Keys(index string) (key []string) {
|
||||
// TODO: Return the correct keys for the instance cache, e.g., i.ID, i.Domain
|
||||
return []string{}
|
||||
}
|
||||
|
||||
type InstanceRepository interface {
|
||||
ByID(ctx context.Context, id string) (*Instance, error)
|
||||
Create(ctx context.Context, instance *Instance) error
|
||||
On(id string) InstanceOperation
|
||||
}
|
||||
|
||||
type InstanceOperation interface {
|
||||
AdminRepository
|
||||
Update(ctx context.Context, instance *Instance) error
|
||||
Delete(ctx context.Context) error
|
||||
}
|
||||
|
||||
type CreateInstance struct {
|
||||
Name string `json:"name"`
|
||||
}
|
94
backend/v3/domain/invoke.go
Normal file
94
backend/v3/domain/invoke.go
Normal file
@@ -0,0 +1,94 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/eventstore"
|
||||
)
|
||||
|
||||
var defaultInvoker = newEventStoreInvoker(newTraceInvoker(nil))
|
||||
|
||||
func Invoke(ctx context.Context, cmd Commander) error {
|
||||
invoker := newEventStoreInvoker(newTraceInvoker(nil))
|
||||
opts := &CommandOpts{
|
||||
Invoker: invoker.collector,
|
||||
}
|
||||
return invoker.Invoke(ctx, cmd, opts)
|
||||
}
|
||||
|
||||
type eventStoreInvoker struct {
|
||||
collector *eventCollector
|
||||
}
|
||||
|
||||
func newEventStoreInvoker(next Invoker) *eventStoreInvoker {
|
||||
return &eventStoreInvoker{collector: &eventCollector{next: next}}
|
||||
}
|
||||
|
||||
func (i *eventStoreInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
|
||||
err = i.collector.Invoke(ctx, command, opts)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if len(i.collector.events) > 0 {
|
||||
err = eventstore.Publish(ctx, i.collector.events, opts.DB)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type eventCollector struct {
|
||||
next Invoker
|
||||
events []*eventstore.Event
|
||||
}
|
||||
|
||||
type eventer interface {
|
||||
Events() []*eventstore.Event
|
||||
}
|
||||
|
||||
func (i *eventCollector) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
|
||||
if e, ok := command.(eventer); ok && len(e.Events()) > 0 {
|
||||
// we need to ensure all commands are executed in the same transaction
|
||||
close, err := opts.EnsureTx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { err = close(ctx, err) }()
|
||||
|
||||
i.events = append(i.events, e.Events()...)
|
||||
}
|
||||
if i.next != nil {
|
||||
err = i.next.Invoke(ctx, command, opts)
|
||||
} else {
|
||||
err = command.Execute(ctx, opts)
|
||||
}
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type traceInvoker struct {
|
||||
next Invoker
|
||||
}
|
||||
|
||||
func newTraceInvoker(next Invoker) *traceInvoker {
|
||||
return &traceInvoker{next: next}
|
||||
}
|
||||
|
||||
func (i *traceInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
|
||||
ctx, span := tracer.Start(ctx, fmt.Sprintf("%T", command))
|
||||
defer span.End()
|
||||
|
||||
if i.next != nil {
|
||||
err = i.next.Invoke(ctx, command, opts)
|
||||
} else {
|
||||
err = command.Execute(ctx, opts)
|
||||
}
|
||||
if err != nil {
|
||||
span.RecordError(err)
|
||||
}
|
||||
return err
|
||||
}
|
39
backend/v3/domain/org.go
Normal file
39
backend/v3/domain/org.go
Normal file
@@ -0,0 +1,39 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Org struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
|
||||
CreatedAt time.Time `json:"createdAt"`
|
||||
UpdatedAt time.Time `json:"updatedAt"`
|
||||
}
|
||||
|
||||
type OrgRepository interface {
|
||||
ByID(ctx context.Context, orgID string) (*Org, error)
|
||||
Create(ctx context.Context, org *Org) error
|
||||
On(id string) OrgOperation
|
||||
}
|
||||
|
||||
type OrgOperation interface {
|
||||
AdminRepository
|
||||
DomainRepository
|
||||
Update(ctx context.Context, org *Org) error
|
||||
Delete(ctx context.Context) error
|
||||
}
|
||||
|
||||
type AdminRepository interface {
|
||||
AddAdmin(ctx context.Context, userID string, roles []string) error
|
||||
SetAdminRoles(ctx context.Context, userID string, roles []string) error
|
||||
RemoveAdmin(ctx context.Context, userID string) error
|
||||
}
|
||||
|
||||
type DomainRepository interface {
|
||||
AddDomain(ctx context.Context, domain string) error
|
||||
SetDomainVerified(ctx context.Context, domain string) error
|
||||
RemoveDomain(ctx context.Context, domain string) error
|
||||
}
|
74
backend/v3/domain/org_add.go
Normal file
74
backend/v3/domain/org_add.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type AddOrgCommand struct {
|
||||
ID string `json:"id"`
|
||||
Name string `json:"name"`
|
||||
Admins []AddAdminCommand `json:"admins"`
|
||||
}
|
||||
|
||||
func NewAddOrgCommand(name string, admins ...AddAdminCommand) *AddOrgCommand {
|
||||
return &AddOrgCommand{
|
||||
Name: name,
|
||||
Admins: admins,
|
||||
}
|
||||
}
|
||||
|
||||
// Execute implements Commander.
|
||||
func (cmd *AddOrgCommand) Execute(ctx context.Context, opts *CommandOpts) (err error) {
|
||||
if len(cmd.Admins) == 0 {
|
||||
return ErrNoAdminSpecified
|
||||
}
|
||||
if err = cmd.ensureID(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
close, err := opts.EnsureTx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { err = close(ctx, err) }()
|
||||
err = orgRepo(opts.DB).Create(ctx, &Org{
|
||||
ID: cmd.ID,
|
||||
Name: cmd.Name,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
_ Commander = (*AddOrgCommand)(nil)
|
||||
)
|
||||
|
||||
func (cmd *AddOrgCommand) ensureID() (err error) {
|
||||
if cmd.ID != "" {
|
||||
return nil
|
||||
}
|
||||
cmd.ID, err = generateID()
|
||||
return err
|
||||
}
|
||||
|
||||
type AddAdminCommand struct {
|
||||
UserID string `json:"userId"`
|
||||
Roles []string `json:"roles"`
|
||||
}
|
||||
|
||||
// Execute implements Commander.
|
||||
func (a *AddAdminCommand) Execute(ctx context.Context, opts *CommandOpts) (err error) {
|
||||
close, err := opts.EnsureTx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { err = close(ctx, err) }()
|
||||
return nil
|
||||
}
|
||||
|
||||
var (
|
||||
_ Commander = (*AddAdminCommand)(nil)
|
||||
)
|
82
backend/v3/domain/repository.go
Normal file
82
backend/v3/domain/repository.go
Normal file
@@ -0,0 +1,82 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
type Operation interface {
|
||||
// TextOperation |
|
||||
// NumberOperation |
|
||||
// BoolOperation
|
||||
|
||||
op()
|
||||
}
|
||||
|
||||
type clause[F ~uint8, Op Operation] struct {
|
||||
field F
|
||||
op Op
|
||||
}
|
||||
|
||||
func (c *clause[F, Op]) Field() F {
|
||||
return c.field
|
||||
}
|
||||
|
||||
func (c *clause[F, Op]) Operation() Op {
|
||||
return c.op
|
||||
}
|
||||
|
||||
type Text interface {
|
||||
~string | ~[]byte
|
||||
}
|
||||
|
||||
type TextOperation uint8
|
||||
|
||||
const (
|
||||
TextOperationEqual TextOperation = iota
|
||||
TextOperationNotEqual
|
||||
TextOperationStartsWith
|
||||
TextOperationStartsWithIgnoreCase
|
||||
)
|
||||
|
||||
func (TextOperation) op() {}
|
||||
|
||||
type Number interface {
|
||||
constraints.Integer | constraints.Float | constraints.Complex | time.Time
|
||||
}
|
||||
|
||||
type NumberOperation uint8
|
||||
|
||||
const (
|
||||
NumberOperationEqual NumberOperation = iota
|
||||
NumberOperationNotEqual
|
||||
NumberOperationLessThan
|
||||
NumberOperationLessThanOrEqual
|
||||
NumberOperationGreaterThan
|
||||
NumberOperationGreaterThanOrEqual
|
||||
)
|
||||
|
||||
func (NumberOperation) op() {}
|
||||
|
||||
type Bool interface {
|
||||
~bool
|
||||
}
|
||||
|
||||
type BoolOperation uint8
|
||||
|
||||
const (
|
||||
BoolOperationIs BoolOperation = iota
|
||||
BoolOperationNot
|
||||
)
|
||||
|
||||
func (BoolOperation) op() {}
|
||||
|
||||
type ListOperation uint8
|
||||
|
||||
const (
|
||||
ListOperationContains ListOperation = iota
|
||||
ListOperationNotContains
|
||||
)
|
||||
|
||||
func (ListOperation) op() {}
|
64
backend/v3/domain/set_email.go
Normal file
64
backend/v3/domain/set_email.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/eventstore"
|
||||
)
|
||||
|
||||
type SetEmailCommand struct {
|
||||
UserID string `json:"userId"`
|
||||
Email string `json:"email"`
|
||||
verification Commander
|
||||
}
|
||||
|
||||
var (
|
||||
_ Commander = (*SetEmailCommand)(nil)
|
||||
_ eventer = (*SetEmailCommand)(nil)
|
||||
_ CreateHumanOpt = (*SetEmailCommand)(nil)
|
||||
)
|
||||
|
||||
type SetEmailOpt interface {
|
||||
applyOnSetEmail(*SetEmailCommand)
|
||||
}
|
||||
|
||||
func NewSetEmailCommand(userID, email string, verificationType SetEmailOpt) *SetEmailCommand {
|
||||
cmd := &SetEmailCommand{
|
||||
UserID: userID,
|
||||
Email: email,
|
||||
}
|
||||
verificationType.applyOnSetEmail(cmd)
|
||||
return cmd
|
||||
}
|
||||
|
||||
func (cmd *SetEmailCommand) Execute(ctx context.Context, opts *CommandOpts) error {
|
||||
close, err := opts.EnsureTx(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() { err = close(ctx, err) }()
|
||||
// userStatement(opts.DB).Human().ByID(cmd.UserID).SetEmail(ctx, cmd.Email)
|
||||
err = userRepo(opts.DB).Human().ByID(cmd.UserID).Exec().SetEmail(ctx, cmd.Email)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return opts.Invoke(ctx, cmd.verification)
|
||||
}
|
||||
|
||||
// Events implements [eventer].
|
||||
func (cmd *SetEmailCommand) Events() []*eventstore.Event {
|
||||
return []*eventstore.Event{
|
||||
{
|
||||
AggregateType: "user",
|
||||
AggregateID: cmd.UserID,
|
||||
Type: "user.email.set",
|
||||
Payload: cmd,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// applyOnCreateHuman implements [CreateHumanOpt].
|
||||
func (cmd *SetEmailCommand) applyOnCreateHuman(createUserCmd *CreateUserCommand[Human]) {
|
||||
createUserCmd.email = cmd
|
||||
}
|
193
backend/v3/domain/user.go
Normal file
193
backend/v3/domain/user.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
v4 "github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt/v4"
|
||||
)
|
||||
|
||||
type userColumns interface {
|
||||
// TODO: move v4.columns to domain
|
||||
InstanceIDColumn() column
|
||||
OrgIDColumn() column
|
||||
IDColumn() column
|
||||
usernameColumn() column
|
||||
CreatedAtColumn() column
|
||||
UpdatedAtColumn() column
|
||||
DeletedAtColumn() column
|
||||
}
|
||||
|
||||
type userConditions interface {
|
||||
InstanceIDCondition(instanceID string) v4.Condition
|
||||
OrgIDCondition(orgID string) v4.Condition
|
||||
IDCondition(userID string) v4.Condition
|
||||
UsernameCondition(op v4.TextOperator, username string) v4.Condition
|
||||
CreatedAtCondition(op v4.NumberOperator, createdAt time.Time) v4.Condition
|
||||
UpdatedAtCondition(op v4.NumberOperator, updatedAt time.Time) v4.Condition
|
||||
DeletedCondition(isDeleted bool) v4.Condition
|
||||
DeletedAtCondition(op v4.NumberOperator, deletedAt time.Time) v4.Condition
|
||||
}
|
||||
|
||||
type UserRepository interface {
|
||||
userColumns
|
||||
userConditions
|
||||
// TODO: move condition to domain
|
||||
WithCondition(condition v4.Condition) UserRepository
|
||||
Get(ctx context.Context) (*User, error)
|
||||
List(ctx context.Context) ([]*User, error)
|
||||
Create(ctx context.Context, user *User) error
|
||||
Delete(ctx context.Context) error
|
||||
|
||||
Human() HumanRepository
|
||||
Machine() MachineRepository
|
||||
}
|
||||
|
||||
type humanColumns interface {
|
||||
FirstNameColumn() column
|
||||
LastNameColumn() column
|
||||
EmailAddressColumn() column
|
||||
EmailVerifiedAtColumn() column
|
||||
PhoneNumberColumn() column
|
||||
PhoneVerifiedAtColumn() column
|
||||
}
|
||||
|
||||
type humanConditions interface {
|
||||
FirstNameCondition(op v4.TextOperator, firstName string) v4.Condition
|
||||
LastNameCondition(op v4.TextOperator, lastName string) v4.Condition
|
||||
EmailAddressCondition(op v4.TextOperator, email string) v4.Condition
|
||||
EmailAddressVerifiedCondition(isVerified bool) v4.Condition
|
||||
EmailVerifiedAtCondition(op v4.TextOperator, emailVerifiedAt string) v4.Condition
|
||||
PhoneNumberCondition(op v4.TextOperator, phoneNumber string) v4.Condition
|
||||
PhoneNumberVerifiedCondition(isVerified bool) v4.Condition
|
||||
PhoneVerifiedAtCondition(op v4.TextOperator, phoneVerifiedAt string) v4.Condition
|
||||
}
|
||||
|
||||
type HumanRepository interface {
|
||||
humanColumns
|
||||
humanConditions
|
||||
|
||||
GetEmail(ctx context.Context) (*Email, error)
|
||||
// TODO: replace any with add email update columns
|
||||
SetEmail(ctx context.Context, columns ...any) error
|
||||
}
|
||||
|
||||
type machineColumns interface {
|
||||
DescriptionColumn() column
|
||||
}
|
||||
|
||||
type machineConditions interface {
|
||||
DescriptionCondition(op v4.TextOperator, description string) v4.Condition
|
||||
}
|
||||
|
||||
type MachineRepository interface {
|
||||
machineColumns
|
||||
machineConditions
|
||||
}
|
||||
|
||||
// type UserRepository interface {
|
||||
// // Get(ctx context.Context, clauses ...UserClause) (*User, error)
|
||||
// // Search(ctx context.Context, clauses ...UserClause) ([]*User, error)
|
||||
|
||||
// UserQuery[UserOperation]
|
||||
// Human() HumanQuery
|
||||
// Machine() MachineQuery
|
||||
// }
|
||||
|
||||
// type UserQuery[Op UserOperation] interface {
|
||||
// ByID(id string) UserQuery[Op]
|
||||
// Username(username string) UserQuery[Op]
|
||||
// Exec() Op
|
||||
// }
|
||||
|
||||
// type HumanQuery interface {
|
||||
// UserQuery[HumanOperation]
|
||||
// Email(op TextOperation, email string) HumanQuery
|
||||
// HumanOperation
|
||||
// }
|
||||
|
||||
// type MachineQuery interface {
|
||||
// UserQuery[MachineOperation]
|
||||
// MachineOperation
|
||||
// }
|
||||
|
||||
// type UserClause interface {
|
||||
// Field() UserField
|
||||
// Operation() Operation
|
||||
// Args() []any
|
||||
// }
|
||||
|
||||
// type UserField uint8
|
||||
|
||||
// const (
|
||||
// // Fields used for all users
|
||||
// UserFieldInstanceID UserField = iota + 1
|
||||
// UserFieldOrgID
|
||||
// UserFieldID
|
||||
// UserFieldUsername
|
||||
|
||||
// // Fields used for human users
|
||||
// UserHumanFieldEmail
|
||||
// UserHumanFieldEmailVerified
|
||||
|
||||
// // Fields used for machine users
|
||||
// UserMachineFieldDescription
|
||||
// )
|
||||
|
||||
// type userByIDClause struct {
|
||||
// id string
|
||||
// }
|
||||
|
||||
// func (c *userByIDClause) Field() UserField {
|
||||
// return UserFieldID
|
||||
// }
|
||||
|
||||
// func (c *userByIDClause) Operation() Operation {
|
||||
// return TextOperationEqual
|
||||
// }
|
||||
|
||||
// func (c *userByIDClause) Args() []any {
|
||||
// return []any{c.id}
|
||||
// }
|
||||
|
||||
// type UserOperation interface {
|
||||
// Delete(ctx context.Context) error
|
||||
// SetUsername(ctx context.Context, username string) error
|
||||
// }
|
||||
|
||||
// type HumanOperation interface {
|
||||
// UserOperation
|
||||
// SetEmail(ctx context.Context, email string) error
|
||||
// SetEmailVerified(ctx context.Context, email string) error
|
||||
// GetEmail(ctx context.Context) (*Email, error)
|
||||
// }
|
||||
|
||||
// type MachineOperation interface {
|
||||
// UserOperation
|
||||
// SetDescription(ctx context.Context, description string) error
|
||||
// }
|
||||
|
||||
type User struct {
|
||||
v4.User
|
||||
}
|
||||
|
||||
// type userTraits interface {
|
||||
// isUserTraits()
|
||||
// }
|
||||
|
||||
// type Human struct {
|
||||
// Email *Email `json:"email"`
|
||||
// }
|
||||
|
||||
// func (*Human) isUserTraits() {}
|
||||
|
||||
// type Machine struct {
|
||||
// Description string `json:"description"`
|
||||
// }
|
||||
|
||||
// func (*Machine) isUserTraits() {}
|
||||
|
||||
// type Email struct {
|
||||
// Address string `json:"address"`
|
||||
// IsVerified bool `json:"isVerified"`
|
||||
// }
|
112
backend/v3/storage/cache/cache.go
vendored
Normal file
112
backend/v3/storage/cache/cache.go
vendored
Normal file
@@ -0,0 +1,112 @@
|
||||
// Package cache provides abstraction of cache implementations that can be used by zitadel.
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
)
|
||||
|
||||
// Purpose describes which object types are stored by a cache.
|
||||
type Purpose int
|
||||
|
||||
//go:generate enumer -type Purpose -transform snake -trimprefix Purpose
|
||||
const (
|
||||
PurposeUnspecified Purpose = iota
|
||||
PurposeAuthzInstance
|
||||
PurposeMilestones
|
||||
PurposeOrganization
|
||||
PurposeIdPFormCallback
|
||||
)
|
||||
|
||||
// Cache stores objects with a value of type `V`.
|
||||
// Objects may be referred to by one or more indices.
|
||||
// Implementations may encode the value for storage.
|
||||
// This means non-exported fields may be lost and objects
|
||||
// with function values may fail to encode.
|
||||
// See https://pkg.go.dev/encoding/json#Marshal for example.
|
||||
//
|
||||
// `I` is the type by which indices are identified,
|
||||
// typically an enum for type-safe access.
|
||||
// Indices are defined when calling the constructor of an implementation of this interface.
|
||||
// It is illegal to refer to an idex not defined during construction.
|
||||
//
|
||||
// `K` is the type used as key in each index.
|
||||
// Due to the limitations in type constraints, all indices use the same key type.
|
||||
//
|
||||
// Implementations are free to use stricter type constraints or fixed typing.
|
||||
type Cache[I, K comparable, V Entry[I, K]] interface {
|
||||
// Get an object through specified index.
|
||||
// An [IndexUnknownError] may be returned if the index is unknown.
|
||||
// [ErrCacheMiss] is returned if the key was not found in the index,
|
||||
// or the object is not valid.
|
||||
Get(ctx context.Context, index I, key K) (V, bool)
|
||||
|
||||
// Set an object.
|
||||
// Keys are created on each index based in the [Entry.Keys] method.
|
||||
// If any key maps to an existing object, the object is invalidated,
|
||||
// regardless if the object has other keys defined in the new entry.
|
||||
// This to prevent ghost objects when an entry reduces the amount of keys
|
||||
// for a given index.
|
||||
Set(ctx context.Context, value V)
|
||||
|
||||
// Invalidate an object through specified index.
|
||||
// Implementations may choose to instantly delete the object,
|
||||
// defer until prune or a separate cleanup routine.
|
||||
// Invalidated object are no longer returned from Get.
|
||||
// It is safe to call Invalidate multiple times or on non-existing entries.
|
||||
Invalidate(ctx context.Context, index I, key ...K) error
|
||||
|
||||
// Delete one or more keys from a specific index.
|
||||
// An [IndexUnknownError] may be returned if the index is unknown.
|
||||
// The referred object is not invalidated and may still be accessible though
|
||||
// other indices and keys.
|
||||
// It is safe to call Delete multiple times or on non-existing entries
|
||||
Delete(ctx context.Context, index I, key ...K) error
|
||||
|
||||
// Truncate deletes all cached objects.
|
||||
Truncate(ctx context.Context) error
|
||||
}
|
||||
|
||||
// Entry contains a value of type `V` to be cached.
|
||||
//
|
||||
// `I` is the type by which indices are identified,
|
||||
// typically an enum for type-safe access.
|
||||
//
|
||||
// `K` is the type used as key in an index.
|
||||
// Due to the limitations in type constraints, all indices use the same key type.
|
||||
type Entry[I, K comparable] interface {
|
||||
// Keys returns which keys map to the object in a specified index.
|
||||
// May return nil if the index in unknown or when there are no keys.
|
||||
Keys(index I) (key []K)
|
||||
}
|
||||
|
||||
type Connector int
|
||||
|
||||
//go:generate enumer -type Connector -transform snake -trimprefix Connector -linecomment -text
|
||||
const (
|
||||
// Empty line comment ensures empty string for unspecified value
|
||||
ConnectorUnspecified Connector = iota //
|
||||
ConnectorMemory
|
||||
ConnectorPostgres
|
||||
ConnectorRedis
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Connector Connector
|
||||
|
||||
// Age since an object was added to the cache,
|
||||
// after which the object is considered invalid.
|
||||
// 0 disables max age checks.
|
||||
MaxAge time.Duration
|
||||
|
||||
// Age since last use (Get) of an object,
|
||||
// after which the object is considered invalid.
|
||||
// 0 disables last use age checks.
|
||||
LastUseAge time.Duration
|
||||
|
||||
// Log allows logging of the specific cache.
|
||||
// By default only errors are logged to stdout.
|
||||
Log *logging.Config
|
||||
}
|
49
backend/v3/storage/cache/connector/connector.go
vendored
Normal file
49
backend/v3/storage/cache/connector/connector.go
vendored
Normal file
@@ -0,0 +1,49 @@
|
||||
// Package connector provides glue between the [cache.Cache] interface and implementations from the connector sub-packages.
|
||||
package connector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/cache"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/cache/connector/gomap"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/cache/connector/noop"
|
||||
)
|
||||
|
||||
type CachesConfig struct {
|
||||
Connectors struct {
|
||||
Memory gomap.Config
|
||||
}
|
||||
Instance *cache.Config
|
||||
Milestones *cache.Config
|
||||
Organization *cache.Config
|
||||
IdPFormCallbacks *cache.Config
|
||||
}
|
||||
|
||||
type Connectors struct {
|
||||
Config CachesConfig
|
||||
Memory *gomap.Connector
|
||||
}
|
||||
|
||||
func StartConnectors(conf *CachesConfig) (Connectors, error) {
|
||||
if conf == nil {
|
||||
return Connectors{}, nil
|
||||
}
|
||||
return Connectors{
|
||||
Config: *conf,
|
||||
Memory: gomap.NewConnector(conf.Connectors.Memory),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func StartCache[I ~int, K ~string, V cache.Entry[I, K]](background context.Context, indices []I, purpose cache.Purpose, conf *cache.Config, connectors Connectors) (cache.Cache[I, K, V], error) {
|
||||
if conf == nil || conf.Connector == cache.ConnectorUnspecified {
|
||||
return noop.NewCache[I, K, V](), nil
|
||||
}
|
||||
if conf.Connector == cache.ConnectorMemory && connectors.Memory != nil {
|
||||
c := gomap.NewCache[I, K, V](background, indices, *conf)
|
||||
connectors.Memory.Config.StartAutoPrune(background, c, purpose)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("cache connector %q not enabled", conf.Connector)
|
||||
}
|
23
backend/v3/storage/cache/connector/gomap/connector.go
vendored
Normal file
23
backend/v3/storage/cache/connector/gomap/connector.go
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
package gomap
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/cache"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Enabled bool
|
||||
AutoPrune cache.AutoPruneConfig
|
||||
}
|
||||
|
||||
type Connector struct {
|
||||
Config cache.AutoPruneConfig
|
||||
}
|
||||
|
||||
func NewConnector(config Config) *Connector {
|
||||
if !config.Enabled {
|
||||
return nil
|
||||
}
|
||||
return &Connector{
|
||||
Config: config.AutoPrune,
|
||||
}
|
||||
}
|
200
backend/v3/storage/cache/connector/gomap/gomap.go
vendored
Normal file
200
backend/v3/storage/cache/connector/gomap/gomap.go
vendored
Normal file
@@ -0,0 +1,200 @@
|
||||
package gomap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/cache"
|
||||
)
|
||||
|
||||
type mapCache[I, K comparable, V cache.Entry[I, K]] struct {
|
||||
config *cache.Config
|
||||
indexMap map[I]*index[K, V]
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewCache returns an in-memory Cache implementation based on the builtin go map type.
|
||||
// Object values are stored as-is and there is no encoding or decoding involved.
|
||||
func NewCache[I, K comparable, V cache.Entry[I, K]](background context.Context, indices []I, config cache.Config) cache.PrunerCache[I, K, V] {
|
||||
m := &mapCache[I, K, V]{
|
||||
config: &config,
|
||||
indexMap: make(map[I]*index[K, V], len(indices)),
|
||||
logger: slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
AddSource: true,
|
||||
Level: slog.LevelError,
|
||||
})),
|
||||
}
|
||||
if config.Log != nil {
|
||||
m.logger = config.Log.Slog()
|
||||
}
|
||||
m.logger.InfoContext(background, "map cache logging enabled")
|
||||
|
||||
for _, name := range indices {
|
||||
m.indexMap[name] = &index[K, V]{
|
||||
config: m.config,
|
||||
entries: make(map[K]*entry[V]),
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Get(ctx context.Context, index I, key K) (value V, ok bool) {
|
||||
i, ok := c.indexMap[index]
|
||||
if !ok {
|
||||
c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key)
|
||||
return value, false
|
||||
}
|
||||
entry, err := i.Get(key)
|
||||
if err == nil {
|
||||
c.logger.DebugContext(ctx, "map cache get", "index", index, "key", key)
|
||||
return entry.value, true
|
||||
}
|
||||
if errors.Is(err, cache.ErrCacheMiss) {
|
||||
c.logger.InfoContext(ctx, "map cache get", "err", err, "index", index, "key", key)
|
||||
return value, false
|
||||
}
|
||||
c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key)
|
||||
return value, false
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Set(ctx context.Context, value V) {
|
||||
now := time.Now()
|
||||
entry := &entry[V]{
|
||||
value: value,
|
||||
created: now,
|
||||
}
|
||||
entry.lastUse.Store(now.UnixMicro())
|
||||
|
||||
for name, i := range c.indexMap {
|
||||
keys := value.Keys(name)
|
||||
i.Set(keys, entry)
|
||||
c.logger.DebugContext(ctx, "map cache set", "index", name, "keys", keys)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Invalidate(ctx context.Context, index I, keys ...K) error {
|
||||
i, ok := c.indexMap[index]
|
||||
if !ok {
|
||||
return cache.NewIndexUnknownErr(index)
|
||||
}
|
||||
i.Invalidate(keys)
|
||||
c.logger.DebugContext(ctx, "map cache invalidate", "index", index, "keys", keys)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Delete(ctx context.Context, index I, keys ...K) error {
|
||||
i, ok := c.indexMap[index]
|
||||
if !ok {
|
||||
return cache.NewIndexUnknownErr(index)
|
||||
}
|
||||
i.Delete(keys)
|
||||
c.logger.DebugContext(ctx, "map cache delete", "index", index, "keys", keys)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Prune(ctx context.Context) error {
|
||||
for name, index := range c.indexMap {
|
||||
index.Prune()
|
||||
c.logger.DebugContext(ctx, "map cache prune", "index", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Truncate(ctx context.Context) error {
|
||||
for name, index := range c.indexMap {
|
||||
index.Truncate()
|
||||
c.logger.DebugContext(ctx, "map cache truncate", "index", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type index[K comparable, V any] struct {
|
||||
mutex sync.RWMutex
|
||||
config *cache.Config
|
||||
entries map[K]*entry[V]
|
||||
}
|
||||
|
||||
func (i *index[K, V]) Get(key K) (*entry[V], error) {
|
||||
i.mutex.RLock()
|
||||
entry, ok := i.entries[key]
|
||||
i.mutex.RUnlock()
|
||||
if ok && entry.isValid(i.config) {
|
||||
return entry, nil
|
||||
}
|
||||
return nil, cache.ErrCacheMiss
|
||||
}
|
||||
|
||||
func (c *index[K, V]) Set(keys []K, entry *entry[V]) {
|
||||
c.mutex.Lock()
|
||||
for _, key := range keys {
|
||||
c.entries[key] = entry
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (i *index[K, V]) Invalidate(keys []K) {
|
||||
i.mutex.RLock()
|
||||
for _, key := range keys {
|
||||
if entry, ok := i.entries[key]; ok {
|
||||
entry.invalid.Store(true)
|
||||
}
|
||||
}
|
||||
i.mutex.RUnlock()
|
||||
}
|
||||
|
||||
func (c *index[K, V]) Delete(keys []K) {
|
||||
c.mutex.Lock()
|
||||
for _, key := range keys {
|
||||
delete(c.entries, key)
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (c *index[K, V]) Prune() {
|
||||
c.mutex.Lock()
|
||||
maps.DeleteFunc(c.entries, func(_ K, entry *entry[V]) bool {
|
||||
return !entry.isValid(c.config)
|
||||
})
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (c *index[K, V]) Truncate() {
|
||||
c.mutex.Lock()
|
||||
c.entries = make(map[K]*entry[V])
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
type entry[V any] struct {
|
||||
value V
|
||||
created time.Time
|
||||
invalid atomic.Bool
|
||||
lastUse atomic.Int64 // UnixMicro time
|
||||
}
|
||||
|
||||
func (e *entry[V]) isValid(c *cache.Config) bool {
|
||||
if e.invalid.Load() {
|
||||
return false
|
||||
}
|
||||
now := time.Now()
|
||||
if c.MaxAge > 0 {
|
||||
if e.created.Add(c.MaxAge).Before(now) {
|
||||
e.invalid.Store(true)
|
||||
return false
|
||||
}
|
||||
}
|
||||
if c.LastUseAge > 0 {
|
||||
lastUse := e.lastUse.Load()
|
||||
if time.UnixMicro(lastUse).Add(c.LastUseAge).Before(now) {
|
||||
e.invalid.Store(true)
|
||||
return false
|
||||
}
|
||||
e.lastUse.CompareAndSwap(lastUse, now.UnixMicro())
|
||||
}
|
||||
return true
|
||||
}
|
329
backend/v3/storage/cache/connector/gomap/gomap_test.go
vendored
Normal file
329
backend/v3/storage/cache/connector/gomap/gomap_test.go
vendored
Normal file
@@ -0,0 +1,329 @@
|
||||
package gomap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/cache"
|
||||
)
|
||||
|
||||
type testIndex int
|
||||
|
||||
const (
|
||||
testIndexID testIndex = iota
|
||||
testIndexName
|
||||
)
|
||||
|
||||
var testIndices = []testIndex{
|
||||
testIndexID,
|
||||
testIndexName,
|
||||
}
|
||||
|
||||
type testObject struct {
|
||||
id string
|
||||
names []string
|
||||
}
|
||||
|
||||
func (o *testObject) Keys(index testIndex) []string {
|
||||
switch index {
|
||||
case testIndexID:
|
||||
return []string{o.id}
|
||||
case testIndexName:
|
||||
return o.names
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func Test_mapCache_Get(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
obj := &testObject{
|
||||
id: "id",
|
||||
names: []string{"foo", "bar"},
|
||||
}
|
||||
c.Set(context.Background(), obj)
|
||||
|
||||
type args struct {
|
||||
index testIndex
|
||||
key string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *testObject
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
args: args{
|
||||
index: testIndexID,
|
||||
key: "id",
|
||||
},
|
||||
want: obj,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "miss",
|
||||
args: args{
|
||||
index: testIndexID,
|
||||
key: "spanac",
|
||||
},
|
||||
want: nil,
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "unknown index",
|
||||
args: args{
|
||||
index: 99,
|
||||
key: "id",
|
||||
},
|
||||
want: nil,
|
||||
wantOk: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, ok := c.Get(context.Background(), tt.args.index, tt.args.key)
|
||||
assert.Equal(t, tt.want, got)
|
||||
assert.Equal(t, tt.wantOk, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_mapCache_Invalidate(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
obj := &testObject{
|
||||
id: "id",
|
||||
names: []string{"foo", "bar"},
|
||||
}
|
||||
c.Set(context.Background(), obj)
|
||||
err := c.Invalidate(context.Background(), testIndexName, "bar")
|
||||
require.NoError(t, err)
|
||||
got, ok := c.Get(context.Background(), testIndexID, "id")
|
||||
assert.Nil(t, got)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func Test_mapCache_Delete(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
obj := &testObject{
|
||||
id: "id",
|
||||
names: []string{"foo", "bar"},
|
||||
}
|
||||
c.Set(context.Background(), obj)
|
||||
err := c.Delete(context.Background(), testIndexName, "bar")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Shouldn't find object by deleted name
|
||||
got, ok := c.Get(context.Background(), testIndexName, "bar")
|
||||
assert.Nil(t, got)
|
||||
assert.False(t, ok)
|
||||
|
||||
// Should find object by other name
|
||||
got, ok = c.Get(context.Background(), testIndexName, "foo")
|
||||
assert.Equal(t, obj, got)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Should find object by id
|
||||
got, ok = c.Get(context.Background(), testIndexID, "id")
|
||||
assert.Equal(t, obj, got)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
func Test_mapCache_Prune(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
|
||||
objects := []*testObject{
|
||||
{
|
||||
id: "id1",
|
||||
names: []string{"foo", "bar"},
|
||||
},
|
||||
{
|
||||
id: "id2",
|
||||
names: []string{"hello"},
|
||||
},
|
||||
}
|
||||
for _, obj := range objects {
|
||||
c.Set(context.Background(), obj)
|
||||
}
|
||||
// invalidate one entry
|
||||
err := c.Invalidate(context.Background(), testIndexName, "bar")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = c.(cache.Pruner).Prune(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Other object should still be found
|
||||
got, ok := c.Get(context.Background(), testIndexID, "id2")
|
||||
assert.Equal(t, objects[1], got)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
func Test_mapCache_Truncate(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
objects := []*testObject{
|
||||
{
|
||||
id: "id1",
|
||||
names: []string{"foo", "bar"},
|
||||
},
|
||||
{
|
||||
id: "id2",
|
||||
names: []string{"hello"},
|
||||
},
|
||||
}
|
||||
for _, obj := range objects {
|
||||
c.Set(context.Background(), obj)
|
||||
}
|
||||
|
||||
err := c.Truncate(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
mc := c.(*mapCache[testIndex, string, *testObject])
|
||||
for _, index := range mc.indexMap {
|
||||
index.mutex.RLock()
|
||||
assert.Len(t, index.entries, 0)
|
||||
index.mutex.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
func Test_entry_isValid(t *testing.T) {
|
||||
type fields struct {
|
||||
created time.Time
|
||||
invalid bool
|
||||
lastUse time.Time
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
config *cache.Config
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "invalid",
|
||||
fields: fields{
|
||||
created: time.Now(),
|
||||
invalid: true,
|
||||
lastUse: time.Now(),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "max age exceeded",
|
||||
fields: fields{
|
||||
created: time.Now().Add(-(time.Minute + time.Second)),
|
||||
invalid: false,
|
||||
lastUse: time.Now(),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "max age disabled",
|
||||
fields: fields{
|
||||
created: time.Now().Add(-(time.Minute + time.Second)),
|
||||
invalid: false,
|
||||
lastUse: time.Now(),
|
||||
},
|
||||
config: &cache.Config{
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "last use age exceeded",
|
||||
fields: fields{
|
||||
created: time.Now().Add(-(time.Minute / 2)),
|
||||
invalid: false,
|
||||
lastUse: time.Now().Add(-(time.Second * 2)),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "last use age disabled",
|
||||
fields: fields{
|
||||
created: time.Now().Add(-(time.Minute / 2)),
|
||||
invalid: false,
|
||||
lastUse: time.Now().Add(-(time.Second * 2)),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
fields: fields{
|
||||
created: time.Now(),
|
||||
invalid: false,
|
||||
lastUse: time.Now(),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := &entry[any]{
|
||||
created: tt.fields.created,
|
||||
}
|
||||
e.invalid.Store(tt.fields.invalid)
|
||||
e.lastUse.Store(tt.fields.lastUse.UnixMicro())
|
||||
got := e.isValid(tt.config)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
21
backend/v3/storage/cache/connector/noop/noop.go
vendored
Normal file
21
backend/v3/storage/cache/connector/noop/noop.go
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
package noop
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/cache"
|
||||
)
|
||||
|
||||
type noop[I, K comparable, V cache.Entry[I, K]] struct{}
|
||||
|
||||
// NewCache returns a cache that does nothing
|
||||
func NewCache[I, K comparable, V cache.Entry[I, K]]() cache.Cache[I, K, V] {
|
||||
return noop[I, K, V]{}
|
||||
}
|
||||
|
||||
func (noop[I, K, V]) Set(context.Context, V) {}
|
||||
func (noop[I, K, V]) Get(context.Context, I, K) (value V, ok bool) { return }
|
||||
func (noop[I, K, V]) Invalidate(context.Context, I, ...K) (err error) { return }
|
||||
func (noop[I, K, V]) Delete(context.Context, I, ...K) (err error) { return }
|
||||
func (noop[I, K, V]) Prune(context.Context) (err error) { return }
|
||||
func (noop[I, K, V]) Truncate(context.Context) (err error) { return }
|
98
backend/v3/storage/cache/connector_enumer.go
vendored
Normal file
98
backend/v3/storage/cache/connector_enumer.go
vendored
Normal file
@@ -0,0 +1,98 @@
|
||||
// Code generated by "enumer -type Connector -transform snake -trimprefix Connector -linecomment -text"; DO NOT EDIT.
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const _ConnectorName = "memorypostgresredis"
|
||||
|
||||
var _ConnectorIndex = [...]uint8{0, 0, 6, 14, 19}
|
||||
|
||||
const _ConnectorLowerName = "memorypostgresredis"
|
||||
|
||||
func (i Connector) String() string {
|
||||
if i < 0 || i >= Connector(len(_ConnectorIndex)-1) {
|
||||
return fmt.Sprintf("Connector(%d)", i)
|
||||
}
|
||||
return _ConnectorName[_ConnectorIndex[i]:_ConnectorIndex[i+1]]
|
||||
}
|
||||
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
func _ConnectorNoOp() {
|
||||
var x [1]struct{}
|
||||
_ = x[ConnectorUnspecified-(0)]
|
||||
_ = x[ConnectorMemory-(1)]
|
||||
_ = x[ConnectorPostgres-(2)]
|
||||
_ = x[ConnectorRedis-(3)]
|
||||
}
|
||||
|
||||
var _ConnectorValues = []Connector{ConnectorUnspecified, ConnectorMemory, ConnectorPostgres, ConnectorRedis}
|
||||
|
||||
var _ConnectorNameToValueMap = map[string]Connector{
|
||||
_ConnectorName[0:0]: ConnectorUnspecified,
|
||||
_ConnectorLowerName[0:0]: ConnectorUnspecified,
|
||||
_ConnectorName[0:6]: ConnectorMemory,
|
||||
_ConnectorLowerName[0:6]: ConnectorMemory,
|
||||
_ConnectorName[6:14]: ConnectorPostgres,
|
||||
_ConnectorLowerName[6:14]: ConnectorPostgres,
|
||||
_ConnectorName[14:19]: ConnectorRedis,
|
||||
_ConnectorLowerName[14:19]: ConnectorRedis,
|
||||
}
|
||||
|
||||
var _ConnectorNames = []string{
|
||||
_ConnectorName[0:0],
|
||||
_ConnectorName[0:6],
|
||||
_ConnectorName[6:14],
|
||||
_ConnectorName[14:19],
|
||||
}
|
||||
|
||||
// ConnectorString retrieves an enum value from the enum constants string name.
|
||||
// Throws an error if the param is not part of the enum.
|
||||
func ConnectorString(s string) (Connector, error) {
|
||||
if val, ok := _ConnectorNameToValueMap[s]; ok {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
if val, ok := _ConnectorNameToValueMap[strings.ToLower(s)]; ok {
|
||||
return val, nil
|
||||
}
|
||||
return 0, fmt.Errorf("%s does not belong to Connector values", s)
|
||||
}
|
||||
|
||||
// ConnectorValues returns all values of the enum
|
||||
func ConnectorValues() []Connector {
|
||||
return _ConnectorValues
|
||||
}
|
||||
|
||||
// ConnectorStrings returns a slice of all String values of the enum
|
||||
func ConnectorStrings() []string {
|
||||
strs := make([]string, len(_ConnectorNames))
|
||||
copy(strs, _ConnectorNames)
|
||||
return strs
|
||||
}
|
||||
|
||||
// IsAConnector returns "true" if the value is listed in the enum definition. "false" otherwise
|
||||
func (i Connector) IsAConnector() bool {
|
||||
for _, v := range _ConnectorValues {
|
||||
if i == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// MarshalText implements the encoding.TextMarshaler interface for Connector
|
||||
func (i Connector) MarshalText() ([]byte, error) {
|
||||
return []byte(i.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements the encoding.TextUnmarshaler interface for Connector
|
||||
func (i *Connector) UnmarshalText(text []byte) error {
|
||||
var err error
|
||||
*i, err = ConnectorString(string(text))
|
||||
return err
|
||||
}
|
29
backend/v3/storage/cache/error.go
vendored
Normal file
29
backend/v3/storage/cache/error.go
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type IndexUnknownError[I comparable] struct {
|
||||
index I
|
||||
}
|
||||
|
||||
func NewIndexUnknownErr[I comparable](index I) error {
|
||||
return IndexUnknownError[I]{index}
|
||||
}
|
||||
|
||||
func (i IndexUnknownError[I]) Error() string {
|
||||
return fmt.Sprintf("index %v unknown", i.index)
|
||||
}
|
||||
|
||||
func (a IndexUnknownError[I]) Is(err error) bool {
|
||||
if b, ok := err.(IndexUnknownError[I]); ok {
|
||||
return a.index == b.index
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var (
|
||||
ErrCacheMiss = errors.New("cache miss")
|
||||
)
|
76
backend/v3/storage/cache/pruner.go
vendored
Normal file
76
backend/v3/storage/cache/pruner.go
vendored
Normal file
@@ -0,0 +1,76 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/jonboulle/clockwork"
|
||||
"github.com/zitadel/logging"
|
||||
)
|
||||
|
||||
// Pruner is an optional [Cache] interface.
|
||||
type Pruner interface {
|
||||
// Prune deletes all invalidated or expired objects.
|
||||
Prune(ctx context.Context) error
|
||||
}
|
||||
|
||||
type PrunerCache[I, K comparable, V Entry[I, K]] interface {
|
||||
Cache[I, K, V]
|
||||
Pruner
|
||||
}
|
||||
|
||||
type AutoPruneConfig struct {
|
||||
// Interval at which the cache is automatically pruned.
|
||||
// 0 or lower disables automatic pruning.
|
||||
Interval time.Duration
|
||||
|
||||
// Timeout for an automatic prune.
|
||||
// It is recommended to keep the value shorter than AutoPruneInterval
|
||||
// 0 or lower disables automatic pruning.
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
func (c AutoPruneConfig) StartAutoPrune(background context.Context, pruner Pruner, purpose Purpose) (close func()) {
|
||||
return c.startAutoPrune(background, pruner, purpose, clockwork.NewRealClock())
|
||||
}
|
||||
|
||||
func (c *AutoPruneConfig) startAutoPrune(background context.Context, pruner Pruner, purpose Purpose, clock clockwork.Clock) (close func()) {
|
||||
if c.Interval <= 0 {
|
||||
return func() {}
|
||||
}
|
||||
background, cancel := context.WithCancel(background)
|
||||
// randomize the first interval
|
||||
timer := clock.NewTimer(time.Duration(rand.Int63n(int64(c.Interval))))
|
||||
go c.pruneTimer(background, pruner, purpose, timer)
|
||||
return cancel
|
||||
}
|
||||
|
||||
func (c *AutoPruneConfig) pruneTimer(background context.Context, pruner Pruner, purpose Purpose, timer clockwork.Timer) {
|
||||
defer func() {
|
||||
if !timer.Stop() {
|
||||
<-timer.Chan()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-background.Done():
|
||||
return
|
||||
case <-timer.Chan():
|
||||
err := c.doPrune(background, pruner)
|
||||
logging.OnError(err).WithField("purpose", purpose).Error("cache auto prune")
|
||||
timer.Reset(c.Interval)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *AutoPruneConfig) doPrune(background context.Context, pruner Pruner) error {
|
||||
ctx, cancel := context.WithCancel(background)
|
||||
defer cancel()
|
||||
if c.Timeout > 0 {
|
||||
ctx, cancel = context.WithTimeout(background, c.Timeout)
|
||||
defer cancel()
|
||||
}
|
||||
return pruner.Prune(ctx)
|
||||
}
|
43
backend/v3/storage/cache/pruner_test.go
vendored
Normal file
43
backend/v3/storage/cache/pruner_test.go
vendored
Normal file
@@ -0,0 +1,43 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jonboulle/clockwork"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type testPruner struct {
|
||||
called chan struct{}
|
||||
}
|
||||
|
||||
func (p *testPruner) Prune(context.Context) error {
|
||||
p.called <- struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAutoPruneConfig_startAutoPrune(t *testing.T) {
|
||||
c := AutoPruneConfig{
|
||||
Interval: time.Second,
|
||||
Timeout: time.Millisecond,
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
pruner := testPruner{
|
||||
called: make(chan struct{}),
|
||||
}
|
||||
clock := clockwork.NewFakeClock()
|
||||
close := c.startAutoPrune(ctx, &pruner, PurposeAuthzInstance, clock)
|
||||
defer close()
|
||||
clock.Advance(time.Second)
|
||||
|
||||
select {
|
||||
case _, ok := <-pruner.called:
|
||||
assert.True(t, ok)
|
||||
case <-ctx.Done():
|
||||
t.Fatal(ctx.Err())
|
||||
}
|
||||
}
|
90
backend/v3/storage/cache/purpose_enumer.go
vendored
Normal file
90
backend/v3/storage/cache/purpose_enumer.go
vendored
Normal file
@@ -0,0 +1,90 @@
|
||||
// Code generated by "enumer -type Purpose -transform snake -trimprefix Purpose"; DO NOT EDIT.
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const _PurposeName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callback"
|
||||
|
||||
var _PurposeIndex = [...]uint8{0, 11, 25, 35, 47, 65}
|
||||
|
||||
const _PurposeLowerName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callback"
|
||||
|
||||
func (i Purpose) String() string {
|
||||
if i < 0 || i >= Purpose(len(_PurposeIndex)-1) {
|
||||
return fmt.Sprintf("Purpose(%d)", i)
|
||||
}
|
||||
return _PurposeName[_PurposeIndex[i]:_PurposeIndex[i+1]]
|
||||
}
|
||||
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
func _PurposeNoOp() {
|
||||
var x [1]struct{}
|
||||
_ = x[PurposeUnspecified-(0)]
|
||||
_ = x[PurposeAuthzInstance-(1)]
|
||||
_ = x[PurposeMilestones-(2)]
|
||||
_ = x[PurposeOrganization-(3)]
|
||||
_ = x[PurposeIdPFormCallback-(4)]
|
||||
}
|
||||
|
||||
var _PurposeValues = []Purpose{PurposeUnspecified, PurposeAuthzInstance, PurposeMilestones, PurposeOrganization, PurposeIdPFormCallback}
|
||||
|
||||
var _PurposeNameToValueMap = map[string]Purpose{
|
||||
_PurposeName[0:11]: PurposeUnspecified,
|
||||
_PurposeLowerName[0:11]: PurposeUnspecified,
|
||||
_PurposeName[11:25]: PurposeAuthzInstance,
|
||||
_PurposeLowerName[11:25]: PurposeAuthzInstance,
|
||||
_PurposeName[25:35]: PurposeMilestones,
|
||||
_PurposeLowerName[25:35]: PurposeMilestones,
|
||||
_PurposeName[35:47]: PurposeOrganization,
|
||||
_PurposeLowerName[35:47]: PurposeOrganization,
|
||||
_PurposeName[47:65]: PurposeIdPFormCallback,
|
||||
_PurposeLowerName[47:65]: PurposeIdPFormCallback,
|
||||
}
|
||||
|
||||
var _PurposeNames = []string{
|
||||
_PurposeName[0:11],
|
||||
_PurposeName[11:25],
|
||||
_PurposeName[25:35],
|
||||
_PurposeName[35:47],
|
||||
_PurposeName[47:65],
|
||||
}
|
||||
|
||||
// PurposeString retrieves an enum value from the enum constants string name.
|
||||
// Throws an error if the param is not part of the enum.
|
||||
func PurposeString(s string) (Purpose, error) {
|
||||
if val, ok := _PurposeNameToValueMap[s]; ok {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
if val, ok := _PurposeNameToValueMap[strings.ToLower(s)]; ok {
|
||||
return val, nil
|
||||
}
|
||||
return 0, fmt.Errorf("%s does not belong to Purpose values", s)
|
||||
}
|
||||
|
||||
// PurposeValues returns all values of the enum
|
||||
func PurposeValues() []Purpose {
|
||||
return _PurposeValues
|
||||
}
|
||||
|
||||
// PurposeStrings returns a slice of all String values of the enum
|
||||
func PurposeStrings() []string {
|
||||
strs := make([]string, len(_PurposeNames))
|
||||
copy(strs, _PurposeNames)
|
||||
return strs
|
||||
}
|
||||
|
||||
// IsAPurpose returns "true" if the value is listed in the enum definition. "false" otherwise
|
||||
func (i Purpose) IsAPurpose() bool {
|
||||
for _, v := range _PurposeValues {
|
||||
if i == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
9
backend/v3/storage/database/config.go
Normal file
9
backend/v3/storage/database/config.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
type Connector interface {
|
||||
Connect(ctx context.Context) (Pool, error)
|
||||
}
|
60
backend/v3/storage/database/database.go
Normal file
60
backend/v3/storage/database/database.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
var (
|
||||
db *database
|
||||
)
|
||||
|
||||
type database struct {
|
||||
connector Connector
|
||||
pool Pool
|
||||
}
|
||||
|
||||
type Pool interface {
|
||||
Beginner
|
||||
QueryExecutor
|
||||
|
||||
Acquire(ctx context.Context) (Client, error)
|
||||
Close(ctx context.Context) error
|
||||
}
|
||||
|
||||
type Client interface {
|
||||
Beginner
|
||||
QueryExecutor
|
||||
|
||||
Release(ctx context.Context) error
|
||||
}
|
||||
|
||||
type Querier interface {
|
||||
Query(ctx context.Context, stmt string, args ...any) (Rows, error)
|
||||
QueryRow(ctx context.Context, stmt string, args ...any) Row
|
||||
}
|
||||
|
||||
type Executor interface {
|
||||
Exec(ctx context.Context, stmt string, args ...any) error
|
||||
}
|
||||
|
||||
type QueryExecutor interface {
|
||||
Querier
|
||||
Executor
|
||||
}
|
||||
|
||||
type Scanner interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
type Row interface {
|
||||
Scanner
|
||||
}
|
||||
|
||||
type Rows interface {
|
||||
Row
|
||||
Next() bool
|
||||
Close() error
|
||||
Err() error
|
||||
}
|
||||
|
||||
type Query[T any] func(querier Querier) (result T, err error)
|
92
backend/v3/storage/database/dialect/config.go
Normal file
92
backend/v3/storage/database/dialect/config.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package dialect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"reflect"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/storage/database"
|
||||
"github.com/zitadel/zitadel/backend/storage/database/dialect/postgres"
|
||||
)
|
||||
|
||||
type Hook struct {
|
||||
Match func(string) bool
|
||||
Decode func(config any) (database.Connector, error)
|
||||
Name string
|
||||
Constructor func() database.Connector
|
||||
}
|
||||
|
||||
var hooks = []Hook{
|
||||
{
|
||||
Match: postgres.NameMatcher,
|
||||
Decode: postgres.DecodeConfig,
|
||||
Name: postgres.Name,
|
||||
Constructor: func() database.Connector { return new(postgres.Config) },
|
||||
},
|
||||
// {
|
||||
// Match: gosql.NameMatcher,
|
||||
// Decode: gosql.DecodeConfig,
|
||||
// Name: gosql.Name,
|
||||
// Constructor: func() database.Connector { return new(gosql.Config) },
|
||||
// },
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Dialects map[string]any `mapstructure:",remain" yaml:",inline"`
|
||||
|
||||
connector database.Connector
|
||||
}
|
||||
|
||||
func (c Config) Connect(ctx context.Context) (database.Pool, error) {
|
||||
if len(c.Dialects) != 1 {
|
||||
return nil, errors.New("Exactly one dialect must be configured")
|
||||
}
|
||||
|
||||
return c.connector.Connect(ctx)
|
||||
}
|
||||
|
||||
// Hooks implements [configure.Unmarshaller].
|
||||
func (c Config) Hooks() []viper.DecoderConfigOption {
|
||||
return []viper.DecoderConfigOption{
|
||||
viper.DecodeHook(decodeHook),
|
||||
}
|
||||
}
|
||||
|
||||
func decodeHook(from, to reflect.Value) (_ any, err error) {
|
||||
if to.Type() != reflect.TypeOf(Config{}) {
|
||||
return from.Interface(), nil
|
||||
}
|
||||
|
||||
config := new(Config)
|
||||
if err = mapstructure.Decode(from.Interface(), config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = config.decodeDialect(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (c *Config) decodeDialect() error {
|
||||
for _, hook := range hooks {
|
||||
for name, config := range c.Dialects {
|
||||
if !hook.Match(name) {
|
||||
continue
|
||||
}
|
||||
|
||||
connector, err := hook.Decode(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.connector = connector
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return errors.New("no dialect found")
|
||||
}
|
80
backend/v3/storage/database/dialect/postgres/config.go
Normal file
80
backend/v3/storage/database/dialect/postgres/config.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
var (
|
||||
_ database.Connector = (*Config)(nil)
|
||||
Name = "postgres"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
*pgxpool.Config
|
||||
|
||||
// Host string
|
||||
// Port int32
|
||||
// Database string
|
||||
// MaxOpenConns uint32
|
||||
// MaxIdleConns uint32
|
||||
// MaxConnLifetime time.Duration
|
||||
// MaxConnIdleTime time.Duration
|
||||
// User User
|
||||
// // Additional options to be appended as options=<Options>
|
||||
// // The value will be taken as is. Multiple options are space separated.
|
||||
// Options string
|
||||
|
||||
configuredFields []string
|
||||
}
|
||||
|
||||
// Connect implements [database.Connector].
|
||||
func (c *Config) Connect(ctx context.Context) (database.Pool, error) {
|
||||
pool, err := pgxpool.NewWithConfig(ctx, c.Config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = pool.Ping(ctx); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &pgxPool{pool}, nil
|
||||
}
|
||||
|
||||
func NameMatcher(name string) bool {
|
||||
return slices.Contains([]string{"postgres", "pg"}, strings.ToLower(name))
|
||||
}
|
||||
|
||||
func DecodeConfig(input any) (database.Connector, error) {
|
||||
switch c := input.(type) {
|
||||
case string:
|
||||
config, err := pgxpool.ParseConfig(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Config{Config: config}, nil
|
||||
case map[string]any:
|
||||
connector := new(Config)
|
||||
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||
DecodeHook: mapstructure.StringToTimeDurationHookFunc(),
|
||||
WeaklyTypedInput: true,
|
||||
Result: connector,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = decoder.Decode(c); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Config{
|
||||
Config: &pgxpool.Config{},
|
||||
}, nil
|
||||
}
|
||||
return nil, errors.New("invalid configuration")
|
||||
}
|
48
backend/v3/storage/database/dialect/postgres/conn.go
Normal file
48
backend/v3/storage/database/dialect/postgres/conn.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type pgxConn struct{ *pgxpool.Conn }
|
||||
|
||||
var _ database.Client = (*pgxConn)(nil)
|
||||
|
||||
// Release implements [database.Client].
|
||||
func (c *pgxConn) Release(_ context.Context) error {
|
||||
c.Conn.Release()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Begin implements [database.Client].
|
||||
func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
||||
tx, err := c.Conn.BeginTx(ctx, transactionOptionsToPgx(opts))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &pgxTx{tx}, nil
|
||||
}
|
||||
|
||||
// Query implements sql.Client.
|
||||
// Subtle: this method shadows the method (*Conn).Query of pgxConn.Conn.
|
||||
func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
||||
rows, err := c.Conn.Query(ctx, sql, args...)
|
||||
return &Rows{rows}, err
|
||||
}
|
||||
|
||||
// QueryRow implements sql.Client.
|
||||
// Subtle: this method shadows the method (*Conn).QueryRow of pgxConn.Conn.
|
||||
func (c *pgxConn) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
||||
return c.Conn.QueryRow(ctx, sql, args...)
|
||||
}
|
||||
|
||||
// Exec implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
|
||||
func (c *pgxConn) Exec(ctx context.Context, sql string, args ...any) error {
|
||||
_, err := c.Conn.Exec(ctx, sql, args...)
|
||||
return err
|
||||
}
|
57
backend/v3/storage/database/dialect/postgres/pool.go
Normal file
57
backend/v3/storage/database/dialect/postgres/pool.go
Normal file
@@ -0,0 +1,57 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type pgxPool struct{ *pgxpool.Pool }
|
||||
|
||||
var _ database.Pool = (*pgxPool)(nil)
|
||||
|
||||
// Acquire implements [database.Pool].
|
||||
func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) {
|
||||
conn, err := c.Pool.Acquire(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &pgxConn{conn}, nil
|
||||
}
|
||||
|
||||
// Query implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool.
|
||||
func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
||||
rows, err := c.Pool.Query(ctx, sql, args...)
|
||||
return &Rows{rows}, err
|
||||
}
|
||||
|
||||
// QueryRow implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool.
|
||||
func (c *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
||||
return c.Pool.QueryRow(ctx, sql, args...)
|
||||
}
|
||||
|
||||
// Exec implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
|
||||
func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) error {
|
||||
_, err := c.Pool.Exec(ctx, sql, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
// Begin implements [database.Pool].
|
||||
func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
||||
tx, err := c.Pool.BeginTx(ctx, transactionOptionsToPgx(opts))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &pgxTx{tx}, nil
|
||||
}
|
||||
|
||||
// Close implements [database.Pool].
|
||||
func (c *pgxPool) Close(_ context.Context) error {
|
||||
c.Pool.Close()
|
||||
return nil
|
||||
}
|
18
backend/v3/storage/database/dialect/postgres/rows.go
Normal file
18
backend/v3/storage/database/dialect/postgres/rows.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
var _ database.Rows = (*Rows)(nil)
|
||||
|
||||
type Rows struct{ pgx.Rows }
|
||||
|
||||
// Close implements [database.Rows].
|
||||
// Subtle: this method shadows the method (Rows).Close of Rows.Rows.
|
||||
func (r *Rows) Close() error {
|
||||
r.Rows.Close()
|
||||
return nil
|
||||
}
|
95
backend/v3/storage/database/dialect/postgres/tx.go
Normal file
95
backend/v3/storage/database/dialect/postgres/tx.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type pgxTx struct{ pgx.Tx }
|
||||
|
||||
var _ database.Transaction = (*pgxTx)(nil)
|
||||
|
||||
// Commit implements [database.Transaction].
|
||||
func (tx *pgxTx) Commit(ctx context.Context) error {
|
||||
return tx.Tx.Commit(ctx)
|
||||
}
|
||||
|
||||
// Rollback implements [database.Transaction].
|
||||
func (tx *pgxTx) Rollback(ctx context.Context) error {
|
||||
return tx.Tx.Rollback(ctx)
|
||||
}
|
||||
|
||||
// End implements [database.Transaction].
|
||||
func (tx *pgxTx) End(ctx context.Context, err error) error {
|
||||
if err != nil {
|
||||
tx.Rollback(ctx)
|
||||
return err
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
// Query implements [database.Transaction].
|
||||
// Subtle: this method shadows the method (Tx).Query of pgxTx.Tx.
|
||||
func (tx *pgxTx) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
||||
rows, err := tx.Tx.Query(ctx, sql, args...)
|
||||
return &Rows{rows}, err
|
||||
}
|
||||
|
||||
// QueryRow implements [database.Transaction].
|
||||
// Subtle: this method shadows the method (Tx).QueryRow of pgxTx.Tx.
|
||||
func (tx *pgxTx) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
||||
return tx.Tx.QueryRow(ctx, sql, args...)
|
||||
}
|
||||
|
||||
// Exec implements [database.Transaction].
|
||||
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
|
||||
func (tx *pgxTx) Exec(ctx context.Context, sql string, args ...any) error {
|
||||
_, err := tx.Tx.Exec(ctx, sql, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
// Begin implements [database.Transaction].
|
||||
// As postgres does not support nested transactions we use savepoints to emulate them.
|
||||
func (tx *pgxTx) Begin(ctx context.Context) (database.Transaction, error) {
|
||||
savepoint, err := tx.Tx.Begin(ctx)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &pgxTx{savepoint}, nil
|
||||
}
|
||||
|
||||
func transactionOptionsToPgx(opts *database.TransactionOptions) pgx.TxOptions {
|
||||
if opts == nil {
|
||||
return pgx.TxOptions{}
|
||||
}
|
||||
|
||||
return pgx.TxOptions{
|
||||
IsoLevel: isolationToPgx(opts.IsolationLevel),
|
||||
AccessMode: accessModeToPgx(opts.AccessMode),
|
||||
}
|
||||
}
|
||||
|
||||
func isolationToPgx(isolation database.IsolationLevel) pgx.TxIsoLevel {
|
||||
switch isolation {
|
||||
case database.IsolationLevelSerializable:
|
||||
return pgx.Serializable
|
||||
case database.IsolationLevelReadCommitted:
|
||||
return pgx.ReadCommitted
|
||||
default:
|
||||
return pgx.Serializable
|
||||
}
|
||||
}
|
||||
|
||||
func accessModeToPgx(accessMode database.AccessMode) pgx.TxAccessMode {
|
||||
switch accessMode {
|
||||
case database.AccessModeReadWrite:
|
||||
return pgx.ReadWrite
|
||||
case database.AccessModeReadOnly:
|
||||
return pgx.ReadOnly
|
||||
default:
|
||||
return pgx.ReadWrite
|
||||
}
|
||||
}
|
3
backend/v3/storage/database/gen_mock.go
Normal file
3
backend/v3/storage/database/gen_mock.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package database
|
||||
|
||||
//go:generate mockgen -typed -package mock -destination ./mock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Client,Row,Rows,Transaction
|
1067
backend/v3/storage/database/mock/database.mock.go
Normal file
1067
backend/v3/storage/database/mock/database.mock.go
Normal file
File diff suppressed because it is too large
Load Diff
160
backend/v3/storage/database/repository/clause.go
Normal file
160
backend/v3/storage/database/repository/clause.go
Normal file
@@ -0,0 +1,160 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
)
|
||||
|
||||
type field interface {
|
||||
fmt.Stringer
|
||||
}
|
||||
|
||||
type fieldDescriptor struct {
|
||||
schema string
|
||||
table string
|
||||
name string
|
||||
}
|
||||
|
||||
func (f fieldDescriptor) String() string {
|
||||
return f.schema + "." + f.table + "." + f.name
|
||||
}
|
||||
|
||||
type ignoreCaseFieldDescriptor struct {
|
||||
fieldDescriptor
|
||||
fieldNameSuffix string
|
||||
}
|
||||
|
||||
func (f ignoreCaseFieldDescriptor) String() string {
|
||||
return f.fieldDescriptor.String() + f.fieldNameSuffix
|
||||
}
|
||||
|
||||
type textFieldDescriptor struct {
|
||||
field
|
||||
isIgnoreCase bool
|
||||
}
|
||||
|
||||
type clause[Op domain.Operation] struct {
|
||||
field field
|
||||
op Op
|
||||
}
|
||||
|
||||
const (
|
||||
schema = "zitadel"
|
||||
userTable = "users"
|
||||
)
|
||||
|
||||
var userFields = map[domain.UserField]field{
|
||||
domain.UserFieldInstanceID: fieldDescriptor{
|
||||
schema: schema,
|
||||
table: userTable,
|
||||
name: "instance_id",
|
||||
},
|
||||
domain.UserFieldOrgID: fieldDescriptor{
|
||||
schema: schema,
|
||||
table: userTable,
|
||||
name: "org_id",
|
||||
},
|
||||
domain.UserFieldID: fieldDescriptor{
|
||||
schema: schema,
|
||||
table: userTable,
|
||||
name: "id",
|
||||
},
|
||||
domain.UserFieldUsername: textFieldDescriptor{
|
||||
field: ignoreCaseFieldDescriptor{
|
||||
fieldDescriptor: fieldDescriptor{
|
||||
schema: schema,
|
||||
table: userTable,
|
||||
name: "username",
|
||||
},
|
||||
fieldNameSuffix: "_lower",
|
||||
},
|
||||
},
|
||||
domain.UserHumanFieldEmail: textFieldDescriptor{
|
||||
field: ignoreCaseFieldDescriptor{
|
||||
fieldDescriptor: fieldDescriptor{
|
||||
schema: schema,
|
||||
table: userTable,
|
||||
name: "email",
|
||||
},
|
||||
fieldNameSuffix: "_lower",
|
||||
},
|
||||
},
|
||||
domain.UserHumanFieldEmailVerified: fieldDescriptor{
|
||||
schema: schema,
|
||||
table: userTable,
|
||||
name: "email_is_verified",
|
||||
},
|
||||
}
|
||||
|
||||
type textClause[V domain.Text] struct {
|
||||
clause[domain.TextOperation]
|
||||
value V
|
||||
}
|
||||
|
||||
var textOp map[domain.TextOperation]string = map[domain.TextOperation]string{
|
||||
domain.TextOperationEqual: " = ",
|
||||
domain.TextOperationNotEqual: " <> ",
|
||||
domain.TextOperationStartsWith: " LIKE ",
|
||||
domain.TextOperationStartsWithIgnoreCase: " LIKE ",
|
||||
}
|
||||
|
||||
func (tc textClause[V]) Write(stmt *statement) {
|
||||
placeholder := stmt.appendArg(tc.value)
|
||||
var (
|
||||
left, right string
|
||||
)
|
||||
switch tc.clause.op {
|
||||
case domain.TextOperationEqual:
|
||||
left = tc.clause.field.String()
|
||||
right = placeholder
|
||||
case domain.TextOperationNotEqual:
|
||||
left = tc.clause.field.String()
|
||||
right = placeholder
|
||||
case domain.TextOperationStartsWith:
|
||||
left = tc.clause.field.String()
|
||||
right = placeholder + "%"
|
||||
case domain.TextOperationStartsWithIgnoreCase:
|
||||
left = tc.clause.field.String()
|
||||
if _, ok := tc.clause.field.(ignoreCaseFieldDescriptor); !ok {
|
||||
left = "LOWER(" + left + ")"
|
||||
}
|
||||
right = "LOWER(" + placeholder + "%)"
|
||||
}
|
||||
|
||||
stmt.builder.WriteString(left)
|
||||
stmt.builder.WriteString(textOp[tc.clause.op])
|
||||
stmt.builder.WriteString(right)
|
||||
}
|
||||
|
||||
type boolClause[V domain.Bool] struct {
|
||||
clause[domain.BoolOperation]
|
||||
value V
|
||||
}
|
||||
|
||||
func (bc boolClause[V]) Write(stmt *statement) {
|
||||
if !bc.value {
|
||||
stmt.builder.WriteString("NOT ")
|
||||
}
|
||||
stmt.builder.WriteString(bc.clause.field.String())
|
||||
}
|
||||
|
||||
type numberClause[V domain.Number] struct {
|
||||
clause[domain.NumberOperation]
|
||||
value V
|
||||
}
|
||||
|
||||
var numberOp map[domain.NumberOperation]string = map[domain.NumberOperation]string{
|
||||
domain.NumberOperationEqual: " = ",
|
||||
domain.NumberOperationNotEqual: " <> ",
|
||||
domain.NumberOperationLessThan: " < ",
|
||||
domain.NumberOperationLessThanOrEqual: " <= ",
|
||||
domain.NumberOperationGreaterThan: " > ",
|
||||
domain.NumberOperationGreaterThanOrEqual: " >= ",
|
||||
}
|
||||
|
||||
func (nc numberClause[V]) Write(stmt *statement) {
|
||||
stmt.builder.WriteString(nc.clause.field.String())
|
||||
stmt.builder.WriteString(numberOp[nc.clause.op])
|
||||
stmt.builder.WriteString(stmt.appendArg(nc.value))
|
||||
}
|
45
backend/v3/storage/database/repository/crypto.go
Normal file
45
backend/v3/storage/database/repository/crypto.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
"github.com/zitadel/zitadel/internal/crypto"
|
||||
)
|
||||
|
||||
type cryptoRepo struct {
|
||||
database.QueryExecutor
|
||||
}
|
||||
|
||||
func Crypto(db database.QueryExecutor) domain.CryptoRepository {
|
||||
return &cryptoRepo{
|
||||
QueryExecutor: db,
|
||||
}
|
||||
}
|
||||
|
||||
const getEncryptionConfigQuery = "SELECT" +
|
||||
" length" +
|
||||
", expiry" +
|
||||
", should_include_lower_letters" +
|
||||
", should_include_upper_letters" +
|
||||
", should_include_digits" +
|
||||
", should_include_symbols" +
|
||||
" FROM encryption_config"
|
||||
|
||||
func (repo *cryptoRepo) GetEncryptionConfig(ctx context.Context) (*crypto.GeneratorConfig, error) {
|
||||
var config crypto.GeneratorConfig
|
||||
row := repo.QueryRow(ctx, getEncryptionConfigQuery)
|
||||
err := row.Scan(
|
||||
&config.Length,
|
||||
&config.Expiry,
|
||||
&config.IncludeLowerLetters,
|
||||
&config.IncludeUpperLetters,
|
||||
&config.IncludeDigits,
|
||||
&config.IncludeSymbols,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &config, nil
|
||||
}
|
7
backend/v3/storage/database/repository/doc.go
Normal file
7
backend/v3/storage/database/repository/doc.go
Normal file
@@ -0,0 +1,7 @@
|
||||
// Repository package provides the database repository for the application.
|
||||
// It contains the implementation of the [repository pattern](https://martinfowler.com/eaaCatalog/repository.html) for the database.
|
||||
|
||||
// funcs which need to interact with the database should create interfaces which are implemented by the
|
||||
// [query] and [exec] structs respectively their factory methods [Query] and [Execute]. The [query] struct is used for read operations, while the [exec] struct is used for write operations.
|
||||
|
||||
package repository
|
54
backend/v3/storage/database/repository/instance.go
Normal file
54
backend/v3/storage/database/repository/instance.go
Normal file
@@ -0,0 +1,54 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type instance struct {
|
||||
database.QueryExecutor
|
||||
}
|
||||
|
||||
func Instance(client database.QueryExecutor) domain.InstanceRepository {
|
||||
return &instance{QueryExecutor: client}
|
||||
}
|
||||
|
||||
func (i *instance) ByID(ctx context.Context, id string) (*domain.Instance, error) {
|
||||
var instance domain.Instance
|
||||
err := i.QueryExecutor.QueryRow(ctx, `SELECT id, name, created_at, updated_at, deleted_at FROM instances WHERE id = $1`, id).Scan(
|
||||
&instance.ID,
|
||||
&instance.Name,
|
||||
&instance.CreatedAt,
|
||||
&instance.UpdatedAt,
|
||||
&instance.DeletedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &instance, nil
|
||||
}
|
||||
|
||||
const createInstanceStmt = `INSERT INTO instances (id, name) VALUES ($1, $2) RETURNING created_at, updated_at`
|
||||
|
||||
// Create implements [domain.InstanceRepository].
|
||||
func (i *instance) Create(ctx context.Context, instance *domain.Instance) error {
|
||||
return i.QueryExecutor.QueryRow(ctx, createInstanceStmt,
|
||||
instance.ID,
|
||||
instance.Name,
|
||||
).Scan(
|
||||
&instance.CreatedAt,
|
||||
&instance.UpdatedAt,
|
||||
)
|
||||
}
|
||||
|
||||
// On implements [domain.InstanceRepository].
|
||||
func (i *instance) On(id string) domain.InstanceOperation {
|
||||
return &instanceOperation{
|
||||
QueryExecutor: i.QueryExecutor,
|
||||
id: id,
|
||||
}
|
||||
}
|
||||
|
||||
var _ domain.InstanceRepository = (*instance)(nil)
|
52
backend/v3/storage/database/repository/instance_operation.go
Normal file
52
backend/v3/storage/database/repository/instance_operation.go
Normal file
@@ -0,0 +1,52 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type instanceOperation struct {
|
||||
database.QueryExecutor
|
||||
id string
|
||||
}
|
||||
|
||||
const addInstanceAdminStmt = `INSERT INTO instance_admins (instance_id, user_id, roles) VALUES ($1, $2, $3)`
|
||||
|
||||
// AddAdmin implements [domain.InstanceOperation].
|
||||
func (i *instanceOperation) AddAdmin(ctx context.Context, userID string, roles []string) error {
|
||||
return i.QueryExecutor.Exec(ctx, addInstanceAdminStmt, i.id, userID, roles)
|
||||
}
|
||||
|
||||
// Delete implements [domain.InstanceOperation].
|
||||
func (i *instanceOperation) Delete(ctx context.Context) error {
|
||||
return i.QueryExecutor.Exec(ctx, `DELETE FROM instances WHERE id = $1`, i.id)
|
||||
}
|
||||
|
||||
const removeInstanceAdminStmt = `DELETE FROM instance_admins WHERE instance_id = $1 AND user_id = $2`
|
||||
|
||||
// RemoveAdmin implements [domain.InstanceOperation].
|
||||
func (i *instanceOperation) RemoveAdmin(ctx context.Context, userID string) error {
|
||||
return i.QueryExecutor.Exec(ctx, removeInstanceAdminStmt, i.id, userID)
|
||||
}
|
||||
|
||||
const setInstanceAdminRolesStmt = `UPDATE instance_admins SET roles = $1 WHERE instance_id = $2 AND user_id = $3`
|
||||
|
||||
// SetAdminRoles implements [domain.InstanceOperation].
|
||||
func (i *instanceOperation) SetAdminRoles(ctx context.Context, userID string, roles []string) error {
|
||||
return i.QueryExecutor.Exec(ctx, setInstanceAdminRolesStmt, roles, i.id, userID)
|
||||
}
|
||||
|
||||
const updateInstanceStmt = `UPDATE instances SET name = $1, updated_at = $2 WHERE id = $3 RETURNING updated_at`
|
||||
|
||||
// Update implements [domain.InstanceOperation].
|
||||
func (i *instanceOperation) Update(ctx context.Context, instance *domain.Instance) error {
|
||||
return i.QueryExecutor.QueryRow(ctx, updateInstanceStmt,
|
||||
instance.Name,
|
||||
instance.UpdatedAt,
|
||||
i.id,
|
||||
).Scan(&instance.UpdatedAt)
|
||||
}
|
||||
|
||||
var _ domain.InstanceOperation = (*instanceOperation)(nil)
|
17
backend/v3/storage/database/repository/query.go
Normal file
17
backend/v3/storage/database/repository/query.go
Normal file
@@ -0,0 +1,17 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type query struct{ database.Querier }
|
||||
|
||||
func Query(querier database.Querier) *query {
|
||||
return &query{Querier: querier}
|
||||
}
|
||||
|
||||
type executor struct{ database.Executor }
|
||||
|
||||
func Execute(exec database.Executor) *executor {
|
||||
return &executor{Executor: exec}
|
||||
}
|
21
backend/v3/storage/database/repository/statement.go
Normal file
21
backend/v3/storage/database/repository/statement.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package repository
|
||||
|
||||
import "strings"
|
||||
|
||||
type statement struct {
|
||||
builder strings.Builder
|
||||
args []any
|
||||
}
|
||||
|
||||
func (s *statement) appendArg(arg any) (placeholder string) {
|
||||
s.args = append(s.args, arg)
|
||||
return "$" + string(len(s.args))
|
||||
}
|
||||
|
||||
func (s *statement) appendArgs(args ...any) (placeholders []string) {
|
||||
placeholders = make([]string, len(args))
|
||||
for i, arg := range args {
|
||||
placeholders[i] = s.appendArg(arg)
|
||||
}
|
||||
return placeholders
|
||||
}
|
43
backend/v3/storage/database/repository/stmt/column.go
Normal file
43
backend/v3/storage/database/repository/stmt/column.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package stmt
|
||||
|
||||
import "fmt"
|
||||
|
||||
type Column[T any] interface {
|
||||
fmt.Stringer
|
||||
statementApplier[T]
|
||||
scanner(t *T) any
|
||||
}
|
||||
|
||||
type columnDescriptor[T any] struct {
|
||||
name string
|
||||
scan func(*T) any
|
||||
}
|
||||
|
||||
func (cd columnDescriptor[T]) scanner(t *T) any {
|
||||
return cd.scan(t)
|
||||
}
|
||||
|
||||
// Apply implements [Column].
|
||||
func (f columnDescriptor[T]) Apply(stmt *statement[T]) {
|
||||
stmt.builder.WriteString(stmt.columnPrefix())
|
||||
stmt.builder.WriteString(f.String())
|
||||
}
|
||||
|
||||
// String implements [Column].
|
||||
func (f columnDescriptor[T]) String() string {
|
||||
return f.name
|
||||
}
|
||||
|
||||
var _ Column[any] = (*columnDescriptor[any])(nil)
|
||||
|
||||
type ignoreCaseColumnDescriptor[T any] struct {
|
||||
columnDescriptor[T]
|
||||
fieldNameSuffix string
|
||||
}
|
||||
|
||||
func (f ignoreCaseColumnDescriptor[T]) ApplyIgnoreCase(stmt *statement[T]) {
|
||||
stmt.builder.WriteString(f.String())
|
||||
stmt.builder.WriteString(f.fieldNameSuffix)
|
||||
}
|
||||
|
||||
var _ Column[any] = (*ignoreCaseColumnDescriptor[any])(nil)
|
97
backend/v3/storage/database/repository/stmt/condition.go
Normal file
97
backend/v3/storage/database/repository/stmt/condition.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package stmt
|
||||
|
||||
import "fmt"
|
||||
|
||||
type statementApplier[T any] interface {
|
||||
// Apply writes the statement to the builder.
|
||||
Apply(stmt *statement[T])
|
||||
}
|
||||
|
||||
type Condition[T any] interface {
|
||||
statementApplier[T]
|
||||
}
|
||||
|
||||
type op interface {
|
||||
TextOperation | NumberOperation | ListOperation
|
||||
fmt.Stringer
|
||||
}
|
||||
|
||||
type operation[T any, O op] struct {
|
||||
o O
|
||||
}
|
||||
|
||||
func (o operation[T, O]) String() string {
|
||||
return o.o.String()
|
||||
}
|
||||
|
||||
func (o operation[T, O]) Apply(stmt *statement[T]) {
|
||||
stmt.builder.WriteString(o.o.String())
|
||||
}
|
||||
|
||||
type condition[V, T any, OP op] struct {
|
||||
field Column[T]
|
||||
op OP
|
||||
value V
|
||||
}
|
||||
|
||||
func (c *condition[V, T, OP]) Apply(stmt *statement[T]) {
|
||||
// placeholder := stmt.appendArg(c.value)
|
||||
stmt.builder.WriteString(stmt.columnPrefix())
|
||||
stmt.builder.WriteString(c.field.String())
|
||||
// stmt.builder.WriteString(c.op)
|
||||
// stmt.builder.WriteString(placeholder)
|
||||
}
|
||||
|
||||
type and[T any] struct {
|
||||
conditions []Condition[T]
|
||||
}
|
||||
|
||||
func And[T any](conditions ...Condition[T]) *and[T] {
|
||||
return &and[T]{
|
||||
conditions: conditions,
|
||||
}
|
||||
}
|
||||
|
||||
// Apply implements [Condition].
|
||||
func (a *and[T]) Apply(stmt *statement[T]) {
|
||||
if len(a.conditions) > 1 {
|
||||
stmt.builder.WriteString("(")
|
||||
defer stmt.builder.WriteString(")")
|
||||
}
|
||||
|
||||
for i, condition := range a.conditions {
|
||||
if i > 0 {
|
||||
stmt.builder.WriteString(" AND ")
|
||||
}
|
||||
condition.Apply(stmt)
|
||||
}
|
||||
}
|
||||
|
||||
var _ Condition[any] = (*and[any])(nil)
|
||||
|
||||
type or[T any] struct {
|
||||
conditions []Condition[T]
|
||||
}
|
||||
|
||||
func Or[T any](conditions ...Condition[T]) *or[T] {
|
||||
return &or[T]{
|
||||
conditions: conditions,
|
||||
}
|
||||
}
|
||||
|
||||
// Apply implements [Condition].
|
||||
func (o *or[T]) Apply(stmt *statement[T]) {
|
||||
if len(o.conditions) > 1 {
|
||||
stmt.builder.WriteString("(")
|
||||
defer stmt.builder.WriteString(")")
|
||||
}
|
||||
|
||||
for i, condition := range o.conditions {
|
||||
if i > 0 {
|
||||
stmt.builder.WriteString(" OR ")
|
||||
}
|
||||
condition.Apply(stmt)
|
||||
}
|
||||
}
|
||||
|
||||
var _ Condition[any] = (*or[any])(nil)
|
71
backend/v3/storage/database/repository/stmt/list.go
Normal file
71
backend/v3/storage/database/repository/stmt/list.go
Normal file
@@ -0,0 +1,71 @@
|
||||
package stmt
|
||||
|
||||
type ListEntry interface {
|
||||
Number | Text | any
|
||||
}
|
||||
|
||||
type ListCondition[E ListEntry, T any] struct {
|
||||
condition[[]E, T, ListOperation]
|
||||
}
|
||||
|
||||
func (lc *ListCondition[E, T]) Apply(stmt *statement[T]) {
|
||||
placeholder := stmt.appendArg(lc.value)
|
||||
|
||||
switch lc.op {
|
||||
case ListOperationEqual, ListOperationNotEqual:
|
||||
lc.field.Apply(stmt)
|
||||
operation[T, ListOperation]{lc.op}.Apply(stmt)
|
||||
stmt.builder.WriteString(placeholder)
|
||||
case ListOperationContainsAny, ListOperationContainsAll:
|
||||
lc.field.Apply(stmt)
|
||||
operation[T, ListOperation]{lc.op}.Apply(stmt)
|
||||
stmt.builder.WriteString(placeholder)
|
||||
case ListOperationNotContainsAny, ListOperationNotContainsAll:
|
||||
stmt.builder.WriteString("NOT (")
|
||||
lc.field.Apply(stmt)
|
||||
operation[T, ListOperation]{lc.op}.Apply(stmt)
|
||||
stmt.builder.WriteString(placeholder)
|
||||
stmt.builder.WriteString(")")
|
||||
default:
|
||||
panic("unknown list operation")
|
||||
}
|
||||
}
|
||||
|
||||
type ListOperation uint8
|
||||
|
||||
const (
|
||||
// ListOperationEqual checks if the arrays are equal including the order of the elements
|
||||
ListOperationEqual ListOperation = iota + 1
|
||||
// ListOperationNotEqual checks if the arrays are not equal including the order of the elements
|
||||
ListOperationNotEqual
|
||||
|
||||
// ListOperationContains checks if the array column contains all the values of the specified array
|
||||
ListOperationContainsAll
|
||||
// ListOperationContainsAny checks if the arrays have at least one value in common
|
||||
ListOperationContainsAny
|
||||
// ListOperationContainsAll checks if the array column contains all the values of the specified array
|
||||
|
||||
// ListOperationNotContainsAll checks if the specified array is not contained by the column
|
||||
ListOperationNotContainsAll
|
||||
// ListOperationNotContainsAny checks if the arrays column contains none of the values of the specified array
|
||||
ListOperationNotContainsAny
|
||||
)
|
||||
|
||||
var listOperations = map[ListOperation]string{
|
||||
// ListOperationEqual checks if the lists are equal
|
||||
ListOperationEqual: " = ",
|
||||
// ListOperationNotEqual checks if the lists are not equal
|
||||
ListOperationNotEqual: " <> ",
|
||||
// ListOperationContainsAny checks if the arrays have at least one value in common
|
||||
ListOperationContainsAny: " && ",
|
||||
// ListOperationContainsAll checks if the array column contains all the values of the specified array
|
||||
ListOperationContainsAll: " @> ",
|
||||
// ListOperationNotContainsAny checks if the arrays column contains none of the values of the specified array
|
||||
ListOperationNotContainsAny: " && ", // Base operator for NOT (A && B)
|
||||
// ListOperationNotContainsAll checks if the array column is not contained by the specified array
|
||||
ListOperationNotContainsAll: " <@ ", // Base operator for NOT (A <@ B)
|
||||
}
|
||||
|
||||
func (lo ListOperation) String() string {
|
||||
return listOperations[lo]
|
||||
}
|
61
backend/v3/storage/database/repository/stmt/number.go
Normal file
61
backend/v3/storage/database/repository/stmt/number.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package stmt
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
type Number interface {
|
||||
constraints.Integer | constraints.Float | constraints.Complex | time.Time | time.Duration
|
||||
}
|
||||
|
||||
type between[N Number] struct {
|
||||
min, max N
|
||||
}
|
||||
|
||||
type NumberBetween[V Number, T any] struct {
|
||||
condition[between[V], T, NumberOperation]
|
||||
}
|
||||
|
||||
func (nb *NumberBetween[V, T]) Apply(stmt *statement[T]) {
|
||||
nb.field.Apply(stmt)
|
||||
stmt.builder.WriteString(" BETWEEN ")
|
||||
stmt.builder.WriteString(stmt.appendArg(nb.value.min))
|
||||
stmt.builder.WriteString(" AND ")
|
||||
stmt.builder.WriteString(stmt.appendArg(nb.value.max))
|
||||
}
|
||||
|
||||
type NumberCondition[V Number, T any] struct {
|
||||
condition[V, T, NumberOperation]
|
||||
}
|
||||
|
||||
func (nc *NumberCondition[V, T]) Apply(stmt *statement[T]) {
|
||||
nc.field.Apply(stmt)
|
||||
operation[T, NumberOperation]{nc.op}.Apply(stmt)
|
||||
stmt.builder.WriteString(stmt.appendArg(nc.value))
|
||||
}
|
||||
|
||||
type NumberOperation uint8
|
||||
|
||||
const (
|
||||
NumberOperationEqual NumberOperation = iota + 1
|
||||
NumberOperationNotEqual
|
||||
NumberOperationLessThan
|
||||
NumberOperationLessThanOrEqual
|
||||
NumberOperationGreaterThan
|
||||
NumberOperationGreaterThanOrEqual
|
||||
)
|
||||
|
||||
var numberOperations = map[NumberOperation]string{
|
||||
NumberOperationEqual: " = ",
|
||||
NumberOperationNotEqual: " <> ",
|
||||
NumberOperationLessThan: " < ",
|
||||
NumberOperationLessThanOrEqual: " <= ",
|
||||
NumberOperationGreaterThan: " > ",
|
||||
NumberOperationGreaterThanOrEqual: " >= ",
|
||||
}
|
||||
|
||||
func (no NumberOperation) String() string {
|
||||
return numberOperations[no]
|
||||
}
|
104
backend/v3/storage/database/repository/stmt/statement.go
Normal file
104
backend/v3/storage/database/repository/stmt/statement.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package stmt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type statement[T any] struct {
|
||||
builder strings.Builder
|
||||
client database.QueryExecutor
|
||||
|
||||
columns []Column[T]
|
||||
|
||||
schema string
|
||||
table string
|
||||
alias string
|
||||
|
||||
condition Condition[T]
|
||||
|
||||
limit uint32
|
||||
offset uint32
|
||||
// order by fieldname and sort direction false for asc true for desc
|
||||
// orderBy SortingColumns[C]
|
||||
args []any
|
||||
existingArgs map[any]string
|
||||
}
|
||||
|
||||
func (s *statement[T]) scanners(t *T) []any {
|
||||
scanners := make([]any, len(s.columns))
|
||||
for i, column := range s.columns {
|
||||
scanners[i] = column.scanner(t)
|
||||
}
|
||||
return scanners
|
||||
}
|
||||
|
||||
func (s *statement[T]) query() string {
|
||||
s.builder.WriteString(`SELECT `)
|
||||
for i, column := range s.columns {
|
||||
if i > 0 {
|
||||
s.builder.WriteString(", ")
|
||||
}
|
||||
column.Apply(s)
|
||||
}
|
||||
s.builder.WriteString(` FROM `)
|
||||
s.builder.WriteString(s.schema)
|
||||
s.builder.WriteRune('.')
|
||||
s.builder.WriteString(s.table)
|
||||
if s.alias != "" {
|
||||
s.builder.WriteString(" AS ")
|
||||
s.builder.WriteString(s.alias)
|
||||
}
|
||||
|
||||
s.builder.WriteString(` WHERE `)
|
||||
|
||||
s.condition.Apply(s)
|
||||
|
||||
if s.limit > 0 {
|
||||
s.builder.WriteString(` LIMIT `)
|
||||
s.builder.WriteString(s.appendArg(s.limit))
|
||||
}
|
||||
if s.offset > 0 {
|
||||
s.builder.WriteString(` OFFSET `)
|
||||
s.builder.WriteString(s.appendArg(s.offset))
|
||||
}
|
||||
|
||||
return s.builder.String()
|
||||
}
|
||||
|
||||
// func (s *statement[T]) Where(condition Condition[T]) *statement[T] {
|
||||
// s.condition = condition
|
||||
// return s
|
||||
// }
|
||||
|
||||
// func (s *statement[T]) Limit(limit uint32) *statement[T] {
|
||||
// s.limit = limit
|
||||
// return s
|
||||
// }
|
||||
|
||||
// func (s *statement[T]) Offset(offset uint32) *statement[T] {
|
||||
// s.offset = offset
|
||||
// return s
|
||||
// }
|
||||
|
||||
func (s *statement[T]) columnPrefix() string {
|
||||
if s.alias != "" {
|
||||
return s.alias + "."
|
||||
}
|
||||
return s.schema + "." + s.table + "."
|
||||
}
|
||||
|
||||
func (s *statement[T]) appendArg(arg any) string {
|
||||
if s.existingArgs == nil {
|
||||
s.existingArgs = make(map[any]string)
|
||||
}
|
||||
if existing, ok := s.existingArgs[arg]; ok {
|
||||
return existing
|
||||
}
|
||||
s.args = append(s.args, arg)
|
||||
placeholder := fmt.Sprintf("$%d", len(s.args))
|
||||
s.existingArgs[arg] = placeholder
|
||||
return placeholder
|
||||
}
|
18
backend/v3/storage/database/repository/stmt/stmt_test.go
Normal file
18
backend/v3/storage/database/repository/stmt/stmt_test.go
Normal file
@@ -0,0 +1,18 @@
|
||||
package stmt_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt"
|
||||
)
|
||||
|
||||
func Test_Bla(t *testing.T) {
|
||||
stmt.User(nil).Where(
|
||||
stmt.Or(
|
||||
stmt.UserIDCondition("123"),
|
||||
stmt.UserIDCondition("123"),
|
||||
stmt.UserUsernameCondition(stmt.TextOperationEqualIgnoreCase, "test"),
|
||||
),
|
||||
).Limit(1).Offset(1).Get(context.Background())
|
||||
}
|
72
backend/v3/storage/database/repository/stmt/text.go
Normal file
72
backend/v3/storage/database/repository/stmt/text.go
Normal file
@@ -0,0 +1,72 @@
|
||||
package stmt
|
||||
|
||||
type Text interface {
|
||||
~string | ~[]byte
|
||||
}
|
||||
|
||||
type TextCondition[V Text, T any] struct {
|
||||
condition[V, T, TextOperation]
|
||||
}
|
||||
|
||||
func (tc *TextCondition[V, T]) Apply(stmt *statement[T]) {
|
||||
placeholder := stmt.appendArg(tc.value)
|
||||
|
||||
switch tc.op {
|
||||
case TextOperationEqual, TextOperationNotEqual:
|
||||
tc.field.Apply(stmt)
|
||||
operation[T, TextOperation]{tc.op}.Apply(stmt)
|
||||
stmt.builder.WriteString(placeholder)
|
||||
case TextOperationEqualIgnoreCase:
|
||||
if desc, ok := tc.field.(ignoreCaseColumnDescriptor[T]); ok {
|
||||
desc.ApplyIgnoreCase(stmt)
|
||||
} else {
|
||||
stmt.builder.WriteString("LOWER(")
|
||||
tc.field.Apply(stmt)
|
||||
stmt.builder.WriteString(")")
|
||||
}
|
||||
operation[T, TextOperation]{tc.op}.Apply(stmt)
|
||||
stmt.builder.WriteString("LOWER(")
|
||||
stmt.builder.WriteString(placeholder)
|
||||
stmt.builder.WriteString(")")
|
||||
case TextOperationStartsWith:
|
||||
tc.field.Apply(stmt)
|
||||
operation[T, TextOperation]{tc.op}.Apply(stmt)
|
||||
stmt.builder.WriteString(placeholder)
|
||||
stmt.builder.WriteString("|| '%'")
|
||||
case TextOperationStartsWithIgnoreCase:
|
||||
if desc, ok := tc.field.(ignoreCaseColumnDescriptor[T]); ok {
|
||||
desc.ApplyIgnoreCase(stmt)
|
||||
} else {
|
||||
stmt.builder.WriteString("LOWER(")
|
||||
tc.field.Apply(stmt)
|
||||
stmt.builder.WriteString(")")
|
||||
}
|
||||
operation[T, TextOperation]{tc.op}.Apply(stmt)
|
||||
stmt.builder.WriteString("LOWER(")
|
||||
stmt.builder.WriteString(placeholder)
|
||||
stmt.builder.WriteString(")")
|
||||
stmt.builder.WriteString("|| '%'")
|
||||
}
|
||||
}
|
||||
|
||||
type TextOperation uint8
|
||||
|
||||
const (
|
||||
TextOperationEqual TextOperation = iota + 1
|
||||
TextOperationEqualIgnoreCase
|
||||
TextOperationNotEqual
|
||||
TextOperationStartsWith
|
||||
TextOperationStartsWithIgnoreCase
|
||||
)
|
||||
|
||||
var textOperations = map[TextOperation]string{
|
||||
TextOperationEqual: " = ",
|
||||
TextOperationEqualIgnoreCase: " = ",
|
||||
TextOperationNotEqual: " <> ",
|
||||
TextOperationStartsWith: " LIKE ",
|
||||
TextOperationStartsWithIgnoreCase: " LIKE ",
|
||||
}
|
||||
|
||||
func (to TextOperation) String() string {
|
||||
return textOperations[to]
|
||||
}
|
193
backend/v3/storage/database/repository/stmt/user.go
Normal file
193
backend/v3/storage/database/repository/stmt/user.go
Normal file
@@ -0,0 +1,193 @@
|
||||
package stmt
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type userStatement struct {
|
||||
statement[domain.User]
|
||||
}
|
||||
|
||||
func User(client database.QueryExecutor) *userStatement {
|
||||
return &userStatement{
|
||||
statement: statement[domain.User]{
|
||||
schema: "zitadel",
|
||||
table: "users",
|
||||
alias: "u",
|
||||
client: client,
|
||||
columns: []Column[domain.User]{
|
||||
userColumns[UserInstanceID],
|
||||
userColumns[UserOrgID],
|
||||
userColumns[UserColumnID],
|
||||
userColumns[UserColumnUsername],
|
||||
userColumns[UserCreatedAt],
|
||||
userColumns[UserUpdatedAt],
|
||||
userColumns[UserDeletedAt],
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *userStatement) Where(condition Condition[domain.User]) *userStatement {
|
||||
s.condition = condition
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *userStatement) Limit(limit uint32) *userStatement {
|
||||
s.limit = limit
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *userStatement) Offset(offset uint32) *userStatement {
|
||||
s.offset = offset
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *userStatement) Get(ctx context.Context) (*domain.User, error) {
|
||||
var user domain.User
|
||||
err := s.client.QueryRow(ctx, s.query(), s.statement.args...).Scan(s.scanners(&user)...)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
|
||||
func (s *userStatement) List(ctx context.Context) ([]*domain.User, error) {
|
||||
var users []*domain.User
|
||||
rows, err := s.client.Query(ctx, s.query(), s.statement.args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer rows.Close()
|
||||
|
||||
for rows.Next() {
|
||||
var user domain.User
|
||||
err = rows.Scan(s.scanners(&user)...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users = append(users, &user)
|
||||
}
|
||||
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (s *userStatement) SetUsername(ctx context.Context, username string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type UserColumn uint8
|
||||
|
||||
var (
|
||||
userColumns map[UserColumn]Column[domain.User] = map[UserColumn]Column[domain.User]{
|
||||
UserInstanceID: columnDescriptor[domain.User]{
|
||||
name: "instance_id",
|
||||
scan: func(u *domain.User) any {
|
||||
return &u.InstanceID
|
||||
},
|
||||
},
|
||||
UserOrgID: columnDescriptor[domain.User]{
|
||||
name: "org_id",
|
||||
scan: func(u *domain.User) any {
|
||||
return &u.OrgID
|
||||
},
|
||||
},
|
||||
UserColumnID: columnDescriptor[domain.User]{
|
||||
name: "id",
|
||||
scan: func(u *domain.User) any {
|
||||
return &u.ID
|
||||
},
|
||||
},
|
||||
UserColumnUsername: ignoreCaseColumnDescriptor[domain.User]{
|
||||
columnDescriptor: columnDescriptor[domain.User]{
|
||||
name: "username",
|
||||
scan: func(u *domain.User) any {
|
||||
return &u.Username
|
||||
},
|
||||
},
|
||||
fieldNameSuffix: "_lower",
|
||||
},
|
||||
UserCreatedAt: columnDescriptor[domain.User]{
|
||||
name: "created_at",
|
||||
scan: func(u *domain.User) any {
|
||||
return &u.CreatedAt
|
||||
},
|
||||
},
|
||||
UserUpdatedAt: columnDescriptor[domain.User]{
|
||||
name: "updated_at",
|
||||
scan: func(u *domain.User) any {
|
||||
return &u.UpdatedAt
|
||||
},
|
||||
},
|
||||
UserDeletedAt: columnDescriptor[domain.User]{
|
||||
name: "deleted_at",
|
||||
scan: func(u *domain.User) any {
|
||||
return &u.DeletedAt
|
||||
},
|
||||
},
|
||||
}
|
||||
humanColumns = map[UserColumn]Column[domain.User]{
|
||||
UserHumanColumnEmail: ignoreCaseColumnDescriptor[domain.User]{
|
||||
columnDescriptor: columnDescriptor[domain.User]{
|
||||
name: "email",
|
||||
scan: func(u *domain.User) any {
|
||||
human, ok := u.Traits.(*domain.Human)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if human.Email == nil {
|
||||
human.Email = new(domain.Email)
|
||||
}
|
||||
return &human.Email.Address
|
||||
},
|
||||
},
|
||||
fieldNameSuffix: "_lower",
|
||||
},
|
||||
UserHumanColumnEmailVerified: columnDescriptor[domain.User]{
|
||||
name: "email_is_verified",
|
||||
scan: func(u *domain.User) any {
|
||||
human, ok := u.Traits.(*domain.Human)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if human.Email == nil {
|
||||
human.Email = new(domain.Email)
|
||||
}
|
||||
return &human.Email.IsVerified
|
||||
},
|
||||
},
|
||||
}
|
||||
machineColumns = map[UserColumn]Column[domain.User]{
|
||||
UserMachineDescription: columnDescriptor[domain.User]{
|
||||
name: "description",
|
||||
scan: func(u *domain.User) any {
|
||||
machine, ok := u.Traits.(*domain.Machine)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if machine == nil {
|
||||
machine = new(domain.Machine)
|
||||
}
|
||||
return &machine.Description
|
||||
},
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
const (
|
||||
UserInstanceID UserColumn = iota + 1
|
||||
UserOrgID
|
||||
UserColumnID
|
||||
UserColumnUsername
|
||||
UserHumanColumnEmail
|
||||
UserHumanColumnEmailVerified
|
||||
UserMachineDescription
|
||||
UserCreatedAt
|
||||
UserUpdatedAt
|
||||
UserDeletedAt
|
||||
)
|
@@ -0,0 +1,23 @@
|
||||
package stmt
|
||||
|
||||
import "github.com/zitadel/zitadel/backend/v3/domain"
|
||||
|
||||
func UserIDCondition(id string) *TextCondition[string, domain.User] {
|
||||
return &TextCondition[string, domain.User]{
|
||||
condition: condition[string, domain.User, TextOperation]{
|
||||
field: userColumns[UserColumnID],
|
||||
op: TextOperationEqual,
|
||||
value: id,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func UserUsernameCondition(op TextOperation, username string) *TextCondition[string, domain.User] {
|
||||
return &TextCondition[string, domain.User]{
|
||||
condition: condition[string, domain.User, TextOperation]{
|
||||
field: userColumns[UserColumnUsername],
|
||||
op: op,
|
||||
value: username,
|
||||
},
|
||||
}
|
||||
}
|
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user