multiple tries

This commit is contained in:
adlerhurst
2025-04-29 06:03:47 +02:00
parent 77c4cc8185
commit 986c62b61a
131 changed files with 9805 additions and 47 deletions

View File

@@ -0,0 +1,56 @@
package grpc
import (
"context"
"github.com/zitadel/zitadel/backend/command/command"
"github.com/zitadel/zitadel/backend/command/query"
"github.com/zitadel/zitadel/backend/command/receiver"
"github.com/zitadel/zitadel/backend/command/receiver/cache"
"github.com/zitadel/zitadel/backend/storage/database"
"github.com/zitadel/zitadel/backend/telemetry/logging"
"github.com/zitadel/zitadel/backend/telemetry/tracing"
)
type api struct {
db database.Pool
manipulator receiver.InstanceManipulator
reader receiver.InstanceReader
tracer *tracing.Tracer
logger *logging.Logger
cache cache.Cache[receiver.InstanceIndex, string, *receiver.Instance]
}
func (a *api) CreateInstance(ctx context.Context) error {
instance := &receiver.Instance{
ID: "123",
Name: "test",
}
return command.Trace(
a.tracer,
command.SetCache(a.cache,
command.Activity(a.logger, command.CreateInstance(a.manipulator, instance)),
instance,
),
).Execute(ctx)
}
func (a *api) DeleteInstance(ctx context.Context) error {
return command.Trace(
a.tracer,
command.DeleteCache(a.cache,
command.Activity(
a.logger,
command.DeleteInstance(a.manipulator, &receiver.Instance{
ID: "123",
})),
receiver.InstanceByID,
"123",
)).Execute(ctx)
}
func (a *api) InstanceByID(ctx context.Context) (*receiver.Instance, error) {
q := query.InstanceByID(a.reader, "123")
return q.Execute(ctx)
}

View File

@@ -0,0 +1,102 @@
package command
import (
"context"
"github.com/zitadel/zitadel/backend/command/receiver/cache"
)
type setCache[I, K comparable, V cache.Entry[I, K]] struct {
cache cache.Cache[I, K, V]
command Command
entry V
}
// SetCache decorates the command, if the command is executed without error it will set the cache entry.
func SetCache[I, K comparable, V cache.Entry[I, K]](cache cache.Cache[I, K, V], command Command, entry V) Command {
return &setCache[I, K, V]{
cache: cache,
command: command,
entry: entry,
}
}
var _ Command = (*setCache[any, any, cache.Entry[any, any]])(nil)
// Execute implements [Command].
func (s *setCache[I, K, V]) Execute(ctx context.Context) error {
if err := s.command.Execute(ctx); err != nil {
return err
}
s.cache.Set(ctx, s.entry)
return nil
}
// Name implements [Command].
func (s *setCache[I, K, V]) Name() string {
return s.command.Name()
}
type deleteCache[I, K comparable, V cache.Entry[I, K]] struct {
cache cache.Cache[I, K, V]
command Command
index I
keys []K
}
// DeleteCache decorates the command, if the command is executed without error it will delete the cache entry.
func DeleteCache[I, K comparable, V cache.Entry[I, K]](cache cache.Cache[I, K, V], command Command, index I, keys ...K) Command {
return &deleteCache[I, K, V]{
cache: cache,
command: command,
index: index,
keys: keys,
}
}
var _ Command = (*deleteCache[any, any, cache.Entry[any, any]])(nil)
// Execute implements [Command].
func (s *deleteCache[I, K, V]) Execute(ctx context.Context) error {
if err := s.command.Execute(ctx); err != nil {
return err
}
return s.cache.Delete(ctx, s.index, s.keys...)
}
// Name implements [Command].
func (s *deleteCache[I, K, V]) Name() string {
return s.command.Name()
}
type invalidateCache[I, K comparable, V cache.Entry[I, K]] struct {
cache cache.Cache[I, K, V]
command Command
index I
keys []K
}
// InvalidateCache decorates the command, if the command is executed without error it will invalidate the cache entry.
func InvalidateCache[I, K comparable, V cache.Entry[I, K]](cache cache.Cache[I, K, V], command Command, index I, keys ...K) Command {
return &invalidateCache[I, K, V]{
cache: cache,
command: command,
index: index,
keys: keys,
}
}
var _ Command = (*invalidateCache[any, any, cache.Entry[any, any]])(nil)
// Execute implements [Command].
func (s *invalidateCache[I, K, V]) Execute(ctx context.Context) error {
if err := s.command.Execute(ctx); err != nil {
return err
}
return s.cache.Invalidate(ctx, s.index, s.keys...)
}
// Name implements [Command].
func (s *invalidateCache[I, K, V]) Name() string {
return s.command.Name()
}

View File

@@ -1,6 +1,10 @@
package command
import "context"
import (
"context"
"github.com/zitadel/zitadel/backend/command/receiver/cache"
)
type Command interface {
Execute(context.Context) error
@@ -20,3 +24,8 @@ func (b *Batch) Execute(ctx context.Context) error {
}
return nil
}
type CacheableCommand[I, K comparable, V cache.Entry[I, K]] interface {
Command
Entry() V
}

View File

@@ -65,7 +65,7 @@ func UpdateInstance(receiver receiver.InstanceManipulator, instance *receiver.In
func (u *updateInstance) Execute(ctx context.Context) error {
u.Instance.Name = u.name
// return u.receiver.(ctx, u.Instance)
// return u.receiver.Update(ctx, u.Instance)
return nil
}

View File

@@ -8,21 +8,28 @@ import (
"github.com/zitadel/zitadel/backend/telemetry/logging"
)
type Logger struct {
type logger struct {
level slog.Level
*logging.Logger
cmd Command
}
func Activity(l *logging.Logger, command Command) *Logger {
return &Logger{
// Activity decorates the commands execute method with logging.
// It logs the command name, duration, and success or failure of the command.
func Activity(l *logging.Logger, command Command) Command {
return &logger{
Logger: l.With(slog.String("type", "activity")),
level: slog.LevelInfo,
cmd: command,
}
}
func (l *Logger) Execute(ctx context.Context) error {
// Name implements [Command].
func (l *logger) Name() string {
return l.cmd.Name()
}
func (l *logger) Execute(ctx context.Context) error {
start := time.Now()
log := l.Logger.With(slog.String("command", l.cmd.Name()))
log.InfoContext(ctx, "execute")

View File

@@ -6,12 +6,27 @@ import (
"github.com/zitadel/zitadel/backend/telemetry/tracing"
)
type Trace struct {
type trace struct {
command Command
tracer *tracing.Tracer
}
func (t *Trace) Execute(ctx context.Context) error {
// Trace decorates the commands execute method with tracing.
// It creates a span with the command name and records any errors that occur during execution.
// The span is ended after the command is executed.
func Trace(tracer *tracing.Tracer, command Command) Command {
return &trace{
command: command,
tracer: tracer,
}
}
// Name implements [Command].
func (l *trace) Name() string {
return l.command.Name()
}
func (t *trace) Execute(ctx context.Context) error {
ctx, span := t.tracer.Start(ctx, t.command.Name())
defer span.End()
err := t.command.Execute(ctx)

View File

@@ -1,38 +0,0 @@
package invoker
import (
"context"
"github.com/zitadel/zitadel/backend/command/command"
"github.com/zitadel/zitadel/backend/command/query"
"github.com/zitadel/zitadel/backend/command/receiver"
"github.com/zitadel/zitadel/backend/command/receiver/db"
"github.com/zitadel/zitadel/backend/storage/database"
)
type api struct {
db database.Pool
manipulator receiver.InstanceManipulator
reader receiver.InstanceReader
}
func (a *api) CreateInstance(ctx context.Context) error {
cmd := command.CreateInstance(db.NewInstance(a.db), &receiver.Instance{
ID: "123",
Name: "test",
})
return cmd.Execute(ctx)
}
func (a *api) DeleteInstance(ctx context.Context) error {
cmd := command.DeleteInstance(db.NewInstance(a.db), &receiver.Instance{
ID: "123",
})
return cmd.Execute(ctx)
}
func (a *api) InstanceByID(ctx context.Context) (*receiver.Instance, error) {
q := query.InstanceByID(a.reader, "123")
return q.Execute(ctx)
}

112
backend/command/receiver/cache/cache.go vendored Normal file
View File

@@ -0,0 +1,112 @@
// Package cache provides abstraction of cache implementations that can be used by zitadel.
package cache
import (
"context"
"time"
"github.com/zitadel/logging"
)
// Purpose describes which object types are stored by a cache.
type Purpose int
//go:generate enumer -type Purpose -transform snake -trimprefix Purpose
const (
PurposeUnspecified Purpose = iota
PurposeAuthzInstance
PurposeMilestones
PurposeOrganization
PurposeIdPFormCallback
)
// Cache stores objects with a value of type `V`.
// Objects may be referred to by one or more indices.
// Implementations may encode the value for storage.
// This means non-exported fields may be lost and objects
// with function values may fail to encode.
// See https://pkg.go.dev/encoding/json#Marshal for example.
//
// `I` is the type by which indices are identified,
// typically an enum for type-safe access.
// Indices are defined when calling the constructor of an implementation of this interface.
// It is illegal to refer to an idex not defined during construction.
//
// `K` is the type used as key in each index.
// Due to the limitations in type constraints, all indices use the same key type.
//
// Implementations are free to use stricter type constraints or fixed typing.
type Cache[I, K comparable, V Entry[I, K]] interface {
// Get an object through specified index.
// An [IndexUnknownError] may be returned if the index is unknown.
// [ErrCacheMiss] is returned if the key was not found in the index,
// or the object is not valid.
Get(ctx context.Context, index I, key K) (V, bool)
// Set an object.
// Keys are created on each index based in the [Entry.Keys] method.
// If any key maps to an existing object, the object is invalidated,
// regardless if the object has other keys defined in the new entry.
// This to prevent ghost objects when an entry reduces the amount of keys
// for a given index.
Set(ctx context.Context, value V)
// Invalidate an object through specified index.
// Implementations may choose to instantly delete the object,
// defer until prune or a separate cleanup routine.
// Invalidated object are no longer returned from Get.
// It is safe to call Invalidate multiple times or on non-existing entries.
Invalidate(ctx context.Context, index I, key ...K) error
// Delete one or more keys from a specific index.
// An [IndexUnknownError] may be returned if the index is unknown.
// The referred object is not invalidated and may still be accessible though
// other indices and keys.
// It is safe to call Delete multiple times or on non-existing entries
Delete(ctx context.Context, index I, key ...K) error
// Truncate deletes all cached objects.
Truncate(ctx context.Context) error
}
// Entry contains a value of type `V` to be cached.
//
// `I` is the type by which indices are identified,
// typically an enum for type-safe access.
//
// `K` is the type used as key in an index.
// Due to the limitations in type constraints, all indices use the same key type.
type Entry[I, K comparable] interface {
// Keys returns which keys map to the object in a specified index.
// May return nil if the index in unknown or when there are no keys.
Keys(index I) (key []K)
}
type Connector int
//go:generate enumer -type Connector -transform snake -trimprefix Connector -linecomment -text
const (
// Empty line comment ensures empty string for unspecified value
ConnectorUnspecified Connector = iota //
ConnectorMemory
ConnectorPostgres
ConnectorRedis
)
type Config struct {
Connector Connector
// Age since an object was added to the cache,
// after which the object is considered invalid.
// 0 disables max age checks.
MaxAge time.Duration
// Age since last use (Get) of an object,
// after which the object is considered invalid.
// 0 disables last use age checks.
LastUseAge time.Duration
// Log allows logging of the specific cache.
// By default only errors are logged to stdout.
Log *logging.Config
}

View File

@@ -0,0 +1,49 @@
// Package connector provides glue between the [cache.Cache] interface and implementations from the connector sub-packages.
package connector
import (
"context"
"fmt"
"github.com/zitadel/zitadel/backend/storage/cache"
"github.com/zitadel/zitadel/backend/storage/cache/connector/gomap"
"github.com/zitadel/zitadel/backend/storage/cache/connector/noop"
)
type CachesConfig struct {
Connectors struct {
Memory gomap.Config
}
Instance *cache.Config
Milestones *cache.Config
Organization *cache.Config
IdPFormCallbacks *cache.Config
}
type Connectors struct {
Config CachesConfig
Memory *gomap.Connector
}
func StartConnectors(conf *CachesConfig) (Connectors, error) {
if conf == nil {
return Connectors{}, nil
}
return Connectors{
Config: *conf,
Memory: gomap.NewConnector(conf.Connectors.Memory),
}, nil
}
func StartCache[I ~int, K ~string, V cache.Entry[I, K]](background context.Context, indices []I, purpose cache.Purpose, conf *cache.Config, connectors Connectors) (cache.Cache[I, K, V], error) {
if conf == nil || conf.Connector == cache.ConnectorUnspecified {
return noop.NewCache[I, K, V](), nil
}
if conf.Connector == cache.ConnectorMemory && connectors.Memory != nil {
c := gomap.NewCache[I, K, V](background, indices, *conf)
connectors.Memory.Config.StartAutoPrune(background, c, purpose)
return c, nil
}
return nil, fmt.Errorf("cache connector %q not enabled", conf.Connector)
}

View File

@@ -0,0 +1,23 @@
package gomap
import (
"github.com/zitadel/zitadel/backend/storage/cache"
)
type Config struct {
Enabled bool
AutoPrune cache.AutoPruneConfig
}
type Connector struct {
Config cache.AutoPruneConfig
}
func NewConnector(config Config) *Connector {
if !config.Enabled {
return nil
}
return &Connector{
Config: config.AutoPrune,
}
}

View File

@@ -0,0 +1,200 @@
package gomap
import (
"context"
"errors"
"log/slog"
"maps"
"os"
"sync"
"sync/atomic"
"time"
"github.com/zitadel/zitadel/backend/storage/cache"
)
type mapCache[I, K comparable, V cache.Entry[I, K]] struct {
config *cache.Config
indexMap map[I]*index[K, V]
logger *slog.Logger
}
// NewCache returns an in-memory Cache implementation based on the builtin go map type.
// Object values are stored as-is and there is no encoding or decoding involved.
func NewCache[I, K comparable, V cache.Entry[I, K]](background context.Context, indices []I, config cache.Config) cache.PrunerCache[I, K, V] {
m := &mapCache[I, K, V]{
config: &config,
indexMap: make(map[I]*index[K, V], len(indices)),
logger: slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
AddSource: true,
Level: slog.LevelError,
})),
}
if config.Log != nil {
m.logger = config.Log.Slog()
}
m.logger.InfoContext(background, "map cache logging enabled")
for _, name := range indices {
m.indexMap[name] = &index[K, V]{
config: m.config,
entries: make(map[K]*entry[V]),
}
}
return m
}
func (c *mapCache[I, K, V]) Get(ctx context.Context, index I, key K) (value V, ok bool) {
i, ok := c.indexMap[index]
if !ok {
c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key)
return value, false
}
entry, err := i.Get(key)
if err == nil {
c.logger.DebugContext(ctx, "map cache get", "index", index, "key", key)
return entry.value, true
}
if errors.Is(err, cache.ErrCacheMiss) {
c.logger.InfoContext(ctx, "map cache get", "err", err, "index", index, "key", key)
return value, false
}
c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key)
return value, false
}
func (c *mapCache[I, K, V]) Set(ctx context.Context, value V) {
now := time.Now()
entry := &entry[V]{
value: value,
created: now,
}
entry.lastUse.Store(now.UnixMicro())
for name, i := range c.indexMap {
keys := value.Keys(name)
i.Set(keys, entry)
c.logger.DebugContext(ctx, "map cache set", "index", name, "keys", keys)
}
}
func (c *mapCache[I, K, V]) Invalidate(ctx context.Context, index I, keys ...K) error {
i, ok := c.indexMap[index]
if !ok {
return cache.NewIndexUnknownErr(index)
}
i.Invalidate(keys)
c.logger.DebugContext(ctx, "map cache invalidate", "index", index, "keys", keys)
return nil
}
func (c *mapCache[I, K, V]) Delete(ctx context.Context, index I, keys ...K) error {
i, ok := c.indexMap[index]
if !ok {
return cache.NewIndexUnknownErr(index)
}
i.Delete(keys)
c.logger.DebugContext(ctx, "map cache delete", "index", index, "keys", keys)
return nil
}
func (c *mapCache[I, K, V]) Prune(ctx context.Context) error {
for name, index := range c.indexMap {
index.Prune()
c.logger.DebugContext(ctx, "map cache prune", "index", name)
}
return nil
}
func (c *mapCache[I, K, V]) Truncate(ctx context.Context) error {
for name, index := range c.indexMap {
index.Truncate()
c.logger.DebugContext(ctx, "map cache truncate", "index", name)
}
return nil
}
type index[K comparable, V any] struct {
mutex sync.RWMutex
config *cache.Config
entries map[K]*entry[V]
}
func (i *index[K, V]) Get(key K) (*entry[V], error) {
i.mutex.RLock()
entry, ok := i.entries[key]
i.mutex.RUnlock()
if ok && entry.isValid(i.config) {
return entry, nil
}
return nil, cache.ErrCacheMiss
}
func (c *index[K, V]) Set(keys []K, entry *entry[V]) {
c.mutex.Lock()
for _, key := range keys {
c.entries[key] = entry
}
c.mutex.Unlock()
}
func (i *index[K, V]) Invalidate(keys []K) {
i.mutex.RLock()
for _, key := range keys {
if entry, ok := i.entries[key]; ok {
entry.invalid.Store(true)
}
}
i.mutex.RUnlock()
}
func (c *index[K, V]) Delete(keys []K) {
c.mutex.Lock()
for _, key := range keys {
delete(c.entries, key)
}
c.mutex.Unlock()
}
func (c *index[K, V]) Prune() {
c.mutex.Lock()
maps.DeleteFunc(c.entries, func(_ K, entry *entry[V]) bool {
return !entry.isValid(c.config)
})
c.mutex.Unlock()
}
func (c *index[K, V]) Truncate() {
c.mutex.Lock()
c.entries = make(map[K]*entry[V])
c.mutex.Unlock()
}
type entry[V any] struct {
value V
created time.Time
invalid atomic.Bool
lastUse atomic.Int64 // UnixMicro time
}
func (e *entry[V]) isValid(c *cache.Config) bool {
if e.invalid.Load() {
return false
}
now := time.Now()
if c.MaxAge > 0 {
if e.created.Add(c.MaxAge).Before(now) {
e.invalid.Store(true)
return false
}
}
if c.LastUseAge > 0 {
lastUse := e.lastUse.Load()
if time.UnixMicro(lastUse).Add(c.LastUseAge).Before(now) {
e.invalid.Store(true)
return false
}
e.lastUse.CompareAndSwap(lastUse, now.UnixMicro())
}
return true
}

View File

@@ -0,0 +1,329 @@
package gomap
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/backend/storage/cache"
)
type testIndex int
const (
testIndexID testIndex = iota
testIndexName
)
var testIndices = []testIndex{
testIndexID,
testIndexName,
}
type testObject struct {
id string
names []string
}
func (o *testObject) Keys(index testIndex) []string {
switch index {
case testIndexID:
return []string{o.id}
case testIndexName:
return o.names
default:
return nil
}
}
func Test_mapCache_Get(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
obj := &testObject{
id: "id",
names: []string{"foo", "bar"},
}
c.Set(context.Background(), obj)
type args struct {
index testIndex
key string
}
tests := []struct {
name string
args args
want *testObject
wantOk bool
}{
{
name: "ok",
args: args{
index: testIndexID,
key: "id",
},
want: obj,
wantOk: true,
},
{
name: "miss",
args: args{
index: testIndexID,
key: "spanac",
},
want: nil,
wantOk: false,
},
{
name: "unknown index",
args: args{
index: 99,
key: "id",
},
want: nil,
wantOk: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok := c.Get(context.Background(), tt.args.index, tt.args.key)
assert.Equal(t, tt.want, got)
assert.Equal(t, tt.wantOk, ok)
})
}
}
func Test_mapCache_Invalidate(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
obj := &testObject{
id: "id",
names: []string{"foo", "bar"},
}
c.Set(context.Background(), obj)
err := c.Invalidate(context.Background(), testIndexName, "bar")
require.NoError(t, err)
got, ok := c.Get(context.Background(), testIndexID, "id")
assert.Nil(t, got)
assert.False(t, ok)
}
func Test_mapCache_Delete(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
obj := &testObject{
id: "id",
names: []string{"foo", "bar"},
}
c.Set(context.Background(), obj)
err := c.Delete(context.Background(), testIndexName, "bar")
require.NoError(t, err)
// Shouldn't find object by deleted name
got, ok := c.Get(context.Background(), testIndexName, "bar")
assert.Nil(t, got)
assert.False(t, ok)
// Should find object by other name
got, ok = c.Get(context.Background(), testIndexName, "foo")
assert.Equal(t, obj, got)
assert.True(t, ok)
// Should find object by id
got, ok = c.Get(context.Background(), testIndexID, "id")
assert.Equal(t, obj, got)
assert.True(t, ok)
}
func Test_mapCache_Prune(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
objects := []*testObject{
{
id: "id1",
names: []string{"foo", "bar"},
},
{
id: "id2",
names: []string{"hello"},
},
}
for _, obj := range objects {
c.Set(context.Background(), obj)
}
// invalidate one entry
err := c.Invalidate(context.Background(), testIndexName, "bar")
require.NoError(t, err)
err = c.(cache.Pruner).Prune(context.Background())
require.NoError(t, err)
// Other object should still be found
got, ok := c.Get(context.Background(), testIndexID, "id2")
assert.Equal(t, objects[1], got)
assert.True(t, ok)
}
func Test_mapCache_Truncate(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
objects := []*testObject{
{
id: "id1",
names: []string{"foo", "bar"},
},
{
id: "id2",
names: []string{"hello"},
},
}
for _, obj := range objects {
c.Set(context.Background(), obj)
}
err := c.Truncate(context.Background())
require.NoError(t, err)
mc := c.(*mapCache[testIndex, string, *testObject])
for _, index := range mc.indexMap {
index.mutex.RLock()
assert.Len(t, index.entries, 0)
index.mutex.RUnlock()
}
}
func Test_entry_isValid(t *testing.T) {
type fields struct {
created time.Time
invalid bool
lastUse time.Time
}
tests := []struct {
name string
fields fields
config *cache.Config
want bool
}{
{
name: "invalid",
fields: fields{
created: time.Now(),
invalid: true,
lastUse: time.Now(),
},
config: &cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
want: false,
},
{
name: "max age exceeded",
fields: fields{
created: time.Now().Add(-(time.Minute + time.Second)),
invalid: false,
lastUse: time.Now(),
},
config: &cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
want: false,
},
{
name: "max age disabled",
fields: fields{
created: time.Now().Add(-(time.Minute + time.Second)),
invalid: false,
lastUse: time.Now(),
},
config: &cache.Config{
LastUseAge: time.Second,
},
want: true,
},
{
name: "last use age exceeded",
fields: fields{
created: time.Now().Add(-(time.Minute / 2)),
invalid: false,
lastUse: time.Now().Add(-(time.Second * 2)),
},
config: &cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
want: false,
},
{
name: "last use age disabled",
fields: fields{
created: time.Now().Add(-(time.Minute / 2)),
invalid: false,
lastUse: time.Now().Add(-(time.Second * 2)),
},
config: &cache.Config{
MaxAge: time.Minute,
},
want: true,
},
{
name: "valid",
fields: fields{
created: time.Now(),
invalid: false,
lastUse: time.Now(),
},
config: &cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &entry[any]{
created: tt.fields.created,
}
e.invalid.Store(tt.fields.invalid)
e.lastUse.Store(tt.fields.lastUse.UnixMicro())
got := e.isValid(tt.config)
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -0,0 +1,21 @@
package noop
import (
"context"
"github.com/zitadel/zitadel/backend/storage/cache"
)
type noop[I, K comparable, V cache.Entry[I, K]] struct{}
// NewCache returns a cache that does nothing
func NewCache[I, K comparable, V cache.Entry[I, K]]() cache.Cache[I, K, V] {
return noop[I, K, V]{}
}
func (noop[I, K, V]) Set(context.Context, V) {}
func (noop[I, K, V]) Get(context.Context, I, K) (value V, ok bool) { return }
func (noop[I, K, V]) Invalidate(context.Context, I, ...K) (err error) { return }
func (noop[I, K, V]) Delete(context.Context, I, ...K) (err error) { return }
func (noop[I, K, V]) Prune(context.Context) (err error) { return }
func (noop[I, K, V]) Truncate(context.Context) (err error) { return }

View File

@@ -0,0 +1,98 @@
// Code generated by "enumer -type Connector -transform snake -trimprefix Connector -linecomment -text"; DO NOT EDIT.
package cache
import (
"fmt"
"strings"
)
const _ConnectorName = "memorypostgresredis"
var _ConnectorIndex = [...]uint8{0, 0, 6, 14, 19}
const _ConnectorLowerName = "memorypostgresredis"
func (i Connector) String() string {
if i < 0 || i >= Connector(len(_ConnectorIndex)-1) {
return fmt.Sprintf("Connector(%d)", i)
}
return _ConnectorName[_ConnectorIndex[i]:_ConnectorIndex[i+1]]
}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _ConnectorNoOp() {
var x [1]struct{}
_ = x[ConnectorUnspecified-(0)]
_ = x[ConnectorMemory-(1)]
_ = x[ConnectorPostgres-(2)]
_ = x[ConnectorRedis-(3)]
}
var _ConnectorValues = []Connector{ConnectorUnspecified, ConnectorMemory, ConnectorPostgres, ConnectorRedis}
var _ConnectorNameToValueMap = map[string]Connector{
_ConnectorName[0:0]: ConnectorUnspecified,
_ConnectorLowerName[0:0]: ConnectorUnspecified,
_ConnectorName[0:6]: ConnectorMemory,
_ConnectorLowerName[0:6]: ConnectorMemory,
_ConnectorName[6:14]: ConnectorPostgres,
_ConnectorLowerName[6:14]: ConnectorPostgres,
_ConnectorName[14:19]: ConnectorRedis,
_ConnectorLowerName[14:19]: ConnectorRedis,
}
var _ConnectorNames = []string{
_ConnectorName[0:0],
_ConnectorName[0:6],
_ConnectorName[6:14],
_ConnectorName[14:19],
}
// ConnectorString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func ConnectorString(s string) (Connector, error) {
if val, ok := _ConnectorNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _ConnectorNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to Connector values", s)
}
// ConnectorValues returns all values of the enum
func ConnectorValues() []Connector {
return _ConnectorValues
}
// ConnectorStrings returns a slice of all String values of the enum
func ConnectorStrings() []string {
strs := make([]string, len(_ConnectorNames))
copy(strs, _ConnectorNames)
return strs
}
// IsAConnector returns "true" if the value is listed in the enum definition. "false" otherwise
func (i Connector) IsAConnector() bool {
for _, v := range _ConnectorValues {
if i == v {
return true
}
}
return false
}
// MarshalText implements the encoding.TextMarshaler interface for Connector
func (i Connector) MarshalText() ([]byte, error) {
return []byte(i.String()), nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface for Connector
func (i *Connector) UnmarshalText(text []byte) error {
var err error
*i, err = ConnectorString(string(text))
return err
}

29
backend/command/receiver/cache/error.go vendored Normal file
View File

@@ -0,0 +1,29 @@
package cache
import (
"errors"
"fmt"
)
type IndexUnknownError[I comparable] struct {
index I
}
func NewIndexUnknownErr[I comparable](index I) error {
return IndexUnknownError[I]{index}
}
func (i IndexUnknownError[I]) Error() string {
return fmt.Sprintf("index %v unknown", i.index)
}
func (a IndexUnknownError[I]) Is(err error) bool {
if b, ok := err.(IndexUnknownError[I]); ok {
return a.index == b.index
}
return false
}
var (
ErrCacheMiss = errors.New("cache miss")
)

View File

@@ -0,0 +1,76 @@
package cache
import (
"context"
"math/rand"
"time"
"github.com/jonboulle/clockwork"
"github.com/zitadel/logging"
)
// Pruner is an optional [Cache] interface.
type Pruner interface {
// Prune deletes all invalidated or expired objects.
Prune(ctx context.Context) error
}
type PrunerCache[I, K comparable, V Entry[I, K]] interface {
Cache[I, K, V]
Pruner
}
type AutoPruneConfig struct {
// Interval at which the cache is automatically pruned.
// 0 or lower disables automatic pruning.
Interval time.Duration
// Timeout for an automatic prune.
// It is recommended to keep the value shorter than AutoPruneInterval
// 0 or lower disables automatic pruning.
Timeout time.Duration
}
func (c AutoPruneConfig) StartAutoPrune(background context.Context, pruner Pruner, purpose Purpose) (close func()) {
return c.startAutoPrune(background, pruner, purpose, clockwork.NewRealClock())
}
func (c *AutoPruneConfig) startAutoPrune(background context.Context, pruner Pruner, purpose Purpose, clock clockwork.Clock) (close func()) {
if c.Interval <= 0 {
return func() {}
}
background, cancel := context.WithCancel(background)
// randomize the first interval
timer := clock.NewTimer(time.Duration(rand.Int63n(int64(c.Interval))))
go c.pruneTimer(background, pruner, purpose, timer)
return cancel
}
func (c *AutoPruneConfig) pruneTimer(background context.Context, pruner Pruner, purpose Purpose, timer clockwork.Timer) {
defer func() {
if !timer.Stop() {
<-timer.Chan()
}
}()
for {
select {
case <-background.Done():
return
case <-timer.Chan():
err := c.doPrune(background, pruner)
logging.OnError(err).WithField("purpose", purpose).Error("cache auto prune")
timer.Reset(c.Interval)
}
}
}
func (c *AutoPruneConfig) doPrune(background context.Context, pruner Pruner) error {
ctx, cancel := context.WithCancel(background)
defer cancel()
if c.Timeout > 0 {
ctx, cancel = context.WithTimeout(background, c.Timeout)
defer cancel()
}
return pruner.Prune(ctx)
}

View File

@@ -0,0 +1,43 @@
package cache
import (
"context"
"testing"
"time"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/assert"
)
type testPruner struct {
called chan struct{}
}
func (p *testPruner) Prune(context.Context) error {
p.called <- struct{}{}
return nil
}
func TestAutoPruneConfig_startAutoPrune(t *testing.T) {
c := AutoPruneConfig{
Interval: time.Second,
Timeout: time.Millisecond,
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
pruner := testPruner{
called: make(chan struct{}),
}
clock := clockwork.NewFakeClock()
close := c.startAutoPrune(ctx, &pruner, PurposeAuthzInstance, clock)
defer close()
clock.Advance(time.Second)
select {
case _, ok := <-pruner.called:
assert.True(t, ok)
case <-ctx.Done():
t.Fatal(ctx.Err())
}
}

View File

@@ -0,0 +1,90 @@
// Code generated by "enumer -type Purpose -transform snake -trimprefix Purpose"; DO NOT EDIT.
package cache
import (
"fmt"
"strings"
)
const _PurposeName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callback"
var _PurposeIndex = [...]uint8{0, 11, 25, 35, 47, 65}
const _PurposeLowerName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callback"
func (i Purpose) String() string {
if i < 0 || i >= Purpose(len(_PurposeIndex)-1) {
return fmt.Sprintf("Purpose(%d)", i)
}
return _PurposeName[_PurposeIndex[i]:_PurposeIndex[i+1]]
}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _PurposeNoOp() {
var x [1]struct{}
_ = x[PurposeUnspecified-(0)]
_ = x[PurposeAuthzInstance-(1)]
_ = x[PurposeMilestones-(2)]
_ = x[PurposeOrganization-(3)]
_ = x[PurposeIdPFormCallback-(4)]
}
var _PurposeValues = []Purpose{PurposeUnspecified, PurposeAuthzInstance, PurposeMilestones, PurposeOrganization, PurposeIdPFormCallback}
var _PurposeNameToValueMap = map[string]Purpose{
_PurposeName[0:11]: PurposeUnspecified,
_PurposeLowerName[0:11]: PurposeUnspecified,
_PurposeName[11:25]: PurposeAuthzInstance,
_PurposeLowerName[11:25]: PurposeAuthzInstance,
_PurposeName[25:35]: PurposeMilestones,
_PurposeLowerName[25:35]: PurposeMilestones,
_PurposeName[35:47]: PurposeOrganization,
_PurposeLowerName[35:47]: PurposeOrganization,
_PurposeName[47:65]: PurposeIdPFormCallback,
_PurposeLowerName[47:65]: PurposeIdPFormCallback,
}
var _PurposeNames = []string{
_PurposeName[0:11],
_PurposeName[11:25],
_PurposeName[25:35],
_PurposeName[35:47],
_PurposeName[47:65],
}
// PurposeString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func PurposeString(s string) (Purpose, error) {
if val, ok := _PurposeNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _PurposeNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to Purpose values", s)
}
// PurposeValues returns all values of the enum
func PurposeValues() []Purpose {
return _PurposeValues
}
// PurposeStrings returns a slice of all String values of the enum
func PurposeStrings() []string {
strs := make([]string, len(_PurposeNames))
copy(strs, _PurposeNames)
return strs
}
// IsAPurpose returns "true" if the value is listed in the enum definition. "false" otherwise
func (i Purpose) IsAPurpose() bool {
for _, v := range _PurposeValues {
if i == v {
return true
}
}
return false
}

View File

@@ -1,6 +1,10 @@
package receiver
import "context"
import (
"context"
"github.com/zitadel/zitadel/backend/command/receiver/cache"
)
type InstanceState uint8
@@ -16,6 +20,31 @@ type Instance struct {
Domains []*Domain
}
type InstanceIndex uint8
var InstanceIndices = []InstanceIndex{
InstanceByID,
InstanceByDomain,
}
const (
InstanceByID InstanceIndex = iota
InstanceByDomain
)
var _ cache.Entry[InstanceIndex, string] = (*Instance)(nil)
// Keys implements [cache.Entry].
func (i *Instance) Keys(index InstanceIndex) (key []string) {
switch index {
case InstanceByID:
return []string{i.ID}
case InstanceByDomain:
return []string{i.Name}
}
return nil
}
type InstanceManipulator interface {
Create(ctx context.Context, instance *Instance) error
Delete(ctx context.Context, instance *Instance) error

View File

@@ -0,0 +1,6 @@
// The API package implements the protobuf stubs
// It uses the Chain of responsibility pattern to handle requests in a modular way
// It implements the client or invoker of the command pattern.
// The client is responsible for creating the concrete command and setting its receiver.
package api

View File

@@ -0,0 +1,35 @@
package userv2
import (
"context"
"github.com/muhlemmer/gu"
"github.com/zitadel/zitadel/backend/command/v2/domain"
"github.com/zitadel/zitadel/pkg/grpc/user/v2"
)
func (s *Server) SetEmail(ctx context.Context, req *user.SetEmailRequest) (resp *user.SetEmailResponse, err error) {
request := &domain.SetUserEmail{
UserID: req.GetUserId(),
Email: req.GetEmail(),
}
switch req.GetVerification().(type) {
case *user.SetEmailRequest_IsVerified:
request.IsVerified = gu.Ptr(req.GetIsVerified())
case *user.SetEmailRequest_SendCode:
request.SendCode = &domain.SendCode{
URLTemplate: req.GetSendCode().UrlTemplate,
}
case *user.SetEmailRequest_ReturnCode:
request.ReturnCode = new(domain.ReturnCode)
}
if err := s.domain.SetUserEmail(ctx, request); err != nil {
return nil, err
}
response := new(user.SetEmailResponse)
if request.ReturnCode != nil {
response.VerificationCode = &request.ReturnCode.Code
}
return response, nil
}

View File

@@ -0,0 +1,12 @@
package userv2
import (
"go.opentelemetry.io/otel/trace"
"github.com/zitadel/zitadel/backend/command/v2/domain"
)
type Server struct {
tracer trace.Tracer
domain *domain.Domain
}

View File

@@ -0,0 +1,41 @@
package command
import (
"context"
"github.com/zitadel/zitadel/backend/command/v2/pattern"
"github.com/zitadel/zitadel/internal/crypto"
)
type generateCode struct {
set func(code string)
generator pattern.Query[crypto.Generator]
}
func GenerateCode(set func(code string), generator pattern.Query[crypto.Generator]) *generateCode {
return &generateCode{
set: set,
generator: generator,
}
}
var _ pattern.Command = (*generateCode)(nil)
// Execute implements [pattern.Command].
func (cmd *generateCode) Execute(ctx context.Context) error {
if err := cmd.generator.Execute(ctx); err != nil {
return err
}
value, code, err := crypto.NewCode(cmd.generator.Result())
_ = value
if err != nil {
return err
}
cmd.set(code)
return nil
}
// Name implements [pattern.Command].
func (*generateCode) Name() string {
return "command.generate_code"
}

View File

@@ -0,0 +1,41 @@
package command
import (
"context"
"github.com/zitadel/zitadel/backend/command/v2/pattern"
)
var _ pattern.Command = (*sendEmailCode)(nil)
type sendEmailCode struct {
UserID string `json:"userId"`
Email string `json:"email"`
URLTemplate *string `json:"urlTemplate"`
code string `json:"-"`
}
func SendEmailCode(userID, email string, urlTemplate *string) pattern.Command {
cmd := &sendEmailCode{
UserID: userID,
Email: email,
URLTemplate: urlTemplate,
}
return pattern.Batch(GenerateCode(cmd.SetCode, generateCode))
}
// Name implements [pattern.Command].
func (c *sendEmailCode) Name() string {
return "user.v2.email.send_code"
}
// Execute implements [pattern.Command].
func (c *sendEmailCode) Execute(ctx context.Context) error {
// Implementation of the command execution
return nil
}
func (c *sendEmailCode) SetCode(code string) {
c.code = code
}

View File

@@ -0,0 +1,39 @@
package command
import (
"context"
"github.com/zitadel/zitadel/backend/command/v2/storage/eventstore"
)
var (
_ eventstore.EventCommander = (*setEmail)(nil)
)
type setEmail struct {
UserID string `json:"userId"`
Email string `json:"email"`
}
func SetEmail(userID, email string) *setEmail {
return &setEmail{
UserID: userID,
Email: email,
}
}
// Event implements [eventstore.EventCommander].
func (c *setEmail) Event() *eventstore.Event {
panic("unimplemented")
}
// Name implements [pattern.Command].
func (c *setEmail) Name() string {
return "user.v2.set_email"
}
// Execute implements [pattern.Command].
func (c *setEmail) Execute(ctx context.Context) error {
// Implementation of the command execution
return nil
}

View File

@@ -0,0 +1,32 @@
package command
import (
"context"
"github.com/zitadel/zitadel/backend/command/v2/pattern"
)
var _ pattern.Command = (*verifyEmail)(nil)
type verifyEmail struct {
UserID string `json:"userId"`
Email string `json:"email"`
}
func VerifyEmail(userID, email string) *verifyEmail {
return &verifyEmail{
UserID: userID,
Email: email,
}
}
// Name implements [pattern.Command].
func (c *verifyEmail) Name() string {
return "user.v2.verify_email"
}
// Execute implements [pattern.Command].
func (c *verifyEmail) Execute(ctx context.Context) error {
// Implementation of the command execution
return nil
}

View File

@@ -0,0 +1,13 @@
package domain
import (
"github.com/zitadel/zitadel/backend/command/v2/storage/database"
"github.com/zitadel/zitadel/internal/crypto"
"go.opentelemetry.io/otel/trace"
)
type Domain struct {
pool database.Pool
tracer trace.Tracer
userCodeAlg crypto.EncryptionAlgorithm
}

View File

@@ -0,0 +1,6 @@
package domain
type Email struct {
Address string
Verified bool
}

View File

@@ -0,0 +1,42 @@
package query
import (
"context"
"github.com/zitadel/zitadel/internal/crypto"
)
type encryptionConfigReceiver interface {
GetEncryptionConfig(ctx context.Context) (*crypto.GeneratorConfig, error)
}
type encryptionGenerator struct {
receiver encryptionConfigReceiver
algorithm crypto.EncryptionAlgorithm
res crypto.Generator
}
func QueryEncryptionGenerator(receiver encryptionConfigReceiver, algorithm crypto.EncryptionAlgorithm) *encryptionGenerator {
return &encryptionGenerator{
receiver: receiver,
algorithm: algorithm,
}
}
func (q *encryptionGenerator) Execute(ctx context.Context) error {
config, err := q.receiver.GetEncryptionConfig(ctx)
if err != nil {
return err
}
q.res = crypto.NewEncryptionGenerator(*config, q.algorithm)
return nil
}
func (q *encryptionGenerator) Name() string {
return "query.encryption_generator"
}
func (q *encryptionGenerator) Result() crypto.Generator {
return q.res
}

View File

@@ -0,0 +1,38 @@
package query
import (
"context"
"github.com/zitadel/zitadel/backend/command/v2/pattern"
)
var _ pattern.Query[string] = (*returnEmailCode)(nil)
type returnEmailCode struct {
UserID string `json:"userId"`
Email string `json:"email"`
code string `json:"-"`
}
func ReturnEmailCode(userID, email string) *returnEmailCode {
return &returnEmailCode{
UserID: userID,
Email: email,
}
}
// Name implements [pattern.Command].
func (c *returnEmailCode) Name() string {
return "user.v2.email.return_code"
}
// Execute implements [pattern.Command].
func (c *returnEmailCode) Execute(ctx context.Context) error {
// Implementation of the command execution
return nil
}
// Result implements [pattern.Query].
func (c *returnEmailCode) Result() string {
return c.code
}

View File

@@ -0,0 +1,38 @@
package query
import (
"context"
"github.com/zitadel/zitadel/backend/command/v2/domain"
"github.com/zitadel/zitadel/backend/command/v2/pattern"
"github.com/zitadel/zitadel/backend/command/v2/storage/database"
)
type UserByIDQuery struct {
querier database.Querier
UserID string `json:"userId"`
res *domain.User
}
var _ pattern.Query[*domain.User] = (*UserByIDQuery)(nil)
// Name implements [pattern.Command].
func (q *UserByIDQuery) Name() string {
return "user.v2.by_id"
}
// Execute implements [pattern.Command].
func (q *UserByIDQuery) Execute(ctx context.Context) error {
var res *domain.User
err := q.querier.QueryRow(ctx, "SELECT id, username, email FROM users WHERE id = $1", q.UserID).Scan(&res.ID, &res.Username, &res.Email.Address)
if err != nil {
return err
}
q.res = res
return nil
}
// Result implements [pattern.Query].
func (q *UserByIDQuery) Result() *domain.User {
return q.res
}

View File

@@ -0,0 +1,77 @@
package domain
import (
"context"
"github.com/zitadel/zitadel/backend/command/v2/domain/command"
"github.com/zitadel/zitadel/backend/command/v2/domain/query"
"github.com/zitadel/zitadel/backend/command/v2/pattern"
"github.com/zitadel/zitadel/backend/command/v2/storage/database"
"github.com/zitadel/zitadel/backend/command/v2/telemetry/tracing"
)
type User struct {
ID string
Username string
Email Email
}
type SetUserEmail struct {
UserID string
Email string
IsVerified *bool
ReturnCode *ReturnCode
SendCode *SendCode
code string
client database.QueryExecutor
}
func (e *SetUserEmail) SetCode(code string) {
e.code = code
}
type ReturnCode struct {
// Code is the code to be sent to the user
Code string
}
type SendCode struct {
// URLTemplate is the template for the URL that is rendered into the message
URLTemplate *string
}
func (d *Domain) SetUserEmail(ctx context.Context, req *SetUserEmail) error {
batch := pattern.Batch(
tracing.Trace(d.tracer, command.SetEmail(req.UserID, req.Email)),
)
if req.IsVerified == nil {
batch.Append(command.GenerateCode(
req.SetCode,
query.QueryEncryptionGenerator(
database.Query(d.pool),
d.userCodeAlg,
),
))
} else {
batch.Append(command.VerifyEmail(req.UserID, req.Email))
}
// if !req.GetVerification().GetIsVerified() {
// batch.
// switch req.GetVerification().(type) {
// case *user.SetEmailRequest_IsVerified:
// batch.Append(tracing.Trace(s.tracer, command.VerifyEmail(req.GetUserId(), req.GetEmail())))
// case *user.SetEmailRequest_SendCode:
// batch.Append(tracing.Trace(s.tracer, command.SendEmailCode(req.GetUserId(), req.GetEmail(), req.GetSendCode().UrlTemplate)))
// case *user.SetEmailRequest_ReturnCode:
// batch.Append(tracing.Trace(s.tracer, query.ReturnEmailCode(req.GetUserId(), req.GetEmail())))
// }
// if err := batch.Execute(ctx); err != nil {
// return nil, err
// }
}

View File

@@ -0,0 +1,100 @@
package pattern
import (
"context"
"github.com/zitadel/zitadel/backend/command/v2/storage/database"
)
// Command implements the command pattern.
// It is used to encapsulate a request as an object, thereby allowing for parameterization of clients with queues, requests, and operations.
// The command pattern allows for the decoupling of the sender and receiver of a request.
// It is often used in conjunction with the invoker pattern, which is responsible for executing the command.
// The command pattern is a behavioral design pattern that turns a request into a stand-alone object.
// This object contains all the information about the request.
// The command pattern is useful for implementing undo/redo functionality, logging, and queuing requests.
// It is also useful for implementing the macro command pattern, which allows for the execution of a series of commands as a single command.
// The command pattern is also used in event-driven architectures, where events are encapsulated as commands.
type Command interface {
Execute(ctx context.Context) error
Name() string
}
type Query[T any] interface {
Command
Result() T
}
type Invoker struct{}
// func bla() {
// sync.Pool{
// New: func() any {
// return new(Invoker)
// },
// }
// }
type Transaction struct {
beginner database.Beginner
cmd Command
opts *database.TransactionOptions
}
func (t *Transaction) Execute(ctx context.Context) error {
tx, err := t.beginner.Begin(ctx, t.opts)
if err != nil {
return err
}
defer func() { err = tx.End(ctx, err) }()
return t.cmd.Execute(ctx)
}
func (t *Transaction) Name() string {
return t.cmd.Name()
}
type batch struct {
Commands []Command
}
func Batch(cmds ...Command) *batch {
return &batch{
Commands: cmds,
}
}
func (b *batch) Execute(ctx context.Context) error {
for _, cmd := range b.Commands {
if err := cmd.Execute(ctx); err != nil {
return err
}
}
return nil
}
func (b *batch) Name() string {
return "batch"
}
func (b *batch) Append(cmds ...Command) {
b.Commands = append(b.Commands, cmds...)
}
type NoopCommand struct{}
func (c *NoopCommand) Execute(_ context.Context) error {
return nil
}
func (c *NoopCommand) Name() string {
return "noop"
}
type NoopQuery[T any] struct {
NoopCommand
}
func (q *NoopQuery[T]) Result() T {
var zero T
return zero
}

View File

@@ -0,0 +1,9 @@
package database
import (
"context"
)
type Connector interface {
Connect(ctx context.Context) (Pool, error)
}

View File

@@ -0,0 +1,54 @@
package database
import (
"context"
)
var (
db *database
)
type database struct {
connector Connector
pool Pool
}
type Pool interface {
Beginner
QueryExecutor
Acquire(ctx context.Context) (Client, error)
Close(ctx context.Context) error
}
type Client interface {
Beginner
QueryExecutor
Release(ctx context.Context) error
}
type Querier interface {
Query(ctx context.Context, stmt string, args ...any) (Rows, error)
QueryRow(ctx context.Context, stmt string, args ...any) Row
}
type Executor interface {
Exec(ctx context.Context, stmt string, args ...any) error
}
type Row interface {
Scan(dest ...any) error
}
type Rows interface {
Row
Next() bool
Close() error
Err() error
}
type QueryExecutor interface {
Querier
Executor
}

View File

@@ -0,0 +1,92 @@
package dialect
import (
"context"
"errors"
"reflect"
"github.com/mitchellh/mapstructure"
"github.com/spf13/viper"
"github.com/zitadel/zitadel/backend/storage/database"
"github.com/zitadel/zitadel/backend/storage/database/dialect/postgres"
)
type Hook struct {
Match func(string) bool
Decode func(config any) (database.Connector, error)
Name string
Constructor func() database.Connector
}
var hooks = []Hook{
{
Match: postgres.NameMatcher,
Decode: postgres.DecodeConfig,
Name: postgres.Name,
Constructor: func() database.Connector { return new(postgres.Config) },
},
// {
// Match: gosql.NameMatcher,
// Decode: gosql.DecodeConfig,
// Name: gosql.Name,
// Constructor: func() database.Connector { return new(gosql.Config) },
// },
}
type Config struct {
Dialects map[string]any `mapstructure:",remain" yaml:",inline"`
connector database.Connector
}
func (c Config) Connect(ctx context.Context) (database.Pool, error) {
if len(c.Dialects) != 1 {
return nil, errors.New("Exactly one dialect must be configured")
}
return c.connector.Connect(ctx)
}
// Hooks implements [configure.Unmarshaller].
func (c Config) Hooks() []viper.DecoderConfigOption {
return []viper.DecoderConfigOption{
viper.DecodeHook(decodeHook),
}
}
func decodeHook(from, to reflect.Value) (_ any, err error) {
if to.Type() != reflect.TypeOf(Config{}) {
return from.Interface(), nil
}
config := new(Config)
if err = mapstructure.Decode(from.Interface(), config); err != nil {
return nil, err
}
if err = config.decodeDialect(); err != nil {
return nil, err
}
return config, nil
}
func (c *Config) decodeDialect() error {
for _, hook := range hooks {
for name, config := range c.Dialects {
if !hook.Match(name) {
continue
}
connector, err := hook.Decode(config)
if err != nil {
return err
}
c.connector = connector
return nil
}
}
return errors.New("no dialect found")
}

View File

@@ -0,0 +1,80 @@
package postgres
import (
"context"
"errors"
"slices"
"strings"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/mitchellh/mapstructure"
"github.com/zitadel/zitadel/backend/command/v2/storage/database"
)
var (
_ database.Connector = (*Config)(nil)
Name = "postgres"
)
type Config struct {
config *pgxpool.Config
// Host string
// Port int32
// Database string
// MaxOpenConns uint32
// MaxIdleConns uint32
// MaxConnLifetime time.Duration
// MaxConnIdleTime time.Duration
// User User
// // Additional options to be appended as options=<Options>
// // The value will be taken as is. Multiple options are space separated.
// Options string
configuredFields []string
}
// Connect implements [database.Connector].
func (c *Config) Connect(ctx context.Context) (database.Pool, error) {
pool, err := pgxpool.NewWithConfig(ctx, c.config)
if err != nil {
return nil, err
}
if err = pool.Ping(ctx); err != nil {
return nil, err
}
return &pgxPool{pool}, nil
}
func NameMatcher(name string) bool {
return slices.Contains([]string{"postgres", "pg"}, strings.ToLower(name))
}
func DecodeConfig(input any) (database.Connector, error) {
switch c := input.(type) {
case string:
config, err := pgxpool.ParseConfig(c)
if err != nil {
return nil, err
}
return &Config{config: config}, nil
case map[string]any:
connector := new(Config)
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
DecodeHook: mapstructure.StringToTimeDurationHookFunc(),
WeaklyTypedInput: true,
Result: connector,
})
if err != nil {
return nil, err
}
if err = decoder.Decode(c); err != nil {
return nil, err
}
return &Config{
config: &pgxpool.Config{},
}, nil
}
return nil, errors.New("invalid configuration")
}

View File

@@ -0,0 +1,48 @@
package postgres
import (
"context"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/zitadel/zitadel/backend/command/v2/storage/database"
)
type pgxConn struct{ *pgxpool.Conn }
var _ database.Client = (*pgxConn)(nil)
// Release implements [database.Client].
func (c *pgxConn) Release(_ context.Context) error {
c.Conn.Release()
return nil
}
// Begin implements [database.Client].
func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
tx, err := c.Conn.BeginTx(ctx, transactionOptionsToPgx(opts))
if err != nil {
return nil, err
}
return &pgxTx{tx}, nil
}
// Query implements sql.Client.
// Subtle: this method shadows the method (*Conn).Query of pgxConn.Conn.
func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
rows, err := c.Conn.Query(ctx, sql, args...)
return &Rows{rows}, err
}
// QueryRow implements sql.Client.
// Subtle: this method shadows the method (*Conn).QueryRow of pgxConn.Conn.
func (c *pgxConn) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return c.Conn.QueryRow(ctx, sql, args...)
}
// Exec implements [database.Pool].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (c *pgxConn) Exec(ctx context.Context, sql string, args ...any) error {
_, err := c.Conn.Exec(ctx, sql, args...)
return err
}

View File

@@ -0,0 +1,57 @@
package postgres
import (
"context"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/zitadel/zitadel/backend/command/v2/storage/database"
)
type pgxPool struct{ *pgxpool.Pool }
var _ database.Pool = (*pgxPool)(nil)
// Acquire implements [database.Pool].
func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) {
conn, err := c.Pool.Acquire(ctx)
if err != nil {
return nil, err
}
return &pgxConn{conn}, nil
}
// Query implements [database.Pool].
// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool.
func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
rows, err := c.Pool.Query(ctx, sql, args...)
return &Rows{rows}, err
}
// QueryRow implements [database.Pool].
// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool.
func (c *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return c.Pool.QueryRow(ctx, sql, args...)
}
// Exec implements [database.Pool].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) error {
_, err := c.Pool.Exec(ctx, sql, args...)
return err
}
// Begin implements [database.Pool].
func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
tx, err := c.Pool.BeginTx(ctx, transactionOptionsToPgx(opts))
if err != nil {
return nil, err
}
return &pgxTx{tx}, nil
}
// Close implements [database.Pool].
func (c *pgxPool) Close(_ context.Context) error {
c.Pool.Close()
return nil
}

View File

@@ -0,0 +1,18 @@
package postgres
import (
"github.com/jackc/pgx/v5"
"github.com/zitadel/zitadel/backend/command/v2/storage/database"
)
var _ database.Rows = (*Rows)(nil)
type Rows struct{ pgx.Rows }
// Close implements [database.Rows].
// Subtle: this method shadows the method (Rows).Close of Rows.Rows.
func (r *Rows) Close() error {
r.Rows.Close()
return nil
}

View File

@@ -0,0 +1,95 @@
package postgres
import (
"context"
"github.com/jackc/pgx/v5"
"github.com/zitadel/zitadel/backend/command/v2/storage/database"
)
type pgxTx struct{ pgx.Tx }
var _ database.Transaction = (*pgxTx)(nil)
// Commit implements [database.Transaction].
func (tx *pgxTx) Commit(ctx context.Context) error {
return tx.Tx.Commit(ctx)
}
// Rollback implements [database.Transaction].
func (tx *pgxTx) Rollback(ctx context.Context) error {
return tx.Tx.Rollback(ctx)
}
// End implements [database.Transaction].
func (tx *pgxTx) End(ctx context.Context, err error) error {
if err != nil {
tx.Rollback(ctx)
return err
}
return tx.Commit(ctx)
}
// Query implements [database.Transaction].
// Subtle: this method shadows the method (Tx).Query of pgxTx.Tx.
func (tx *pgxTx) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
rows, err := tx.Tx.Query(ctx, sql, args...)
return &Rows{rows}, err
}
// QueryRow implements [database.Transaction].
// Subtle: this method shadows the method (Tx).QueryRow of pgxTx.Tx.
func (tx *pgxTx) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return tx.Tx.QueryRow(ctx, sql, args...)
}
// Exec implements [database.Transaction].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (tx *pgxTx) Exec(ctx context.Context, sql string, args ...any) error {
_, err := tx.Tx.Exec(ctx, sql, args...)
return err
}
// Begin implements [database.Transaction].
// As postgres does not support nested transactions we use savepoints to emulate them.
func (tx *pgxTx) Begin(ctx context.Context) (database.Transaction, error) {
savepoint, err := tx.Tx.Begin(ctx)
if err != nil {
return nil, err
}
return &pgxTx{savepoint}, nil
}
func transactionOptionsToPgx(opts *database.TransactionOptions) pgx.TxOptions {
if opts == nil {
return pgx.TxOptions{}
}
return pgx.TxOptions{
IsoLevel: isolationToPgx(opts.IsolationLevel),
AccessMode: accessModeToPgx(opts.AccessMode),
}
}
func isolationToPgx(isolation database.IsolationLevel) pgx.TxIsoLevel {
switch isolation {
case database.IsolationLevelSerializable:
return pgx.Serializable
case database.IsolationLevelReadCommitted:
return pgx.ReadCommitted
default:
return pgx.Serializable
}
}
func accessModeToPgx(accessMode database.AccessMode) pgx.TxAccessMode {
switch accessMode {
case database.AccessModeReadWrite:
return pgx.ReadWrite
case database.AccessModeReadOnly:
return pgx.ReadOnly
default:
return pgx.ReadWrite
}
}

View File

@@ -0,0 +1,39 @@
package database
import (
"context"
"github.com/zitadel/zitadel/internal/crypto"
)
type query struct{ Querier }
func Query(querier Querier) *query {
return &query{Querier: querier}
}
const getEncryptionConfigQuery = "SELECT" +
" length" +
", expiry" +
", should_include_lower_letters" +
", should_include_upper_letters" +
", should_include_digits" +
", should_include_symbols" +
" FROM encryption_config"
func (q query) GetEncryptionConfig(ctx context.Context) (*crypto.GeneratorConfig, error) {
var config crypto.GeneratorConfig
row := q.QueryRow(ctx, getEncryptionConfigQuery)
err := row.Scan(
&config.Length,
&config.Expiry,
&config.IncludeLowerLetters,
&config.IncludeUpperLetters,
&config.IncludeDigits,
&config.IncludeSymbols,
)
if err != nil {
return nil, err
}
return &config, nil
}

View File

@@ -0,0 +1,36 @@
package database
import "context"
type Transaction interface {
Commit(ctx context.Context) error
Rollback(ctx context.Context) error
End(ctx context.Context, err error) error
Begin(ctx context.Context) (Transaction, error)
QueryExecutor
}
type Beginner interface {
Begin(ctx context.Context, opts *TransactionOptions) (Transaction, error)
}
type TransactionOptions struct {
IsolationLevel IsolationLevel
AccessMode AccessMode
}
type IsolationLevel uint8
const (
IsolationLevelSerializable IsolationLevel = iota
IsolationLevelReadCommitted
)
type AccessMode uint8
const (
AccessModeReadWrite AccessMode = iota
AccessModeReadOnly
)

View File

@@ -0,0 +1,13 @@
package eventstore
import "github.com/zitadel/zitadel/backend/command/v2/pattern"
type Event struct {
AggregateType string `json:"aggregateType"`
AggregateID string `json:"aggregateId"`
}
type EventCommander interface {
pattern.Command
Event() *Event
}

View File

@@ -0,0 +1,55 @@
package tracing
import (
"context"
"go.opentelemetry.io/otel/trace"
"github.com/zitadel/zitadel/backend/command/v2/pattern"
)
type command struct {
trace.Tracer
cmd pattern.Command
}
func Trace(tracer trace.Tracer, cmd pattern.Command) pattern.Command {
return &command{
Tracer: tracer,
cmd: cmd,
}
}
func (cmd *command) Name() string {
return cmd.cmd.Name()
}
func (cmd *command) Execute(ctx context.Context) error {
ctx, span := cmd.Tracer.Start(ctx, cmd.Name())
defer span.End()
err := cmd.cmd.Execute(ctx)
if err != nil {
span.RecordError(err)
}
return err
}
type query[T any] struct {
command
query pattern.Query[T]
}
func Query[T any](tracer trace.Tracer, q pattern.Query[T]) pattern.Query[T] {
return &query[T]{
command: command{
Tracer: tracer,
cmd: q,
},
query: q,
}
}
func (q *query[T]) Result() T {
return q.query.Result()
}

View File

@@ -0,0 +1,19 @@
package v2
import (
"github.com/zitadel/zitadel/backend/v3/telemetry/logging"
"github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
)
var (
logger logging.Logger
tracer tracing.Tracer
)
func SetLogger(l logging.Logger) {
logger = l
}
func SetTracer(t tracing.Tracer) {
tracer = t
}

View File

@@ -0,0 +1,93 @@
package userv2
import (
"context"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/pkg/grpc/user/v2"
)
func SetEmail(ctx context.Context, req *user.SetEmailRequest) (resp *user.SetEmailResponse, err error) {
var (
verification domain.SetEmailOpt
returnCode *domain.ReturnCodeCommand
)
switch req.GetVerification().(type) {
case *user.SetEmailRequest_IsVerified:
verification = domain.NewEmailVerifiedCommand(req.GetUserId(), req.GetIsVerified())
case *user.SetEmailRequest_SendCode:
verification = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
case *user.SetEmailRequest_ReturnCode:
returnCode = domain.NewReturnCodeCommand(req.GetUserId())
verification = returnCode
default:
verification = domain.NewSendCodeCommand(req.GetUserId(), nil)
}
err = domain.Invoke(ctx, domain.NewSetEmailCommand(req.GetUserId(), req.GetEmail(), verification))
if err != nil {
return nil, err
}
var code *string
if returnCode != nil && returnCode.Code != "" {
code = &returnCode.Code
}
return &user.SetEmailResponse{
VerificationCode: code,
}, nil
}
func SendEmailCode(ctx context.Context, req *user.SendEmailCodeRequest) (resp *user.SendEmailCodeResponse, err error) {
var (
returnCode *domain.ReturnCodeCommand
cmd domain.Commander
)
switch req.GetVerification().(type) {
case *user.SendEmailCodeRequest_SendCode:
cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
case *user.SendEmailCodeRequest_ReturnCode:
returnCode = domain.NewReturnCodeCommand(req.GetUserId())
cmd = returnCode
default:
cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
}
err = domain.Invoke(ctx, cmd)
if err != nil {
return nil, err
}
resp = new(user.SendEmailCodeResponse)
if returnCode != nil {
resp.VerificationCode = &returnCode.Code
}
return resp, nil
}
func ResendEmailCode(ctx context.Context, req *user.ResendEmailCodeRequest) (resp *user.SendEmailCodeResponse, err error) {
var (
returnCode *domain.ReturnCodeCommand
cmd domain.Commander
)
switch req.GetVerification().(type) {
case *user.ResendEmailCodeRequest_SendCode:
cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
case *user.ResendEmailCodeRequest_ReturnCode:
returnCode = domain.NewReturnCodeCommand(req.GetUserId())
cmd = returnCode
default:
cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
}
err = domain.Invoke(ctx, cmd)
if err != nil {
return nil, err
}
resp = new(user.SendEmailCodeResponse)
if returnCode != nil {
resp.VerificationCode = &returnCode.Code
}
return resp, nil
}

View File

@@ -0,0 +1,19 @@
package userv2
import (
"github.com/zitadel/zitadel/backend/v3/telemetry/logging"
"github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
)
var (
logger logging.Logger
tracer tracing.Tracer
)
func SetLogger(l logging.Logger) {
logger = l
}
func SetTracer(t tracing.Tracer) {
tracer = t
}

12
backend/v3/doc.go Normal file
View File

@@ -0,0 +1,12 @@
// the test used the manly relies on the following patterns:
// - domain:
// - hexagonal architecture, it defines its dependencies as interfaces and the dependencies must use the objects defined by this package
// - command pattern which implements the changes
// - the invoker decorates the commands by checking for events and tracing
// - the database connections are manged in this package
// - the database connections are passed to the repositories
//
// - storage:
// - repository pattern, the repositories are defined as interfaces and the implementations are in the storage package
// - the repositories are used by the domain package to access the database
package v3

View File

@@ -0,0 +1,105 @@
package domain
import (
"context"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type Commander interface {
Execute(ctx context.Context, opts *CommandOpts) (err error)
}
type Invoker interface {
Invoke(ctx context.Context, command Commander, opts *CommandOpts) error
}
type CommandOpts struct {
DB database.QueryExecutor
Invoker Invoker
}
type ensureTxOpts struct {
*database.TransactionOptions
}
type EnsureTransactionOpt func(*ensureTxOpts)
// EnsureTx ensures that the DB is a transaction. If it is not, it will start a new transaction.
// The returned close function will end the transaction. If the DB is already a transaction, the close function
// will do nothing because another [Commander] is already responsible for ending the transaction.
func (o *CommandOpts) EnsureTx(ctx context.Context, opts ...EnsureTransactionOpt) (close func(context.Context, error) error, err error) {
beginner, ok := o.DB.(database.Beginner)
if !ok {
// db is already a transaction
return func(_ context.Context, err error) error {
return err
}, nil
}
txOpts := &ensureTxOpts{
TransactionOptions: new(database.TransactionOptions),
}
for _, opt := range opts {
opt(txOpts)
}
tx, err := beginner.Begin(ctx, txOpts.TransactionOptions)
if err != nil {
return nil, err
}
o.DB = tx
return func(ctx context.Context, err error) error {
return tx.End(ctx, err)
}, nil
}
// EnsureClient ensures that the o.DB is a client. If it is not, it will get a new client from the [database.Pool].
// The returned close function will release the client. If the o.DB is already a client or transaction, the close function
// will do nothing because another [Commander] is already responsible for releasing the client.
func (o *CommandOpts) EnsureClient(ctx context.Context) (close func(_ context.Context) error, err error) {
pool, ok := o.DB.(database.Pool)
if !ok {
// o.DB is already a client
return func(_ context.Context) error {
return nil
}, nil
}
client, err := pool.Acquire(ctx)
if err != nil {
return nil, err
}
o.DB = client
return func(ctx context.Context) error {
return client.Release(ctx)
}, nil
}
func (o *CommandOpts) Invoke(ctx context.Context, command Commander) error {
if o.Invoker == nil {
return command.Execute(ctx, o)
}
return o.Invoker.Invoke(ctx, command, o)
}
func DefaultOpts(invoker Invoker) *CommandOpts {
if invoker == nil {
invoker = &noopInvoker{}
}
return &CommandOpts{
DB: pool,
Invoker: invoker,
}
}
type noopInvoker struct {
next Invoker
}
func (i *noopInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) error {
if i.next != nil {
return i.next.Invoke(ctx, command, opts)
}
return command.Execute(ctx, opts)
}

View File

@@ -0,0 +1,76 @@
package domain
import (
"context"
v4 "github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt/v4"
"github.com/zitadel/zitadel/backend/v3/storage/eventstore"
)
type CreateUserCommand struct {
user *User
email *SetEmailCommand
}
var (
_ Commander = (*CreateUserCommand)(nil)
_ eventer = (*CreateUserCommand)(nil)
)
func NewCreateHumanCommand(username string, opts ...CreateHumanOpt) *CreateUserCommand {
cmd := &CreateUserCommand{
user: &User{
User: v4.User{
Username: username,
Traits: &v4.Human{},
},
},
}
for _, opt := range opts {
opt.applyOnCreateHuman(cmd)
}
return cmd
}
// Events implements [eventer].
func (c *CreateUserCommand) Events() []*eventstore.Event {
panic("unimplemented")
}
// Execute implements [Commander].
func (c *CreateUserCommand) Execute(ctx context.Context, opts *CommandOpts) error {
if err := c.ensureUserID(); err != nil {
return err
}
c.email.UserID = c.user.ID
if err := opts.Invoke(ctx, c.email); err != nil {
return err
}
return nil
}
type CreateHumanOpt interface {
applyOnCreateHuman(*CreateUserCommand)
}
type createHumanIDOpt string
// applyOnCreateHuman implements [CreateHumanOpt].
func (c createHumanIDOpt) applyOnCreateHuman(cmd *CreateUserCommand) {
cmd.user.ID = string(c)
}
var _ CreateHumanOpt = (*createHumanIDOpt)(nil)
func CreateHumanWithID(id string) CreateHumanOpt {
return createHumanIDOpt(id)
}
func (c *CreateUserCommand) ensureUserID() (err error) {
if c.user.ID != "" {
return nil
}
c.user.ID, err = generateID()
return err
}

View File

@@ -0,0 +1,26 @@
package domain
import (
"context"
"github.com/zitadel/zitadel/internal/crypto"
)
type generateCodeCommand struct {
code string
value *crypto.CryptoValue
}
type CryptoRepository interface {
GetEncryptionConfig(ctx context.Context) (*crypto.GeneratorConfig, error)
}
func (cmd *generateCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error {
config, err := cryptoRepo(opts.DB).GetEncryptionConfig(ctx)
if err != nil {
return err
}
generator := crypto.NewEncryptionGenerator(*config, userCodeAlgorithm)
cmd.value, cmd.code, err = crypto.NewCode(generator)
return err
}

View File

@@ -0,0 +1,52 @@
package domain
import (
"math/rand/v2"
"strconv"
"github.com/zitadel/zitadel/backend/v3/storage/cache"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
"github.com/zitadel/zitadel/internal/crypto"
)
var (
pool database.Pool
userCodeAlgorithm crypto.EncryptionAlgorithm
tracer tracing.Tracer
// userRepo func(database.QueryExecutor) UserRepository
instanceRepo func(database.QueryExecutor) InstanceRepository
cryptoRepo func(database.QueryExecutor) CryptoRepository
orgRepo func(database.QueryExecutor) OrgRepository
instanceCache cache.Cache[string, string, *Instance]
generateID func() (string, error) = func() (string, error) {
return strconv.FormatUint(rand.Uint64(), 10), nil
}
)
func SetPool(p database.Pool) {
pool = p
}
func SetUserCodeAlgorithm(algorithm crypto.EncryptionAlgorithm) {
userCodeAlgorithm = algorithm
}
func SetTracer(t tracing.Tracer) {
tracer = t
}
// func SetUserRepository(repo func(database.QueryExecutor) UserRepository) {
// userRepo = repo
// }
func SetInstanceRepository(repo func(database.QueryExecutor) InstanceRepository) {
instanceRepo = repo
}
func SetCryptoRepository(repo func(database.QueryExecutor) CryptoRepository) {
cryptoRepo = repo
}

View File

@@ -0,0 +1,45 @@
package domain_test
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
. "github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
"github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
)
func TestExample(t *testing.T) {
ctx := context.Background()
// SetPool(pool)
exporter, err := stdouttrace.New(stdouttrace.WithPrettyPrint())
require.NoError(t, err)
tracerProvider := sdktrace.NewTracerProvider(
sdktrace.WithSyncer(exporter),
)
otel.SetTracerProvider(tracerProvider)
SetTracer(tracing.Tracer{Tracer: tracerProvider.Tracer("test")})
defer func() { assert.NoError(t, tracerProvider.Shutdown(ctx)) }()
SetUserRepository(repository.User)
SetInstanceRepository(repository.Instance)
SetCryptoRepository(repository.Crypto)
t.Run("verified email", func(t *testing.T) {
err := Invoke(ctx, NewSetEmailCommand("u1", "test@example.com", NewEmailVerifiedCommand("u1", true)))
assert.NoError(t, err)
})
t.Run("unverified email", func(t *testing.T) {
err := Invoke(ctx, NewSetEmailCommand("u2", "test2@example.com", NewEmailVerifiedCommand("u2", false)))
assert.NoError(t, err)
})
}

View File

@@ -0,0 +1,155 @@
package domain
import (
"context"
v4 "github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt/v4"
)
type EmailVerifiedCommand struct {
UserID string `json:"userId"`
Email *Email `json:"email"`
}
func NewEmailVerifiedCommand(userID string, isVerified bool) *EmailVerifiedCommand {
return &EmailVerifiedCommand{
UserID: userID,
Email: &Email{
IsVerified: isVerified,
},
}
}
var (
_ Commander = (*EmailVerifiedCommand)(nil)
_ SetEmailOpt = (*EmailVerifiedCommand)(nil)
)
// Execute implements [Commander]
func (cmd *EmailVerifiedCommand) Execute(ctx context.Context, opts *CommandOpts) error {
return userRepo(opts.DB).Human().ByID(cmd.UserID).Exec().SetEmailVerified(ctx, cmd.Email.Address)
}
// applyOnSetEmail implements [SetEmailOpt]
func (cmd *EmailVerifiedCommand) applyOnSetEmail(setEmailCmd *SetEmailCommand) {
cmd.UserID = setEmailCmd.UserID
cmd.Email.Address = setEmailCmd.Email
setEmailCmd.verification = cmd
}
type SendCodeCommand struct {
UserID string `json:"userId"`
Email string `json:"email"`
URLTemplate *string `json:"urlTemplate"`
generator *generateCodeCommand
}
var (
_ Commander = (*SendCodeCommand)(nil)
_ SetEmailOpt = (*SendCodeCommand)(nil)
)
func NewSendCodeCommand(userID string, urlTemplate *string) *SendCodeCommand {
return &SendCodeCommand{
UserID: userID,
generator: &generateCodeCommand{},
URLTemplate: urlTemplate,
}
}
// Execute implements [Commander]
func (cmd *SendCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error {
if err := cmd.ensureEmail(ctx, opts); err != nil {
return err
}
if err := cmd.ensureURL(ctx, opts); err != nil {
return err
}
if err := opts.Invoker.Invoke(ctx, cmd.generator, opts); err != nil {
return err
}
// TODO: queue notification
return nil
}
func (cmd *SendCodeCommand) ensureEmail(ctx context.Context, opts *CommandOpts) error {
if cmd.Email != "" {
return nil
}
email, err := userRepo(opts.DB).Human().ByID(cmd.UserID).Exec().GetEmail(ctx)
if err != nil || email.IsVerified {
return err
}
cmd.Email = email.Address
return nil
}
func (cmd *SendCodeCommand) ensureURL(ctx context.Context, opts *CommandOpts) error {
if cmd.URLTemplate != nil && *cmd.URLTemplate != "" {
return nil
}
_, _ = ctx, opts
// TODO: load default template
return nil
}
// applyOnSetEmail implements [SetEmailOpt]
func (cmd *SendCodeCommand) applyOnSetEmail(setEmailCmd *SetEmailCommand) {
cmd.UserID = setEmailCmd.UserID
cmd.Email = setEmailCmd.Email
setEmailCmd.verification = cmd
}
type ReturnCodeCommand struct {
UserID string `json:"userId"`
Email string `json:"email"`
Code string `json:"code"`
generator *generateCodeCommand
}
var (
_ Commander = (*ReturnCodeCommand)(nil)
_ SetEmailOpt = (*ReturnCodeCommand)(nil)
)
func NewReturnCodeCommand(userID string) *ReturnCodeCommand {
return &ReturnCodeCommand{
UserID: userID,
generator: &generateCodeCommand{},
}
}
// Execute implements [Commander]
func (cmd *ReturnCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error {
if err := cmd.ensureEmail(ctx, opts); err != nil {
return err
}
if err := opts.Invoker.Invoke(ctx, cmd.generator, opts); err != nil {
return err
}
cmd.Code = cmd.generator.code
return nil
}
func (cmd *ReturnCodeCommand) ensureEmail(ctx context.Context, opts *CommandOpts) error {
if cmd.Email != "" {
return nil
}
user := v4.UserRepository(opts.DB)
user.WithCondition(user.IDCondition(cmd.UserID))
email, err := user.he.GetEmail(ctx)
if err != nil || email.IsVerified {
return err
}
cmd.Email = email.Address
return nil
}
// applyOnSetEmail implements [SetEmailOpt]
func (cmd *ReturnCodeCommand) applyOnSetEmail(setEmailCmd *SetEmailCommand) {
cmd.UserID = setEmailCmd.UserID
cmd.Email = setEmailCmd.Email
setEmailCmd.verification = cmd
}

View File

@@ -0,0 +1,7 @@
package domain
import "errors"
var (
ErrNoAdminSpecified = errors.New("at least one admin must be specified")
)

View File

@@ -0,0 +1,36 @@
package domain
import (
"context"
"time"
)
type Instance struct {
ID string `json:"id"`
Name string `json:"name"`
CreatedAt time.Time `json:"-"`
UpdatedAt time.Time `json:"-"`
DeletedAt time.Time `json:"-"`
}
// Keys implements the [cache.Entry].
func (i *Instance) Keys(index string) (key []string) {
// TODO: Return the correct keys for the instance cache, e.g., i.ID, i.Domain
return []string{}
}
type InstanceRepository interface {
ByID(ctx context.Context, id string) (*Instance, error)
Create(ctx context.Context, instance *Instance) error
On(id string) InstanceOperation
}
type InstanceOperation interface {
AdminRepository
Update(ctx context.Context, instance *Instance) error
Delete(ctx context.Context) error
}
type CreateInstance struct {
Name string `json:"name"`
}

View File

@@ -0,0 +1,94 @@
package domain
import (
"context"
"fmt"
"github.com/zitadel/zitadel/backend/v3/storage/eventstore"
)
var defaultInvoker = newEventStoreInvoker(newTraceInvoker(nil))
func Invoke(ctx context.Context, cmd Commander) error {
invoker := newEventStoreInvoker(newTraceInvoker(nil))
opts := &CommandOpts{
Invoker: invoker.collector,
}
return invoker.Invoke(ctx, cmd, opts)
}
type eventStoreInvoker struct {
collector *eventCollector
}
func newEventStoreInvoker(next Invoker) *eventStoreInvoker {
return &eventStoreInvoker{collector: &eventCollector{next: next}}
}
func (i *eventStoreInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
err = i.collector.Invoke(ctx, command, opts)
if err != nil {
return err
}
if len(i.collector.events) > 0 {
err = eventstore.Publish(ctx, i.collector.events, opts.DB)
if err != nil {
return err
}
}
return nil
}
type eventCollector struct {
next Invoker
events []*eventstore.Event
}
type eventer interface {
Events() []*eventstore.Event
}
func (i *eventCollector) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
if e, ok := command.(eventer); ok && len(e.Events()) > 0 {
// we need to ensure all commands are executed in the same transaction
close, err := opts.EnsureTx(ctx)
if err != nil {
return err
}
defer func() { err = close(ctx, err) }()
i.events = append(i.events, e.Events()...)
}
if i.next != nil {
err = i.next.Invoke(ctx, command, opts)
} else {
err = command.Execute(ctx, opts)
}
if err != nil {
return err
}
return nil
}
type traceInvoker struct {
next Invoker
}
func newTraceInvoker(next Invoker) *traceInvoker {
return &traceInvoker{next: next}
}
func (i *traceInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
ctx, span := tracer.Start(ctx, fmt.Sprintf("%T", command))
defer span.End()
if i.next != nil {
err = i.next.Invoke(ctx, command, opts)
} else {
err = command.Execute(ctx, opts)
}
if err != nil {
span.RecordError(err)
}
return err
}

39
backend/v3/domain/org.go Normal file
View File

@@ -0,0 +1,39 @@
package domain
import (
"context"
"time"
)
type Org struct {
ID string `json:"id"`
Name string `json:"name"`
CreatedAt time.Time `json:"createdAt"`
UpdatedAt time.Time `json:"updatedAt"`
}
type OrgRepository interface {
ByID(ctx context.Context, orgID string) (*Org, error)
Create(ctx context.Context, org *Org) error
On(id string) OrgOperation
}
type OrgOperation interface {
AdminRepository
DomainRepository
Update(ctx context.Context, org *Org) error
Delete(ctx context.Context) error
}
type AdminRepository interface {
AddAdmin(ctx context.Context, userID string, roles []string) error
SetAdminRoles(ctx context.Context, userID string, roles []string) error
RemoveAdmin(ctx context.Context, userID string) error
}
type DomainRepository interface {
AddDomain(ctx context.Context, domain string) error
SetDomainVerified(ctx context.Context, domain string) error
RemoveDomain(ctx context.Context, domain string) error
}

View File

@@ -0,0 +1,74 @@
package domain
import (
"context"
)
type AddOrgCommand struct {
ID string `json:"id"`
Name string `json:"name"`
Admins []AddAdminCommand `json:"admins"`
}
func NewAddOrgCommand(name string, admins ...AddAdminCommand) *AddOrgCommand {
return &AddOrgCommand{
Name: name,
Admins: admins,
}
}
// Execute implements Commander.
func (cmd *AddOrgCommand) Execute(ctx context.Context, opts *CommandOpts) (err error) {
if len(cmd.Admins) == 0 {
return ErrNoAdminSpecified
}
if err = cmd.ensureID(); err != nil {
return err
}
close, err := opts.EnsureTx(ctx)
if err != nil {
return err
}
defer func() { err = close(ctx, err) }()
err = orgRepo(opts.DB).Create(ctx, &Org{
ID: cmd.ID,
Name: cmd.Name,
})
if err != nil {
return err
}
return nil
}
var (
_ Commander = (*AddOrgCommand)(nil)
)
func (cmd *AddOrgCommand) ensureID() (err error) {
if cmd.ID != "" {
return nil
}
cmd.ID, err = generateID()
return err
}
type AddAdminCommand struct {
UserID string `json:"userId"`
Roles []string `json:"roles"`
}
// Execute implements Commander.
func (a *AddAdminCommand) Execute(ctx context.Context, opts *CommandOpts) (err error) {
close, err := opts.EnsureTx(ctx)
if err != nil {
return err
}
defer func() { err = close(ctx, err) }()
return nil
}
var (
_ Commander = (*AddAdminCommand)(nil)
)

View File

@@ -0,0 +1,82 @@
package domain
import (
"time"
"golang.org/x/exp/constraints"
)
type Operation interface {
// TextOperation |
// NumberOperation |
// BoolOperation
op()
}
type clause[F ~uint8, Op Operation] struct {
field F
op Op
}
func (c *clause[F, Op]) Field() F {
return c.field
}
func (c *clause[F, Op]) Operation() Op {
return c.op
}
type Text interface {
~string | ~[]byte
}
type TextOperation uint8
const (
TextOperationEqual TextOperation = iota
TextOperationNotEqual
TextOperationStartsWith
TextOperationStartsWithIgnoreCase
)
func (TextOperation) op() {}
type Number interface {
constraints.Integer | constraints.Float | constraints.Complex | time.Time
}
type NumberOperation uint8
const (
NumberOperationEqual NumberOperation = iota
NumberOperationNotEqual
NumberOperationLessThan
NumberOperationLessThanOrEqual
NumberOperationGreaterThan
NumberOperationGreaterThanOrEqual
)
func (NumberOperation) op() {}
type Bool interface {
~bool
}
type BoolOperation uint8
const (
BoolOperationIs BoolOperation = iota
BoolOperationNot
)
func (BoolOperation) op() {}
type ListOperation uint8
const (
ListOperationContains ListOperation = iota
ListOperationNotContains
)
func (ListOperation) op() {}

View File

@@ -0,0 +1,64 @@
package domain
import (
"context"
"github.com/zitadel/zitadel/backend/v3/storage/eventstore"
)
type SetEmailCommand struct {
UserID string `json:"userId"`
Email string `json:"email"`
verification Commander
}
var (
_ Commander = (*SetEmailCommand)(nil)
_ eventer = (*SetEmailCommand)(nil)
_ CreateHumanOpt = (*SetEmailCommand)(nil)
)
type SetEmailOpt interface {
applyOnSetEmail(*SetEmailCommand)
}
func NewSetEmailCommand(userID, email string, verificationType SetEmailOpt) *SetEmailCommand {
cmd := &SetEmailCommand{
UserID: userID,
Email: email,
}
verificationType.applyOnSetEmail(cmd)
return cmd
}
func (cmd *SetEmailCommand) Execute(ctx context.Context, opts *CommandOpts) error {
close, err := opts.EnsureTx(ctx)
if err != nil {
return err
}
defer func() { err = close(ctx, err) }()
// userStatement(opts.DB).Human().ByID(cmd.UserID).SetEmail(ctx, cmd.Email)
err = userRepo(opts.DB).Human().ByID(cmd.UserID).Exec().SetEmail(ctx, cmd.Email)
if err != nil {
return err
}
return opts.Invoke(ctx, cmd.verification)
}
// Events implements [eventer].
func (cmd *SetEmailCommand) Events() []*eventstore.Event {
return []*eventstore.Event{
{
AggregateType: "user",
AggregateID: cmd.UserID,
Type: "user.email.set",
Payload: cmd,
},
}
}
// applyOnCreateHuman implements [CreateHumanOpt].
func (cmd *SetEmailCommand) applyOnCreateHuman(createUserCmd *CreateUserCommand[Human]) {
createUserCmd.email = cmd
}

193
backend/v3/domain/user.go Normal file
View File

@@ -0,0 +1,193 @@
package domain
import (
"context"
"time"
v4 "github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt/v4"
)
type userColumns interface {
// TODO: move v4.columns to domain
InstanceIDColumn() column
OrgIDColumn() column
IDColumn() column
usernameColumn() column
CreatedAtColumn() column
UpdatedAtColumn() column
DeletedAtColumn() column
}
type userConditions interface {
InstanceIDCondition(instanceID string) v4.Condition
OrgIDCondition(orgID string) v4.Condition
IDCondition(userID string) v4.Condition
UsernameCondition(op v4.TextOperator, username string) v4.Condition
CreatedAtCondition(op v4.NumberOperator, createdAt time.Time) v4.Condition
UpdatedAtCondition(op v4.NumberOperator, updatedAt time.Time) v4.Condition
DeletedCondition(isDeleted bool) v4.Condition
DeletedAtCondition(op v4.NumberOperator, deletedAt time.Time) v4.Condition
}
type UserRepository interface {
userColumns
userConditions
// TODO: move condition to domain
WithCondition(condition v4.Condition) UserRepository
Get(ctx context.Context) (*User, error)
List(ctx context.Context) ([]*User, error)
Create(ctx context.Context, user *User) error
Delete(ctx context.Context) error
Human() HumanRepository
Machine() MachineRepository
}
type humanColumns interface {
FirstNameColumn() column
LastNameColumn() column
EmailAddressColumn() column
EmailVerifiedAtColumn() column
PhoneNumberColumn() column
PhoneVerifiedAtColumn() column
}
type humanConditions interface {
FirstNameCondition(op v4.TextOperator, firstName string) v4.Condition
LastNameCondition(op v4.TextOperator, lastName string) v4.Condition
EmailAddressCondition(op v4.TextOperator, email string) v4.Condition
EmailAddressVerifiedCondition(isVerified bool) v4.Condition
EmailVerifiedAtCondition(op v4.TextOperator, emailVerifiedAt string) v4.Condition
PhoneNumberCondition(op v4.TextOperator, phoneNumber string) v4.Condition
PhoneNumberVerifiedCondition(isVerified bool) v4.Condition
PhoneVerifiedAtCondition(op v4.TextOperator, phoneVerifiedAt string) v4.Condition
}
type HumanRepository interface {
humanColumns
humanConditions
GetEmail(ctx context.Context) (*Email, error)
// TODO: replace any with add email update columns
SetEmail(ctx context.Context, columns ...any) error
}
type machineColumns interface {
DescriptionColumn() column
}
type machineConditions interface {
DescriptionCondition(op v4.TextOperator, description string) v4.Condition
}
type MachineRepository interface {
machineColumns
machineConditions
}
// type UserRepository interface {
// // Get(ctx context.Context, clauses ...UserClause) (*User, error)
// // Search(ctx context.Context, clauses ...UserClause) ([]*User, error)
// UserQuery[UserOperation]
// Human() HumanQuery
// Machine() MachineQuery
// }
// type UserQuery[Op UserOperation] interface {
// ByID(id string) UserQuery[Op]
// Username(username string) UserQuery[Op]
// Exec() Op
// }
// type HumanQuery interface {
// UserQuery[HumanOperation]
// Email(op TextOperation, email string) HumanQuery
// HumanOperation
// }
// type MachineQuery interface {
// UserQuery[MachineOperation]
// MachineOperation
// }
// type UserClause interface {
// Field() UserField
// Operation() Operation
// Args() []any
// }
// type UserField uint8
// const (
// // Fields used for all users
// UserFieldInstanceID UserField = iota + 1
// UserFieldOrgID
// UserFieldID
// UserFieldUsername
// // Fields used for human users
// UserHumanFieldEmail
// UserHumanFieldEmailVerified
// // Fields used for machine users
// UserMachineFieldDescription
// )
// type userByIDClause struct {
// id string
// }
// func (c *userByIDClause) Field() UserField {
// return UserFieldID
// }
// func (c *userByIDClause) Operation() Operation {
// return TextOperationEqual
// }
// func (c *userByIDClause) Args() []any {
// return []any{c.id}
// }
// type UserOperation interface {
// Delete(ctx context.Context) error
// SetUsername(ctx context.Context, username string) error
// }
// type HumanOperation interface {
// UserOperation
// SetEmail(ctx context.Context, email string) error
// SetEmailVerified(ctx context.Context, email string) error
// GetEmail(ctx context.Context) (*Email, error)
// }
// type MachineOperation interface {
// UserOperation
// SetDescription(ctx context.Context, description string) error
// }
type User struct {
v4.User
}
// type userTraits interface {
// isUserTraits()
// }
// type Human struct {
// Email *Email `json:"email"`
// }
// func (*Human) isUserTraits() {}
// type Machine struct {
// Description string `json:"description"`
// }
// func (*Machine) isUserTraits() {}
// type Email struct {
// Address string `json:"address"`
// IsVerified bool `json:"isVerified"`
// }

112
backend/v3/storage/cache/cache.go vendored Normal file
View File

@@ -0,0 +1,112 @@
// Package cache provides abstraction of cache implementations that can be used by zitadel.
package cache
import (
"context"
"time"
"github.com/zitadel/logging"
)
// Purpose describes which object types are stored by a cache.
type Purpose int
//go:generate enumer -type Purpose -transform snake -trimprefix Purpose
const (
PurposeUnspecified Purpose = iota
PurposeAuthzInstance
PurposeMilestones
PurposeOrganization
PurposeIdPFormCallback
)
// Cache stores objects with a value of type `V`.
// Objects may be referred to by one or more indices.
// Implementations may encode the value for storage.
// This means non-exported fields may be lost and objects
// with function values may fail to encode.
// See https://pkg.go.dev/encoding/json#Marshal for example.
//
// `I` is the type by which indices are identified,
// typically an enum for type-safe access.
// Indices are defined when calling the constructor of an implementation of this interface.
// It is illegal to refer to an idex not defined during construction.
//
// `K` is the type used as key in each index.
// Due to the limitations in type constraints, all indices use the same key type.
//
// Implementations are free to use stricter type constraints or fixed typing.
type Cache[I, K comparable, V Entry[I, K]] interface {
// Get an object through specified index.
// An [IndexUnknownError] may be returned if the index is unknown.
// [ErrCacheMiss] is returned if the key was not found in the index,
// or the object is not valid.
Get(ctx context.Context, index I, key K) (V, bool)
// Set an object.
// Keys are created on each index based in the [Entry.Keys] method.
// If any key maps to an existing object, the object is invalidated,
// regardless if the object has other keys defined in the new entry.
// This to prevent ghost objects when an entry reduces the amount of keys
// for a given index.
Set(ctx context.Context, value V)
// Invalidate an object through specified index.
// Implementations may choose to instantly delete the object,
// defer until prune or a separate cleanup routine.
// Invalidated object are no longer returned from Get.
// It is safe to call Invalidate multiple times or on non-existing entries.
Invalidate(ctx context.Context, index I, key ...K) error
// Delete one or more keys from a specific index.
// An [IndexUnknownError] may be returned if the index is unknown.
// The referred object is not invalidated and may still be accessible though
// other indices and keys.
// It is safe to call Delete multiple times or on non-existing entries
Delete(ctx context.Context, index I, key ...K) error
// Truncate deletes all cached objects.
Truncate(ctx context.Context) error
}
// Entry contains a value of type `V` to be cached.
//
// `I` is the type by which indices are identified,
// typically an enum for type-safe access.
//
// `K` is the type used as key in an index.
// Due to the limitations in type constraints, all indices use the same key type.
type Entry[I, K comparable] interface {
// Keys returns which keys map to the object in a specified index.
// May return nil if the index in unknown or when there are no keys.
Keys(index I) (key []K)
}
type Connector int
//go:generate enumer -type Connector -transform snake -trimprefix Connector -linecomment -text
const (
// Empty line comment ensures empty string for unspecified value
ConnectorUnspecified Connector = iota //
ConnectorMemory
ConnectorPostgres
ConnectorRedis
)
type Config struct {
Connector Connector
// Age since an object was added to the cache,
// after which the object is considered invalid.
// 0 disables max age checks.
MaxAge time.Duration
// Age since last use (Get) of an object,
// after which the object is considered invalid.
// 0 disables last use age checks.
LastUseAge time.Duration
// Log allows logging of the specific cache.
// By default only errors are logged to stdout.
Log *logging.Config
}

View File

@@ -0,0 +1,49 @@
// Package connector provides glue between the [cache.Cache] interface and implementations from the connector sub-packages.
package connector
import (
"context"
"fmt"
"github.com/zitadel/zitadel/backend/v3/storage/cache"
"github.com/zitadel/zitadel/backend/v3/storage/cache/connector/gomap"
"github.com/zitadel/zitadel/backend/v3/storage/cache/connector/noop"
)
type CachesConfig struct {
Connectors struct {
Memory gomap.Config
}
Instance *cache.Config
Milestones *cache.Config
Organization *cache.Config
IdPFormCallbacks *cache.Config
}
type Connectors struct {
Config CachesConfig
Memory *gomap.Connector
}
func StartConnectors(conf *CachesConfig) (Connectors, error) {
if conf == nil {
return Connectors{}, nil
}
return Connectors{
Config: *conf,
Memory: gomap.NewConnector(conf.Connectors.Memory),
}, nil
}
func StartCache[I ~int, K ~string, V cache.Entry[I, K]](background context.Context, indices []I, purpose cache.Purpose, conf *cache.Config, connectors Connectors) (cache.Cache[I, K, V], error) {
if conf == nil || conf.Connector == cache.ConnectorUnspecified {
return noop.NewCache[I, K, V](), nil
}
if conf.Connector == cache.ConnectorMemory && connectors.Memory != nil {
c := gomap.NewCache[I, K, V](background, indices, *conf)
connectors.Memory.Config.StartAutoPrune(background, c, purpose)
return c, nil
}
return nil, fmt.Errorf("cache connector %q not enabled", conf.Connector)
}

View File

@@ -0,0 +1,23 @@
package gomap
import (
"github.com/zitadel/zitadel/backend/v3/storage/cache"
)
type Config struct {
Enabled bool
AutoPrune cache.AutoPruneConfig
}
type Connector struct {
Config cache.AutoPruneConfig
}
func NewConnector(config Config) *Connector {
if !config.Enabled {
return nil
}
return &Connector{
Config: config.AutoPrune,
}
}

View File

@@ -0,0 +1,200 @@
package gomap
import (
"context"
"errors"
"log/slog"
"maps"
"os"
"sync"
"sync/atomic"
"time"
"github.com/zitadel/zitadel/backend/v3/storage/cache"
)
type mapCache[I, K comparable, V cache.Entry[I, K]] struct {
config *cache.Config
indexMap map[I]*index[K, V]
logger *slog.Logger
}
// NewCache returns an in-memory Cache implementation based on the builtin go map type.
// Object values are stored as-is and there is no encoding or decoding involved.
func NewCache[I, K comparable, V cache.Entry[I, K]](background context.Context, indices []I, config cache.Config) cache.PrunerCache[I, K, V] {
m := &mapCache[I, K, V]{
config: &config,
indexMap: make(map[I]*index[K, V], len(indices)),
logger: slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
AddSource: true,
Level: slog.LevelError,
})),
}
if config.Log != nil {
m.logger = config.Log.Slog()
}
m.logger.InfoContext(background, "map cache logging enabled")
for _, name := range indices {
m.indexMap[name] = &index[K, V]{
config: m.config,
entries: make(map[K]*entry[V]),
}
}
return m
}
func (c *mapCache[I, K, V]) Get(ctx context.Context, index I, key K) (value V, ok bool) {
i, ok := c.indexMap[index]
if !ok {
c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key)
return value, false
}
entry, err := i.Get(key)
if err == nil {
c.logger.DebugContext(ctx, "map cache get", "index", index, "key", key)
return entry.value, true
}
if errors.Is(err, cache.ErrCacheMiss) {
c.logger.InfoContext(ctx, "map cache get", "err", err, "index", index, "key", key)
return value, false
}
c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key)
return value, false
}
func (c *mapCache[I, K, V]) Set(ctx context.Context, value V) {
now := time.Now()
entry := &entry[V]{
value: value,
created: now,
}
entry.lastUse.Store(now.UnixMicro())
for name, i := range c.indexMap {
keys := value.Keys(name)
i.Set(keys, entry)
c.logger.DebugContext(ctx, "map cache set", "index", name, "keys", keys)
}
}
func (c *mapCache[I, K, V]) Invalidate(ctx context.Context, index I, keys ...K) error {
i, ok := c.indexMap[index]
if !ok {
return cache.NewIndexUnknownErr(index)
}
i.Invalidate(keys)
c.logger.DebugContext(ctx, "map cache invalidate", "index", index, "keys", keys)
return nil
}
func (c *mapCache[I, K, V]) Delete(ctx context.Context, index I, keys ...K) error {
i, ok := c.indexMap[index]
if !ok {
return cache.NewIndexUnknownErr(index)
}
i.Delete(keys)
c.logger.DebugContext(ctx, "map cache delete", "index", index, "keys", keys)
return nil
}
func (c *mapCache[I, K, V]) Prune(ctx context.Context) error {
for name, index := range c.indexMap {
index.Prune()
c.logger.DebugContext(ctx, "map cache prune", "index", name)
}
return nil
}
func (c *mapCache[I, K, V]) Truncate(ctx context.Context) error {
for name, index := range c.indexMap {
index.Truncate()
c.logger.DebugContext(ctx, "map cache truncate", "index", name)
}
return nil
}
type index[K comparable, V any] struct {
mutex sync.RWMutex
config *cache.Config
entries map[K]*entry[V]
}
func (i *index[K, V]) Get(key K) (*entry[V], error) {
i.mutex.RLock()
entry, ok := i.entries[key]
i.mutex.RUnlock()
if ok && entry.isValid(i.config) {
return entry, nil
}
return nil, cache.ErrCacheMiss
}
func (c *index[K, V]) Set(keys []K, entry *entry[V]) {
c.mutex.Lock()
for _, key := range keys {
c.entries[key] = entry
}
c.mutex.Unlock()
}
func (i *index[K, V]) Invalidate(keys []K) {
i.mutex.RLock()
for _, key := range keys {
if entry, ok := i.entries[key]; ok {
entry.invalid.Store(true)
}
}
i.mutex.RUnlock()
}
func (c *index[K, V]) Delete(keys []K) {
c.mutex.Lock()
for _, key := range keys {
delete(c.entries, key)
}
c.mutex.Unlock()
}
func (c *index[K, V]) Prune() {
c.mutex.Lock()
maps.DeleteFunc(c.entries, func(_ K, entry *entry[V]) bool {
return !entry.isValid(c.config)
})
c.mutex.Unlock()
}
func (c *index[K, V]) Truncate() {
c.mutex.Lock()
c.entries = make(map[K]*entry[V])
c.mutex.Unlock()
}
type entry[V any] struct {
value V
created time.Time
invalid atomic.Bool
lastUse atomic.Int64 // UnixMicro time
}
func (e *entry[V]) isValid(c *cache.Config) bool {
if e.invalid.Load() {
return false
}
now := time.Now()
if c.MaxAge > 0 {
if e.created.Add(c.MaxAge).Before(now) {
e.invalid.Store(true)
return false
}
}
if c.LastUseAge > 0 {
lastUse := e.lastUse.Load()
if time.UnixMicro(lastUse).Add(c.LastUseAge).Before(now) {
e.invalid.Store(true)
return false
}
e.lastUse.CompareAndSwap(lastUse, now.UnixMicro())
}
return true
}

View File

@@ -0,0 +1,329 @@
package gomap
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/backend/v3/storage/cache"
)
type testIndex int
const (
testIndexID testIndex = iota
testIndexName
)
var testIndices = []testIndex{
testIndexID,
testIndexName,
}
type testObject struct {
id string
names []string
}
func (o *testObject) Keys(index testIndex) []string {
switch index {
case testIndexID:
return []string{o.id}
case testIndexName:
return o.names
default:
return nil
}
}
func Test_mapCache_Get(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
obj := &testObject{
id: "id",
names: []string{"foo", "bar"},
}
c.Set(context.Background(), obj)
type args struct {
index testIndex
key string
}
tests := []struct {
name string
args args
want *testObject
wantOk bool
}{
{
name: "ok",
args: args{
index: testIndexID,
key: "id",
},
want: obj,
wantOk: true,
},
{
name: "miss",
args: args{
index: testIndexID,
key: "spanac",
},
want: nil,
wantOk: false,
},
{
name: "unknown index",
args: args{
index: 99,
key: "id",
},
want: nil,
wantOk: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok := c.Get(context.Background(), tt.args.index, tt.args.key)
assert.Equal(t, tt.want, got)
assert.Equal(t, tt.wantOk, ok)
})
}
}
func Test_mapCache_Invalidate(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
obj := &testObject{
id: "id",
names: []string{"foo", "bar"},
}
c.Set(context.Background(), obj)
err := c.Invalidate(context.Background(), testIndexName, "bar")
require.NoError(t, err)
got, ok := c.Get(context.Background(), testIndexID, "id")
assert.Nil(t, got)
assert.False(t, ok)
}
func Test_mapCache_Delete(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
obj := &testObject{
id: "id",
names: []string{"foo", "bar"},
}
c.Set(context.Background(), obj)
err := c.Delete(context.Background(), testIndexName, "bar")
require.NoError(t, err)
// Shouldn't find object by deleted name
got, ok := c.Get(context.Background(), testIndexName, "bar")
assert.Nil(t, got)
assert.False(t, ok)
// Should find object by other name
got, ok = c.Get(context.Background(), testIndexName, "foo")
assert.Equal(t, obj, got)
assert.True(t, ok)
// Should find object by id
got, ok = c.Get(context.Background(), testIndexID, "id")
assert.Equal(t, obj, got)
assert.True(t, ok)
}
func Test_mapCache_Prune(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
objects := []*testObject{
{
id: "id1",
names: []string{"foo", "bar"},
},
{
id: "id2",
names: []string{"hello"},
},
}
for _, obj := range objects {
c.Set(context.Background(), obj)
}
// invalidate one entry
err := c.Invalidate(context.Background(), testIndexName, "bar")
require.NoError(t, err)
err = c.(cache.Pruner).Prune(context.Background())
require.NoError(t, err)
// Other object should still be found
got, ok := c.Get(context.Background(), testIndexID, "id2")
assert.Equal(t, objects[1], got)
assert.True(t, ok)
}
func Test_mapCache_Truncate(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
objects := []*testObject{
{
id: "id1",
names: []string{"foo", "bar"},
},
{
id: "id2",
names: []string{"hello"},
},
}
for _, obj := range objects {
c.Set(context.Background(), obj)
}
err := c.Truncate(context.Background())
require.NoError(t, err)
mc := c.(*mapCache[testIndex, string, *testObject])
for _, index := range mc.indexMap {
index.mutex.RLock()
assert.Len(t, index.entries, 0)
index.mutex.RUnlock()
}
}
func Test_entry_isValid(t *testing.T) {
type fields struct {
created time.Time
invalid bool
lastUse time.Time
}
tests := []struct {
name string
fields fields
config *cache.Config
want bool
}{
{
name: "invalid",
fields: fields{
created: time.Now(),
invalid: true,
lastUse: time.Now(),
},
config: &cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
want: false,
},
{
name: "max age exceeded",
fields: fields{
created: time.Now().Add(-(time.Minute + time.Second)),
invalid: false,
lastUse: time.Now(),
},
config: &cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
want: false,
},
{
name: "max age disabled",
fields: fields{
created: time.Now().Add(-(time.Minute + time.Second)),
invalid: false,
lastUse: time.Now(),
},
config: &cache.Config{
LastUseAge: time.Second,
},
want: true,
},
{
name: "last use age exceeded",
fields: fields{
created: time.Now().Add(-(time.Minute / 2)),
invalid: false,
lastUse: time.Now().Add(-(time.Second * 2)),
},
config: &cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
want: false,
},
{
name: "last use age disabled",
fields: fields{
created: time.Now().Add(-(time.Minute / 2)),
invalid: false,
lastUse: time.Now().Add(-(time.Second * 2)),
},
config: &cache.Config{
MaxAge: time.Minute,
},
want: true,
},
{
name: "valid",
fields: fields{
created: time.Now(),
invalid: false,
lastUse: time.Now(),
},
config: &cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &entry[any]{
created: tt.fields.created,
}
e.invalid.Store(tt.fields.invalid)
e.lastUse.Store(tt.fields.lastUse.UnixMicro())
got := e.isValid(tt.config)
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -0,0 +1,21 @@
package noop
import (
"context"
"github.com/zitadel/zitadel/backend/v3/storage/cache"
)
type noop[I, K comparable, V cache.Entry[I, K]] struct{}
// NewCache returns a cache that does nothing
func NewCache[I, K comparable, V cache.Entry[I, K]]() cache.Cache[I, K, V] {
return noop[I, K, V]{}
}
func (noop[I, K, V]) Set(context.Context, V) {}
func (noop[I, K, V]) Get(context.Context, I, K) (value V, ok bool) { return }
func (noop[I, K, V]) Invalidate(context.Context, I, ...K) (err error) { return }
func (noop[I, K, V]) Delete(context.Context, I, ...K) (err error) { return }
func (noop[I, K, V]) Prune(context.Context) (err error) { return }
func (noop[I, K, V]) Truncate(context.Context) (err error) { return }

View File

@@ -0,0 +1,98 @@
// Code generated by "enumer -type Connector -transform snake -trimprefix Connector -linecomment -text"; DO NOT EDIT.
package cache
import (
"fmt"
"strings"
)
const _ConnectorName = "memorypostgresredis"
var _ConnectorIndex = [...]uint8{0, 0, 6, 14, 19}
const _ConnectorLowerName = "memorypostgresredis"
func (i Connector) String() string {
if i < 0 || i >= Connector(len(_ConnectorIndex)-1) {
return fmt.Sprintf("Connector(%d)", i)
}
return _ConnectorName[_ConnectorIndex[i]:_ConnectorIndex[i+1]]
}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _ConnectorNoOp() {
var x [1]struct{}
_ = x[ConnectorUnspecified-(0)]
_ = x[ConnectorMemory-(1)]
_ = x[ConnectorPostgres-(2)]
_ = x[ConnectorRedis-(3)]
}
var _ConnectorValues = []Connector{ConnectorUnspecified, ConnectorMemory, ConnectorPostgres, ConnectorRedis}
var _ConnectorNameToValueMap = map[string]Connector{
_ConnectorName[0:0]: ConnectorUnspecified,
_ConnectorLowerName[0:0]: ConnectorUnspecified,
_ConnectorName[0:6]: ConnectorMemory,
_ConnectorLowerName[0:6]: ConnectorMemory,
_ConnectorName[6:14]: ConnectorPostgres,
_ConnectorLowerName[6:14]: ConnectorPostgres,
_ConnectorName[14:19]: ConnectorRedis,
_ConnectorLowerName[14:19]: ConnectorRedis,
}
var _ConnectorNames = []string{
_ConnectorName[0:0],
_ConnectorName[0:6],
_ConnectorName[6:14],
_ConnectorName[14:19],
}
// ConnectorString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func ConnectorString(s string) (Connector, error) {
if val, ok := _ConnectorNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _ConnectorNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to Connector values", s)
}
// ConnectorValues returns all values of the enum
func ConnectorValues() []Connector {
return _ConnectorValues
}
// ConnectorStrings returns a slice of all String values of the enum
func ConnectorStrings() []string {
strs := make([]string, len(_ConnectorNames))
copy(strs, _ConnectorNames)
return strs
}
// IsAConnector returns "true" if the value is listed in the enum definition. "false" otherwise
func (i Connector) IsAConnector() bool {
for _, v := range _ConnectorValues {
if i == v {
return true
}
}
return false
}
// MarshalText implements the encoding.TextMarshaler interface for Connector
func (i Connector) MarshalText() ([]byte, error) {
return []byte(i.String()), nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface for Connector
func (i *Connector) UnmarshalText(text []byte) error {
var err error
*i, err = ConnectorString(string(text))
return err
}

29
backend/v3/storage/cache/error.go vendored Normal file
View File

@@ -0,0 +1,29 @@
package cache
import (
"errors"
"fmt"
)
type IndexUnknownError[I comparable] struct {
index I
}
func NewIndexUnknownErr[I comparable](index I) error {
return IndexUnknownError[I]{index}
}
func (i IndexUnknownError[I]) Error() string {
return fmt.Sprintf("index %v unknown", i.index)
}
func (a IndexUnknownError[I]) Is(err error) bool {
if b, ok := err.(IndexUnknownError[I]); ok {
return a.index == b.index
}
return false
}
var (
ErrCacheMiss = errors.New("cache miss")
)

76
backend/v3/storage/cache/pruner.go vendored Normal file
View File

@@ -0,0 +1,76 @@
package cache
import (
"context"
"math/rand"
"time"
"github.com/jonboulle/clockwork"
"github.com/zitadel/logging"
)
// Pruner is an optional [Cache] interface.
type Pruner interface {
// Prune deletes all invalidated or expired objects.
Prune(ctx context.Context) error
}
type PrunerCache[I, K comparable, V Entry[I, K]] interface {
Cache[I, K, V]
Pruner
}
type AutoPruneConfig struct {
// Interval at which the cache is automatically pruned.
// 0 or lower disables automatic pruning.
Interval time.Duration
// Timeout for an automatic prune.
// It is recommended to keep the value shorter than AutoPruneInterval
// 0 or lower disables automatic pruning.
Timeout time.Duration
}
func (c AutoPruneConfig) StartAutoPrune(background context.Context, pruner Pruner, purpose Purpose) (close func()) {
return c.startAutoPrune(background, pruner, purpose, clockwork.NewRealClock())
}
func (c *AutoPruneConfig) startAutoPrune(background context.Context, pruner Pruner, purpose Purpose, clock clockwork.Clock) (close func()) {
if c.Interval <= 0 {
return func() {}
}
background, cancel := context.WithCancel(background)
// randomize the first interval
timer := clock.NewTimer(time.Duration(rand.Int63n(int64(c.Interval))))
go c.pruneTimer(background, pruner, purpose, timer)
return cancel
}
func (c *AutoPruneConfig) pruneTimer(background context.Context, pruner Pruner, purpose Purpose, timer clockwork.Timer) {
defer func() {
if !timer.Stop() {
<-timer.Chan()
}
}()
for {
select {
case <-background.Done():
return
case <-timer.Chan():
err := c.doPrune(background, pruner)
logging.OnError(err).WithField("purpose", purpose).Error("cache auto prune")
timer.Reset(c.Interval)
}
}
}
func (c *AutoPruneConfig) doPrune(background context.Context, pruner Pruner) error {
ctx, cancel := context.WithCancel(background)
defer cancel()
if c.Timeout > 0 {
ctx, cancel = context.WithTimeout(background, c.Timeout)
defer cancel()
}
return pruner.Prune(ctx)
}

43
backend/v3/storage/cache/pruner_test.go vendored Normal file
View File

@@ -0,0 +1,43 @@
package cache
import (
"context"
"testing"
"time"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/assert"
)
type testPruner struct {
called chan struct{}
}
func (p *testPruner) Prune(context.Context) error {
p.called <- struct{}{}
return nil
}
func TestAutoPruneConfig_startAutoPrune(t *testing.T) {
c := AutoPruneConfig{
Interval: time.Second,
Timeout: time.Millisecond,
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
pruner := testPruner{
called: make(chan struct{}),
}
clock := clockwork.NewFakeClock()
close := c.startAutoPrune(ctx, &pruner, PurposeAuthzInstance, clock)
defer close()
clock.Advance(time.Second)
select {
case _, ok := <-pruner.called:
assert.True(t, ok)
case <-ctx.Done():
t.Fatal(ctx.Err())
}
}

View File

@@ -0,0 +1,90 @@
// Code generated by "enumer -type Purpose -transform snake -trimprefix Purpose"; DO NOT EDIT.
package cache
import (
"fmt"
"strings"
)
const _PurposeName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callback"
var _PurposeIndex = [...]uint8{0, 11, 25, 35, 47, 65}
const _PurposeLowerName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callback"
func (i Purpose) String() string {
if i < 0 || i >= Purpose(len(_PurposeIndex)-1) {
return fmt.Sprintf("Purpose(%d)", i)
}
return _PurposeName[_PurposeIndex[i]:_PurposeIndex[i+1]]
}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _PurposeNoOp() {
var x [1]struct{}
_ = x[PurposeUnspecified-(0)]
_ = x[PurposeAuthzInstance-(1)]
_ = x[PurposeMilestones-(2)]
_ = x[PurposeOrganization-(3)]
_ = x[PurposeIdPFormCallback-(4)]
}
var _PurposeValues = []Purpose{PurposeUnspecified, PurposeAuthzInstance, PurposeMilestones, PurposeOrganization, PurposeIdPFormCallback}
var _PurposeNameToValueMap = map[string]Purpose{
_PurposeName[0:11]: PurposeUnspecified,
_PurposeLowerName[0:11]: PurposeUnspecified,
_PurposeName[11:25]: PurposeAuthzInstance,
_PurposeLowerName[11:25]: PurposeAuthzInstance,
_PurposeName[25:35]: PurposeMilestones,
_PurposeLowerName[25:35]: PurposeMilestones,
_PurposeName[35:47]: PurposeOrganization,
_PurposeLowerName[35:47]: PurposeOrganization,
_PurposeName[47:65]: PurposeIdPFormCallback,
_PurposeLowerName[47:65]: PurposeIdPFormCallback,
}
var _PurposeNames = []string{
_PurposeName[0:11],
_PurposeName[11:25],
_PurposeName[25:35],
_PurposeName[35:47],
_PurposeName[47:65],
}
// PurposeString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func PurposeString(s string) (Purpose, error) {
if val, ok := _PurposeNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _PurposeNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to Purpose values", s)
}
// PurposeValues returns all values of the enum
func PurposeValues() []Purpose {
return _PurposeValues
}
// PurposeStrings returns a slice of all String values of the enum
func PurposeStrings() []string {
strs := make([]string, len(_PurposeNames))
copy(strs, _PurposeNames)
return strs
}
// IsAPurpose returns "true" if the value is listed in the enum definition. "false" otherwise
func (i Purpose) IsAPurpose() bool {
for _, v := range _PurposeValues {
if i == v {
return true
}
}
return false
}

View File

@@ -0,0 +1,9 @@
package database
import (
"context"
)
type Connector interface {
Connect(ctx context.Context) (Pool, error)
}

View File

@@ -0,0 +1,60 @@
package database
import (
"context"
)
var (
db *database
)
type database struct {
connector Connector
pool Pool
}
type Pool interface {
Beginner
QueryExecutor
Acquire(ctx context.Context) (Client, error)
Close(ctx context.Context) error
}
type Client interface {
Beginner
QueryExecutor
Release(ctx context.Context) error
}
type Querier interface {
Query(ctx context.Context, stmt string, args ...any) (Rows, error)
QueryRow(ctx context.Context, stmt string, args ...any) Row
}
type Executor interface {
Exec(ctx context.Context, stmt string, args ...any) error
}
type QueryExecutor interface {
Querier
Executor
}
type Scanner interface {
Scan(dest ...any) error
}
type Row interface {
Scanner
}
type Rows interface {
Row
Next() bool
Close() error
Err() error
}
type Query[T any] func(querier Querier) (result T, err error)

View File

@@ -0,0 +1,92 @@
package dialect
import (
"context"
"errors"
"reflect"
"github.com/mitchellh/mapstructure"
"github.com/spf13/viper"
"github.com/zitadel/zitadel/backend/storage/database"
"github.com/zitadel/zitadel/backend/storage/database/dialect/postgres"
)
type Hook struct {
Match func(string) bool
Decode func(config any) (database.Connector, error)
Name string
Constructor func() database.Connector
}
var hooks = []Hook{
{
Match: postgres.NameMatcher,
Decode: postgres.DecodeConfig,
Name: postgres.Name,
Constructor: func() database.Connector { return new(postgres.Config) },
},
// {
// Match: gosql.NameMatcher,
// Decode: gosql.DecodeConfig,
// Name: gosql.Name,
// Constructor: func() database.Connector { return new(gosql.Config) },
// },
}
type Config struct {
Dialects map[string]any `mapstructure:",remain" yaml:",inline"`
connector database.Connector
}
func (c Config) Connect(ctx context.Context) (database.Pool, error) {
if len(c.Dialects) != 1 {
return nil, errors.New("Exactly one dialect must be configured")
}
return c.connector.Connect(ctx)
}
// Hooks implements [configure.Unmarshaller].
func (c Config) Hooks() []viper.DecoderConfigOption {
return []viper.DecoderConfigOption{
viper.DecodeHook(decodeHook),
}
}
func decodeHook(from, to reflect.Value) (_ any, err error) {
if to.Type() != reflect.TypeOf(Config{}) {
return from.Interface(), nil
}
config := new(Config)
if err = mapstructure.Decode(from.Interface(), config); err != nil {
return nil, err
}
if err = config.decodeDialect(); err != nil {
return nil, err
}
return config, nil
}
func (c *Config) decodeDialect() error {
for _, hook := range hooks {
for name, config := range c.Dialects {
if !hook.Match(name) {
continue
}
connector, err := hook.Decode(config)
if err != nil {
return err
}
c.connector = connector
return nil
}
}
return errors.New("no dialect found")
}

View File

@@ -0,0 +1,80 @@
package postgres
import (
"context"
"errors"
"slices"
"strings"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/mitchellh/mapstructure"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
var (
_ database.Connector = (*Config)(nil)
Name = "postgres"
)
type Config struct {
*pgxpool.Config
// Host string
// Port int32
// Database string
// MaxOpenConns uint32
// MaxIdleConns uint32
// MaxConnLifetime time.Duration
// MaxConnIdleTime time.Duration
// User User
// // Additional options to be appended as options=<Options>
// // The value will be taken as is. Multiple options are space separated.
// Options string
configuredFields []string
}
// Connect implements [database.Connector].
func (c *Config) Connect(ctx context.Context) (database.Pool, error) {
pool, err := pgxpool.NewWithConfig(ctx, c.Config)
if err != nil {
return nil, err
}
if err = pool.Ping(ctx); err != nil {
return nil, err
}
return &pgxPool{pool}, nil
}
func NameMatcher(name string) bool {
return slices.Contains([]string{"postgres", "pg"}, strings.ToLower(name))
}
func DecodeConfig(input any) (database.Connector, error) {
switch c := input.(type) {
case string:
config, err := pgxpool.ParseConfig(c)
if err != nil {
return nil, err
}
return &Config{Config: config}, nil
case map[string]any:
connector := new(Config)
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
DecodeHook: mapstructure.StringToTimeDurationHookFunc(),
WeaklyTypedInput: true,
Result: connector,
})
if err != nil {
return nil, err
}
if err = decoder.Decode(c); err != nil {
return nil, err
}
return &Config{
Config: &pgxpool.Config{},
}, nil
}
return nil, errors.New("invalid configuration")
}

View File

@@ -0,0 +1,48 @@
package postgres
import (
"context"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type pgxConn struct{ *pgxpool.Conn }
var _ database.Client = (*pgxConn)(nil)
// Release implements [database.Client].
func (c *pgxConn) Release(_ context.Context) error {
c.Conn.Release()
return nil
}
// Begin implements [database.Client].
func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
tx, err := c.Conn.BeginTx(ctx, transactionOptionsToPgx(opts))
if err != nil {
return nil, err
}
return &pgxTx{tx}, nil
}
// Query implements sql.Client.
// Subtle: this method shadows the method (*Conn).Query of pgxConn.Conn.
func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
rows, err := c.Conn.Query(ctx, sql, args...)
return &Rows{rows}, err
}
// QueryRow implements sql.Client.
// Subtle: this method shadows the method (*Conn).QueryRow of pgxConn.Conn.
func (c *pgxConn) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return c.Conn.QueryRow(ctx, sql, args...)
}
// Exec implements [database.Pool].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (c *pgxConn) Exec(ctx context.Context, sql string, args ...any) error {
_, err := c.Conn.Exec(ctx, sql, args...)
return err
}

View File

@@ -0,0 +1,57 @@
package postgres
import (
"context"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type pgxPool struct{ *pgxpool.Pool }
var _ database.Pool = (*pgxPool)(nil)
// Acquire implements [database.Pool].
func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) {
conn, err := c.Pool.Acquire(ctx)
if err != nil {
return nil, err
}
return &pgxConn{conn}, nil
}
// Query implements [database.Pool].
// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool.
func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
rows, err := c.Pool.Query(ctx, sql, args...)
return &Rows{rows}, err
}
// QueryRow implements [database.Pool].
// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool.
func (c *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return c.Pool.QueryRow(ctx, sql, args...)
}
// Exec implements [database.Pool].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) error {
_, err := c.Pool.Exec(ctx, sql, args...)
return err
}
// Begin implements [database.Pool].
func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
tx, err := c.Pool.BeginTx(ctx, transactionOptionsToPgx(opts))
if err != nil {
return nil, err
}
return &pgxTx{tx}, nil
}
// Close implements [database.Pool].
func (c *pgxPool) Close(_ context.Context) error {
c.Pool.Close()
return nil
}

View File

@@ -0,0 +1,18 @@
package postgres
import (
"github.com/jackc/pgx/v5"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
var _ database.Rows = (*Rows)(nil)
type Rows struct{ pgx.Rows }
// Close implements [database.Rows].
// Subtle: this method shadows the method (Rows).Close of Rows.Rows.
func (r *Rows) Close() error {
r.Rows.Close()
return nil
}

View File

@@ -0,0 +1,95 @@
package postgres
import (
"context"
"github.com/jackc/pgx/v5"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type pgxTx struct{ pgx.Tx }
var _ database.Transaction = (*pgxTx)(nil)
// Commit implements [database.Transaction].
func (tx *pgxTx) Commit(ctx context.Context) error {
return tx.Tx.Commit(ctx)
}
// Rollback implements [database.Transaction].
func (tx *pgxTx) Rollback(ctx context.Context) error {
return tx.Tx.Rollback(ctx)
}
// End implements [database.Transaction].
func (tx *pgxTx) End(ctx context.Context, err error) error {
if err != nil {
tx.Rollback(ctx)
return err
}
return tx.Commit(ctx)
}
// Query implements [database.Transaction].
// Subtle: this method shadows the method (Tx).Query of pgxTx.Tx.
func (tx *pgxTx) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
rows, err := tx.Tx.Query(ctx, sql, args...)
return &Rows{rows}, err
}
// QueryRow implements [database.Transaction].
// Subtle: this method shadows the method (Tx).QueryRow of pgxTx.Tx.
func (tx *pgxTx) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return tx.Tx.QueryRow(ctx, sql, args...)
}
// Exec implements [database.Transaction].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (tx *pgxTx) Exec(ctx context.Context, sql string, args ...any) error {
_, err := tx.Tx.Exec(ctx, sql, args...)
return err
}
// Begin implements [database.Transaction].
// As postgres does not support nested transactions we use savepoints to emulate them.
func (tx *pgxTx) Begin(ctx context.Context) (database.Transaction, error) {
savepoint, err := tx.Tx.Begin(ctx)
if err != nil {
return nil, err
}
return &pgxTx{savepoint}, nil
}
func transactionOptionsToPgx(opts *database.TransactionOptions) pgx.TxOptions {
if opts == nil {
return pgx.TxOptions{}
}
return pgx.TxOptions{
IsoLevel: isolationToPgx(opts.IsolationLevel),
AccessMode: accessModeToPgx(opts.AccessMode),
}
}
func isolationToPgx(isolation database.IsolationLevel) pgx.TxIsoLevel {
switch isolation {
case database.IsolationLevelSerializable:
return pgx.Serializable
case database.IsolationLevelReadCommitted:
return pgx.ReadCommitted
default:
return pgx.Serializable
}
}
func accessModeToPgx(accessMode database.AccessMode) pgx.TxAccessMode {
switch accessMode {
case database.AccessModeReadWrite:
return pgx.ReadWrite
case database.AccessModeReadOnly:
return pgx.ReadOnly
default:
return pgx.ReadWrite
}
}

View File

@@ -0,0 +1,3 @@
package database
//go:generate mockgen -typed -package mock -destination ./mock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Client,Row,Rows,Transaction

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,160 @@
package repository
import (
"fmt"
"github.com/zitadel/zitadel/backend/v3/domain"
)
type field interface {
fmt.Stringer
}
type fieldDescriptor struct {
schema string
table string
name string
}
func (f fieldDescriptor) String() string {
return f.schema + "." + f.table + "." + f.name
}
type ignoreCaseFieldDescriptor struct {
fieldDescriptor
fieldNameSuffix string
}
func (f ignoreCaseFieldDescriptor) String() string {
return f.fieldDescriptor.String() + f.fieldNameSuffix
}
type textFieldDescriptor struct {
field
isIgnoreCase bool
}
type clause[Op domain.Operation] struct {
field field
op Op
}
const (
schema = "zitadel"
userTable = "users"
)
var userFields = map[domain.UserField]field{
domain.UserFieldInstanceID: fieldDescriptor{
schema: schema,
table: userTable,
name: "instance_id",
},
domain.UserFieldOrgID: fieldDescriptor{
schema: schema,
table: userTable,
name: "org_id",
},
domain.UserFieldID: fieldDescriptor{
schema: schema,
table: userTable,
name: "id",
},
domain.UserFieldUsername: textFieldDescriptor{
field: ignoreCaseFieldDescriptor{
fieldDescriptor: fieldDescriptor{
schema: schema,
table: userTable,
name: "username",
},
fieldNameSuffix: "_lower",
},
},
domain.UserHumanFieldEmail: textFieldDescriptor{
field: ignoreCaseFieldDescriptor{
fieldDescriptor: fieldDescriptor{
schema: schema,
table: userTable,
name: "email",
},
fieldNameSuffix: "_lower",
},
},
domain.UserHumanFieldEmailVerified: fieldDescriptor{
schema: schema,
table: userTable,
name: "email_is_verified",
},
}
type textClause[V domain.Text] struct {
clause[domain.TextOperation]
value V
}
var textOp map[domain.TextOperation]string = map[domain.TextOperation]string{
domain.TextOperationEqual: " = ",
domain.TextOperationNotEqual: " <> ",
domain.TextOperationStartsWith: " LIKE ",
domain.TextOperationStartsWithIgnoreCase: " LIKE ",
}
func (tc textClause[V]) Write(stmt *statement) {
placeholder := stmt.appendArg(tc.value)
var (
left, right string
)
switch tc.clause.op {
case domain.TextOperationEqual:
left = tc.clause.field.String()
right = placeholder
case domain.TextOperationNotEqual:
left = tc.clause.field.String()
right = placeholder
case domain.TextOperationStartsWith:
left = tc.clause.field.String()
right = placeholder + "%"
case domain.TextOperationStartsWithIgnoreCase:
left = tc.clause.field.String()
if _, ok := tc.clause.field.(ignoreCaseFieldDescriptor); !ok {
left = "LOWER(" + left + ")"
}
right = "LOWER(" + placeholder + "%)"
}
stmt.builder.WriteString(left)
stmt.builder.WriteString(textOp[tc.clause.op])
stmt.builder.WriteString(right)
}
type boolClause[V domain.Bool] struct {
clause[domain.BoolOperation]
value V
}
func (bc boolClause[V]) Write(stmt *statement) {
if !bc.value {
stmt.builder.WriteString("NOT ")
}
stmt.builder.WriteString(bc.clause.field.String())
}
type numberClause[V domain.Number] struct {
clause[domain.NumberOperation]
value V
}
var numberOp map[domain.NumberOperation]string = map[domain.NumberOperation]string{
domain.NumberOperationEqual: " = ",
domain.NumberOperationNotEqual: " <> ",
domain.NumberOperationLessThan: " < ",
domain.NumberOperationLessThanOrEqual: " <= ",
domain.NumberOperationGreaterThan: " > ",
domain.NumberOperationGreaterThanOrEqual: " >= ",
}
func (nc numberClause[V]) Write(stmt *statement) {
stmt.builder.WriteString(nc.clause.field.String())
stmt.builder.WriteString(numberOp[nc.clause.op])
stmt.builder.WriteString(stmt.appendArg(nc.value))
}

View File

@@ -0,0 +1,45 @@
package repository
import (
"context"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/internal/crypto"
)
type cryptoRepo struct {
database.QueryExecutor
}
func Crypto(db database.QueryExecutor) domain.CryptoRepository {
return &cryptoRepo{
QueryExecutor: db,
}
}
const getEncryptionConfigQuery = "SELECT" +
" length" +
", expiry" +
", should_include_lower_letters" +
", should_include_upper_letters" +
", should_include_digits" +
", should_include_symbols" +
" FROM encryption_config"
func (repo *cryptoRepo) GetEncryptionConfig(ctx context.Context) (*crypto.GeneratorConfig, error) {
var config crypto.GeneratorConfig
row := repo.QueryRow(ctx, getEncryptionConfigQuery)
err := row.Scan(
&config.Length,
&config.Expiry,
&config.IncludeLowerLetters,
&config.IncludeUpperLetters,
&config.IncludeDigits,
&config.IncludeSymbols,
)
if err != nil {
return nil, err
}
return &config, nil
}

View File

@@ -0,0 +1,7 @@
// Repository package provides the database repository for the application.
// It contains the implementation of the [repository pattern](https://martinfowler.com/eaaCatalog/repository.html) for the database.
// funcs which need to interact with the database should create interfaces which are implemented by the
// [query] and [exec] structs respectively their factory methods [Query] and [Execute]. The [query] struct is used for read operations, while the [exec] struct is used for write operations.
package repository

View File

@@ -0,0 +1,54 @@
package repository
import (
"context"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type instance struct {
database.QueryExecutor
}
func Instance(client database.QueryExecutor) domain.InstanceRepository {
return &instance{QueryExecutor: client}
}
func (i *instance) ByID(ctx context.Context, id string) (*domain.Instance, error) {
var instance domain.Instance
err := i.QueryExecutor.QueryRow(ctx, `SELECT id, name, created_at, updated_at, deleted_at FROM instances WHERE id = $1`, id).Scan(
&instance.ID,
&instance.Name,
&instance.CreatedAt,
&instance.UpdatedAt,
&instance.DeletedAt,
)
if err != nil {
return nil, err
}
return &instance, nil
}
const createInstanceStmt = `INSERT INTO instances (id, name) VALUES ($1, $2) RETURNING created_at, updated_at`
// Create implements [domain.InstanceRepository].
func (i *instance) Create(ctx context.Context, instance *domain.Instance) error {
return i.QueryExecutor.QueryRow(ctx, createInstanceStmt,
instance.ID,
instance.Name,
).Scan(
&instance.CreatedAt,
&instance.UpdatedAt,
)
}
// On implements [domain.InstanceRepository].
func (i *instance) On(id string) domain.InstanceOperation {
return &instanceOperation{
QueryExecutor: i.QueryExecutor,
id: id,
}
}
var _ domain.InstanceRepository = (*instance)(nil)

View File

@@ -0,0 +1,52 @@
package repository
import (
"context"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type instanceOperation struct {
database.QueryExecutor
id string
}
const addInstanceAdminStmt = `INSERT INTO instance_admins (instance_id, user_id, roles) VALUES ($1, $2, $3)`
// AddAdmin implements [domain.InstanceOperation].
func (i *instanceOperation) AddAdmin(ctx context.Context, userID string, roles []string) error {
return i.QueryExecutor.Exec(ctx, addInstanceAdminStmt, i.id, userID, roles)
}
// Delete implements [domain.InstanceOperation].
func (i *instanceOperation) Delete(ctx context.Context) error {
return i.QueryExecutor.Exec(ctx, `DELETE FROM instances WHERE id = $1`, i.id)
}
const removeInstanceAdminStmt = `DELETE FROM instance_admins WHERE instance_id = $1 AND user_id = $2`
// RemoveAdmin implements [domain.InstanceOperation].
func (i *instanceOperation) RemoveAdmin(ctx context.Context, userID string) error {
return i.QueryExecutor.Exec(ctx, removeInstanceAdminStmt, i.id, userID)
}
const setInstanceAdminRolesStmt = `UPDATE instance_admins SET roles = $1 WHERE instance_id = $2 AND user_id = $3`
// SetAdminRoles implements [domain.InstanceOperation].
func (i *instanceOperation) SetAdminRoles(ctx context.Context, userID string, roles []string) error {
return i.QueryExecutor.Exec(ctx, setInstanceAdminRolesStmt, roles, i.id, userID)
}
const updateInstanceStmt = `UPDATE instances SET name = $1, updated_at = $2 WHERE id = $3 RETURNING updated_at`
// Update implements [domain.InstanceOperation].
func (i *instanceOperation) Update(ctx context.Context, instance *domain.Instance) error {
return i.QueryExecutor.QueryRow(ctx, updateInstanceStmt,
instance.Name,
instance.UpdatedAt,
i.id,
).Scan(&instance.UpdatedAt)
}
var _ domain.InstanceOperation = (*instanceOperation)(nil)

View File

@@ -0,0 +1,17 @@
package repository
import (
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type query struct{ database.Querier }
func Query(querier database.Querier) *query {
return &query{Querier: querier}
}
type executor struct{ database.Executor }
func Execute(exec database.Executor) *executor {
return &executor{Executor: exec}
}

View File

@@ -0,0 +1,21 @@
package repository
import "strings"
type statement struct {
builder strings.Builder
args []any
}
func (s *statement) appendArg(arg any) (placeholder string) {
s.args = append(s.args, arg)
return "$" + string(len(s.args))
}
func (s *statement) appendArgs(args ...any) (placeholders []string) {
placeholders = make([]string, len(args))
for i, arg := range args {
placeholders[i] = s.appendArg(arg)
}
return placeholders
}

View File

@@ -0,0 +1,43 @@
package stmt
import "fmt"
type Column[T any] interface {
fmt.Stringer
statementApplier[T]
scanner(t *T) any
}
type columnDescriptor[T any] struct {
name string
scan func(*T) any
}
func (cd columnDescriptor[T]) scanner(t *T) any {
return cd.scan(t)
}
// Apply implements [Column].
func (f columnDescriptor[T]) Apply(stmt *statement[T]) {
stmt.builder.WriteString(stmt.columnPrefix())
stmt.builder.WriteString(f.String())
}
// String implements [Column].
func (f columnDescriptor[T]) String() string {
return f.name
}
var _ Column[any] = (*columnDescriptor[any])(nil)
type ignoreCaseColumnDescriptor[T any] struct {
columnDescriptor[T]
fieldNameSuffix string
}
func (f ignoreCaseColumnDescriptor[T]) ApplyIgnoreCase(stmt *statement[T]) {
stmt.builder.WriteString(f.String())
stmt.builder.WriteString(f.fieldNameSuffix)
}
var _ Column[any] = (*ignoreCaseColumnDescriptor[any])(nil)

View File

@@ -0,0 +1,97 @@
package stmt
import "fmt"
type statementApplier[T any] interface {
// Apply writes the statement to the builder.
Apply(stmt *statement[T])
}
type Condition[T any] interface {
statementApplier[T]
}
type op interface {
TextOperation | NumberOperation | ListOperation
fmt.Stringer
}
type operation[T any, O op] struct {
o O
}
func (o operation[T, O]) String() string {
return o.o.String()
}
func (o operation[T, O]) Apply(stmt *statement[T]) {
stmt.builder.WriteString(o.o.String())
}
type condition[V, T any, OP op] struct {
field Column[T]
op OP
value V
}
func (c *condition[V, T, OP]) Apply(stmt *statement[T]) {
// placeholder := stmt.appendArg(c.value)
stmt.builder.WriteString(stmt.columnPrefix())
stmt.builder.WriteString(c.field.String())
// stmt.builder.WriteString(c.op)
// stmt.builder.WriteString(placeholder)
}
type and[T any] struct {
conditions []Condition[T]
}
func And[T any](conditions ...Condition[T]) *and[T] {
return &and[T]{
conditions: conditions,
}
}
// Apply implements [Condition].
func (a *and[T]) Apply(stmt *statement[T]) {
if len(a.conditions) > 1 {
stmt.builder.WriteString("(")
defer stmt.builder.WriteString(")")
}
for i, condition := range a.conditions {
if i > 0 {
stmt.builder.WriteString(" AND ")
}
condition.Apply(stmt)
}
}
var _ Condition[any] = (*and[any])(nil)
type or[T any] struct {
conditions []Condition[T]
}
func Or[T any](conditions ...Condition[T]) *or[T] {
return &or[T]{
conditions: conditions,
}
}
// Apply implements [Condition].
func (o *or[T]) Apply(stmt *statement[T]) {
if len(o.conditions) > 1 {
stmt.builder.WriteString("(")
defer stmt.builder.WriteString(")")
}
for i, condition := range o.conditions {
if i > 0 {
stmt.builder.WriteString(" OR ")
}
condition.Apply(stmt)
}
}
var _ Condition[any] = (*or[any])(nil)

View File

@@ -0,0 +1,71 @@
package stmt
type ListEntry interface {
Number | Text | any
}
type ListCondition[E ListEntry, T any] struct {
condition[[]E, T, ListOperation]
}
func (lc *ListCondition[E, T]) Apply(stmt *statement[T]) {
placeholder := stmt.appendArg(lc.value)
switch lc.op {
case ListOperationEqual, ListOperationNotEqual:
lc.field.Apply(stmt)
operation[T, ListOperation]{lc.op}.Apply(stmt)
stmt.builder.WriteString(placeholder)
case ListOperationContainsAny, ListOperationContainsAll:
lc.field.Apply(stmt)
operation[T, ListOperation]{lc.op}.Apply(stmt)
stmt.builder.WriteString(placeholder)
case ListOperationNotContainsAny, ListOperationNotContainsAll:
stmt.builder.WriteString("NOT (")
lc.field.Apply(stmt)
operation[T, ListOperation]{lc.op}.Apply(stmt)
stmt.builder.WriteString(placeholder)
stmt.builder.WriteString(")")
default:
panic("unknown list operation")
}
}
type ListOperation uint8
const (
// ListOperationEqual checks if the arrays are equal including the order of the elements
ListOperationEqual ListOperation = iota + 1
// ListOperationNotEqual checks if the arrays are not equal including the order of the elements
ListOperationNotEqual
// ListOperationContains checks if the array column contains all the values of the specified array
ListOperationContainsAll
// ListOperationContainsAny checks if the arrays have at least one value in common
ListOperationContainsAny
// ListOperationContainsAll checks if the array column contains all the values of the specified array
// ListOperationNotContainsAll checks if the specified array is not contained by the column
ListOperationNotContainsAll
// ListOperationNotContainsAny checks if the arrays column contains none of the values of the specified array
ListOperationNotContainsAny
)
var listOperations = map[ListOperation]string{
// ListOperationEqual checks if the lists are equal
ListOperationEqual: " = ",
// ListOperationNotEqual checks if the lists are not equal
ListOperationNotEqual: " <> ",
// ListOperationContainsAny checks if the arrays have at least one value in common
ListOperationContainsAny: " && ",
// ListOperationContainsAll checks if the array column contains all the values of the specified array
ListOperationContainsAll: " @> ",
// ListOperationNotContainsAny checks if the arrays column contains none of the values of the specified array
ListOperationNotContainsAny: " && ", // Base operator for NOT (A && B)
// ListOperationNotContainsAll checks if the array column is not contained by the specified array
ListOperationNotContainsAll: " <@ ", // Base operator for NOT (A <@ B)
}
func (lo ListOperation) String() string {
return listOperations[lo]
}

View File

@@ -0,0 +1,61 @@
package stmt
import (
"time"
"golang.org/x/exp/constraints"
)
type Number interface {
constraints.Integer | constraints.Float | constraints.Complex | time.Time | time.Duration
}
type between[N Number] struct {
min, max N
}
type NumberBetween[V Number, T any] struct {
condition[between[V], T, NumberOperation]
}
func (nb *NumberBetween[V, T]) Apply(stmt *statement[T]) {
nb.field.Apply(stmt)
stmt.builder.WriteString(" BETWEEN ")
stmt.builder.WriteString(stmt.appendArg(nb.value.min))
stmt.builder.WriteString(" AND ")
stmt.builder.WriteString(stmt.appendArg(nb.value.max))
}
type NumberCondition[V Number, T any] struct {
condition[V, T, NumberOperation]
}
func (nc *NumberCondition[V, T]) Apply(stmt *statement[T]) {
nc.field.Apply(stmt)
operation[T, NumberOperation]{nc.op}.Apply(stmt)
stmt.builder.WriteString(stmt.appendArg(nc.value))
}
type NumberOperation uint8
const (
NumberOperationEqual NumberOperation = iota + 1
NumberOperationNotEqual
NumberOperationLessThan
NumberOperationLessThanOrEqual
NumberOperationGreaterThan
NumberOperationGreaterThanOrEqual
)
var numberOperations = map[NumberOperation]string{
NumberOperationEqual: " = ",
NumberOperationNotEqual: " <> ",
NumberOperationLessThan: " < ",
NumberOperationLessThanOrEqual: " <= ",
NumberOperationGreaterThan: " > ",
NumberOperationGreaterThanOrEqual: " >= ",
}
func (no NumberOperation) String() string {
return numberOperations[no]
}

View File

@@ -0,0 +1,104 @@
package stmt
import (
"fmt"
"strings"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type statement[T any] struct {
builder strings.Builder
client database.QueryExecutor
columns []Column[T]
schema string
table string
alias string
condition Condition[T]
limit uint32
offset uint32
// order by fieldname and sort direction false for asc true for desc
// orderBy SortingColumns[C]
args []any
existingArgs map[any]string
}
func (s *statement[T]) scanners(t *T) []any {
scanners := make([]any, len(s.columns))
for i, column := range s.columns {
scanners[i] = column.scanner(t)
}
return scanners
}
func (s *statement[T]) query() string {
s.builder.WriteString(`SELECT `)
for i, column := range s.columns {
if i > 0 {
s.builder.WriteString(", ")
}
column.Apply(s)
}
s.builder.WriteString(` FROM `)
s.builder.WriteString(s.schema)
s.builder.WriteRune('.')
s.builder.WriteString(s.table)
if s.alias != "" {
s.builder.WriteString(" AS ")
s.builder.WriteString(s.alias)
}
s.builder.WriteString(` WHERE `)
s.condition.Apply(s)
if s.limit > 0 {
s.builder.WriteString(` LIMIT `)
s.builder.WriteString(s.appendArg(s.limit))
}
if s.offset > 0 {
s.builder.WriteString(` OFFSET `)
s.builder.WriteString(s.appendArg(s.offset))
}
return s.builder.String()
}
// func (s *statement[T]) Where(condition Condition[T]) *statement[T] {
// s.condition = condition
// return s
// }
// func (s *statement[T]) Limit(limit uint32) *statement[T] {
// s.limit = limit
// return s
// }
// func (s *statement[T]) Offset(offset uint32) *statement[T] {
// s.offset = offset
// return s
// }
func (s *statement[T]) columnPrefix() string {
if s.alias != "" {
return s.alias + "."
}
return s.schema + "." + s.table + "."
}
func (s *statement[T]) appendArg(arg any) string {
if s.existingArgs == nil {
s.existingArgs = make(map[any]string)
}
if existing, ok := s.existingArgs[arg]; ok {
return existing
}
s.args = append(s.args, arg)
placeholder := fmt.Sprintf("$%d", len(s.args))
s.existingArgs[arg] = placeholder
return placeholder
}

View File

@@ -0,0 +1,18 @@
package stmt_test
import (
"context"
"testing"
"github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt"
)
func Test_Bla(t *testing.T) {
stmt.User(nil).Where(
stmt.Or(
stmt.UserIDCondition("123"),
stmt.UserIDCondition("123"),
stmt.UserUsernameCondition(stmt.TextOperationEqualIgnoreCase, "test"),
),
).Limit(1).Offset(1).Get(context.Background())
}

View File

@@ -0,0 +1,72 @@
package stmt
type Text interface {
~string | ~[]byte
}
type TextCondition[V Text, T any] struct {
condition[V, T, TextOperation]
}
func (tc *TextCondition[V, T]) Apply(stmt *statement[T]) {
placeholder := stmt.appendArg(tc.value)
switch tc.op {
case TextOperationEqual, TextOperationNotEqual:
tc.field.Apply(stmt)
operation[T, TextOperation]{tc.op}.Apply(stmt)
stmt.builder.WriteString(placeholder)
case TextOperationEqualIgnoreCase:
if desc, ok := tc.field.(ignoreCaseColumnDescriptor[T]); ok {
desc.ApplyIgnoreCase(stmt)
} else {
stmt.builder.WriteString("LOWER(")
tc.field.Apply(stmt)
stmt.builder.WriteString(")")
}
operation[T, TextOperation]{tc.op}.Apply(stmt)
stmt.builder.WriteString("LOWER(")
stmt.builder.WriteString(placeholder)
stmt.builder.WriteString(")")
case TextOperationStartsWith:
tc.field.Apply(stmt)
operation[T, TextOperation]{tc.op}.Apply(stmt)
stmt.builder.WriteString(placeholder)
stmt.builder.WriteString("|| '%'")
case TextOperationStartsWithIgnoreCase:
if desc, ok := tc.field.(ignoreCaseColumnDescriptor[T]); ok {
desc.ApplyIgnoreCase(stmt)
} else {
stmt.builder.WriteString("LOWER(")
tc.field.Apply(stmt)
stmt.builder.WriteString(")")
}
operation[T, TextOperation]{tc.op}.Apply(stmt)
stmt.builder.WriteString("LOWER(")
stmt.builder.WriteString(placeholder)
stmt.builder.WriteString(")")
stmt.builder.WriteString("|| '%'")
}
}
type TextOperation uint8
const (
TextOperationEqual TextOperation = iota + 1
TextOperationEqualIgnoreCase
TextOperationNotEqual
TextOperationStartsWith
TextOperationStartsWithIgnoreCase
)
var textOperations = map[TextOperation]string{
TextOperationEqual: " = ",
TextOperationEqualIgnoreCase: " = ",
TextOperationNotEqual: " <> ",
TextOperationStartsWith: " LIKE ",
TextOperationStartsWithIgnoreCase: " LIKE ",
}
func (to TextOperation) String() string {
return textOperations[to]
}

View File

@@ -0,0 +1,193 @@
package stmt
import (
"context"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type userStatement struct {
statement[domain.User]
}
func User(client database.QueryExecutor) *userStatement {
return &userStatement{
statement: statement[domain.User]{
schema: "zitadel",
table: "users",
alias: "u",
client: client,
columns: []Column[domain.User]{
userColumns[UserInstanceID],
userColumns[UserOrgID],
userColumns[UserColumnID],
userColumns[UserColumnUsername],
userColumns[UserCreatedAt],
userColumns[UserUpdatedAt],
userColumns[UserDeletedAt],
},
},
}
}
func (s *userStatement) Where(condition Condition[domain.User]) *userStatement {
s.condition = condition
return s
}
func (s *userStatement) Limit(limit uint32) *userStatement {
s.limit = limit
return s
}
func (s *userStatement) Offset(offset uint32) *userStatement {
s.offset = offset
return s
}
func (s *userStatement) Get(ctx context.Context) (*domain.User, error) {
var user domain.User
err := s.client.QueryRow(ctx, s.query(), s.statement.args...).Scan(s.scanners(&user)...)
if err != nil {
return nil, err
}
return &user, nil
}
func (s *userStatement) List(ctx context.Context) ([]*domain.User, error) {
var users []*domain.User
rows, err := s.client.Query(ctx, s.query(), s.statement.args...)
if err != nil {
return nil, err
}
defer rows.Close()
for rows.Next() {
var user domain.User
err = rows.Scan(s.scanners(&user)...)
if err != nil {
return nil, err
}
users = append(users, &user)
}
return users, nil
}
func (s *userStatement) SetUsername(ctx context.Context, username string) error {
return nil
}
type UserColumn uint8
var (
userColumns map[UserColumn]Column[domain.User] = map[UserColumn]Column[domain.User]{
UserInstanceID: columnDescriptor[domain.User]{
name: "instance_id",
scan: func(u *domain.User) any {
return &u.InstanceID
},
},
UserOrgID: columnDescriptor[domain.User]{
name: "org_id",
scan: func(u *domain.User) any {
return &u.OrgID
},
},
UserColumnID: columnDescriptor[domain.User]{
name: "id",
scan: func(u *domain.User) any {
return &u.ID
},
},
UserColumnUsername: ignoreCaseColumnDescriptor[domain.User]{
columnDescriptor: columnDescriptor[domain.User]{
name: "username",
scan: func(u *domain.User) any {
return &u.Username
},
},
fieldNameSuffix: "_lower",
},
UserCreatedAt: columnDescriptor[domain.User]{
name: "created_at",
scan: func(u *domain.User) any {
return &u.CreatedAt
},
},
UserUpdatedAt: columnDescriptor[domain.User]{
name: "updated_at",
scan: func(u *domain.User) any {
return &u.UpdatedAt
},
},
UserDeletedAt: columnDescriptor[domain.User]{
name: "deleted_at",
scan: func(u *domain.User) any {
return &u.DeletedAt
},
},
}
humanColumns = map[UserColumn]Column[domain.User]{
UserHumanColumnEmail: ignoreCaseColumnDescriptor[domain.User]{
columnDescriptor: columnDescriptor[domain.User]{
name: "email",
scan: func(u *domain.User) any {
human, ok := u.Traits.(*domain.Human)
if !ok {
return nil
}
if human.Email == nil {
human.Email = new(domain.Email)
}
return &human.Email.Address
},
},
fieldNameSuffix: "_lower",
},
UserHumanColumnEmailVerified: columnDescriptor[domain.User]{
name: "email_is_verified",
scan: func(u *domain.User) any {
human, ok := u.Traits.(*domain.Human)
if !ok {
return nil
}
if human.Email == nil {
human.Email = new(domain.Email)
}
return &human.Email.IsVerified
},
},
}
machineColumns = map[UserColumn]Column[domain.User]{
UserMachineDescription: columnDescriptor[domain.User]{
name: "description",
scan: func(u *domain.User) any {
machine, ok := u.Traits.(*domain.Machine)
if !ok {
return nil
}
if machine == nil {
machine = new(domain.Machine)
}
return &machine.Description
},
},
}
)
const (
UserInstanceID UserColumn = iota + 1
UserOrgID
UserColumnID
UserColumnUsername
UserHumanColumnEmail
UserHumanColumnEmailVerified
UserMachineDescription
UserCreatedAt
UserUpdatedAt
UserDeletedAt
)

View File

@@ -0,0 +1,23 @@
package stmt
import "github.com/zitadel/zitadel/backend/v3/domain"
func UserIDCondition(id string) *TextCondition[string, domain.User] {
return &TextCondition[string, domain.User]{
condition: condition[string, domain.User, TextOperation]{
field: userColumns[UserColumnID],
op: TextOperationEqual,
value: id,
},
}
}
func UserUsernameCondition(op TextOperation, username string) *TextCondition[string, domain.User] {
return &TextCondition[string, domain.User]{
condition: condition[string, domain.User, TextOperation]{
field: userColumns[UserColumnUsername],
op: op,
value: username,
},
}
}

Some files were not shown because too many files have changed in this diff Show More