diff --git a/backend/command/client/grpc/api.go b/backend/command/client/grpc/api.go new file mode 100644 index 0000000000..417aa0a9d0 --- /dev/null +++ b/backend/command/client/grpc/api.go @@ -0,0 +1,56 @@ +package grpc + +import ( + "context" + + "github.com/zitadel/zitadel/backend/command/command" + "github.com/zitadel/zitadel/backend/command/query" + "github.com/zitadel/zitadel/backend/command/receiver" + "github.com/zitadel/zitadel/backend/command/receiver/cache" + "github.com/zitadel/zitadel/backend/storage/database" + "github.com/zitadel/zitadel/backend/telemetry/logging" + "github.com/zitadel/zitadel/backend/telemetry/tracing" +) + +type api struct { + db database.Pool + + manipulator receiver.InstanceManipulator + reader receiver.InstanceReader + tracer *tracing.Tracer + logger *logging.Logger + cache cache.Cache[receiver.InstanceIndex, string, *receiver.Instance] +} + +func (a *api) CreateInstance(ctx context.Context) error { + instance := &receiver.Instance{ + ID: "123", + Name: "test", + } + return command.Trace( + a.tracer, + command.SetCache(a.cache, + command.Activity(a.logger, command.CreateInstance(a.manipulator, instance)), + instance, + ), + ).Execute(ctx) +} + +func (a *api) DeleteInstance(ctx context.Context) error { + return command.Trace( + a.tracer, + command.DeleteCache(a.cache, + command.Activity( + a.logger, + command.DeleteInstance(a.manipulator, &receiver.Instance{ + ID: "123", + })), + receiver.InstanceByID, + "123", + )).Execute(ctx) +} + +func (a *api) InstanceByID(ctx context.Context) (*receiver.Instance, error) { + q := query.InstanceByID(a.reader, "123") + return q.Execute(ctx) +} diff --git a/backend/command/command/caching.go b/backend/command/command/caching.go new file mode 100644 index 0000000000..880ce24ba0 --- /dev/null +++ b/backend/command/command/caching.go @@ -0,0 +1,102 @@ +package command + +import ( + "context" + + "github.com/zitadel/zitadel/backend/command/receiver/cache" +) + +type setCache[I, K comparable, V cache.Entry[I, K]] struct { + cache cache.Cache[I, K, V] + command Command + entry V +} + +// SetCache decorates the command, if the command is executed without error it will set the cache entry. +func SetCache[I, K comparable, V cache.Entry[I, K]](cache cache.Cache[I, K, V], command Command, entry V) Command { + return &setCache[I, K, V]{ + cache: cache, + command: command, + entry: entry, + } +} + +var _ Command = (*setCache[any, any, cache.Entry[any, any]])(nil) + +// Execute implements [Command]. +func (s *setCache[I, K, V]) Execute(ctx context.Context) error { + if err := s.command.Execute(ctx); err != nil { + return err + } + s.cache.Set(ctx, s.entry) + return nil +} + +// Name implements [Command]. +func (s *setCache[I, K, V]) Name() string { + return s.command.Name() +} + +type deleteCache[I, K comparable, V cache.Entry[I, K]] struct { + cache cache.Cache[I, K, V] + command Command + index I + keys []K +} + +// DeleteCache decorates the command, if the command is executed without error it will delete the cache entry. +func DeleteCache[I, K comparable, V cache.Entry[I, K]](cache cache.Cache[I, K, V], command Command, index I, keys ...K) Command { + return &deleteCache[I, K, V]{ + cache: cache, + command: command, + index: index, + keys: keys, + } +} + +var _ Command = (*deleteCache[any, any, cache.Entry[any, any]])(nil) + +// Execute implements [Command]. +func (s *deleteCache[I, K, V]) Execute(ctx context.Context) error { + if err := s.command.Execute(ctx); err != nil { + return err + } + return s.cache.Delete(ctx, s.index, s.keys...) +} + +// Name implements [Command]. +func (s *deleteCache[I, K, V]) Name() string { + return s.command.Name() +} + +type invalidateCache[I, K comparable, V cache.Entry[I, K]] struct { + cache cache.Cache[I, K, V] + command Command + index I + keys []K +} + +// InvalidateCache decorates the command, if the command is executed without error it will invalidate the cache entry. +func InvalidateCache[I, K comparable, V cache.Entry[I, K]](cache cache.Cache[I, K, V], command Command, index I, keys ...K) Command { + return &invalidateCache[I, K, V]{ + cache: cache, + command: command, + index: index, + keys: keys, + } +} + +var _ Command = (*invalidateCache[any, any, cache.Entry[any, any]])(nil) + +// Execute implements [Command]. +func (s *invalidateCache[I, K, V]) Execute(ctx context.Context) error { + if err := s.command.Execute(ctx); err != nil { + return err + } + return s.cache.Invalidate(ctx, s.index, s.keys...) +} + +// Name implements [Command]. +func (s *invalidateCache[I, K, V]) Name() string { + return s.command.Name() +} diff --git a/backend/command/command/command.go b/backend/command/command/command.go index 478fd23b55..f0fd82e201 100644 --- a/backend/command/command/command.go +++ b/backend/command/command/command.go @@ -1,6 +1,10 @@ package command -import "context" +import ( + "context" + + "github.com/zitadel/zitadel/backend/command/receiver/cache" +) type Command interface { Execute(context.Context) error @@ -20,3 +24,8 @@ func (b *Batch) Execute(ctx context.Context) error { } return nil } + +type CacheableCommand[I, K comparable, V cache.Entry[I, K]] interface { + Command + Entry() V +} diff --git a/backend/command/command/instance.go b/backend/command/command/instance.go index f019b29da6..b5fd76b0b7 100644 --- a/backend/command/command/instance.go +++ b/backend/command/command/instance.go @@ -65,7 +65,7 @@ func UpdateInstance(receiver receiver.InstanceManipulator, instance *receiver.In func (u *updateInstance) Execute(ctx context.Context) error { u.Instance.Name = u.name - // return u.receiver.(ctx, u.Instance) + // return u.receiver.Update(ctx, u.Instance) return nil } diff --git a/backend/command/command/logging.go b/backend/command/command/logging.go index 27721079d0..a1b7512e4e 100644 --- a/backend/command/command/logging.go +++ b/backend/command/command/logging.go @@ -8,21 +8,28 @@ import ( "github.com/zitadel/zitadel/backend/telemetry/logging" ) -type Logger struct { +type logger struct { level slog.Level *logging.Logger cmd Command } -func Activity(l *logging.Logger, command Command) *Logger { - return &Logger{ +// Activity decorates the commands execute method with logging. +// It logs the command name, duration, and success or failure of the command. +func Activity(l *logging.Logger, command Command) Command { + return &logger{ Logger: l.With(slog.String("type", "activity")), level: slog.LevelInfo, cmd: command, } } -func (l *Logger) Execute(ctx context.Context) error { +// Name implements [Command]. +func (l *logger) Name() string { + return l.cmd.Name() +} + +func (l *logger) Execute(ctx context.Context) error { start := time.Now() log := l.Logger.With(slog.String("command", l.cmd.Name())) log.InfoContext(ctx, "execute") diff --git a/backend/command/command/tracing.go b/backend/command/command/tracing.go index 9dc3ec73a4..81a5735de6 100644 --- a/backend/command/command/tracing.go +++ b/backend/command/command/tracing.go @@ -6,12 +6,27 @@ import ( "github.com/zitadel/zitadel/backend/telemetry/tracing" ) -type Trace struct { +type trace struct { command Command tracer *tracing.Tracer } -func (t *Trace) Execute(ctx context.Context) error { +// Trace decorates the commands execute method with tracing. +// It creates a span with the command name and records any errors that occur during execution. +// The span is ended after the command is executed. +func Trace(tracer *tracing.Tracer, command Command) Command { + return &trace{ + command: command, + tracer: tracer, + } +} + +// Name implements [Command]. +func (l *trace) Name() string { + return l.command.Name() +} + +func (t *trace) Execute(ctx context.Context) error { ctx, span := t.tracer.Start(ctx, t.command.Name()) defer span.End() err := t.command.Execute(ctx) diff --git a/backend/command/invoker/api.go b/backend/command/invoker/api.go deleted file mode 100644 index 7e04ee4507..0000000000 --- a/backend/command/invoker/api.go +++ /dev/null @@ -1,38 +0,0 @@ -package invoker - -import ( - "context" - - "github.com/zitadel/zitadel/backend/command/command" - "github.com/zitadel/zitadel/backend/command/query" - "github.com/zitadel/zitadel/backend/command/receiver" - "github.com/zitadel/zitadel/backend/command/receiver/db" - "github.com/zitadel/zitadel/backend/storage/database" -) - -type api struct { - db database.Pool - - manipulator receiver.InstanceManipulator - reader receiver.InstanceReader -} - -func (a *api) CreateInstance(ctx context.Context) error { - cmd := command.CreateInstance(db.NewInstance(a.db), &receiver.Instance{ - ID: "123", - Name: "test", - }) - return cmd.Execute(ctx) -} - -func (a *api) DeleteInstance(ctx context.Context) error { - cmd := command.DeleteInstance(db.NewInstance(a.db), &receiver.Instance{ - ID: "123", - }) - return cmd.Execute(ctx) -} - -func (a *api) InstanceByID(ctx context.Context) (*receiver.Instance, error) { - q := query.InstanceByID(a.reader, "123") - return q.Execute(ctx) -} diff --git a/backend/command/receiver/cache/cache.go b/backend/command/receiver/cache/cache.go new file mode 100644 index 0000000000..dc05208caa --- /dev/null +++ b/backend/command/receiver/cache/cache.go @@ -0,0 +1,112 @@ +// Package cache provides abstraction of cache implementations that can be used by zitadel. +package cache + +import ( + "context" + "time" + + "github.com/zitadel/logging" +) + +// Purpose describes which object types are stored by a cache. +type Purpose int + +//go:generate enumer -type Purpose -transform snake -trimprefix Purpose +const ( + PurposeUnspecified Purpose = iota + PurposeAuthzInstance + PurposeMilestones + PurposeOrganization + PurposeIdPFormCallback +) + +// Cache stores objects with a value of type `V`. +// Objects may be referred to by one or more indices. +// Implementations may encode the value for storage. +// This means non-exported fields may be lost and objects +// with function values may fail to encode. +// See https://pkg.go.dev/encoding/json#Marshal for example. +// +// `I` is the type by which indices are identified, +// typically an enum for type-safe access. +// Indices are defined when calling the constructor of an implementation of this interface. +// It is illegal to refer to an idex not defined during construction. +// +// `K` is the type used as key in each index. +// Due to the limitations in type constraints, all indices use the same key type. +// +// Implementations are free to use stricter type constraints or fixed typing. +type Cache[I, K comparable, V Entry[I, K]] interface { + // Get an object through specified index. + // An [IndexUnknownError] may be returned if the index is unknown. + // [ErrCacheMiss] is returned if the key was not found in the index, + // or the object is not valid. + Get(ctx context.Context, index I, key K) (V, bool) + + // Set an object. + // Keys are created on each index based in the [Entry.Keys] method. + // If any key maps to an existing object, the object is invalidated, + // regardless if the object has other keys defined in the new entry. + // This to prevent ghost objects when an entry reduces the amount of keys + // for a given index. + Set(ctx context.Context, value V) + + // Invalidate an object through specified index. + // Implementations may choose to instantly delete the object, + // defer until prune or a separate cleanup routine. + // Invalidated object are no longer returned from Get. + // It is safe to call Invalidate multiple times or on non-existing entries. + Invalidate(ctx context.Context, index I, key ...K) error + + // Delete one or more keys from a specific index. + // An [IndexUnknownError] may be returned if the index is unknown. + // The referred object is not invalidated and may still be accessible though + // other indices and keys. + // It is safe to call Delete multiple times or on non-existing entries + Delete(ctx context.Context, index I, key ...K) error + + // Truncate deletes all cached objects. + Truncate(ctx context.Context) error +} + +// Entry contains a value of type `V` to be cached. +// +// `I` is the type by which indices are identified, +// typically an enum for type-safe access. +// +// `K` is the type used as key in an index. +// Due to the limitations in type constraints, all indices use the same key type. +type Entry[I, K comparable] interface { + // Keys returns which keys map to the object in a specified index. + // May return nil if the index in unknown or when there are no keys. + Keys(index I) (key []K) +} + +type Connector int + +//go:generate enumer -type Connector -transform snake -trimprefix Connector -linecomment -text +const ( + // Empty line comment ensures empty string for unspecified value + ConnectorUnspecified Connector = iota // + ConnectorMemory + ConnectorPostgres + ConnectorRedis +) + +type Config struct { + Connector Connector + + // Age since an object was added to the cache, + // after which the object is considered invalid. + // 0 disables max age checks. + MaxAge time.Duration + + // Age since last use (Get) of an object, + // after which the object is considered invalid. + // 0 disables last use age checks. + LastUseAge time.Duration + + // Log allows logging of the specific cache. + // By default only errors are logged to stdout. + Log *logging.Config +} diff --git a/backend/command/receiver/cache/connector/connector.go b/backend/command/receiver/cache/connector/connector.go new file mode 100644 index 0000000000..3cc5e852a6 --- /dev/null +++ b/backend/command/receiver/cache/connector/connector.go @@ -0,0 +1,49 @@ +// Package connector provides glue between the [cache.Cache] interface and implementations from the connector sub-packages. +package connector + +import ( + "context" + "fmt" + + "github.com/zitadel/zitadel/backend/storage/cache" + "github.com/zitadel/zitadel/backend/storage/cache/connector/gomap" + "github.com/zitadel/zitadel/backend/storage/cache/connector/noop" +) + +type CachesConfig struct { + Connectors struct { + Memory gomap.Config + } + Instance *cache.Config + Milestones *cache.Config + Organization *cache.Config + IdPFormCallbacks *cache.Config +} + +type Connectors struct { + Config CachesConfig + Memory *gomap.Connector +} + +func StartConnectors(conf *CachesConfig) (Connectors, error) { + if conf == nil { + return Connectors{}, nil + } + return Connectors{ + Config: *conf, + Memory: gomap.NewConnector(conf.Connectors.Memory), + }, nil +} + +func StartCache[I ~int, K ~string, V cache.Entry[I, K]](background context.Context, indices []I, purpose cache.Purpose, conf *cache.Config, connectors Connectors) (cache.Cache[I, K, V], error) { + if conf == nil || conf.Connector == cache.ConnectorUnspecified { + return noop.NewCache[I, K, V](), nil + } + if conf.Connector == cache.ConnectorMemory && connectors.Memory != nil { + c := gomap.NewCache[I, K, V](background, indices, *conf) + connectors.Memory.Config.StartAutoPrune(background, c, purpose) + return c, nil + } + + return nil, fmt.Errorf("cache connector %q not enabled", conf.Connector) +} diff --git a/backend/command/receiver/cache/connector/gomap/connector.go b/backend/command/receiver/cache/connector/gomap/connector.go new file mode 100644 index 0000000000..a37055bd73 --- /dev/null +++ b/backend/command/receiver/cache/connector/gomap/connector.go @@ -0,0 +1,23 @@ +package gomap + +import ( + "github.com/zitadel/zitadel/backend/storage/cache" +) + +type Config struct { + Enabled bool + AutoPrune cache.AutoPruneConfig +} + +type Connector struct { + Config cache.AutoPruneConfig +} + +func NewConnector(config Config) *Connector { + if !config.Enabled { + return nil + } + return &Connector{ + Config: config.AutoPrune, + } +} diff --git a/backend/command/receiver/cache/connector/gomap/gomap.go b/backend/command/receiver/cache/connector/gomap/gomap.go new file mode 100644 index 0000000000..d79e323801 --- /dev/null +++ b/backend/command/receiver/cache/connector/gomap/gomap.go @@ -0,0 +1,200 @@ +package gomap + +import ( + "context" + "errors" + "log/slog" + "maps" + "os" + "sync" + "sync/atomic" + "time" + + "github.com/zitadel/zitadel/backend/storage/cache" +) + +type mapCache[I, K comparable, V cache.Entry[I, K]] struct { + config *cache.Config + indexMap map[I]*index[K, V] + logger *slog.Logger +} + +// NewCache returns an in-memory Cache implementation based on the builtin go map type. +// Object values are stored as-is and there is no encoding or decoding involved. +func NewCache[I, K comparable, V cache.Entry[I, K]](background context.Context, indices []I, config cache.Config) cache.PrunerCache[I, K, V] { + m := &mapCache[I, K, V]{ + config: &config, + indexMap: make(map[I]*index[K, V], len(indices)), + logger: slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelError, + })), + } + if config.Log != nil { + m.logger = config.Log.Slog() + } + m.logger.InfoContext(background, "map cache logging enabled") + + for _, name := range indices { + m.indexMap[name] = &index[K, V]{ + config: m.config, + entries: make(map[K]*entry[V]), + } + } + return m +} + +func (c *mapCache[I, K, V]) Get(ctx context.Context, index I, key K) (value V, ok bool) { + i, ok := c.indexMap[index] + if !ok { + c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key) + return value, false + } + entry, err := i.Get(key) + if err == nil { + c.logger.DebugContext(ctx, "map cache get", "index", index, "key", key) + return entry.value, true + } + if errors.Is(err, cache.ErrCacheMiss) { + c.logger.InfoContext(ctx, "map cache get", "err", err, "index", index, "key", key) + return value, false + } + c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key) + return value, false +} + +func (c *mapCache[I, K, V]) Set(ctx context.Context, value V) { + now := time.Now() + entry := &entry[V]{ + value: value, + created: now, + } + entry.lastUse.Store(now.UnixMicro()) + + for name, i := range c.indexMap { + keys := value.Keys(name) + i.Set(keys, entry) + c.logger.DebugContext(ctx, "map cache set", "index", name, "keys", keys) + } +} + +func (c *mapCache[I, K, V]) Invalidate(ctx context.Context, index I, keys ...K) error { + i, ok := c.indexMap[index] + if !ok { + return cache.NewIndexUnknownErr(index) + } + i.Invalidate(keys) + c.logger.DebugContext(ctx, "map cache invalidate", "index", index, "keys", keys) + return nil +} + +func (c *mapCache[I, K, V]) Delete(ctx context.Context, index I, keys ...K) error { + i, ok := c.indexMap[index] + if !ok { + return cache.NewIndexUnknownErr(index) + } + i.Delete(keys) + c.logger.DebugContext(ctx, "map cache delete", "index", index, "keys", keys) + return nil +} + +func (c *mapCache[I, K, V]) Prune(ctx context.Context) error { + for name, index := range c.indexMap { + index.Prune() + c.logger.DebugContext(ctx, "map cache prune", "index", name) + } + return nil +} + +func (c *mapCache[I, K, V]) Truncate(ctx context.Context) error { + for name, index := range c.indexMap { + index.Truncate() + c.logger.DebugContext(ctx, "map cache truncate", "index", name) + } + return nil +} + +type index[K comparable, V any] struct { + mutex sync.RWMutex + config *cache.Config + entries map[K]*entry[V] +} + +func (i *index[K, V]) Get(key K) (*entry[V], error) { + i.mutex.RLock() + entry, ok := i.entries[key] + i.mutex.RUnlock() + if ok && entry.isValid(i.config) { + return entry, nil + } + return nil, cache.ErrCacheMiss +} + +func (c *index[K, V]) Set(keys []K, entry *entry[V]) { + c.mutex.Lock() + for _, key := range keys { + c.entries[key] = entry + } + c.mutex.Unlock() +} + +func (i *index[K, V]) Invalidate(keys []K) { + i.mutex.RLock() + for _, key := range keys { + if entry, ok := i.entries[key]; ok { + entry.invalid.Store(true) + } + } + i.mutex.RUnlock() +} + +func (c *index[K, V]) Delete(keys []K) { + c.mutex.Lock() + for _, key := range keys { + delete(c.entries, key) + } + c.mutex.Unlock() +} + +func (c *index[K, V]) Prune() { + c.mutex.Lock() + maps.DeleteFunc(c.entries, func(_ K, entry *entry[V]) bool { + return !entry.isValid(c.config) + }) + c.mutex.Unlock() +} + +func (c *index[K, V]) Truncate() { + c.mutex.Lock() + c.entries = make(map[K]*entry[V]) + c.mutex.Unlock() +} + +type entry[V any] struct { + value V + created time.Time + invalid atomic.Bool + lastUse atomic.Int64 // UnixMicro time +} + +func (e *entry[V]) isValid(c *cache.Config) bool { + if e.invalid.Load() { + return false + } + now := time.Now() + if c.MaxAge > 0 { + if e.created.Add(c.MaxAge).Before(now) { + e.invalid.Store(true) + return false + } + } + if c.LastUseAge > 0 { + lastUse := e.lastUse.Load() + if time.UnixMicro(lastUse).Add(c.LastUseAge).Before(now) { + e.invalid.Store(true) + return false + } + e.lastUse.CompareAndSwap(lastUse, now.UnixMicro()) + } + return true +} diff --git a/backend/command/receiver/cache/connector/gomap/gomap_test.go b/backend/command/receiver/cache/connector/gomap/gomap_test.go new file mode 100644 index 0000000000..8ed4f0f30a --- /dev/null +++ b/backend/command/receiver/cache/connector/gomap/gomap_test.go @@ -0,0 +1,329 @@ +package gomap + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/backend/storage/cache" +) + +type testIndex int + +const ( + testIndexID testIndex = iota + testIndexName +) + +var testIndices = []testIndex{ + testIndexID, + testIndexName, +} + +type testObject struct { + id string + names []string +} + +func (o *testObject) Keys(index testIndex) []string { + switch index { + case testIndexID: + return []string{o.id} + case testIndexName: + return o.names + default: + return nil + } +} + +func Test_mapCache_Get(t *testing.T) { + c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{ + MaxAge: time.Second, + LastUseAge: time.Second / 4, + Log: &logging.Config{ + Level: "debug", + AddSource: true, + }, + }) + obj := &testObject{ + id: "id", + names: []string{"foo", "bar"}, + } + c.Set(context.Background(), obj) + + type args struct { + index testIndex + key string + } + tests := []struct { + name string + args args + want *testObject + wantOk bool + }{ + { + name: "ok", + args: args{ + index: testIndexID, + key: "id", + }, + want: obj, + wantOk: true, + }, + { + name: "miss", + args: args{ + index: testIndexID, + key: "spanac", + }, + want: nil, + wantOk: false, + }, + { + name: "unknown index", + args: args{ + index: 99, + key: "id", + }, + want: nil, + wantOk: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := c.Get(context.Background(), tt.args.index, tt.args.key) + assert.Equal(t, tt.want, got) + assert.Equal(t, tt.wantOk, ok) + }) + } +} + +func Test_mapCache_Invalidate(t *testing.T) { + c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{ + MaxAge: time.Second, + LastUseAge: time.Second / 4, + Log: &logging.Config{ + Level: "debug", + AddSource: true, + }, + }) + obj := &testObject{ + id: "id", + names: []string{"foo", "bar"}, + } + c.Set(context.Background(), obj) + err := c.Invalidate(context.Background(), testIndexName, "bar") + require.NoError(t, err) + got, ok := c.Get(context.Background(), testIndexID, "id") + assert.Nil(t, got) + assert.False(t, ok) +} + +func Test_mapCache_Delete(t *testing.T) { + c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{ + MaxAge: time.Second, + LastUseAge: time.Second / 4, + Log: &logging.Config{ + Level: "debug", + AddSource: true, + }, + }) + obj := &testObject{ + id: "id", + names: []string{"foo", "bar"}, + } + c.Set(context.Background(), obj) + err := c.Delete(context.Background(), testIndexName, "bar") + require.NoError(t, err) + + // Shouldn't find object by deleted name + got, ok := c.Get(context.Background(), testIndexName, "bar") + assert.Nil(t, got) + assert.False(t, ok) + + // Should find object by other name + got, ok = c.Get(context.Background(), testIndexName, "foo") + assert.Equal(t, obj, got) + assert.True(t, ok) + + // Should find object by id + got, ok = c.Get(context.Background(), testIndexID, "id") + assert.Equal(t, obj, got) + assert.True(t, ok) +} + +func Test_mapCache_Prune(t *testing.T) { + c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{ + MaxAge: time.Second, + LastUseAge: time.Second / 4, + Log: &logging.Config{ + Level: "debug", + AddSource: true, + }, + }) + + objects := []*testObject{ + { + id: "id1", + names: []string{"foo", "bar"}, + }, + { + id: "id2", + names: []string{"hello"}, + }, + } + for _, obj := range objects { + c.Set(context.Background(), obj) + } + // invalidate one entry + err := c.Invalidate(context.Background(), testIndexName, "bar") + require.NoError(t, err) + + err = c.(cache.Pruner).Prune(context.Background()) + require.NoError(t, err) + + // Other object should still be found + got, ok := c.Get(context.Background(), testIndexID, "id2") + assert.Equal(t, objects[1], got) + assert.True(t, ok) +} + +func Test_mapCache_Truncate(t *testing.T) { + c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{ + MaxAge: time.Second, + LastUseAge: time.Second / 4, + Log: &logging.Config{ + Level: "debug", + AddSource: true, + }, + }) + objects := []*testObject{ + { + id: "id1", + names: []string{"foo", "bar"}, + }, + { + id: "id2", + names: []string{"hello"}, + }, + } + for _, obj := range objects { + c.Set(context.Background(), obj) + } + + err := c.Truncate(context.Background()) + require.NoError(t, err) + + mc := c.(*mapCache[testIndex, string, *testObject]) + for _, index := range mc.indexMap { + index.mutex.RLock() + assert.Len(t, index.entries, 0) + index.mutex.RUnlock() + } +} + +func Test_entry_isValid(t *testing.T) { + type fields struct { + created time.Time + invalid bool + lastUse time.Time + } + tests := []struct { + name string + fields fields + config *cache.Config + want bool + }{ + { + name: "invalid", + fields: fields{ + created: time.Now(), + invalid: true, + lastUse: time.Now(), + }, + config: &cache.Config{ + MaxAge: time.Minute, + LastUseAge: time.Second, + }, + want: false, + }, + { + name: "max age exceeded", + fields: fields{ + created: time.Now().Add(-(time.Minute + time.Second)), + invalid: false, + lastUse: time.Now(), + }, + config: &cache.Config{ + MaxAge: time.Minute, + LastUseAge: time.Second, + }, + want: false, + }, + { + name: "max age disabled", + fields: fields{ + created: time.Now().Add(-(time.Minute + time.Second)), + invalid: false, + lastUse: time.Now(), + }, + config: &cache.Config{ + LastUseAge: time.Second, + }, + want: true, + }, + { + name: "last use age exceeded", + fields: fields{ + created: time.Now().Add(-(time.Minute / 2)), + invalid: false, + lastUse: time.Now().Add(-(time.Second * 2)), + }, + config: &cache.Config{ + MaxAge: time.Minute, + LastUseAge: time.Second, + }, + want: false, + }, + { + name: "last use age disabled", + fields: fields{ + created: time.Now().Add(-(time.Minute / 2)), + invalid: false, + lastUse: time.Now().Add(-(time.Second * 2)), + }, + config: &cache.Config{ + MaxAge: time.Minute, + }, + want: true, + }, + { + name: "valid", + fields: fields{ + created: time.Now(), + invalid: false, + lastUse: time.Now(), + }, + config: &cache.Config{ + MaxAge: time.Minute, + LastUseAge: time.Second, + }, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &entry[any]{ + created: tt.fields.created, + } + e.invalid.Store(tt.fields.invalid) + e.lastUse.Store(tt.fields.lastUse.UnixMicro()) + got := e.isValid(tt.config) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/backend/command/receiver/cache/connector/noop/noop.go b/backend/command/receiver/cache/connector/noop/noop.go new file mode 100644 index 0000000000..e3cf69c8ec --- /dev/null +++ b/backend/command/receiver/cache/connector/noop/noop.go @@ -0,0 +1,21 @@ +package noop + +import ( + "context" + + "github.com/zitadel/zitadel/backend/storage/cache" +) + +type noop[I, K comparable, V cache.Entry[I, K]] struct{} + +// NewCache returns a cache that does nothing +func NewCache[I, K comparable, V cache.Entry[I, K]]() cache.Cache[I, K, V] { + return noop[I, K, V]{} +} + +func (noop[I, K, V]) Set(context.Context, V) {} +func (noop[I, K, V]) Get(context.Context, I, K) (value V, ok bool) { return } +func (noop[I, K, V]) Invalidate(context.Context, I, ...K) (err error) { return } +func (noop[I, K, V]) Delete(context.Context, I, ...K) (err error) { return } +func (noop[I, K, V]) Prune(context.Context) (err error) { return } +func (noop[I, K, V]) Truncate(context.Context) (err error) { return } diff --git a/backend/command/receiver/cache/connector_enumer.go b/backend/command/receiver/cache/connector_enumer.go new file mode 100644 index 0000000000..7ea014db16 --- /dev/null +++ b/backend/command/receiver/cache/connector_enumer.go @@ -0,0 +1,98 @@ +// Code generated by "enumer -type Connector -transform snake -trimprefix Connector -linecomment -text"; DO NOT EDIT. + +package cache + +import ( + "fmt" + "strings" +) + +const _ConnectorName = "memorypostgresredis" + +var _ConnectorIndex = [...]uint8{0, 0, 6, 14, 19} + +const _ConnectorLowerName = "memorypostgresredis" + +func (i Connector) String() string { + if i < 0 || i >= Connector(len(_ConnectorIndex)-1) { + return fmt.Sprintf("Connector(%d)", i) + } + return _ConnectorName[_ConnectorIndex[i]:_ConnectorIndex[i+1]] +} + +// An "invalid array index" compiler error signifies that the constant values have changed. +// Re-run the stringer command to generate them again. +func _ConnectorNoOp() { + var x [1]struct{} + _ = x[ConnectorUnspecified-(0)] + _ = x[ConnectorMemory-(1)] + _ = x[ConnectorPostgres-(2)] + _ = x[ConnectorRedis-(3)] +} + +var _ConnectorValues = []Connector{ConnectorUnspecified, ConnectorMemory, ConnectorPostgres, ConnectorRedis} + +var _ConnectorNameToValueMap = map[string]Connector{ + _ConnectorName[0:0]: ConnectorUnspecified, + _ConnectorLowerName[0:0]: ConnectorUnspecified, + _ConnectorName[0:6]: ConnectorMemory, + _ConnectorLowerName[0:6]: ConnectorMemory, + _ConnectorName[6:14]: ConnectorPostgres, + _ConnectorLowerName[6:14]: ConnectorPostgres, + _ConnectorName[14:19]: ConnectorRedis, + _ConnectorLowerName[14:19]: ConnectorRedis, +} + +var _ConnectorNames = []string{ + _ConnectorName[0:0], + _ConnectorName[0:6], + _ConnectorName[6:14], + _ConnectorName[14:19], +} + +// ConnectorString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func ConnectorString(s string) (Connector, error) { + if val, ok := _ConnectorNameToValueMap[s]; ok { + return val, nil + } + + if val, ok := _ConnectorNameToValueMap[strings.ToLower(s)]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to Connector values", s) +} + +// ConnectorValues returns all values of the enum +func ConnectorValues() []Connector { + return _ConnectorValues +} + +// ConnectorStrings returns a slice of all String values of the enum +func ConnectorStrings() []string { + strs := make([]string, len(_ConnectorNames)) + copy(strs, _ConnectorNames) + return strs +} + +// IsAConnector returns "true" if the value is listed in the enum definition. "false" otherwise +func (i Connector) IsAConnector() bool { + for _, v := range _ConnectorValues { + if i == v { + return true + } + } + return false +} + +// MarshalText implements the encoding.TextMarshaler interface for Connector +func (i Connector) MarshalText() ([]byte, error) { + return []byte(i.String()), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface for Connector +func (i *Connector) UnmarshalText(text []byte) error { + var err error + *i, err = ConnectorString(string(text)) + return err +} diff --git a/backend/command/receiver/cache/error.go b/backend/command/receiver/cache/error.go new file mode 100644 index 0000000000..b66b9447bf --- /dev/null +++ b/backend/command/receiver/cache/error.go @@ -0,0 +1,29 @@ +package cache + +import ( + "errors" + "fmt" +) + +type IndexUnknownError[I comparable] struct { + index I +} + +func NewIndexUnknownErr[I comparable](index I) error { + return IndexUnknownError[I]{index} +} + +func (i IndexUnknownError[I]) Error() string { + return fmt.Sprintf("index %v unknown", i.index) +} + +func (a IndexUnknownError[I]) Is(err error) bool { + if b, ok := err.(IndexUnknownError[I]); ok { + return a.index == b.index + } + return false +} + +var ( + ErrCacheMiss = errors.New("cache miss") +) diff --git a/backend/command/receiver/cache/pruner.go b/backend/command/receiver/cache/pruner.go new file mode 100644 index 0000000000..959762d410 --- /dev/null +++ b/backend/command/receiver/cache/pruner.go @@ -0,0 +1,76 @@ +package cache + +import ( + "context" + "math/rand" + "time" + + "github.com/jonboulle/clockwork" + "github.com/zitadel/logging" +) + +// Pruner is an optional [Cache] interface. +type Pruner interface { + // Prune deletes all invalidated or expired objects. + Prune(ctx context.Context) error +} + +type PrunerCache[I, K comparable, V Entry[I, K]] interface { + Cache[I, K, V] + Pruner +} + +type AutoPruneConfig struct { + // Interval at which the cache is automatically pruned. + // 0 or lower disables automatic pruning. + Interval time.Duration + + // Timeout for an automatic prune. + // It is recommended to keep the value shorter than AutoPruneInterval + // 0 or lower disables automatic pruning. + Timeout time.Duration +} + +func (c AutoPruneConfig) StartAutoPrune(background context.Context, pruner Pruner, purpose Purpose) (close func()) { + return c.startAutoPrune(background, pruner, purpose, clockwork.NewRealClock()) +} + +func (c *AutoPruneConfig) startAutoPrune(background context.Context, pruner Pruner, purpose Purpose, clock clockwork.Clock) (close func()) { + if c.Interval <= 0 { + return func() {} + } + background, cancel := context.WithCancel(background) + // randomize the first interval + timer := clock.NewTimer(time.Duration(rand.Int63n(int64(c.Interval)))) + go c.pruneTimer(background, pruner, purpose, timer) + return cancel +} + +func (c *AutoPruneConfig) pruneTimer(background context.Context, pruner Pruner, purpose Purpose, timer clockwork.Timer) { + defer func() { + if !timer.Stop() { + <-timer.Chan() + } + }() + + for { + select { + case <-background.Done(): + return + case <-timer.Chan(): + err := c.doPrune(background, pruner) + logging.OnError(err).WithField("purpose", purpose).Error("cache auto prune") + timer.Reset(c.Interval) + } + } +} + +func (c *AutoPruneConfig) doPrune(background context.Context, pruner Pruner) error { + ctx, cancel := context.WithCancel(background) + defer cancel() + if c.Timeout > 0 { + ctx, cancel = context.WithTimeout(background, c.Timeout) + defer cancel() + } + return pruner.Prune(ctx) +} diff --git a/backend/command/receiver/cache/pruner_test.go b/backend/command/receiver/cache/pruner_test.go new file mode 100644 index 0000000000..faaedeb88c --- /dev/null +++ b/backend/command/receiver/cache/pruner_test.go @@ -0,0 +1,43 @@ +package cache + +import ( + "context" + "testing" + "time" + + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/assert" +) + +type testPruner struct { + called chan struct{} +} + +func (p *testPruner) Prune(context.Context) error { + p.called <- struct{}{} + return nil +} + +func TestAutoPruneConfig_startAutoPrune(t *testing.T) { + c := AutoPruneConfig{ + Interval: time.Second, + Timeout: time.Millisecond, + } + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + pruner := testPruner{ + called: make(chan struct{}), + } + clock := clockwork.NewFakeClock() + close := c.startAutoPrune(ctx, &pruner, PurposeAuthzInstance, clock) + defer close() + clock.Advance(time.Second) + + select { + case _, ok := <-pruner.called: + assert.True(t, ok) + case <-ctx.Done(): + t.Fatal(ctx.Err()) + } +} diff --git a/backend/command/receiver/cache/purpose_enumer.go b/backend/command/receiver/cache/purpose_enumer.go new file mode 100644 index 0000000000..a93a978efb --- /dev/null +++ b/backend/command/receiver/cache/purpose_enumer.go @@ -0,0 +1,90 @@ +// Code generated by "enumer -type Purpose -transform snake -trimprefix Purpose"; DO NOT EDIT. + +package cache + +import ( + "fmt" + "strings" +) + +const _PurposeName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callback" + +var _PurposeIndex = [...]uint8{0, 11, 25, 35, 47, 65} + +const _PurposeLowerName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callback" + +func (i Purpose) String() string { + if i < 0 || i >= Purpose(len(_PurposeIndex)-1) { + return fmt.Sprintf("Purpose(%d)", i) + } + return _PurposeName[_PurposeIndex[i]:_PurposeIndex[i+1]] +} + +// An "invalid array index" compiler error signifies that the constant values have changed. +// Re-run the stringer command to generate them again. +func _PurposeNoOp() { + var x [1]struct{} + _ = x[PurposeUnspecified-(0)] + _ = x[PurposeAuthzInstance-(1)] + _ = x[PurposeMilestones-(2)] + _ = x[PurposeOrganization-(3)] + _ = x[PurposeIdPFormCallback-(4)] +} + +var _PurposeValues = []Purpose{PurposeUnspecified, PurposeAuthzInstance, PurposeMilestones, PurposeOrganization, PurposeIdPFormCallback} + +var _PurposeNameToValueMap = map[string]Purpose{ + _PurposeName[0:11]: PurposeUnspecified, + _PurposeLowerName[0:11]: PurposeUnspecified, + _PurposeName[11:25]: PurposeAuthzInstance, + _PurposeLowerName[11:25]: PurposeAuthzInstance, + _PurposeName[25:35]: PurposeMilestones, + _PurposeLowerName[25:35]: PurposeMilestones, + _PurposeName[35:47]: PurposeOrganization, + _PurposeLowerName[35:47]: PurposeOrganization, + _PurposeName[47:65]: PurposeIdPFormCallback, + _PurposeLowerName[47:65]: PurposeIdPFormCallback, +} + +var _PurposeNames = []string{ + _PurposeName[0:11], + _PurposeName[11:25], + _PurposeName[25:35], + _PurposeName[35:47], + _PurposeName[47:65], +} + +// PurposeString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func PurposeString(s string) (Purpose, error) { + if val, ok := _PurposeNameToValueMap[s]; ok { + return val, nil + } + + if val, ok := _PurposeNameToValueMap[strings.ToLower(s)]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to Purpose values", s) +} + +// PurposeValues returns all values of the enum +func PurposeValues() []Purpose { + return _PurposeValues +} + +// PurposeStrings returns a slice of all String values of the enum +func PurposeStrings() []string { + strs := make([]string, len(_PurposeNames)) + copy(strs, _PurposeNames) + return strs +} + +// IsAPurpose returns "true" if the value is listed in the enum definition. "false" otherwise +func (i Purpose) IsAPurpose() bool { + for _, v := range _PurposeValues { + if i == v { + return true + } + } + return false +} diff --git a/backend/command/receiver/instance.go b/backend/command/receiver/instance.go index 43be941eda..728bc6773b 100644 --- a/backend/command/receiver/instance.go +++ b/backend/command/receiver/instance.go @@ -1,6 +1,10 @@ package receiver -import "context" +import ( + "context" + + "github.com/zitadel/zitadel/backend/command/receiver/cache" +) type InstanceState uint8 @@ -16,6 +20,31 @@ type Instance struct { Domains []*Domain } +type InstanceIndex uint8 + +var InstanceIndices = []InstanceIndex{ + InstanceByID, + InstanceByDomain, +} + +const ( + InstanceByID InstanceIndex = iota + InstanceByDomain +) + +var _ cache.Entry[InstanceIndex, string] = (*Instance)(nil) + +// Keys implements [cache.Entry]. +func (i *Instance) Keys(index InstanceIndex) (key []string) { + switch index { + case InstanceByID: + return []string{i.ID} + case InstanceByDomain: + return []string{i.Name} + } + return nil +} + type InstanceManipulator interface { Create(ctx context.Context, instance *Instance) error Delete(ctx context.Context, instance *Instance) error diff --git a/backend/command/v2/api/doc.go b/backend/command/v2/api/doc.go new file mode 100644 index 0000000000..f67198d466 --- /dev/null +++ b/backend/command/v2/api/doc.go @@ -0,0 +1,6 @@ +// The API package implements the protobuf stubs + +// It uses the Chain of responsibility pattern to handle requests in a modular way +// It implements the client or invoker of the command pattern. +// The client is responsible for creating the concrete command and setting its receiver. +package api diff --git a/backend/command/v2/api/user/v2/email.go b/backend/command/v2/api/user/v2/email.go new file mode 100644 index 0000000000..76e266e2c3 --- /dev/null +++ b/backend/command/v2/api/user/v2/email.go @@ -0,0 +1,35 @@ +package userv2 + +import ( + "context" + + "github.com/muhlemmer/gu" + "github.com/zitadel/zitadel/backend/command/v2/domain" + "github.com/zitadel/zitadel/pkg/grpc/user/v2" +) + +func (s *Server) SetEmail(ctx context.Context, req *user.SetEmailRequest) (resp *user.SetEmailResponse, err error) { + request := &domain.SetUserEmail{ + UserID: req.GetUserId(), + Email: req.GetEmail(), + } + switch req.GetVerification().(type) { + case *user.SetEmailRequest_IsVerified: + request.IsVerified = gu.Ptr(req.GetIsVerified()) + case *user.SetEmailRequest_SendCode: + request.SendCode = &domain.SendCode{ + URLTemplate: req.GetSendCode().UrlTemplate, + } + case *user.SetEmailRequest_ReturnCode: + request.ReturnCode = new(domain.ReturnCode) + } + if err := s.domain.SetUserEmail(ctx, request); err != nil { + return nil, err + } + + response := new(user.SetEmailResponse) + if request.ReturnCode != nil { + response.VerificationCode = &request.ReturnCode.Code + } + return response, nil +} diff --git a/backend/command/v2/api/user/v2/server.go b/backend/command/v2/api/user/v2/server.go new file mode 100644 index 0000000000..583149a6b6 --- /dev/null +++ b/backend/command/v2/api/user/v2/server.go @@ -0,0 +1,12 @@ +package userv2 + +import ( + "go.opentelemetry.io/otel/trace" + + "github.com/zitadel/zitadel/backend/command/v2/domain" +) + +type Server struct { + tracer trace.Tracer + domain *domain.Domain +} diff --git a/backend/command/v2/domain/command/generate_code.go b/backend/command/v2/domain/command/generate_code.go new file mode 100644 index 0000000000..7ec00bd2e2 --- /dev/null +++ b/backend/command/v2/domain/command/generate_code.go @@ -0,0 +1,41 @@ +package command + +import ( + "context" + + "github.com/zitadel/zitadel/backend/command/v2/pattern" + "github.com/zitadel/zitadel/internal/crypto" +) + +type generateCode struct { + set func(code string) + generator pattern.Query[crypto.Generator] +} + +func GenerateCode(set func(code string), generator pattern.Query[crypto.Generator]) *generateCode { + return &generateCode{ + set: set, + generator: generator, + } +} + +var _ pattern.Command = (*generateCode)(nil) + +// Execute implements [pattern.Command]. +func (cmd *generateCode) Execute(ctx context.Context) error { + if err := cmd.generator.Execute(ctx); err != nil { + return err + } + value, code, err := crypto.NewCode(cmd.generator.Result()) + _ = value + if err != nil { + return err + } + cmd.set(code) + return nil +} + +// Name implements [pattern.Command]. +func (*generateCode) Name() string { + return "command.generate_code" +} diff --git a/backend/command/v2/domain/command/send_email_code.go b/backend/command/v2/domain/command/send_email_code.go new file mode 100644 index 0000000000..9e752332b8 --- /dev/null +++ b/backend/command/v2/domain/command/send_email_code.go @@ -0,0 +1,41 @@ +package command + +import ( + "context" + + "github.com/zitadel/zitadel/backend/command/v2/pattern" +) + +var _ pattern.Command = (*sendEmailCode)(nil) + +type sendEmailCode struct { + UserID string `json:"userId"` + Email string `json:"email"` + URLTemplate *string `json:"urlTemplate"` + code string `json:"-"` +} + +func SendEmailCode(userID, email string, urlTemplate *string) pattern.Command { + cmd := &sendEmailCode{ + UserID: userID, + Email: email, + URLTemplate: urlTemplate, + } + + return pattern.Batch(GenerateCode(cmd.SetCode, generateCode)) +} + +// Name implements [pattern.Command]. +func (c *sendEmailCode) Name() string { + return "user.v2.email.send_code" +} + +// Execute implements [pattern.Command]. +func (c *sendEmailCode) Execute(ctx context.Context) error { + // Implementation of the command execution + return nil +} + +func (c *sendEmailCode) SetCode(code string) { + c.code = code +} diff --git a/backend/command/v2/domain/command/set_email.go b/backend/command/v2/domain/command/set_email.go new file mode 100644 index 0000000000..5e525d8ea3 --- /dev/null +++ b/backend/command/v2/domain/command/set_email.go @@ -0,0 +1,39 @@ +package command + +import ( + "context" + + "github.com/zitadel/zitadel/backend/command/v2/storage/eventstore" +) + +var ( + _ eventstore.EventCommander = (*setEmail)(nil) +) + +type setEmail struct { + UserID string `json:"userId"` + Email string `json:"email"` +} + +func SetEmail(userID, email string) *setEmail { + return &setEmail{ + UserID: userID, + Email: email, + } +} + +// Event implements [eventstore.EventCommander]. +func (c *setEmail) Event() *eventstore.Event { + panic("unimplemented") +} + +// Name implements [pattern.Command]. +func (c *setEmail) Name() string { + return "user.v2.set_email" +} + +// Execute implements [pattern.Command]. +func (c *setEmail) Execute(ctx context.Context) error { + // Implementation of the command execution + return nil +} diff --git a/backend/command/v2/domain/command/verify_email.go b/backend/command/v2/domain/command/verify_email.go new file mode 100644 index 0000000000..80ec2c1b04 --- /dev/null +++ b/backend/command/v2/domain/command/verify_email.go @@ -0,0 +1,32 @@ +package command + +import ( + "context" + + "github.com/zitadel/zitadel/backend/command/v2/pattern" +) + +var _ pattern.Command = (*verifyEmail)(nil) + +type verifyEmail struct { + UserID string `json:"userId"` + Email string `json:"email"` +} + +func VerifyEmail(userID, email string) *verifyEmail { + return &verifyEmail{ + UserID: userID, + Email: email, + } +} + +// Name implements [pattern.Command]. +func (c *verifyEmail) Name() string { + return "user.v2.verify_email" +} + +// Execute implements [pattern.Command]. +func (c *verifyEmail) Execute(ctx context.Context) error { + // Implementation of the command execution + return nil +} diff --git a/backend/command/v2/domain/domain.go b/backend/command/v2/domain/domain.go new file mode 100644 index 0000000000..e6ee5d9c61 --- /dev/null +++ b/backend/command/v2/domain/domain.go @@ -0,0 +1,13 @@ +package domain + +import ( + "github.com/zitadel/zitadel/backend/command/v2/storage/database" + "github.com/zitadel/zitadel/internal/crypto" + "go.opentelemetry.io/otel/trace" +) + +type Domain struct { + pool database.Pool + tracer trace.Tracer + userCodeAlg crypto.EncryptionAlgorithm +} diff --git a/backend/command/v2/domain/email.go b/backend/command/v2/domain/email.go new file mode 100644 index 0000000000..14032e9d8e --- /dev/null +++ b/backend/command/v2/domain/email.go @@ -0,0 +1,6 @@ +package domain + +type Email struct { + Address string + Verified bool +} diff --git a/backend/command/v2/domain/query/encryption_generator.go b/backend/command/v2/domain/query/encryption_generator.go new file mode 100644 index 0000000000..dc9fe66056 --- /dev/null +++ b/backend/command/v2/domain/query/encryption_generator.go @@ -0,0 +1,42 @@ +package query + +import ( + "context" + + "github.com/zitadel/zitadel/internal/crypto" +) + +type encryptionConfigReceiver interface { + GetEncryptionConfig(ctx context.Context) (*crypto.GeneratorConfig, error) +} + +type encryptionGenerator struct { + receiver encryptionConfigReceiver + algorithm crypto.EncryptionAlgorithm + + res crypto.Generator +} + +func QueryEncryptionGenerator(receiver encryptionConfigReceiver, algorithm crypto.EncryptionAlgorithm) *encryptionGenerator { + return &encryptionGenerator{ + receiver: receiver, + algorithm: algorithm, + } +} + +func (q *encryptionGenerator) Execute(ctx context.Context) error { + config, err := q.receiver.GetEncryptionConfig(ctx) + if err != nil { + return err + } + q.res = crypto.NewEncryptionGenerator(*config, q.algorithm) + return nil +} + +func (q *encryptionGenerator) Name() string { + return "query.encryption_generator" +} + +func (q *encryptionGenerator) Result() crypto.Generator { + return q.res +} diff --git a/backend/command/v2/domain/query/return_email_code.go b/backend/command/v2/domain/query/return_email_code.go new file mode 100644 index 0000000000..ae726fbe54 --- /dev/null +++ b/backend/command/v2/domain/query/return_email_code.go @@ -0,0 +1,38 @@ +package query + +import ( + "context" + + "github.com/zitadel/zitadel/backend/command/v2/pattern" +) + +var _ pattern.Query[string] = (*returnEmailCode)(nil) + +type returnEmailCode struct { + UserID string `json:"userId"` + Email string `json:"email"` + code string `json:"-"` +} + +func ReturnEmailCode(userID, email string) *returnEmailCode { + return &returnEmailCode{ + UserID: userID, + Email: email, + } +} + +// Name implements [pattern.Command]. +func (c *returnEmailCode) Name() string { + return "user.v2.email.return_code" +} + +// Execute implements [pattern.Command]. +func (c *returnEmailCode) Execute(ctx context.Context) error { + // Implementation of the command execution + return nil +} + +// Result implements [pattern.Query]. +func (c *returnEmailCode) Result() string { + return c.code +} diff --git a/backend/command/v2/domain/query/user_by_id.go b/backend/command/v2/domain/query/user_by_id.go new file mode 100644 index 0000000000..04a6b9ea77 --- /dev/null +++ b/backend/command/v2/domain/query/user_by_id.go @@ -0,0 +1,38 @@ +package query + +import ( + "context" + + "github.com/zitadel/zitadel/backend/command/v2/domain" + "github.com/zitadel/zitadel/backend/command/v2/pattern" + "github.com/zitadel/zitadel/backend/command/v2/storage/database" +) + +type UserByIDQuery struct { + querier database.Querier + UserID string `json:"userId"` + res *domain.User +} + +var _ pattern.Query[*domain.User] = (*UserByIDQuery)(nil) + +// Name implements [pattern.Command]. +func (q *UserByIDQuery) Name() string { + return "user.v2.by_id" +} + +// Execute implements [pattern.Command]. +func (q *UserByIDQuery) Execute(ctx context.Context) error { + var res *domain.User + err := q.querier.QueryRow(ctx, "SELECT id, username, email FROM users WHERE id = $1", q.UserID).Scan(&res.ID, &res.Username, &res.Email.Address) + if err != nil { + return err + } + q.res = res + return nil +} + +// Result implements [pattern.Query]. +func (q *UserByIDQuery) Result() *domain.User { + return q.res +} diff --git a/backend/command/v2/domain/user.go b/backend/command/v2/domain/user.go new file mode 100644 index 0000000000..9e4c3a6c87 --- /dev/null +++ b/backend/command/v2/domain/user.go @@ -0,0 +1,77 @@ +package domain + +import ( + "context" + + "github.com/zitadel/zitadel/backend/command/v2/domain/command" + "github.com/zitadel/zitadel/backend/command/v2/domain/query" + "github.com/zitadel/zitadel/backend/command/v2/pattern" + "github.com/zitadel/zitadel/backend/command/v2/storage/database" + "github.com/zitadel/zitadel/backend/command/v2/telemetry/tracing" +) + +type User struct { + ID string + Username string + Email Email +} + +type SetUserEmail struct { + UserID string + Email string + + IsVerified *bool + ReturnCode *ReturnCode + SendCode *SendCode + + code string + client database.QueryExecutor +} + +func (e *SetUserEmail) SetCode(code string) { + e.code = code +} + +type ReturnCode struct { + // Code is the code to be sent to the user + Code string +} + +type SendCode struct { + // URLTemplate is the template for the URL that is rendered into the message + URLTemplate *string +} + +func (d *Domain) SetUserEmail(ctx context.Context, req *SetUserEmail) error { + batch := pattern.Batch( + tracing.Trace(d.tracer, command.SetEmail(req.UserID, req.Email)), + ) + + if req.IsVerified == nil { + batch.Append(command.GenerateCode( + req.SetCode, + query.QueryEncryptionGenerator( + database.Query(d.pool), + d.userCodeAlg, + ), + )) + } else { + batch.Append(command.VerifyEmail(req.UserID, req.Email)) + } + + // if !req.GetVerification().GetIsVerified() { + // batch. + + // switch req.GetVerification().(type) { + // case *user.SetEmailRequest_IsVerified: + // batch.Append(tracing.Trace(s.tracer, command.VerifyEmail(req.GetUserId(), req.GetEmail()))) + // case *user.SetEmailRequest_SendCode: + // batch.Append(tracing.Trace(s.tracer, command.SendEmailCode(req.GetUserId(), req.GetEmail(), req.GetSendCode().UrlTemplate))) + // case *user.SetEmailRequest_ReturnCode: + // batch.Append(tracing.Trace(s.tracer, query.ReturnEmailCode(req.GetUserId(), req.GetEmail()))) + // } + + // if err := batch.Execute(ctx); err != nil { + // return nil, err + // } +} diff --git a/backend/command/v2/pattern/command.go b/backend/command/v2/pattern/command.go new file mode 100644 index 0000000000..787e05dc62 --- /dev/null +++ b/backend/command/v2/pattern/command.go @@ -0,0 +1,100 @@ +package pattern + +import ( + "context" + + "github.com/zitadel/zitadel/backend/command/v2/storage/database" +) + +// Command implements the command pattern. +// It is used to encapsulate a request as an object, thereby allowing for parameterization of clients with queues, requests, and operations. +// The command pattern allows for the decoupling of the sender and receiver of a request. +// It is often used in conjunction with the invoker pattern, which is responsible for executing the command. +// The command pattern is a behavioral design pattern that turns a request into a stand-alone object. +// This object contains all the information about the request. +// The command pattern is useful for implementing undo/redo functionality, logging, and queuing requests. +// It is also useful for implementing the macro command pattern, which allows for the execution of a series of commands as a single command. +// The command pattern is also used in event-driven architectures, where events are encapsulated as commands. +type Command interface { + Execute(ctx context.Context) error + Name() string +} + +type Query[T any] interface { + Command + Result() T +} + +type Invoker struct{} + +// func bla() { +// sync.Pool{ +// New: func() any { +// return new(Invoker) +// }, +// } +// } + +type Transaction struct { + beginner database.Beginner + cmd Command + opts *database.TransactionOptions +} + +func (t *Transaction) Execute(ctx context.Context) error { + tx, err := t.beginner.Begin(ctx, t.opts) + if err != nil { + return err + } + defer func() { err = tx.End(ctx, err) }() + return t.cmd.Execute(ctx) +} + +func (t *Transaction) Name() string { + return t.cmd.Name() +} + +type batch struct { + Commands []Command +} + +func Batch(cmds ...Command) *batch { + return &batch{ + Commands: cmds, + } +} + +func (b *batch) Execute(ctx context.Context) error { + for _, cmd := range b.Commands { + if err := cmd.Execute(ctx); err != nil { + return err + } + } + return nil +} + +func (b *batch) Name() string { + return "batch" +} + +func (b *batch) Append(cmds ...Command) { + b.Commands = append(b.Commands, cmds...) +} + +type NoopCommand struct{} + +func (c *NoopCommand) Execute(_ context.Context) error { + return nil +} +func (c *NoopCommand) Name() string { + return "noop" +} + +type NoopQuery[T any] struct { + NoopCommand +} + +func (q *NoopQuery[T]) Result() T { + var zero T + return zero +} diff --git a/backend/command/v2/storage/database/config.go b/backend/command/v2/storage/database/config.go new file mode 100644 index 0000000000..d9aa99b869 --- /dev/null +++ b/backend/command/v2/storage/database/config.go @@ -0,0 +1,9 @@ +package database + +import ( + "context" +) + +type Connector interface { + Connect(ctx context.Context) (Pool, error) +} diff --git a/backend/command/v2/storage/database/database.go b/backend/command/v2/storage/database/database.go new file mode 100644 index 0000000000..6dddefe22c --- /dev/null +++ b/backend/command/v2/storage/database/database.go @@ -0,0 +1,54 @@ +package database + +import ( + "context" +) + +var ( + db *database +) + +type database struct { + connector Connector + pool Pool +} + +type Pool interface { + Beginner + QueryExecutor + + Acquire(ctx context.Context) (Client, error) + Close(ctx context.Context) error +} + +type Client interface { + Beginner + QueryExecutor + + Release(ctx context.Context) error +} + +type Querier interface { + Query(ctx context.Context, stmt string, args ...any) (Rows, error) + QueryRow(ctx context.Context, stmt string, args ...any) Row +} + +type Executor interface { + Exec(ctx context.Context, stmt string, args ...any) error +} + +type Row interface { + Scan(dest ...any) error +} + +type Rows interface { + Row + Next() bool + Close() error + Err() error +} + +type QueryExecutor interface { + Querier + Executor +} diff --git a/backend/command/v2/storage/database/dialect/config.go b/backend/command/v2/storage/database/dialect/config.go new file mode 100644 index 0000000000..a044f7bd4e --- /dev/null +++ b/backend/command/v2/storage/database/dialect/config.go @@ -0,0 +1,92 @@ +package dialect + +import ( + "context" + "errors" + "reflect" + + "github.com/mitchellh/mapstructure" + "github.com/spf13/viper" + + "github.com/zitadel/zitadel/backend/storage/database" + "github.com/zitadel/zitadel/backend/storage/database/dialect/postgres" +) + +type Hook struct { + Match func(string) bool + Decode func(config any) (database.Connector, error) + Name string + Constructor func() database.Connector +} + +var hooks = []Hook{ + { + Match: postgres.NameMatcher, + Decode: postgres.DecodeConfig, + Name: postgres.Name, + Constructor: func() database.Connector { return new(postgres.Config) }, + }, + // { + // Match: gosql.NameMatcher, + // Decode: gosql.DecodeConfig, + // Name: gosql.Name, + // Constructor: func() database.Connector { return new(gosql.Config) }, + // }, +} + +type Config struct { + Dialects map[string]any `mapstructure:",remain" yaml:",inline"` + + connector database.Connector +} + +func (c Config) Connect(ctx context.Context) (database.Pool, error) { + if len(c.Dialects) != 1 { + return nil, errors.New("Exactly one dialect must be configured") + } + + return c.connector.Connect(ctx) +} + +// Hooks implements [configure.Unmarshaller]. +func (c Config) Hooks() []viper.DecoderConfigOption { + return []viper.DecoderConfigOption{ + viper.DecodeHook(decodeHook), + } +} + +func decodeHook(from, to reflect.Value) (_ any, err error) { + if to.Type() != reflect.TypeOf(Config{}) { + return from.Interface(), nil + } + + config := new(Config) + if err = mapstructure.Decode(from.Interface(), config); err != nil { + return nil, err + } + + if err = config.decodeDialect(); err != nil { + return nil, err + } + + return config, nil +} + +func (c *Config) decodeDialect() error { + for _, hook := range hooks { + for name, config := range c.Dialects { + if !hook.Match(name) { + continue + } + + connector, err := hook.Decode(config) + if err != nil { + return err + } + + c.connector = connector + return nil + } + } + return errors.New("no dialect found") +} diff --git a/backend/command/v2/storage/database/dialect/postgres/config.go b/backend/command/v2/storage/database/dialect/postgres/config.go new file mode 100644 index 0000000000..1007c09542 --- /dev/null +++ b/backend/command/v2/storage/database/dialect/postgres/config.go @@ -0,0 +1,80 @@ +package postgres + +import ( + "context" + "errors" + "slices" + "strings" + + "github.com/jackc/pgx/v5/pgxpool" + "github.com/mitchellh/mapstructure" + + "github.com/zitadel/zitadel/backend/command/v2/storage/database" +) + +var ( + _ database.Connector = (*Config)(nil) + Name = "postgres" +) + +type Config struct { + config *pgxpool.Config + + // Host string + // Port int32 + // Database string + // MaxOpenConns uint32 + // MaxIdleConns uint32 + // MaxConnLifetime time.Duration + // MaxConnIdleTime time.Duration + // User User + // // Additional options to be appended as options= + // // The value will be taken as is. Multiple options are space separated. + // Options string + + configuredFields []string +} + +// Connect implements [database.Connector]. +func (c *Config) Connect(ctx context.Context) (database.Pool, error) { + pool, err := pgxpool.NewWithConfig(ctx, c.config) + if err != nil { + return nil, err + } + if err = pool.Ping(ctx); err != nil { + return nil, err + } + return &pgxPool{pool}, nil +} + +func NameMatcher(name string) bool { + return slices.Contains([]string{"postgres", "pg"}, strings.ToLower(name)) +} + +func DecodeConfig(input any) (database.Connector, error) { + switch c := input.(type) { + case string: + config, err := pgxpool.ParseConfig(c) + if err != nil { + return nil, err + } + return &Config{config: config}, nil + case map[string]any: + connector := new(Config) + decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + DecodeHook: mapstructure.StringToTimeDurationHookFunc(), + WeaklyTypedInput: true, + Result: connector, + }) + if err != nil { + return nil, err + } + if err = decoder.Decode(c); err != nil { + return nil, err + } + return &Config{ + config: &pgxpool.Config{}, + }, nil + } + return nil, errors.New("invalid configuration") +} diff --git a/backend/command/v2/storage/database/dialect/postgres/conn.go b/backend/command/v2/storage/database/dialect/postgres/conn.go new file mode 100644 index 0000000000..e7bdc0741a --- /dev/null +++ b/backend/command/v2/storage/database/dialect/postgres/conn.go @@ -0,0 +1,48 @@ +package postgres + +import ( + "context" + + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/zitadel/zitadel/backend/command/v2/storage/database" +) + +type pgxConn struct{ *pgxpool.Conn } + +var _ database.Client = (*pgxConn)(nil) + +// Release implements [database.Client]. +func (c *pgxConn) Release(_ context.Context) error { + c.Conn.Release() + return nil +} + +// Begin implements [database.Client]. +func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) { + tx, err := c.Conn.BeginTx(ctx, transactionOptionsToPgx(opts)) + if err != nil { + return nil, err + } + return &pgxTx{tx}, nil +} + +// Query implements sql.Client. +// Subtle: this method shadows the method (*Conn).Query of pgxConn.Conn. +func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { + rows, err := c.Conn.Query(ctx, sql, args...) + return &Rows{rows}, err +} + +// QueryRow implements sql.Client. +// Subtle: this method shadows the method (*Conn).QueryRow of pgxConn.Conn. +func (c *pgxConn) QueryRow(ctx context.Context, sql string, args ...any) database.Row { + return c.Conn.QueryRow(ctx, sql, args...) +} + +// Exec implements [database.Pool]. +// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool. +func (c *pgxConn) Exec(ctx context.Context, sql string, args ...any) error { + _, err := c.Conn.Exec(ctx, sql, args...) + return err +} diff --git a/backend/command/v2/storage/database/dialect/postgres/pool.go b/backend/command/v2/storage/database/dialect/postgres/pool.go new file mode 100644 index 0000000000..aba0231213 --- /dev/null +++ b/backend/command/v2/storage/database/dialect/postgres/pool.go @@ -0,0 +1,57 @@ +package postgres + +import ( + "context" + + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/zitadel/zitadel/backend/command/v2/storage/database" +) + +type pgxPool struct{ *pgxpool.Pool } + +var _ database.Pool = (*pgxPool)(nil) + +// Acquire implements [database.Pool]. +func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) { + conn, err := c.Pool.Acquire(ctx) + if err != nil { + return nil, err + } + return &pgxConn{conn}, nil +} + +// Query implements [database.Pool]. +// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool. +func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { + rows, err := c.Pool.Query(ctx, sql, args...) + return &Rows{rows}, err +} + +// QueryRow implements [database.Pool]. +// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool. +func (c *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row { + return c.Pool.QueryRow(ctx, sql, args...) +} + +// Exec implements [database.Pool]. +// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool. +func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) error { + _, err := c.Pool.Exec(ctx, sql, args...) + return err +} + +// Begin implements [database.Pool]. +func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) { + tx, err := c.Pool.BeginTx(ctx, transactionOptionsToPgx(opts)) + if err != nil { + return nil, err + } + return &pgxTx{tx}, nil +} + +// Close implements [database.Pool]. +func (c *pgxPool) Close(_ context.Context) error { + c.Pool.Close() + return nil +} diff --git a/backend/command/v2/storage/database/dialect/postgres/rows.go b/backend/command/v2/storage/database/dialect/postgres/rows.go new file mode 100644 index 0000000000..c5ec8aabfd --- /dev/null +++ b/backend/command/v2/storage/database/dialect/postgres/rows.go @@ -0,0 +1,18 @@ +package postgres + +import ( + "github.com/jackc/pgx/v5" + + "github.com/zitadel/zitadel/backend/command/v2/storage/database" +) + +var _ database.Rows = (*Rows)(nil) + +type Rows struct{ pgx.Rows } + +// Close implements [database.Rows]. +// Subtle: this method shadows the method (Rows).Close of Rows.Rows. +func (r *Rows) Close() error { + r.Rows.Close() + return nil +} diff --git a/backend/command/v2/storage/database/dialect/postgres/tx.go b/backend/command/v2/storage/database/dialect/postgres/tx.go new file mode 100644 index 0000000000..677a433240 --- /dev/null +++ b/backend/command/v2/storage/database/dialect/postgres/tx.go @@ -0,0 +1,95 @@ +package postgres + +import ( + "context" + + "github.com/jackc/pgx/v5" + + "github.com/zitadel/zitadel/backend/command/v2/storage/database" +) + +type pgxTx struct{ pgx.Tx } + +var _ database.Transaction = (*pgxTx)(nil) + +// Commit implements [database.Transaction]. +func (tx *pgxTx) Commit(ctx context.Context) error { + return tx.Tx.Commit(ctx) +} + +// Rollback implements [database.Transaction]. +func (tx *pgxTx) Rollback(ctx context.Context) error { + return tx.Tx.Rollback(ctx) +} + +// End implements [database.Transaction]. +func (tx *pgxTx) End(ctx context.Context, err error) error { + if err != nil { + tx.Rollback(ctx) + return err + } + return tx.Commit(ctx) +} + +// Query implements [database.Transaction]. +// Subtle: this method shadows the method (Tx).Query of pgxTx.Tx. +func (tx *pgxTx) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { + rows, err := tx.Tx.Query(ctx, sql, args...) + return &Rows{rows}, err +} + +// QueryRow implements [database.Transaction]. +// Subtle: this method shadows the method (Tx).QueryRow of pgxTx.Tx. +func (tx *pgxTx) QueryRow(ctx context.Context, sql string, args ...any) database.Row { + return tx.Tx.QueryRow(ctx, sql, args...) +} + +// Exec implements [database.Transaction]. +// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool. +func (tx *pgxTx) Exec(ctx context.Context, sql string, args ...any) error { + _, err := tx.Tx.Exec(ctx, sql, args...) + return err +} + +// Begin implements [database.Transaction]. +// As postgres does not support nested transactions we use savepoints to emulate them. +func (tx *pgxTx) Begin(ctx context.Context) (database.Transaction, error) { + savepoint, err := tx.Tx.Begin(ctx) + if err != nil { + return nil, err + } + return &pgxTx{savepoint}, nil +} + +func transactionOptionsToPgx(opts *database.TransactionOptions) pgx.TxOptions { + if opts == nil { + return pgx.TxOptions{} + } + + return pgx.TxOptions{ + IsoLevel: isolationToPgx(opts.IsolationLevel), + AccessMode: accessModeToPgx(opts.AccessMode), + } +} + +func isolationToPgx(isolation database.IsolationLevel) pgx.TxIsoLevel { + switch isolation { + case database.IsolationLevelSerializable: + return pgx.Serializable + case database.IsolationLevelReadCommitted: + return pgx.ReadCommitted + default: + return pgx.Serializable + } +} + +func accessModeToPgx(accessMode database.AccessMode) pgx.TxAccessMode { + switch accessMode { + case database.AccessModeReadWrite: + return pgx.ReadWrite + case database.AccessModeReadOnly: + return pgx.ReadOnly + default: + return pgx.ReadWrite + } +} diff --git a/backend/command/v2/storage/database/secret_generator.go b/backend/command/v2/storage/database/secret_generator.go new file mode 100644 index 0000000000..fddebdce44 --- /dev/null +++ b/backend/command/v2/storage/database/secret_generator.go @@ -0,0 +1,39 @@ +package database + +import ( + "context" + + "github.com/zitadel/zitadel/internal/crypto" +) + +type query struct{ Querier } + +func Query(querier Querier) *query { + return &query{Querier: querier} +} + +const getEncryptionConfigQuery = "SELECT" + + " length" + + ", expiry" + + ", should_include_lower_letters" + + ", should_include_upper_letters" + + ", should_include_digits" + + ", should_include_symbols" + + " FROM encryption_config" + +func (q query) GetEncryptionConfig(ctx context.Context) (*crypto.GeneratorConfig, error) { + var config crypto.GeneratorConfig + row := q.QueryRow(ctx, getEncryptionConfigQuery) + err := row.Scan( + &config.Length, + &config.Expiry, + &config.IncludeLowerLetters, + &config.IncludeUpperLetters, + &config.IncludeDigits, + &config.IncludeSymbols, + ) + if err != nil { + return nil, err + } + return &config, nil +} diff --git a/backend/command/v2/storage/database/tx.go b/backend/command/v2/storage/database/tx.go new file mode 100644 index 0000000000..02c945dc77 --- /dev/null +++ b/backend/command/v2/storage/database/tx.go @@ -0,0 +1,36 @@ +package database + +import "context" + +type Transaction interface { + Commit(ctx context.Context) error + Rollback(ctx context.Context) error + End(ctx context.Context, err error) error + + Begin(ctx context.Context) (Transaction, error) + + QueryExecutor +} + +type Beginner interface { + Begin(ctx context.Context, opts *TransactionOptions) (Transaction, error) +} + +type TransactionOptions struct { + IsolationLevel IsolationLevel + AccessMode AccessMode +} + +type IsolationLevel uint8 + +const ( + IsolationLevelSerializable IsolationLevel = iota + IsolationLevelReadCommitted +) + +type AccessMode uint8 + +const ( + AccessModeReadWrite AccessMode = iota + AccessModeReadOnly +) diff --git a/backend/command/v2/storage/eventstore/event.go b/backend/command/v2/storage/eventstore/event.go new file mode 100644 index 0000000000..52d0491558 --- /dev/null +++ b/backend/command/v2/storage/eventstore/event.go @@ -0,0 +1,13 @@ +package eventstore + +import "github.com/zitadel/zitadel/backend/command/v2/pattern" + +type Event struct { + AggregateType string `json:"aggregateType"` + AggregateID string `json:"aggregateId"` +} + +type EventCommander interface { + pattern.Command + Event() *Event +} diff --git a/backend/command/v2/telemetry/tracing/command.go b/backend/command/v2/telemetry/tracing/command.go new file mode 100644 index 0000000000..679f1b5512 --- /dev/null +++ b/backend/command/v2/telemetry/tracing/command.go @@ -0,0 +1,55 @@ +package tracing + +import ( + "context" + + "go.opentelemetry.io/otel/trace" + + "github.com/zitadel/zitadel/backend/command/v2/pattern" +) + +type command struct { + trace.Tracer + cmd pattern.Command +} + +func Trace(tracer trace.Tracer, cmd pattern.Command) pattern.Command { + return &command{ + Tracer: tracer, + cmd: cmd, + } +} + +func (cmd *command) Name() string { + return cmd.cmd.Name() +} + +func (cmd *command) Execute(ctx context.Context) error { + ctx, span := cmd.Tracer.Start(ctx, cmd.Name()) + defer span.End() + + err := cmd.cmd.Execute(ctx) + if err != nil { + span.RecordError(err) + } + return err +} + +type query[T any] struct { + command + query pattern.Query[T] +} + +func Query[T any](tracer trace.Tracer, q pattern.Query[T]) pattern.Query[T] { + return &query[T]{ + command: command{ + Tracer: tracer, + cmd: q, + }, + query: q, + } +} + +func (q *query[T]) Result() T { + return q.query.Result() +} diff --git a/backend/v3/api/instance/v2/server.go b/backend/v3/api/instance/v2/server.go new file mode 100644 index 0000000000..07e09c68aa --- /dev/null +++ b/backend/v3/api/instance/v2/server.go @@ -0,0 +1,19 @@ +package v2 + +import ( + "github.com/zitadel/zitadel/backend/v3/telemetry/logging" + "github.com/zitadel/zitadel/backend/v3/telemetry/tracing" +) + +var ( + logger logging.Logger + tracer tracing.Tracer +) + +func SetLogger(l logging.Logger) { + logger = l +} + +func SetTracer(t tracing.Tracer) { + tracer = t +} diff --git a/backend/v3/api/user/v2/email.go b/backend/v3/api/user/v2/email.go new file mode 100644 index 0000000000..f02c7ff0dd --- /dev/null +++ b/backend/v3/api/user/v2/email.go @@ -0,0 +1,93 @@ +package userv2 + +import ( + "context" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/pkg/grpc/user/v2" +) + +func SetEmail(ctx context.Context, req *user.SetEmailRequest) (resp *user.SetEmailResponse, err error) { + var ( + verification domain.SetEmailOpt + returnCode *domain.ReturnCodeCommand + ) + + switch req.GetVerification().(type) { + case *user.SetEmailRequest_IsVerified: + verification = domain.NewEmailVerifiedCommand(req.GetUserId(), req.GetIsVerified()) + case *user.SetEmailRequest_SendCode: + verification = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate) + case *user.SetEmailRequest_ReturnCode: + returnCode = domain.NewReturnCodeCommand(req.GetUserId()) + verification = returnCode + default: + verification = domain.NewSendCodeCommand(req.GetUserId(), nil) + } + + err = domain.Invoke(ctx, domain.NewSetEmailCommand(req.GetUserId(), req.GetEmail(), verification)) + if err != nil { + return nil, err + } + + var code *string + if returnCode != nil && returnCode.Code != "" { + code = &returnCode.Code + } + + return &user.SetEmailResponse{ + VerificationCode: code, + }, nil +} + +func SendEmailCode(ctx context.Context, req *user.SendEmailCodeRequest) (resp *user.SendEmailCodeResponse, err error) { + var ( + returnCode *domain.ReturnCodeCommand + cmd domain.Commander + ) + + switch req.GetVerification().(type) { + case *user.SendEmailCodeRequest_SendCode: + cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate) + case *user.SendEmailCodeRequest_ReturnCode: + returnCode = domain.NewReturnCodeCommand(req.GetUserId()) + cmd = returnCode + default: + cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate) + } + err = domain.Invoke(ctx, cmd) + if err != nil { + return nil, err + } + resp = new(user.SendEmailCodeResponse) + if returnCode != nil { + resp.VerificationCode = &returnCode.Code + } + return resp, nil +} + +func ResendEmailCode(ctx context.Context, req *user.ResendEmailCodeRequest) (resp *user.SendEmailCodeResponse, err error) { + var ( + returnCode *domain.ReturnCodeCommand + cmd domain.Commander + ) + + switch req.GetVerification().(type) { + case *user.ResendEmailCodeRequest_SendCode: + cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate) + case *user.ResendEmailCodeRequest_ReturnCode: + returnCode = domain.NewReturnCodeCommand(req.GetUserId()) + cmd = returnCode + default: + cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate) + } + err = domain.Invoke(ctx, cmd) + if err != nil { + return nil, err + } + resp = new(user.SendEmailCodeResponse) + if returnCode != nil { + resp.VerificationCode = &returnCode.Code + } + return resp, nil +} diff --git a/backend/v3/api/user/v2/server.go b/backend/v3/api/user/v2/server.go new file mode 100644 index 0000000000..c179a4a417 --- /dev/null +++ b/backend/v3/api/user/v2/server.go @@ -0,0 +1,19 @@ +package userv2 + +import ( + "github.com/zitadel/zitadel/backend/v3/telemetry/logging" + "github.com/zitadel/zitadel/backend/v3/telemetry/tracing" +) + +var ( + logger logging.Logger + tracer tracing.Tracer +) + +func SetLogger(l logging.Logger) { + logger = l +} + +func SetTracer(t tracing.Tracer) { + tracer = t +} diff --git a/backend/v3/doc.go b/backend/v3/doc.go new file mode 100644 index 0000000000..7032207dcd --- /dev/null +++ b/backend/v3/doc.go @@ -0,0 +1,12 @@ +// the test used the manly relies on the following patterns: +// - domain: +// - hexagonal architecture, it defines its dependencies as interfaces and the dependencies must use the objects defined by this package +// - command pattern which implements the changes +// - the invoker decorates the commands by checking for events and tracing +// - the database connections are manged in this package +// - the database connections are passed to the repositories +// +// - storage: +// - repository pattern, the repositories are defined as interfaces and the implementations are in the storage package +// - the repositories are used by the domain package to access the database +package v3 diff --git a/backend/v3/domain/command.go b/backend/v3/domain/command.go new file mode 100644 index 0000000000..40809c6e1e --- /dev/null +++ b/backend/v3/domain/command.go @@ -0,0 +1,105 @@ +package domain + +import ( + "context" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type Commander interface { + Execute(ctx context.Context, opts *CommandOpts) (err error) +} + +type Invoker interface { + Invoke(ctx context.Context, command Commander, opts *CommandOpts) error +} + +type CommandOpts struct { + DB database.QueryExecutor + Invoker Invoker +} + +type ensureTxOpts struct { + *database.TransactionOptions +} + +type EnsureTransactionOpt func(*ensureTxOpts) + +// EnsureTx ensures that the DB is a transaction. If it is not, it will start a new transaction. +// The returned close function will end the transaction. If the DB is already a transaction, the close function +// will do nothing because another [Commander] is already responsible for ending the transaction. +func (o *CommandOpts) EnsureTx(ctx context.Context, opts ...EnsureTransactionOpt) (close func(context.Context, error) error, err error) { + beginner, ok := o.DB.(database.Beginner) + if !ok { + // db is already a transaction + return func(_ context.Context, err error) error { + return err + }, nil + } + + txOpts := &ensureTxOpts{ + TransactionOptions: new(database.TransactionOptions), + } + for _, opt := range opts { + opt(txOpts) + } + + tx, err := beginner.Begin(ctx, txOpts.TransactionOptions) + if err != nil { + return nil, err + } + o.DB = tx + + return func(ctx context.Context, err error) error { + return tx.End(ctx, err) + }, nil +} + +// EnsureClient ensures that the o.DB is a client. If it is not, it will get a new client from the [database.Pool]. +// The returned close function will release the client. If the o.DB is already a client or transaction, the close function +// will do nothing because another [Commander] is already responsible for releasing the client. +func (o *CommandOpts) EnsureClient(ctx context.Context) (close func(_ context.Context) error, err error) { + pool, ok := o.DB.(database.Pool) + if !ok { + // o.DB is already a client + return func(_ context.Context) error { + return nil + }, nil + } + client, err := pool.Acquire(ctx) + if err != nil { + return nil, err + } + o.DB = client + return func(ctx context.Context) error { + return client.Release(ctx) + }, nil +} + +func (o *CommandOpts) Invoke(ctx context.Context, command Commander) error { + if o.Invoker == nil { + return command.Execute(ctx, o) + } + return o.Invoker.Invoke(ctx, command, o) +} + +func DefaultOpts(invoker Invoker) *CommandOpts { + if invoker == nil { + invoker = &noopInvoker{} + } + return &CommandOpts{ + DB: pool, + Invoker: invoker, + } +} + +type noopInvoker struct { + next Invoker +} + +func (i *noopInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) error { + if i.next != nil { + return i.next.Invoke(ctx, command, opts) + } + return command.Execute(ctx, opts) +} diff --git a/backend/v3/domain/create_user.go b/backend/v3/domain/create_user.go new file mode 100644 index 0000000000..c4f480d1cb --- /dev/null +++ b/backend/v3/domain/create_user.go @@ -0,0 +1,76 @@ +package domain + +import ( + "context" + + v4 "github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt/v4" + "github.com/zitadel/zitadel/backend/v3/storage/eventstore" +) + +type CreateUserCommand struct { + user *User + email *SetEmailCommand +} + +var ( + _ Commander = (*CreateUserCommand)(nil) + _ eventer = (*CreateUserCommand)(nil) +) + +func NewCreateHumanCommand(username string, opts ...CreateHumanOpt) *CreateUserCommand { + cmd := &CreateUserCommand{ + user: &User{ + User: v4.User{ + Username: username, + Traits: &v4.Human{}, + }, + }, + } + + for _, opt := range opts { + opt.applyOnCreateHuman(cmd) + } + return cmd +} + +// Events implements [eventer]. +func (c *CreateUserCommand) Events() []*eventstore.Event { + panic("unimplemented") +} + +// Execute implements [Commander]. +func (c *CreateUserCommand) Execute(ctx context.Context, opts *CommandOpts) error { + if err := c.ensureUserID(); err != nil { + return err + } + c.email.UserID = c.user.ID + if err := opts.Invoke(ctx, c.email); err != nil { + return err + } + return nil +} + +type CreateHumanOpt interface { + applyOnCreateHuman(*CreateUserCommand) +} + +type createHumanIDOpt string + +// applyOnCreateHuman implements [CreateHumanOpt]. +func (c createHumanIDOpt) applyOnCreateHuman(cmd *CreateUserCommand) { + cmd.user.ID = string(c) +} + +var _ CreateHumanOpt = (*createHumanIDOpt)(nil) + +func CreateHumanWithID(id string) CreateHumanOpt { + return createHumanIDOpt(id) +} + +func (c *CreateUserCommand) ensureUserID() (err error) { + if c.user.ID != "" { + return nil + } + c.user.ID, err = generateID() + return err +} diff --git a/backend/v3/domain/crypto.go b/backend/v3/domain/crypto.go new file mode 100644 index 0000000000..a3b06f6645 --- /dev/null +++ b/backend/v3/domain/crypto.go @@ -0,0 +1,26 @@ +package domain + +import ( + "context" + + "github.com/zitadel/zitadel/internal/crypto" +) + +type generateCodeCommand struct { + code string + value *crypto.CryptoValue +} + +type CryptoRepository interface { + GetEncryptionConfig(ctx context.Context) (*crypto.GeneratorConfig, error) +} + +func (cmd *generateCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error { + config, err := cryptoRepo(opts.DB).GetEncryptionConfig(ctx) + if err != nil { + return err + } + generator := crypto.NewEncryptionGenerator(*config, userCodeAlgorithm) + cmd.value, cmd.code, err = crypto.NewCode(generator) + return err +} diff --git a/backend/v3/domain/domain.go b/backend/v3/domain/domain.go new file mode 100644 index 0000000000..8f28acf00d --- /dev/null +++ b/backend/v3/domain/domain.go @@ -0,0 +1,52 @@ +package domain + +import ( + "math/rand/v2" + "strconv" + + "github.com/zitadel/zitadel/backend/v3/storage/cache" + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/telemetry/tracing" + "github.com/zitadel/zitadel/internal/crypto" +) + +var ( + pool database.Pool + userCodeAlgorithm crypto.EncryptionAlgorithm + tracer tracing.Tracer + + // userRepo func(database.QueryExecutor) UserRepository + instanceRepo func(database.QueryExecutor) InstanceRepository + cryptoRepo func(database.QueryExecutor) CryptoRepository + orgRepo func(database.QueryExecutor) OrgRepository + + instanceCache cache.Cache[string, string, *Instance] + + generateID func() (string, error) = func() (string, error) { + return strconv.FormatUint(rand.Uint64(), 10), nil + } +) + +func SetPool(p database.Pool) { + pool = p +} + +func SetUserCodeAlgorithm(algorithm crypto.EncryptionAlgorithm) { + userCodeAlgorithm = algorithm +} + +func SetTracer(t tracing.Tracer) { + tracer = t +} + +// func SetUserRepository(repo func(database.QueryExecutor) UserRepository) { +// userRepo = repo +// } + +func SetInstanceRepository(repo func(database.QueryExecutor) InstanceRepository) { + instanceRepo = repo +} + +func SetCryptoRepository(repo func(database.QueryExecutor) CryptoRepository) { + cryptoRepo = repo +} diff --git a/backend/v3/domain/domain_test.go b/backend/v3/domain/domain_test.go new file mode 100644 index 0000000000..6e4ee248f6 --- /dev/null +++ b/backend/v3/domain/domain_test.go @@ -0,0 +1,45 @@ +package domain_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "go.opentelemetry.io/otel" + "go.opentelemetry.io/otel/exporters/stdout/stdouttrace" + sdktrace "go.opentelemetry.io/otel/sdk/trace" + + . "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database/repository" + "github.com/zitadel/zitadel/backend/v3/telemetry/tracing" +) + +func TestExample(t *testing.T) { + ctx := context.Background() + + // SetPool(pool) + + exporter, err := stdouttrace.New(stdouttrace.WithPrettyPrint()) + require.NoError(t, err) + tracerProvider := sdktrace.NewTracerProvider( + sdktrace.WithSyncer(exporter), + ) + otel.SetTracerProvider(tracerProvider) + SetTracer(tracing.Tracer{Tracer: tracerProvider.Tracer("test")}) + defer func() { assert.NoError(t, tracerProvider.Shutdown(ctx)) }() + + SetUserRepository(repository.User) + SetInstanceRepository(repository.Instance) + SetCryptoRepository(repository.Crypto) + + t.Run("verified email", func(t *testing.T) { + err := Invoke(ctx, NewSetEmailCommand("u1", "test@example.com", NewEmailVerifiedCommand("u1", true))) + assert.NoError(t, err) + }) + + t.Run("unverified email", func(t *testing.T) { + err := Invoke(ctx, NewSetEmailCommand("u2", "test2@example.com", NewEmailVerifiedCommand("u2", false))) + assert.NoError(t, err) + }) +} diff --git a/backend/v3/domain/email_verification.go b/backend/v3/domain/email_verification.go new file mode 100644 index 0000000000..61abfe93d2 --- /dev/null +++ b/backend/v3/domain/email_verification.go @@ -0,0 +1,155 @@ +package domain + +import ( + "context" + + v4 "github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt/v4" +) + +type EmailVerifiedCommand struct { + UserID string `json:"userId"` + Email *Email `json:"email"` +} + +func NewEmailVerifiedCommand(userID string, isVerified bool) *EmailVerifiedCommand { + return &EmailVerifiedCommand{ + UserID: userID, + Email: &Email{ + IsVerified: isVerified, + }, + } +} + +var ( + _ Commander = (*EmailVerifiedCommand)(nil) + _ SetEmailOpt = (*EmailVerifiedCommand)(nil) +) + +// Execute implements [Commander] +func (cmd *EmailVerifiedCommand) Execute(ctx context.Context, opts *CommandOpts) error { + return userRepo(opts.DB).Human().ByID(cmd.UserID).Exec().SetEmailVerified(ctx, cmd.Email.Address) +} + +// applyOnSetEmail implements [SetEmailOpt] +func (cmd *EmailVerifiedCommand) applyOnSetEmail(setEmailCmd *SetEmailCommand) { + cmd.UserID = setEmailCmd.UserID + cmd.Email.Address = setEmailCmd.Email + setEmailCmd.verification = cmd +} + +type SendCodeCommand struct { + UserID string `json:"userId"` + Email string `json:"email"` + URLTemplate *string `json:"urlTemplate"` + generator *generateCodeCommand +} + +var ( + _ Commander = (*SendCodeCommand)(nil) + _ SetEmailOpt = (*SendCodeCommand)(nil) +) + +func NewSendCodeCommand(userID string, urlTemplate *string) *SendCodeCommand { + return &SendCodeCommand{ + UserID: userID, + generator: &generateCodeCommand{}, + URLTemplate: urlTemplate, + } +} + +// Execute implements [Commander] +func (cmd *SendCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error { + if err := cmd.ensureEmail(ctx, opts); err != nil { + return err + } + if err := cmd.ensureURL(ctx, opts); err != nil { + return err + } + + if err := opts.Invoker.Invoke(ctx, cmd.generator, opts); err != nil { + return err + } + // TODO: queue notification + + return nil +} + +func (cmd *SendCodeCommand) ensureEmail(ctx context.Context, opts *CommandOpts) error { + if cmd.Email != "" { + return nil + } + email, err := userRepo(opts.DB).Human().ByID(cmd.UserID).Exec().GetEmail(ctx) + if err != nil || email.IsVerified { + return err + } + cmd.Email = email.Address + return nil +} + +func (cmd *SendCodeCommand) ensureURL(ctx context.Context, opts *CommandOpts) error { + if cmd.URLTemplate != nil && *cmd.URLTemplate != "" { + return nil + } + _, _ = ctx, opts + // TODO: load default template + return nil +} + +// applyOnSetEmail implements [SetEmailOpt] +func (cmd *SendCodeCommand) applyOnSetEmail(setEmailCmd *SetEmailCommand) { + cmd.UserID = setEmailCmd.UserID + cmd.Email = setEmailCmd.Email + setEmailCmd.verification = cmd +} + +type ReturnCodeCommand struct { + UserID string `json:"userId"` + Email string `json:"email"` + Code string `json:"code"` + generator *generateCodeCommand +} + +var ( + _ Commander = (*ReturnCodeCommand)(nil) + _ SetEmailOpt = (*ReturnCodeCommand)(nil) +) + +func NewReturnCodeCommand(userID string) *ReturnCodeCommand { + return &ReturnCodeCommand{ + UserID: userID, + generator: &generateCodeCommand{}, + } +} + +// Execute implements [Commander] +func (cmd *ReturnCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error { + if err := cmd.ensureEmail(ctx, opts); err != nil { + return err + } + if err := opts.Invoker.Invoke(ctx, cmd.generator, opts); err != nil { + return err + } + cmd.Code = cmd.generator.code + return nil +} + +func (cmd *ReturnCodeCommand) ensureEmail(ctx context.Context, opts *CommandOpts) error { + if cmd.Email != "" { + return nil + } + user := v4.UserRepository(opts.DB) + user.WithCondition(user.IDCondition(cmd.UserID)) + email, err := user.he.GetEmail(ctx) + if err != nil || email.IsVerified { + return err + } + cmd.Email = email.Address + return nil +} + +// applyOnSetEmail implements [SetEmailOpt] +func (cmd *ReturnCodeCommand) applyOnSetEmail(setEmailCmd *SetEmailCommand) { + cmd.UserID = setEmailCmd.UserID + cmd.Email = setEmailCmd.Email + setEmailCmd.verification = cmd +} diff --git a/backend/v3/domain/errors.go b/backend/v3/domain/errors.go new file mode 100644 index 0000000000..a11c31c07d --- /dev/null +++ b/backend/v3/domain/errors.go @@ -0,0 +1,7 @@ +package domain + +import "errors" + +var ( + ErrNoAdminSpecified = errors.New("at least one admin must be specified") +) diff --git a/backend/v3/domain/instance.go b/backend/v3/domain/instance.go new file mode 100644 index 0000000000..54833a1173 --- /dev/null +++ b/backend/v3/domain/instance.go @@ -0,0 +1,36 @@ +package domain + +import ( + "context" + "time" +) + +type Instance struct { + ID string `json:"id"` + Name string `json:"name"` + CreatedAt time.Time `json:"-"` + UpdatedAt time.Time `json:"-"` + DeletedAt time.Time `json:"-"` +} + +// Keys implements the [cache.Entry]. +func (i *Instance) Keys(index string) (key []string) { + // TODO: Return the correct keys for the instance cache, e.g., i.ID, i.Domain + return []string{} +} + +type InstanceRepository interface { + ByID(ctx context.Context, id string) (*Instance, error) + Create(ctx context.Context, instance *Instance) error + On(id string) InstanceOperation +} + +type InstanceOperation interface { + AdminRepository + Update(ctx context.Context, instance *Instance) error + Delete(ctx context.Context) error +} + +type CreateInstance struct { + Name string `json:"name"` +} diff --git a/backend/v3/domain/invoke.go b/backend/v3/domain/invoke.go new file mode 100644 index 0000000000..54999deae1 --- /dev/null +++ b/backend/v3/domain/invoke.go @@ -0,0 +1,94 @@ +package domain + +import ( + "context" + "fmt" + + "github.com/zitadel/zitadel/backend/v3/storage/eventstore" +) + +var defaultInvoker = newEventStoreInvoker(newTraceInvoker(nil)) + +func Invoke(ctx context.Context, cmd Commander) error { + invoker := newEventStoreInvoker(newTraceInvoker(nil)) + opts := &CommandOpts{ + Invoker: invoker.collector, + } + return invoker.Invoke(ctx, cmd, opts) +} + +type eventStoreInvoker struct { + collector *eventCollector +} + +func newEventStoreInvoker(next Invoker) *eventStoreInvoker { + return &eventStoreInvoker{collector: &eventCollector{next: next}} +} + +func (i *eventStoreInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) { + err = i.collector.Invoke(ctx, command, opts) + if err != nil { + return err + } + if len(i.collector.events) > 0 { + err = eventstore.Publish(ctx, i.collector.events, opts.DB) + if err != nil { + return err + } + } + return nil +} + +type eventCollector struct { + next Invoker + events []*eventstore.Event +} + +type eventer interface { + Events() []*eventstore.Event +} + +func (i *eventCollector) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) { + if e, ok := command.(eventer); ok && len(e.Events()) > 0 { + // we need to ensure all commands are executed in the same transaction + close, err := opts.EnsureTx(ctx) + if err != nil { + return err + } + defer func() { err = close(ctx, err) }() + + i.events = append(i.events, e.Events()...) + } + if i.next != nil { + err = i.next.Invoke(ctx, command, opts) + } else { + err = command.Execute(ctx, opts) + } + if err != nil { + return err + } + return nil +} + +type traceInvoker struct { + next Invoker +} + +func newTraceInvoker(next Invoker) *traceInvoker { + return &traceInvoker{next: next} +} + +func (i *traceInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) { + ctx, span := tracer.Start(ctx, fmt.Sprintf("%T", command)) + defer span.End() + + if i.next != nil { + err = i.next.Invoke(ctx, command, opts) + } else { + err = command.Execute(ctx, opts) + } + if err != nil { + span.RecordError(err) + } + return err +} diff --git a/backend/v3/domain/org.go b/backend/v3/domain/org.go new file mode 100644 index 0000000000..13e90dbef6 --- /dev/null +++ b/backend/v3/domain/org.go @@ -0,0 +1,39 @@ +package domain + +import ( + "context" + "time" +) + +type Org struct { + ID string `json:"id"` + Name string `json:"name"` + + CreatedAt time.Time `json:"createdAt"` + UpdatedAt time.Time `json:"updatedAt"` +} + +type OrgRepository interface { + ByID(ctx context.Context, orgID string) (*Org, error) + Create(ctx context.Context, org *Org) error + On(id string) OrgOperation +} + +type OrgOperation interface { + AdminRepository + DomainRepository + Update(ctx context.Context, org *Org) error + Delete(ctx context.Context) error +} + +type AdminRepository interface { + AddAdmin(ctx context.Context, userID string, roles []string) error + SetAdminRoles(ctx context.Context, userID string, roles []string) error + RemoveAdmin(ctx context.Context, userID string) error +} + +type DomainRepository interface { + AddDomain(ctx context.Context, domain string) error + SetDomainVerified(ctx context.Context, domain string) error + RemoveDomain(ctx context.Context, domain string) error +} diff --git a/backend/v3/domain/org_add.go b/backend/v3/domain/org_add.go new file mode 100644 index 0000000000..c7bbe56650 --- /dev/null +++ b/backend/v3/domain/org_add.go @@ -0,0 +1,74 @@ +package domain + +import ( + "context" +) + +type AddOrgCommand struct { + ID string `json:"id"` + Name string `json:"name"` + Admins []AddAdminCommand `json:"admins"` +} + +func NewAddOrgCommand(name string, admins ...AddAdminCommand) *AddOrgCommand { + return &AddOrgCommand{ + Name: name, + Admins: admins, + } +} + +// Execute implements Commander. +func (cmd *AddOrgCommand) Execute(ctx context.Context, opts *CommandOpts) (err error) { + if len(cmd.Admins) == 0 { + return ErrNoAdminSpecified + } + if err = cmd.ensureID(); err != nil { + return err + } + + close, err := opts.EnsureTx(ctx) + if err != nil { + return err + } + defer func() { err = close(ctx, err) }() + err = orgRepo(opts.DB).Create(ctx, &Org{ + ID: cmd.ID, + Name: cmd.Name, + }) + if err != nil { + return err + } + + return nil +} + +var ( + _ Commander = (*AddOrgCommand)(nil) +) + +func (cmd *AddOrgCommand) ensureID() (err error) { + if cmd.ID != "" { + return nil + } + cmd.ID, err = generateID() + return err +} + +type AddAdminCommand struct { + UserID string `json:"userId"` + Roles []string `json:"roles"` +} + +// Execute implements Commander. +func (a *AddAdminCommand) Execute(ctx context.Context, opts *CommandOpts) (err error) { + close, err := opts.EnsureTx(ctx) + if err != nil { + return err + } + defer func() { err = close(ctx, err) }() + return nil +} + +var ( + _ Commander = (*AddAdminCommand)(nil) +) diff --git a/backend/v3/domain/repository.go b/backend/v3/domain/repository.go new file mode 100644 index 0000000000..f9ea8b92f4 --- /dev/null +++ b/backend/v3/domain/repository.go @@ -0,0 +1,82 @@ +package domain + +import ( + "time" + + "golang.org/x/exp/constraints" +) + +type Operation interface { + // TextOperation | + // NumberOperation | + // BoolOperation + + op() +} + +type clause[F ~uint8, Op Operation] struct { + field F + op Op +} + +func (c *clause[F, Op]) Field() F { + return c.field +} + +func (c *clause[F, Op]) Operation() Op { + return c.op +} + +type Text interface { + ~string | ~[]byte +} + +type TextOperation uint8 + +const ( + TextOperationEqual TextOperation = iota + TextOperationNotEqual + TextOperationStartsWith + TextOperationStartsWithIgnoreCase +) + +func (TextOperation) op() {} + +type Number interface { + constraints.Integer | constraints.Float | constraints.Complex | time.Time +} + +type NumberOperation uint8 + +const ( + NumberOperationEqual NumberOperation = iota + NumberOperationNotEqual + NumberOperationLessThan + NumberOperationLessThanOrEqual + NumberOperationGreaterThan + NumberOperationGreaterThanOrEqual +) + +func (NumberOperation) op() {} + +type Bool interface { + ~bool +} + +type BoolOperation uint8 + +const ( + BoolOperationIs BoolOperation = iota + BoolOperationNot +) + +func (BoolOperation) op() {} + +type ListOperation uint8 + +const ( + ListOperationContains ListOperation = iota + ListOperationNotContains +) + +func (ListOperation) op() {} diff --git a/backend/v3/domain/set_email.go b/backend/v3/domain/set_email.go new file mode 100644 index 0000000000..bae2b20313 --- /dev/null +++ b/backend/v3/domain/set_email.go @@ -0,0 +1,64 @@ +package domain + +import ( + "context" + + "github.com/zitadel/zitadel/backend/v3/storage/eventstore" +) + +type SetEmailCommand struct { + UserID string `json:"userId"` + Email string `json:"email"` + verification Commander +} + +var ( + _ Commander = (*SetEmailCommand)(nil) + _ eventer = (*SetEmailCommand)(nil) + _ CreateHumanOpt = (*SetEmailCommand)(nil) +) + +type SetEmailOpt interface { + applyOnSetEmail(*SetEmailCommand) +} + +func NewSetEmailCommand(userID, email string, verificationType SetEmailOpt) *SetEmailCommand { + cmd := &SetEmailCommand{ + UserID: userID, + Email: email, + } + verificationType.applyOnSetEmail(cmd) + return cmd +} + +func (cmd *SetEmailCommand) Execute(ctx context.Context, opts *CommandOpts) error { + close, err := opts.EnsureTx(ctx) + if err != nil { + return err + } + defer func() { err = close(ctx, err) }() + // userStatement(opts.DB).Human().ByID(cmd.UserID).SetEmail(ctx, cmd.Email) + err = userRepo(opts.DB).Human().ByID(cmd.UserID).Exec().SetEmail(ctx, cmd.Email) + if err != nil { + return err + } + + return opts.Invoke(ctx, cmd.verification) +} + +// Events implements [eventer]. +func (cmd *SetEmailCommand) Events() []*eventstore.Event { + return []*eventstore.Event{ + { + AggregateType: "user", + AggregateID: cmd.UserID, + Type: "user.email.set", + Payload: cmd, + }, + } +} + +// applyOnCreateHuman implements [CreateHumanOpt]. +func (cmd *SetEmailCommand) applyOnCreateHuman(createUserCmd *CreateUserCommand[Human]) { + createUserCmd.email = cmd +} diff --git a/backend/v3/domain/user.go b/backend/v3/domain/user.go new file mode 100644 index 0000000000..6adbc18e5a --- /dev/null +++ b/backend/v3/domain/user.go @@ -0,0 +1,193 @@ +package domain + +import ( + "context" + "time" + + v4 "github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt/v4" +) + +type userColumns interface { + // TODO: move v4.columns to domain + InstanceIDColumn() column + OrgIDColumn() column + IDColumn() column + usernameColumn() column + CreatedAtColumn() column + UpdatedAtColumn() column + DeletedAtColumn() column +} + +type userConditions interface { + InstanceIDCondition(instanceID string) v4.Condition + OrgIDCondition(orgID string) v4.Condition + IDCondition(userID string) v4.Condition + UsernameCondition(op v4.TextOperator, username string) v4.Condition + CreatedAtCondition(op v4.NumberOperator, createdAt time.Time) v4.Condition + UpdatedAtCondition(op v4.NumberOperator, updatedAt time.Time) v4.Condition + DeletedCondition(isDeleted bool) v4.Condition + DeletedAtCondition(op v4.NumberOperator, deletedAt time.Time) v4.Condition +} + +type UserRepository interface { + userColumns + userConditions + // TODO: move condition to domain + WithCondition(condition v4.Condition) UserRepository + Get(ctx context.Context) (*User, error) + List(ctx context.Context) ([]*User, error) + Create(ctx context.Context, user *User) error + Delete(ctx context.Context) error + + Human() HumanRepository + Machine() MachineRepository +} + +type humanColumns interface { + FirstNameColumn() column + LastNameColumn() column + EmailAddressColumn() column + EmailVerifiedAtColumn() column + PhoneNumberColumn() column + PhoneVerifiedAtColumn() column +} + +type humanConditions interface { + FirstNameCondition(op v4.TextOperator, firstName string) v4.Condition + LastNameCondition(op v4.TextOperator, lastName string) v4.Condition + EmailAddressCondition(op v4.TextOperator, email string) v4.Condition + EmailAddressVerifiedCondition(isVerified bool) v4.Condition + EmailVerifiedAtCondition(op v4.TextOperator, emailVerifiedAt string) v4.Condition + PhoneNumberCondition(op v4.TextOperator, phoneNumber string) v4.Condition + PhoneNumberVerifiedCondition(isVerified bool) v4.Condition + PhoneVerifiedAtCondition(op v4.TextOperator, phoneVerifiedAt string) v4.Condition +} + +type HumanRepository interface { + humanColumns + humanConditions + + GetEmail(ctx context.Context) (*Email, error) + // TODO: replace any with add email update columns + SetEmail(ctx context.Context, columns ...any) error +} + +type machineColumns interface { + DescriptionColumn() column +} + +type machineConditions interface { + DescriptionCondition(op v4.TextOperator, description string) v4.Condition +} + +type MachineRepository interface { + machineColumns + machineConditions +} + +// type UserRepository interface { +// // Get(ctx context.Context, clauses ...UserClause) (*User, error) +// // Search(ctx context.Context, clauses ...UserClause) ([]*User, error) + +// UserQuery[UserOperation] +// Human() HumanQuery +// Machine() MachineQuery +// } + +// type UserQuery[Op UserOperation] interface { +// ByID(id string) UserQuery[Op] +// Username(username string) UserQuery[Op] +// Exec() Op +// } + +// type HumanQuery interface { +// UserQuery[HumanOperation] +// Email(op TextOperation, email string) HumanQuery +// HumanOperation +// } + +// type MachineQuery interface { +// UserQuery[MachineOperation] +// MachineOperation +// } + +// type UserClause interface { +// Field() UserField +// Operation() Operation +// Args() []any +// } + +// type UserField uint8 + +// const ( +// // Fields used for all users +// UserFieldInstanceID UserField = iota + 1 +// UserFieldOrgID +// UserFieldID +// UserFieldUsername + +// // Fields used for human users +// UserHumanFieldEmail +// UserHumanFieldEmailVerified + +// // Fields used for machine users +// UserMachineFieldDescription +// ) + +// type userByIDClause struct { +// id string +// } + +// func (c *userByIDClause) Field() UserField { +// return UserFieldID +// } + +// func (c *userByIDClause) Operation() Operation { +// return TextOperationEqual +// } + +// func (c *userByIDClause) Args() []any { +// return []any{c.id} +// } + +// type UserOperation interface { +// Delete(ctx context.Context) error +// SetUsername(ctx context.Context, username string) error +// } + +// type HumanOperation interface { +// UserOperation +// SetEmail(ctx context.Context, email string) error +// SetEmailVerified(ctx context.Context, email string) error +// GetEmail(ctx context.Context) (*Email, error) +// } + +// type MachineOperation interface { +// UserOperation +// SetDescription(ctx context.Context, description string) error +// } + +type User struct { + v4.User +} + +// type userTraits interface { +// isUserTraits() +// } + +// type Human struct { +// Email *Email `json:"email"` +// } + +// func (*Human) isUserTraits() {} + +// type Machine struct { +// Description string `json:"description"` +// } + +// func (*Machine) isUserTraits() {} + +// type Email struct { +// Address string `json:"address"` +// IsVerified bool `json:"isVerified"` +// } diff --git a/backend/v3/storage/cache/cache.go b/backend/v3/storage/cache/cache.go new file mode 100644 index 0000000000..dc05208caa --- /dev/null +++ b/backend/v3/storage/cache/cache.go @@ -0,0 +1,112 @@ +// Package cache provides abstraction of cache implementations that can be used by zitadel. +package cache + +import ( + "context" + "time" + + "github.com/zitadel/logging" +) + +// Purpose describes which object types are stored by a cache. +type Purpose int + +//go:generate enumer -type Purpose -transform snake -trimprefix Purpose +const ( + PurposeUnspecified Purpose = iota + PurposeAuthzInstance + PurposeMilestones + PurposeOrganization + PurposeIdPFormCallback +) + +// Cache stores objects with a value of type `V`. +// Objects may be referred to by one or more indices. +// Implementations may encode the value for storage. +// This means non-exported fields may be lost and objects +// with function values may fail to encode. +// See https://pkg.go.dev/encoding/json#Marshal for example. +// +// `I` is the type by which indices are identified, +// typically an enum for type-safe access. +// Indices are defined when calling the constructor of an implementation of this interface. +// It is illegal to refer to an idex not defined during construction. +// +// `K` is the type used as key in each index. +// Due to the limitations in type constraints, all indices use the same key type. +// +// Implementations are free to use stricter type constraints or fixed typing. +type Cache[I, K comparable, V Entry[I, K]] interface { + // Get an object through specified index. + // An [IndexUnknownError] may be returned if the index is unknown. + // [ErrCacheMiss] is returned if the key was not found in the index, + // or the object is not valid. + Get(ctx context.Context, index I, key K) (V, bool) + + // Set an object. + // Keys are created on each index based in the [Entry.Keys] method. + // If any key maps to an existing object, the object is invalidated, + // regardless if the object has other keys defined in the new entry. + // This to prevent ghost objects when an entry reduces the amount of keys + // for a given index. + Set(ctx context.Context, value V) + + // Invalidate an object through specified index. + // Implementations may choose to instantly delete the object, + // defer until prune or a separate cleanup routine. + // Invalidated object are no longer returned from Get. + // It is safe to call Invalidate multiple times or on non-existing entries. + Invalidate(ctx context.Context, index I, key ...K) error + + // Delete one or more keys from a specific index. + // An [IndexUnknownError] may be returned if the index is unknown. + // The referred object is not invalidated and may still be accessible though + // other indices and keys. + // It is safe to call Delete multiple times or on non-existing entries + Delete(ctx context.Context, index I, key ...K) error + + // Truncate deletes all cached objects. + Truncate(ctx context.Context) error +} + +// Entry contains a value of type `V` to be cached. +// +// `I` is the type by which indices are identified, +// typically an enum for type-safe access. +// +// `K` is the type used as key in an index. +// Due to the limitations in type constraints, all indices use the same key type. +type Entry[I, K comparable] interface { + // Keys returns which keys map to the object in a specified index. + // May return nil if the index in unknown or when there are no keys. + Keys(index I) (key []K) +} + +type Connector int + +//go:generate enumer -type Connector -transform snake -trimprefix Connector -linecomment -text +const ( + // Empty line comment ensures empty string for unspecified value + ConnectorUnspecified Connector = iota // + ConnectorMemory + ConnectorPostgres + ConnectorRedis +) + +type Config struct { + Connector Connector + + // Age since an object was added to the cache, + // after which the object is considered invalid. + // 0 disables max age checks. + MaxAge time.Duration + + // Age since last use (Get) of an object, + // after which the object is considered invalid. + // 0 disables last use age checks. + LastUseAge time.Duration + + // Log allows logging of the specific cache. + // By default only errors are logged to stdout. + Log *logging.Config +} diff --git a/backend/v3/storage/cache/connector/connector.go b/backend/v3/storage/cache/connector/connector.go new file mode 100644 index 0000000000..487680155c --- /dev/null +++ b/backend/v3/storage/cache/connector/connector.go @@ -0,0 +1,49 @@ +// Package connector provides glue between the [cache.Cache] interface and implementations from the connector sub-packages. +package connector + +import ( + "context" + "fmt" + + "github.com/zitadel/zitadel/backend/v3/storage/cache" + "github.com/zitadel/zitadel/backend/v3/storage/cache/connector/gomap" + "github.com/zitadel/zitadel/backend/v3/storage/cache/connector/noop" +) + +type CachesConfig struct { + Connectors struct { + Memory gomap.Config + } + Instance *cache.Config + Milestones *cache.Config + Organization *cache.Config + IdPFormCallbacks *cache.Config +} + +type Connectors struct { + Config CachesConfig + Memory *gomap.Connector +} + +func StartConnectors(conf *CachesConfig) (Connectors, error) { + if conf == nil { + return Connectors{}, nil + } + return Connectors{ + Config: *conf, + Memory: gomap.NewConnector(conf.Connectors.Memory), + }, nil +} + +func StartCache[I ~int, K ~string, V cache.Entry[I, K]](background context.Context, indices []I, purpose cache.Purpose, conf *cache.Config, connectors Connectors) (cache.Cache[I, K, V], error) { + if conf == nil || conf.Connector == cache.ConnectorUnspecified { + return noop.NewCache[I, K, V](), nil + } + if conf.Connector == cache.ConnectorMemory && connectors.Memory != nil { + c := gomap.NewCache[I, K, V](background, indices, *conf) + connectors.Memory.Config.StartAutoPrune(background, c, purpose) + return c, nil + } + + return nil, fmt.Errorf("cache connector %q not enabled", conf.Connector) +} diff --git a/backend/v3/storage/cache/connector/gomap/connector.go b/backend/v3/storage/cache/connector/gomap/connector.go new file mode 100644 index 0000000000..c453e34fc5 --- /dev/null +++ b/backend/v3/storage/cache/connector/gomap/connector.go @@ -0,0 +1,23 @@ +package gomap + +import ( + "github.com/zitadel/zitadel/backend/v3/storage/cache" +) + +type Config struct { + Enabled bool + AutoPrune cache.AutoPruneConfig +} + +type Connector struct { + Config cache.AutoPruneConfig +} + +func NewConnector(config Config) *Connector { + if !config.Enabled { + return nil + } + return &Connector{ + Config: config.AutoPrune, + } +} diff --git a/backend/v3/storage/cache/connector/gomap/gomap.go b/backend/v3/storage/cache/connector/gomap/gomap.go new file mode 100644 index 0000000000..6b25d642c4 --- /dev/null +++ b/backend/v3/storage/cache/connector/gomap/gomap.go @@ -0,0 +1,200 @@ +package gomap + +import ( + "context" + "errors" + "log/slog" + "maps" + "os" + "sync" + "sync/atomic" + "time" + + "github.com/zitadel/zitadel/backend/v3/storage/cache" +) + +type mapCache[I, K comparable, V cache.Entry[I, K]] struct { + config *cache.Config + indexMap map[I]*index[K, V] + logger *slog.Logger +} + +// NewCache returns an in-memory Cache implementation based on the builtin go map type. +// Object values are stored as-is and there is no encoding or decoding involved. +func NewCache[I, K comparable, V cache.Entry[I, K]](background context.Context, indices []I, config cache.Config) cache.PrunerCache[I, K, V] { + m := &mapCache[I, K, V]{ + config: &config, + indexMap: make(map[I]*index[K, V], len(indices)), + logger: slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelError, + })), + } + if config.Log != nil { + m.logger = config.Log.Slog() + } + m.logger.InfoContext(background, "map cache logging enabled") + + for _, name := range indices { + m.indexMap[name] = &index[K, V]{ + config: m.config, + entries: make(map[K]*entry[V]), + } + } + return m +} + +func (c *mapCache[I, K, V]) Get(ctx context.Context, index I, key K) (value V, ok bool) { + i, ok := c.indexMap[index] + if !ok { + c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key) + return value, false + } + entry, err := i.Get(key) + if err == nil { + c.logger.DebugContext(ctx, "map cache get", "index", index, "key", key) + return entry.value, true + } + if errors.Is(err, cache.ErrCacheMiss) { + c.logger.InfoContext(ctx, "map cache get", "err", err, "index", index, "key", key) + return value, false + } + c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key) + return value, false +} + +func (c *mapCache[I, K, V]) Set(ctx context.Context, value V) { + now := time.Now() + entry := &entry[V]{ + value: value, + created: now, + } + entry.lastUse.Store(now.UnixMicro()) + + for name, i := range c.indexMap { + keys := value.Keys(name) + i.Set(keys, entry) + c.logger.DebugContext(ctx, "map cache set", "index", name, "keys", keys) + } +} + +func (c *mapCache[I, K, V]) Invalidate(ctx context.Context, index I, keys ...K) error { + i, ok := c.indexMap[index] + if !ok { + return cache.NewIndexUnknownErr(index) + } + i.Invalidate(keys) + c.logger.DebugContext(ctx, "map cache invalidate", "index", index, "keys", keys) + return nil +} + +func (c *mapCache[I, K, V]) Delete(ctx context.Context, index I, keys ...K) error { + i, ok := c.indexMap[index] + if !ok { + return cache.NewIndexUnknownErr(index) + } + i.Delete(keys) + c.logger.DebugContext(ctx, "map cache delete", "index", index, "keys", keys) + return nil +} + +func (c *mapCache[I, K, V]) Prune(ctx context.Context) error { + for name, index := range c.indexMap { + index.Prune() + c.logger.DebugContext(ctx, "map cache prune", "index", name) + } + return nil +} + +func (c *mapCache[I, K, V]) Truncate(ctx context.Context) error { + for name, index := range c.indexMap { + index.Truncate() + c.logger.DebugContext(ctx, "map cache truncate", "index", name) + } + return nil +} + +type index[K comparable, V any] struct { + mutex sync.RWMutex + config *cache.Config + entries map[K]*entry[V] +} + +func (i *index[K, V]) Get(key K) (*entry[V], error) { + i.mutex.RLock() + entry, ok := i.entries[key] + i.mutex.RUnlock() + if ok && entry.isValid(i.config) { + return entry, nil + } + return nil, cache.ErrCacheMiss +} + +func (c *index[K, V]) Set(keys []K, entry *entry[V]) { + c.mutex.Lock() + for _, key := range keys { + c.entries[key] = entry + } + c.mutex.Unlock() +} + +func (i *index[K, V]) Invalidate(keys []K) { + i.mutex.RLock() + for _, key := range keys { + if entry, ok := i.entries[key]; ok { + entry.invalid.Store(true) + } + } + i.mutex.RUnlock() +} + +func (c *index[K, V]) Delete(keys []K) { + c.mutex.Lock() + for _, key := range keys { + delete(c.entries, key) + } + c.mutex.Unlock() +} + +func (c *index[K, V]) Prune() { + c.mutex.Lock() + maps.DeleteFunc(c.entries, func(_ K, entry *entry[V]) bool { + return !entry.isValid(c.config) + }) + c.mutex.Unlock() +} + +func (c *index[K, V]) Truncate() { + c.mutex.Lock() + c.entries = make(map[K]*entry[V]) + c.mutex.Unlock() +} + +type entry[V any] struct { + value V + created time.Time + invalid atomic.Bool + lastUse atomic.Int64 // UnixMicro time +} + +func (e *entry[V]) isValid(c *cache.Config) bool { + if e.invalid.Load() { + return false + } + now := time.Now() + if c.MaxAge > 0 { + if e.created.Add(c.MaxAge).Before(now) { + e.invalid.Store(true) + return false + } + } + if c.LastUseAge > 0 { + lastUse := e.lastUse.Load() + if time.UnixMicro(lastUse).Add(c.LastUseAge).Before(now) { + e.invalid.Store(true) + return false + } + e.lastUse.CompareAndSwap(lastUse, now.UnixMicro()) + } + return true +} diff --git a/backend/v3/storage/cache/connector/gomap/gomap_test.go b/backend/v3/storage/cache/connector/gomap/gomap_test.go new file mode 100644 index 0000000000..62bbc471a1 --- /dev/null +++ b/backend/v3/storage/cache/connector/gomap/gomap_test.go @@ -0,0 +1,329 @@ +package gomap + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/backend/v3/storage/cache" +) + +type testIndex int + +const ( + testIndexID testIndex = iota + testIndexName +) + +var testIndices = []testIndex{ + testIndexID, + testIndexName, +} + +type testObject struct { + id string + names []string +} + +func (o *testObject) Keys(index testIndex) []string { + switch index { + case testIndexID: + return []string{o.id} + case testIndexName: + return o.names + default: + return nil + } +} + +func Test_mapCache_Get(t *testing.T) { + c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{ + MaxAge: time.Second, + LastUseAge: time.Second / 4, + Log: &logging.Config{ + Level: "debug", + AddSource: true, + }, + }) + obj := &testObject{ + id: "id", + names: []string{"foo", "bar"}, + } + c.Set(context.Background(), obj) + + type args struct { + index testIndex + key string + } + tests := []struct { + name string + args args + want *testObject + wantOk bool + }{ + { + name: "ok", + args: args{ + index: testIndexID, + key: "id", + }, + want: obj, + wantOk: true, + }, + { + name: "miss", + args: args{ + index: testIndexID, + key: "spanac", + }, + want: nil, + wantOk: false, + }, + { + name: "unknown index", + args: args{ + index: 99, + key: "id", + }, + want: nil, + wantOk: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := c.Get(context.Background(), tt.args.index, tt.args.key) + assert.Equal(t, tt.want, got) + assert.Equal(t, tt.wantOk, ok) + }) + } +} + +func Test_mapCache_Invalidate(t *testing.T) { + c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{ + MaxAge: time.Second, + LastUseAge: time.Second / 4, + Log: &logging.Config{ + Level: "debug", + AddSource: true, + }, + }) + obj := &testObject{ + id: "id", + names: []string{"foo", "bar"}, + } + c.Set(context.Background(), obj) + err := c.Invalidate(context.Background(), testIndexName, "bar") + require.NoError(t, err) + got, ok := c.Get(context.Background(), testIndexID, "id") + assert.Nil(t, got) + assert.False(t, ok) +} + +func Test_mapCache_Delete(t *testing.T) { + c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{ + MaxAge: time.Second, + LastUseAge: time.Second / 4, + Log: &logging.Config{ + Level: "debug", + AddSource: true, + }, + }) + obj := &testObject{ + id: "id", + names: []string{"foo", "bar"}, + } + c.Set(context.Background(), obj) + err := c.Delete(context.Background(), testIndexName, "bar") + require.NoError(t, err) + + // Shouldn't find object by deleted name + got, ok := c.Get(context.Background(), testIndexName, "bar") + assert.Nil(t, got) + assert.False(t, ok) + + // Should find object by other name + got, ok = c.Get(context.Background(), testIndexName, "foo") + assert.Equal(t, obj, got) + assert.True(t, ok) + + // Should find object by id + got, ok = c.Get(context.Background(), testIndexID, "id") + assert.Equal(t, obj, got) + assert.True(t, ok) +} + +func Test_mapCache_Prune(t *testing.T) { + c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{ + MaxAge: time.Second, + LastUseAge: time.Second / 4, + Log: &logging.Config{ + Level: "debug", + AddSource: true, + }, + }) + + objects := []*testObject{ + { + id: "id1", + names: []string{"foo", "bar"}, + }, + { + id: "id2", + names: []string{"hello"}, + }, + } + for _, obj := range objects { + c.Set(context.Background(), obj) + } + // invalidate one entry + err := c.Invalidate(context.Background(), testIndexName, "bar") + require.NoError(t, err) + + err = c.(cache.Pruner).Prune(context.Background()) + require.NoError(t, err) + + // Other object should still be found + got, ok := c.Get(context.Background(), testIndexID, "id2") + assert.Equal(t, objects[1], got) + assert.True(t, ok) +} + +func Test_mapCache_Truncate(t *testing.T) { + c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{ + MaxAge: time.Second, + LastUseAge: time.Second / 4, + Log: &logging.Config{ + Level: "debug", + AddSource: true, + }, + }) + objects := []*testObject{ + { + id: "id1", + names: []string{"foo", "bar"}, + }, + { + id: "id2", + names: []string{"hello"}, + }, + } + for _, obj := range objects { + c.Set(context.Background(), obj) + } + + err := c.Truncate(context.Background()) + require.NoError(t, err) + + mc := c.(*mapCache[testIndex, string, *testObject]) + for _, index := range mc.indexMap { + index.mutex.RLock() + assert.Len(t, index.entries, 0) + index.mutex.RUnlock() + } +} + +func Test_entry_isValid(t *testing.T) { + type fields struct { + created time.Time + invalid bool + lastUse time.Time + } + tests := []struct { + name string + fields fields + config *cache.Config + want bool + }{ + { + name: "invalid", + fields: fields{ + created: time.Now(), + invalid: true, + lastUse: time.Now(), + }, + config: &cache.Config{ + MaxAge: time.Minute, + LastUseAge: time.Second, + }, + want: false, + }, + { + name: "max age exceeded", + fields: fields{ + created: time.Now().Add(-(time.Minute + time.Second)), + invalid: false, + lastUse: time.Now(), + }, + config: &cache.Config{ + MaxAge: time.Minute, + LastUseAge: time.Second, + }, + want: false, + }, + { + name: "max age disabled", + fields: fields{ + created: time.Now().Add(-(time.Minute + time.Second)), + invalid: false, + lastUse: time.Now(), + }, + config: &cache.Config{ + LastUseAge: time.Second, + }, + want: true, + }, + { + name: "last use age exceeded", + fields: fields{ + created: time.Now().Add(-(time.Minute / 2)), + invalid: false, + lastUse: time.Now().Add(-(time.Second * 2)), + }, + config: &cache.Config{ + MaxAge: time.Minute, + LastUseAge: time.Second, + }, + want: false, + }, + { + name: "last use age disabled", + fields: fields{ + created: time.Now().Add(-(time.Minute / 2)), + invalid: false, + lastUse: time.Now().Add(-(time.Second * 2)), + }, + config: &cache.Config{ + MaxAge: time.Minute, + }, + want: true, + }, + { + name: "valid", + fields: fields{ + created: time.Now(), + invalid: false, + lastUse: time.Now(), + }, + config: &cache.Config{ + MaxAge: time.Minute, + LastUseAge: time.Second, + }, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &entry[any]{ + created: tt.fields.created, + } + e.invalid.Store(tt.fields.invalid) + e.lastUse.Store(tt.fields.lastUse.UnixMicro()) + got := e.isValid(tt.config) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/backend/v3/storage/cache/connector/noop/noop.go b/backend/v3/storage/cache/connector/noop/noop.go new file mode 100644 index 0000000000..12d261a77d --- /dev/null +++ b/backend/v3/storage/cache/connector/noop/noop.go @@ -0,0 +1,21 @@ +package noop + +import ( + "context" + + "github.com/zitadel/zitadel/backend/v3/storage/cache" +) + +type noop[I, K comparable, V cache.Entry[I, K]] struct{} + +// NewCache returns a cache that does nothing +func NewCache[I, K comparable, V cache.Entry[I, K]]() cache.Cache[I, K, V] { + return noop[I, K, V]{} +} + +func (noop[I, K, V]) Set(context.Context, V) {} +func (noop[I, K, V]) Get(context.Context, I, K) (value V, ok bool) { return } +func (noop[I, K, V]) Invalidate(context.Context, I, ...K) (err error) { return } +func (noop[I, K, V]) Delete(context.Context, I, ...K) (err error) { return } +func (noop[I, K, V]) Prune(context.Context) (err error) { return } +func (noop[I, K, V]) Truncate(context.Context) (err error) { return } diff --git a/backend/v3/storage/cache/connector_enumer.go b/backend/v3/storage/cache/connector_enumer.go new file mode 100644 index 0000000000..7ea014db16 --- /dev/null +++ b/backend/v3/storage/cache/connector_enumer.go @@ -0,0 +1,98 @@ +// Code generated by "enumer -type Connector -transform snake -trimprefix Connector -linecomment -text"; DO NOT EDIT. + +package cache + +import ( + "fmt" + "strings" +) + +const _ConnectorName = "memorypostgresredis" + +var _ConnectorIndex = [...]uint8{0, 0, 6, 14, 19} + +const _ConnectorLowerName = "memorypostgresredis" + +func (i Connector) String() string { + if i < 0 || i >= Connector(len(_ConnectorIndex)-1) { + return fmt.Sprintf("Connector(%d)", i) + } + return _ConnectorName[_ConnectorIndex[i]:_ConnectorIndex[i+1]] +} + +// An "invalid array index" compiler error signifies that the constant values have changed. +// Re-run the stringer command to generate them again. +func _ConnectorNoOp() { + var x [1]struct{} + _ = x[ConnectorUnspecified-(0)] + _ = x[ConnectorMemory-(1)] + _ = x[ConnectorPostgres-(2)] + _ = x[ConnectorRedis-(3)] +} + +var _ConnectorValues = []Connector{ConnectorUnspecified, ConnectorMemory, ConnectorPostgres, ConnectorRedis} + +var _ConnectorNameToValueMap = map[string]Connector{ + _ConnectorName[0:0]: ConnectorUnspecified, + _ConnectorLowerName[0:0]: ConnectorUnspecified, + _ConnectorName[0:6]: ConnectorMemory, + _ConnectorLowerName[0:6]: ConnectorMemory, + _ConnectorName[6:14]: ConnectorPostgres, + _ConnectorLowerName[6:14]: ConnectorPostgres, + _ConnectorName[14:19]: ConnectorRedis, + _ConnectorLowerName[14:19]: ConnectorRedis, +} + +var _ConnectorNames = []string{ + _ConnectorName[0:0], + _ConnectorName[0:6], + _ConnectorName[6:14], + _ConnectorName[14:19], +} + +// ConnectorString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func ConnectorString(s string) (Connector, error) { + if val, ok := _ConnectorNameToValueMap[s]; ok { + return val, nil + } + + if val, ok := _ConnectorNameToValueMap[strings.ToLower(s)]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to Connector values", s) +} + +// ConnectorValues returns all values of the enum +func ConnectorValues() []Connector { + return _ConnectorValues +} + +// ConnectorStrings returns a slice of all String values of the enum +func ConnectorStrings() []string { + strs := make([]string, len(_ConnectorNames)) + copy(strs, _ConnectorNames) + return strs +} + +// IsAConnector returns "true" if the value is listed in the enum definition. "false" otherwise +func (i Connector) IsAConnector() bool { + for _, v := range _ConnectorValues { + if i == v { + return true + } + } + return false +} + +// MarshalText implements the encoding.TextMarshaler interface for Connector +func (i Connector) MarshalText() ([]byte, error) { + return []byte(i.String()), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface for Connector +func (i *Connector) UnmarshalText(text []byte) error { + var err error + *i, err = ConnectorString(string(text)) + return err +} diff --git a/backend/v3/storage/cache/error.go b/backend/v3/storage/cache/error.go new file mode 100644 index 0000000000..b66b9447bf --- /dev/null +++ b/backend/v3/storage/cache/error.go @@ -0,0 +1,29 @@ +package cache + +import ( + "errors" + "fmt" +) + +type IndexUnknownError[I comparable] struct { + index I +} + +func NewIndexUnknownErr[I comparable](index I) error { + return IndexUnknownError[I]{index} +} + +func (i IndexUnknownError[I]) Error() string { + return fmt.Sprintf("index %v unknown", i.index) +} + +func (a IndexUnknownError[I]) Is(err error) bool { + if b, ok := err.(IndexUnknownError[I]); ok { + return a.index == b.index + } + return false +} + +var ( + ErrCacheMiss = errors.New("cache miss") +) diff --git a/backend/v3/storage/cache/pruner.go b/backend/v3/storage/cache/pruner.go new file mode 100644 index 0000000000..959762d410 --- /dev/null +++ b/backend/v3/storage/cache/pruner.go @@ -0,0 +1,76 @@ +package cache + +import ( + "context" + "math/rand" + "time" + + "github.com/jonboulle/clockwork" + "github.com/zitadel/logging" +) + +// Pruner is an optional [Cache] interface. +type Pruner interface { + // Prune deletes all invalidated or expired objects. + Prune(ctx context.Context) error +} + +type PrunerCache[I, K comparable, V Entry[I, K]] interface { + Cache[I, K, V] + Pruner +} + +type AutoPruneConfig struct { + // Interval at which the cache is automatically pruned. + // 0 or lower disables automatic pruning. + Interval time.Duration + + // Timeout for an automatic prune. + // It is recommended to keep the value shorter than AutoPruneInterval + // 0 or lower disables automatic pruning. + Timeout time.Duration +} + +func (c AutoPruneConfig) StartAutoPrune(background context.Context, pruner Pruner, purpose Purpose) (close func()) { + return c.startAutoPrune(background, pruner, purpose, clockwork.NewRealClock()) +} + +func (c *AutoPruneConfig) startAutoPrune(background context.Context, pruner Pruner, purpose Purpose, clock clockwork.Clock) (close func()) { + if c.Interval <= 0 { + return func() {} + } + background, cancel := context.WithCancel(background) + // randomize the first interval + timer := clock.NewTimer(time.Duration(rand.Int63n(int64(c.Interval)))) + go c.pruneTimer(background, pruner, purpose, timer) + return cancel +} + +func (c *AutoPruneConfig) pruneTimer(background context.Context, pruner Pruner, purpose Purpose, timer clockwork.Timer) { + defer func() { + if !timer.Stop() { + <-timer.Chan() + } + }() + + for { + select { + case <-background.Done(): + return + case <-timer.Chan(): + err := c.doPrune(background, pruner) + logging.OnError(err).WithField("purpose", purpose).Error("cache auto prune") + timer.Reset(c.Interval) + } + } +} + +func (c *AutoPruneConfig) doPrune(background context.Context, pruner Pruner) error { + ctx, cancel := context.WithCancel(background) + defer cancel() + if c.Timeout > 0 { + ctx, cancel = context.WithTimeout(background, c.Timeout) + defer cancel() + } + return pruner.Prune(ctx) +} diff --git a/backend/v3/storage/cache/pruner_test.go b/backend/v3/storage/cache/pruner_test.go new file mode 100644 index 0000000000..faaedeb88c --- /dev/null +++ b/backend/v3/storage/cache/pruner_test.go @@ -0,0 +1,43 @@ +package cache + +import ( + "context" + "testing" + "time" + + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/assert" +) + +type testPruner struct { + called chan struct{} +} + +func (p *testPruner) Prune(context.Context) error { + p.called <- struct{}{} + return nil +} + +func TestAutoPruneConfig_startAutoPrune(t *testing.T) { + c := AutoPruneConfig{ + Interval: time.Second, + Timeout: time.Millisecond, + } + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + pruner := testPruner{ + called: make(chan struct{}), + } + clock := clockwork.NewFakeClock() + close := c.startAutoPrune(ctx, &pruner, PurposeAuthzInstance, clock) + defer close() + clock.Advance(time.Second) + + select { + case _, ok := <-pruner.called: + assert.True(t, ok) + case <-ctx.Done(): + t.Fatal(ctx.Err()) + } +} diff --git a/backend/v3/storage/cache/purpose_enumer.go b/backend/v3/storage/cache/purpose_enumer.go new file mode 100644 index 0000000000..a93a978efb --- /dev/null +++ b/backend/v3/storage/cache/purpose_enumer.go @@ -0,0 +1,90 @@ +// Code generated by "enumer -type Purpose -transform snake -trimprefix Purpose"; DO NOT EDIT. + +package cache + +import ( + "fmt" + "strings" +) + +const _PurposeName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callback" + +var _PurposeIndex = [...]uint8{0, 11, 25, 35, 47, 65} + +const _PurposeLowerName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callback" + +func (i Purpose) String() string { + if i < 0 || i >= Purpose(len(_PurposeIndex)-1) { + return fmt.Sprintf("Purpose(%d)", i) + } + return _PurposeName[_PurposeIndex[i]:_PurposeIndex[i+1]] +} + +// An "invalid array index" compiler error signifies that the constant values have changed. +// Re-run the stringer command to generate them again. +func _PurposeNoOp() { + var x [1]struct{} + _ = x[PurposeUnspecified-(0)] + _ = x[PurposeAuthzInstance-(1)] + _ = x[PurposeMilestones-(2)] + _ = x[PurposeOrganization-(3)] + _ = x[PurposeIdPFormCallback-(4)] +} + +var _PurposeValues = []Purpose{PurposeUnspecified, PurposeAuthzInstance, PurposeMilestones, PurposeOrganization, PurposeIdPFormCallback} + +var _PurposeNameToValueMap = map[string]Purpose{ + _PurposeName[0:11]: PurposeUnspecified, + _PurposeLowerName[0:11]: PurposeUnspecified, + _PurposeName[11:25]: PurposeAuthzInstance, + _PurposeLowerName[11:25]: PurposeAuthzInstance, + _PurposeName[25:35]: PurposeMilestones, + _PurposeLowerName[25:35]: PurposeMilestones, + _PurposeName[35:47]: PurposeOrganization, + _PurposeLowerName[35:47]: PurposeOrganization, + _PurposeName[47:65]: PurposeIdPFormCallback, + _PurposeLowerName[47:65]: PurposeIdPFormCallback, +} + +var _PurposeNames = []string{ + _PurposeName[0:11], + _PurposeName[11:25], + _PurposeName[25:35], + _PurposeName[35:47], + _PurposeName[47:65], +} + +// PurposeString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func PurposeString(s string) (Purpose, error) { + if val, ok := _PurposeNameToValueMap[s]; ok { + return val, nil + } + + if val, ok := _PurposeNameToValueMap[strings.ToLower(s)]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to Purpose values", s) +} + +// PurposeValues returns all values of the enum +func PurposeValues() []Purpose { + return _PurposeValues +} + +// PurposeStrings returns a slice of all String values of the enum +func PurposeStrings() []string { + strs := make([]string, len(_PurposeNames)) + copy(strs, _PurposeNames) + return strs +} + +// IsAPurpose returns "true" if the value is listed in the enum definition. "false" otherwise +func (i Purpose) IsAPurpose() bool { + for _, v := range _PurposeValues { + if i == v { + return true + } + } + return false +} diff --git a/backend/v3/storage/database/config.go b/backend/v3/storage/database/config.go new file mode 100644 index 0000000000..d9aa99b869 --- /dev/null +++ b/backend/v3/storage/database/config.go @@ -0,0 +1,9 @@ +package database + +import ( + "context" +) + +type Connector interface { + Connect(ctx context.Context) (Pool, error) +} diff --git a/backend/v3/storage/database/database.go b/backend/v3/storage/database/database.go new file mode 100644 index 0000000000..33d297adf0 --- /dev/null +++ b/backend/v3/storage/database/database.go @@ -0,0 +1,60 @@ +package database + +import ( + "context" +) + +var ( + db *database +) + +type database struct { + connector Connector + pool Pool +} + +type Pool interface { + Beginner + QueryExecutor + + Acquire(ctx context.Context) (Client, error) + Close(ctx context.Context) error +} + +type Client interface { + Beginner + QueryExecutor + + Release(ctx context.Context) error +} + +type Querier interface { + Query(ctx context.Context, stmt string, args ...any) (Rows, error) + QueryRow(ctx context.Context, stmt string, args ...any) Row +} + +type Executor interface { + Exec(ctx context.Context, stmt string, args ...any) error +} + +type QueryExecutor interface { + Querier + Executor +} + +type Scanner interface { + Scan(dest ...any) error +} + +type Row interface { + Scanner +} + +type Rows interface { + Row + Next() bool + Close() error + Err() error +} + +type Query[T any] func(querier Querier) (result T, err error) diff --git a/backend/v3/storage/database/dialect/config.go b/backend/v3/storage/database/dialect/config.go new file mode 100644 index 0000000000..a044f7bd4e --- /dev/null +++ b/backend/v3/storage/database/dialect/config.go @@ -0,0 +1,92 @@ +package dialect + +import ( + "context" + "errors" + "reflect" + + "github.com/mitchellh/mapstructure" + "github.com/spf13/viper" + + "github.com/zitadel/zitadel/backend/storage/database" + "github.com/zitadel/zitadel/backend/storage/database/dialect/postgres" +) + +type Hook struct { + Match func(string) bool + Decode func(config any) (database.Connector, error) + Name string + Constructor func() database.Connector +} + +var hooks = []Hook{ + { + Match: postgres.NameMatcher, + Decode: postgres.DecodeConfig, + Name: postgres.Name, + Constructor: func() database.Connector { return new(postgres.Config) }, + }, + // { + // Match: gosql.NameMatcher, + // Decode: gosql.DecodeConfig, + // Name: gosql.Name, + // Constructor: func() database.Connector { return new(gosql.Config) }, + // }, +} + +type Config struct { + Dialects map[string]any `mapstructure:",remain" yaml:",inline"` + + connector database.Connector +} + +func (c Config) Connect(ctx context.Context) (database.Pool, error) { + if len(c.Dialects) != 1 { + return nil, errors.New("Exactly one dialect must be configured") + } + + return c.connector.Connect(ctx) +} + +// Hooks implements [configure.Unmarshaller]. +func (c Config) Hooks() []viper.DecoderConfigOption { + return []viper.DecoderConfigOption{ + viper.DecodeHook(decodeHook), + } +} + +func decodeHook(from, to reflect.Value) (_ any, err error) { + if to.Type() != reflect.TypeOf(Config{}) { + return from.Interface(), nil + } + + config := new(Config) + if err = mapstructure.Decode(from.Interface(), config); err != nil { + return nil, err + } + + if err = config.decodeDialect(); err != nil { + return nil, err + } + + return config, nil +} + +func (c *Config) decodeDialect() error { + for _, hook := range hooks { + for name, config := range c.Dialects { + if !hook.Match(name) { + continue + } + + connector, err := hook.Decode(config) + if err != nil { + return err + } + + c.connector = connector + return nil + } + } + return errors.New("no dialect found") +} diff --git a/backend/v3/storage/database/dialect/postgres/config.go b/backend/v3/storage/database/dialect/postgres/config.go new file mode 100644 index 0000000000..9862526106 --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/config.go @@ -0,0 +1,80 @@ +package postgres + +import ( + "context" + "errors" + "slices" + "strings" + + "github.com/jackc/pgx/v5/pgxpool" + "github.com/mitchellh/mapstructure" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +var ( + _ database.Connector = (*Config)(nil) + Name = "postgres" +) + +type Config struct { + *pgxpool.Config + + // Host string + // Port int32 + // Database string + // MaxOpenConns uint32 + // MaxIdleConns uint32 + // MaxConnLifetime time.Duration + // MaxConnIdleTime time.Duration + // User User + // // Additional options to be appended as options= + // // The value will be taken as is. Multiple options are space separated. + // Options string + + configuredFields []string +} + +// Connect implements [database.Connector]. +func (c *Config) Connect(ctx context.Context) (database.Pool, error) { + pool, err := pgxpool.NewWithConfig(ctx, c.Config) + if err != nil { + return nil, err + } + if err = pool.Ping(ctx); err != nil { + return nil, err + } + return &pgxPool{pool}, nil +} + +func NameMatcher(name string) bool { + return slices.Contains([]string{"postgres", "pg"}, strings.ToLower(name)) +} + +func DecodeConfig(input any) (database.Connector, error) { + switch c := input.(type) { + case string: + config, err := pgxpool.ParseConfig(c) + if err != nil { + return nil, err + } + return &Config{Config: config}, nil + case map[string]any: + connector := new(Config) + decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + DecodeHook: mapstructure.StringToTimeDurationHookFunc(), + WeaklyTypedInput: true, + Result: connector, + }) + if err != nil { + return nil, err + } + if err = decoder.Decode(c); err != nil { + return nil, err + } + return &Config{ + Config: &pgxpool.Config{}, + }, nil + } + return nil, errors.New("invalid configuration") +} diff --git a/backend/v3/storage/database/dialect/postgres/conn.go b/backend/v3/storage/database/dialect/postgres/conn.go new file mode 100644 index 0000000000..1ce061e2fb --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/conn.go @@ -0,0 +1,48 @@ +package postgres + +import ( + "context" + + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type pgxConn struct{ *pgxpool.Conn } + +var _ database.Client = (*pgxConn)(nil) + +// Release implements [database.Client]. +func (c *pgxConn) Release(_ context.Context) error { + c.Conn.Release() + return nil +} + +// Begin implements [database.Client]. +func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) { + tx, err := c.Conn.BeginTx(ctx, transactionOptionsToPgx(opts)) + if err != nil { + return nil, err + } + return &pgxTx{tx}, nil +} + +// Query implements sql.Client. +// Subtle: this method shadows the method (*Conn).Query of pgxConn.Conn. +func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { + rows, err := c.Conn.Query(ctx, sql, args...) + return &Rows{rows}, err +} + +// QueryRow implements sql.Client. +// Subtle: this method shadows the method (*Conn).QueryRow of pgxConn.Conn. +func (c *pgxConn) QueryRow(ctx context.Context, sql string, args ...any) database.Row { + return c.Conn.QueryRow(ctx, sql, args...) +} + +// Exec implements [database.Pool]. +// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool. +func (c *pgxConn) Exec(ctx context.Context, sql string, args ...any) error { + _, err := c.Conn.Exec(ctx, sql, args...) + return err +} diff --git a/backend/v3/storage/database/dialect/postgres/pool.go b/backend/v3/storage/database/dialect/postgres/pool.go new file mode 100644 index 0000000000..bbfb421413 --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/pool.go @@ -0,0 +1,57 @@ +package postgres + +import ( + "context" + + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type pgxPool struct{ *pgxpool.Pool } + +var _ database.Pool = (*pgxPool)(nil) + +// Acquire implements [database.Pool]. +func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) { + conn, err := c.Pool.Acquire(ctx) + if err != nil { + return nil, err + } + return &pgxConn{conn}, nil +} + +// Query implements [database.Pool]. +// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool. +func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { + rows, err := c.Pool.Query(ctx, sql, args...) + return &Rows{rows}, err +} + +// QueryRow implements [database.Pool]. +// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool. +func (c *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row { + return c.Pool.QueryRow(ctx, sql, args...) +} + +// Exec implements [database.Pool]. +// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool. +func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) error { + _, err := c.Pool.Exec(ctx, sql, args...) + return err +} + +// Begin implements [database.Pool]. +func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) { + tx, err := c.Pool.BeginTx(ctx, transactionOptionsToPgx(opts)) + if err != nil { + return nil, err + } + return &pgxTx{tx}, nil +} + +// Close implements [database.Pool]. +func (c *pgxPool) Close(_ context.Context) error { + c.Pool.Close() + return nil +} diff --git a/backend/v3/storage/database/dialect/postgres/rows.go b/backend/v3/storage/database/dialect/postgres/rows.go new file mode 100644 index 0000000000..891a2a3f46 --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/rows.go @@ -0,0 +1,18 @@ +package postgres + +import ( + "github.com/jackc/pgx/v5" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +var _ database.Rows = (*Rows)(nil) + +type Rows struct{ pgx.Rows } + +// Close implements [database.Rows]. +// Subtle: this method shadows the method (Rows).Close of Rows.Rows. +func (r *Rows) Close() error { + r.Rows.Close() + return nil +} diff --git a/backend/v3/storage/database/dialect/postgres/tx.go b/backend/v3/storage/database/dialect/postgres/tx.go new file mode 100644 index 0000000000..bfac46572d --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/tx.go @@ -0,0 +1,95 @@ +package postgres + +import ( + "context" + + "github.com/jackc/pgx/v5" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type pgxTx struct{ pgx.Tx } + +var _ database.Transaction = (*pgxTx)(nil) + +// Commit implements [database.Transaction]. +func (tx *pgxTx) Commit(ctx context.Context) error { + return tx.Tx.Commit(ctx) +} + +// Rollback implements [database.Transaction]. +func (tx *pgxTx) Rollback(ctx context.Context) error { + return tx.Tx.Rollback(ctx) +} + +// End implements [database.Transaction]. +func (tx *pgxTx) End(ctx context.Context, err error) error { + if err != nil { + tx.Rollback(ctx) + return err + } + return tx.Commit(ctx) +} + +// Query implements [database.Transaction]. +// Subtle: this method shadows the method (Tx).Query of pgxTx.Tx. +func (tx *pgxTx) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { + rows, err := tx.Tx.Query(ctx, sql, args...) + return &Rows{rows}, err +} + +// QueryRow implements [database.Transaction]. +// Subtle: this method shadows the method (Tx).QueryRow of pgxTx.Tx. +func (tx *pgxTx) QueryRow(ctx context.Context, sql string, args ...any) database.Row { + return tx.Tx.QueryRow(ctx, sql, args...) +} + +// Exec implements [database.Transaction]. +// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool. +func (tx *pgxTx) Exec(ctx context.Context, sql string, args ...any) error { + _, err := tx.Tx.Exec(ctx, sql, args...) + return err +} + +// Begin implements [database.Transaction]. +// As postgres does not support nested transactions we use savepoints to emulate them. +func (tx *pgxTx) Begin(ctx context.Context) (database.Transaction, error) { + savepoint, err := tx.Tx.Begin(ctx) + if err != nil { + return nil, err + } + return &pgxTx{savepoint}, nil +} + +func transactionOptionsToPgx(opts *database.TransactionOptions) pgx.TxOptions { + if opts == nil { + return pgx.TxOptions{} + } + + return pgx.TxOptions{ + IsoLevel: isolationToPgx(opts.IsolationLevel), + AccessMode: accessModeToPgx(opts.AccessMode), + } +} + +func isolationToPgx(isolation database.IsolationLevel) pgx.TxIsoLevel { + switch isolation { + case database.IsolationLevelSerializable: + return pgx.Serializable + case database.IsolationLevelReadCommitted: + return pgx.ReadCommitted + default: + return pgx.Serializable + } +} + +func accessModeToPgx(accessMode database.AccessMode) pgx.TxAccessMode { + switch accessMode { + case database.AccessModeReadWrite: + return pgx.ReadWrite + case database.AccessModeReadOnly: + return pgx.ReadOnly + default: + return pgx.ReadWrite + } +} diff --git a/backend/v3/storage/database/gen_mock.go b/backend/v3/storage/database/gen_mock.go new file mode 100644 index 0000000000..e8a319c7f0 --- /dev/null +++ b/backend/v3/storage/database/gen_mock.go @@ -0,0 +1,3 @@ +package database + +//go:generate mockgen -typed -package mock -destination ./mock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Client,Row,Rows,Transaction diff --git a/backend/v3/storage/database/mock/database.mock.go b/backend/v3/storage/database/mock/database.mock.go new file mode 100644 index 0000000000..2460d5b75c --- /dev/null +++ b/backend/v3/storage/database/mock/database.mock.go @@ -0,0 +1,1067 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/zitadel/zitadel/backend/v3/storage/database (interfaces: Pool,Client,Row,Rows,Transaction) +// +// Generated by this command: +// +// mockgen -typed -package mock -destination ./mock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Client,Row,Rows,Transaction +// + +// Package mock is a generated GoMock package. +package mock + +import ( + context "context" + reflect "reflect" + + database "github.com/zitadel/zitadel/backend/v3/storage/database" + gomock "go.uber.org/mock/gomock" +) + +// MockPool is a mock of Pool interface. +type MockPool struct { + ctrl *gomock.Controller + recorder *MockPoolMockRecorder +} + +// MockPoolMockRecorder is the mock recorder for MockPool. +type MockPoolMockRecorder struct { + mock *MockPool +} + +// NewMockPool creates a new mock instance. +func NewMockPool(ctrl *gomock.Controller) *MockPool { + mock := &MockPool{ctrl: ctrl} + mock.recorder = &MockPoolMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPool) EXPECT() *MockPoolMockRecorder { + return m.recorder +} + +// Acquire mocks base method. +func (m *MockPool) Acquire(arg0 context.Context) (database.Client, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Acquire", arg0) + ret0, _ := ret[0].(database.Client) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Acquire indicates an expected call of Acquire. +func (mr *MockPoolMockRecorder) Acquire(arg0 any) *MockPoolAcquireCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockPool)(nil).Acquire), arg0) + return &MockPoolAcquireCall{Call: call} +} + +// MockPoolAcquireCall wrap *gomock.Call +type MockPoolAcquireCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPoolAcquireCall) Return(arg0 database.Client, arg1 error) *MockPoolAcquireCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPoolAcquireCall) Do(f func(context.Context) (database.Client, error)) *MockPoolAcquireCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPoolAcquireCall) DoAndReturn(f func(context.Context) (database.Client, error)) *MockPoolAcquireCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Begin mocks base method. +func (m *MockPool) Begin(arg0 context.Context, arg1 *database.TransactionOptions) (database.Transaction, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Begin", arg0, arg1) + ret0, _ := ret[0].(database.Transaction) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Begin indicates an expected call of Begin. +func (mr *MockPoolMockRecorder) Begin(arg0, arg1 any) *MockPoolBeginCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockPool)(nil).Begin), arg0, arg1) + return &MockPoolBeginCall{Call: call} +} + +// MockPoolBeginCall wrap *gomock.Call +type MockPoolBeginCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPoolBeginCall) Return(arg0 database.Transaction, arg1 error) *MockPoolBeginCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPoolBeginCall) Do(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockPoolBeginCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPoolBeginCall) DoAndReturn(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockPoolBeginCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Close mocks base method. +func (m *MockPool) Close(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockPoolMockRecorder) Close(arg0 any) *MockPoolCloseCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPool)(nil).Close), arg0) + return &MockPoolCloseCall{Call: call} +} + +// MockPoolCloseCall wrap *gomock.Call +type MockPoolCloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPoolCloseCall) Return(arg0 error) *MockPoolCloseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPoolCloseCall) Do(f func(context.Context) error) *MockPoolCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPoolCloseCall) DoAndReturn(f func(context.Context) error) *MockPoolCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Exec mocks base method. +func (m *MockPool) Exec(arg0 context.Context, arg1 string, arg2 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Exec indicates an expected call of Exec. +func (mr *MockPoolMockRecorder) Exec(arg0, arg1 any, arg2 ...any) *MockPoolExecCall { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockPool)(nil).Exec), varargs...) + return &MockPoolExecCall{Call: call} +} + +// MockPoolExecCall wrap *gomock.Call +type MockPoolExecCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPoolExecCall) Return(arg0 error) *MockPoolExecCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPoolExecCall) Do(f func(context.Context, string, ...any) error) *MockPoolExecCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPoolExecCall) DoAndReturn(f func(context.Context, string, ...any) error) *MockPoolExecCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Query mocks base method. +func (m *MockPool) Query(arg0 context.Context, arg1 string, arg2 ...any) (database.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Query", varargs...) + ret0, _ := ret[0].(database.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Query indicates an expected call of Query. +func (mr *MockPoolMockRecorder) Query(arg0, arg1 any, arg2 ...any) *MockPoolQueryCall { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockPool)(nil).Query), varargs...) + return &MockPoolQueryCall{Call: call} +} + +// MockPoolQueryCall wrap *gomock.Call +type MockPoolQueryCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPoolQueryCall) Return(arg0 database.Rows, arg1 error) *MockPoolQueryCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPoolQueryCall) Do(f func(context.Context, string, ...any) (database.Rows, error)) *MockPoolQueryCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPoolQueryCall) DoAndReturn(f func(context.Context, string, ...any) (database.Rows, error)) *MockPoolQueryCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// QueryRow mocks base method. +func (m *MockPool) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) database.Row { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryRow", varargs...) + ret0, _ := ret[0].(database.Row) + return ret0 +} + +// QueryRow indicates an expected call of QueryRow. +func (mr *MockPoolMockRecorder) QueryRow(arg0, arg1 any, arg2 ...any) *MockPoolQueryRowCall { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockPool)(nil).QueryRow), varargs...) + return &MockPoolQueryRowCall{Call: call} +} + +// MockPoolQueryRowCall wrap *gomock.Call +type MockPoolQueryRowCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPoolQueryRowCall) Return(arg0 database.Row) *MockPoolQueryRowCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPoolQueryRowCall) Do(f func(context.Context, string, ...any) database.Row) *MockPoolQueryRowCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPoolQueryRowCall) DoAndReturn(f func(context.Context, string, ...any) database.Row) *MockPoolQueryRowCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MockClient is a mock of Client interface. +type MockClient struct { + ctrl *gomock.Controller + recorder *MockClientMockRecorder +} + +// MockClientMockRecorder is the mock recorder for MockClient. +type MockClientMockRecorder struct { + mock *MockClient +} + +// NewMockClient creates a new mock instance. +func NewMockClient(ctrl *gomock.Controller) *MockClient { + mock := &MockClient{ctrl: ctrl} + mock.recorder = &MockClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClient) EXPECT() *MockClientMockRecorder { + return m.recorder +} + +// Begin mocks base method. +func (m *MockClient) Begin(arg0 context.Context, arg1 *database.TransactionOptions) (database.Transaction, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Begin", arg0, arg1) + ret0, _ := ret[0].(database.Transaction) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Begin indicates an expected call of Begin. +func (mr *MockClientMockRecorder) Begin(arg0, arg1 any) *MockClientBeginCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockClient)(nil).Begin), arg0, arg1) + return &MockClientBeginCall{Call: call} +} + +// MockClientBeginCall wrap *gomock.Call +type MockClientBeginCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockClientBeginCall) Return(arg0 database.Transaction, arg1 error) *MockClientBeginCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockClientBeginCall) Do(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockClientBeginCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockClientBeginCall) DoAndReturn(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockClientBeginCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Exec mocks base method. +func (m *MockClient) Exec(arg0 context.Context, arg1 string, arg2 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Exec indicates an expected call of Exec. +func (mr *MockClientMockRecorder) Exec(arg0, arg1 any, arg2 ...any) *MockClientExecCall { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockClient)(nil).Exec), varargs...) + return &MockClientExecCall{Call: call} +} + +// MockClientExecCall wrap *gomock.Call +type MockClientExecCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockClientExecCall) Return(arg0 error) *MockClientExecCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockClientExecCall) Do(f func(context.Context, string, ...any) error) *MockClientExecCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockClientExecCall) DoAndReturn(f func(context.Context, string, ...any) error) *MockClientExecCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Query mocks base method. +func (m *MockClient) Query(arg0 context.Context, arg1 string, arg2 ...any) (database.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Query", varargs...) + ret0, _ := ret[0].(database.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Query indicates an expected call of Query. +func (mr *MockClientMockRecorder) Query(arg0, arg1 any, arg2 ...any) *MockClientQueryCall { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockClient)(nil).Query), varargs...) + return &MockClientQueryCall{Call: call} +} + +// MockClientQueryCall wrap *gomock.Call +type MockClientQueryCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockClientQueryCall) Return(arg0 database.Rows, arg1 error) *MockClientQueryCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockClientQueryCall) Do(f func(context.Context, string, ...any) (database.Rows, error)) *MockClientQueryCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockClientQueryCall) DoAndReturn(f func(context.Context, string, ...any) (database.Rows, error)) *MockClientQueryCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// QueryRow mocks base method. +func (m *MockClient) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) database.Row { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryRow", varargs...) + ret0, _ := ret[0].(database.Row) + return ret0 +} + +// QueryRow indicates an expected call of QueryRow. +func (mr *MockClientMockRecorder) QueryRow(arg0, arg1 any, arg2 ...any) *MockClientQueryRowCall { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockClient)(nil).QueryRow), varargs...) + return &MockClientQueryRowCall{Call: call} +} + +// MockClientQueryRowCall wrap *gomock.Call +type MockClientQueryRowCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockClientQueryRowCall) Return(arg0 database.Row) *MockClientQueryRowCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockClientQueryRowCall) Do(f func(context.Context, string, ...any) database.Row) *MockClientQueryRowCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockClientQueryRowCall) DoAndReturn(f func(context.Context, string, ...any) database.Row) *MockClientQueryRowCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Release mocks base method. +func (m *MockClient) Release(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Release", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Release indicates an expected call of Release. +func (mr *MockClientMockRecorder) Release(arg0 any) *MockClientReleaseCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Release", reflect.TypeOf((*MockClient)(nil).Release), arg0) + return &MockClientReleaseCall{Call: call} +} + +// MockClientReleaseCall wrap *gomock.Call +type MockClientReleaseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockClientReleaseCall) Return(arg0 error) *MockClientReleaseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockClientReleaseCall) Do(f func(context.Context) error) *MockClientReleaseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockClientReleaseCall) DoAndReturn(f func(context.Context) error) *MockClientReleaseCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MockRow is a mock of Row interface. +type MockRow struct { + ctrl *gomock.Controller + recorder *MockRowMockRecorder +} + +// MockRowMockRecorder is the mock recorder for MockRow. +type MockRowMockRecorder struct { + mock *MockRow +} + +// NewMockRow creates a new mock instance. +func NewMockRow(ctrl *gomock.Controller) *MockRow { + mock := &MockRow{ctrl: ctrl} + mock.recorder = &MockRowMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRow) EXPECT() *MockRowMockRecorder { + return m.recorder +} + +// Scan mocks base method. +func (m *MockRow) Scan(arg0 ...any) error { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Scan", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Scan indicates an expected call of Scan. +func (mr *MockRowMockRecorder) Scan(arg0 ...any) *MockRowScanCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRow)(nil).Scan), arg0...) + return &MockRowScanCall{Call: call} +} + +// MockRowScanCall wrap *gomock.Call +type MockRowScanCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockRowScanCall) Return(arg0 error) *MockRowScanCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockRowScanCall) Do(f func(...any) error) *MockRowScanCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockRowScanCall) DoAndReturn(f func(...any) error) *MockRowScanCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MockRows is a mock of Rows interface. +type MockRows struct { + ctrl *gomock.Controller + recorder *MockRowsMockRecorder +} + +// MockRowsMockRecorder is the mock recorder for MockRows. +type MockRowsMockRecorder struct { + mock *MockRows +} + +// NewMockRows creates a new mock instance. +func NewMockRows(ctrl *gomock.Controller) *MockRows { + mock := &MockRows{ctrl: ctrl} + mock.recorder = &MockRowsMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRows) EXPECT() *MockRowsMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockRows) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockRowsMockRecorder) Close() *MockRowsCloseCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRows)(nil).Close)) + return &MockRowsCloseCall{Call: call} +} + +// MockRowsCloseCall wrap *gomock.Call +type MockRowsCloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockRowsCloseCall) Return(arg0 error) *MockRowsCloseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockRowsCloseCall) Do(f func() error) *MockRowsCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockRowsCloseCall) DoAndReturn(f func() error) *MockRowsCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Err mocks base method. +func (m *MockRows) Err() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Err") + ret0, _ := ret[0].(error) + return ret0 +} + +// Err indicates an expected call of Err. +func (mr *MockRowsMockRecorder) Err() *MockRowsErrCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Err", reflect.TypeOf((*MockRows)(nil).Err)) + return &MockRowsErrCall{Call: call} +} + +// MockRowsErrCall wrap *gomock.Call +type MockRowsErrCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockRowsErrCall) Return(arg0 error) *MockRowsErrCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockRowsErrCall) Do(f func() error) *MockRowsErrCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockRowsErrCall) DoAndReturn(f func() error) *MockRowsErrCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Next mocks base method. +func (m *MockRows) Next() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Next") + ret0, _ := ret[0].(bool) + return ret0 +} + +// Next indicates an expected call of Next. +func (mr *MockRowsMockRecorder) Next() *MockRowsNextCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockRows)(nil).Next)) + return &MockRowsNextCall{Call: call} +} + +// MockRowsNextCall wrap *gomock.Call +type MockRowsNextCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockRowsNextCall) Return(arg0 bool) *MockRowsNextCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockRowsNextCall) Do(f func() bool) *MockRowsNextCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockRowsNextCall) DoAndReturn(f func() bool) *MockRowsNextCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Scan mocks base method. +func (m *MockRows) Scan(arg0 ...any) error { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Scan", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Scan indicates an expected call of Scan. +func (mr *MockRowsMockRecorder) Scan(arg0 ...any) *MockRowsScanCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), arg0...) + return &MockRowsScanCall{Call: call} +} + +// MockRowsScanCall wrap *gomock.Call +type MockRowsScanCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockRowsScanCall) Return(arg0 error) *MockRowsScanCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockRowsScanCall) Do(f func(...any) error) *MockRowsScanCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockRowsScanCall) DoAndReturn(f func(...any) error) *MockRowsScanCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MockTransaction is a mock of Transaction interface. +type MockTransaction struct { + ctrl *gomock.Controller + recorder *MockTransactionMockRecorder +} + +// MockTransactionMockRecorder is the mock recorder for MockTransaction. +type MockTransactionMockRecorder struct { + mock *MockTransaction +} + +// NewMockTransaction creates a new mock instance. +func NewMockTransaction(ctrl *gomock.Controller) *MockTransaction { + mock := &MockTransaction{ctrl: ctrl} + mock.recorder = &MockTransactionMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTransaction) EXPECT() *MockTransactionMockRecorder { + return m.recorder +} + +// Begin mocks base method. +func (m *MockTransaction) Begin(arg0 context.Context) (database.Transaction, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Begin", arg0) + ret0, _ := ret[0].(database.Transaction) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Begin indicates an expected call of Begin. +func (mr *MockTransactionMockRecorder) Begin(arg0 any) *MockTransactionBeginCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockTransaction)(nil).Begin), arg0) + return &MockTransactionBeginCall{Call: call} +} + +// MockTransactionBeginCall wrap *gomock.Call +type MockTransactionBeginCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTransactionBeginCall) Return(arg0 database.Transaction, arg1 error) *MockTransactionBeginCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTransactionBeginCall) Do(f func(context.Context) (database.Transaction, error)) *MockTransactionBeginCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTransactionBeginCall) DoAndReturn(f func(context.Context) (database.Transaction, error)) *MockTransactionBeginCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Commit mocks base method. +func (m *MockTransaction) Commit(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockTransactionMockRecorder) Commit(arg0 any) *MockTransactionCommitCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTransaction)(nil).Commit), arg0) + return &MockTransactionCommitCall{Call: call} +} + +// MockTransactionCommitCall wrap *gomock.Call +type MockTransactionCommitCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTransactionCommitCall) Return(arg0 error) *MockTransactionCommitCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTransactionCommitCall) Do(f func(context.Context) error) *MockTransactionCommitCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTransactionCommitCall) DoAndReturn(f func(context.Context) error) *MockTransactionCommitCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// End mocks base method. +func (m *MockTransaction) End(arg0 context.Context, arg1 error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "End", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// End indicates an expected call of End. +func (mr *MockTransactionMockRecorder) End(arg0, arg1 any) *MockTransactionEndCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "End", reflect.TypeOf((*MockTransaction)(nil).End), arg0, arg1) + return &MockTransactionEndCall{Call: call} +} + +// MockTransactionEndCall wrap *gomock.Call +type MockTransactionEndCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTransactionEndCall) Return(arg0 error) *MockTransactionEndCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTransactionEndCall) Do(f func(context.Context, error) error) *MockTransactionEndCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTransactionEndCall) DoAndReturn(f func(context.Context, error) error) *MockTransactionEndCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Exec mocks base method. +func (m *MockTransaction) Exec(arg0 context.Context, arg1 string, arg2 ...any) error { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Exec indicates an expected call of Exec. +func (mr *MockTransactionMockRecorder) Exec(arg0, arg1 any, arg2 ...any) *MockTransactionExecCall { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTransaction)(nil).Exec), varargs...) + return &MockTransactionExecCall{Call: call} +} + +// MockTransactionExecCall wrap *gomock.Call +type MockTransactionExecCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTransactionExecCall) Return(arg0 error) *MockTransactionExecCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTransactionExecCall) Do(f func(context.Context, string, ...any) error) *MockTransactionExecCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTransactionExecCall) DoAndReturn(f func(context.Context, string, ...any) error) *MockTransactionExecCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Query mocks base method. +func (m *MockTransaction) Query(arg0 context.Context, arg1 string, arg2 ...any) (database.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Query", varargs...) + ret0, _ := ret[0].(database.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Query indicates an expected call of Query. +func (mr *MockTransactionMockRecorder) Query(arg0, arg1 any, arg2 ...any) *MockTransactionQueryCall { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockTransaction)(nil).Query), varargs...) + return &MockTransactionQueryCall{Call: call} +} + +// MockTransactionQueryCall wrap *gomock.Call +type MockTransactionQueryCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTransactionQueryCall) Return(arg0 database.Rows, arg1 error) *MockTransactionQueryCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTransactionQueryCall) Do(f func(context.Context, string, ...any) (database.Rows, error)) *MockTransactionQueryCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTransactionQueryCall) DoAndReturn(f func(context.Context, string, ...any) (database.Rows, error)) *MockTransactionQueryCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// QueryRow mocks base method. +func (m *MockTransaction) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) database.Row { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryRow", varargs...) + ret0, _ := ret[0].(database.Row) + return ret0 +} + +// QueryRow indicates an expected call of QueryRow. +func (mr *MockTransactionMockRecorder) QueryRow(arg0, arg1 any, arg2 ...any) *MockTransactionQueryRowCall { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockTransaction)(nil).QueryRow), varargs...) + return &MockTransactionQueryRowCall{Call: call} +} + +// MockTransactionQueryRowCall wrap *gomock.Call +type MockTransactionQueryRowCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTransactionQueryRowCall) Return(arg0 database.Row) *MockTransactionQueryRowCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTransactionQueryRowCall) Do(f func(context.Context, string, ...any) database.Row) *MockTransactionQueryRowCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTransactionQueryRowCall) DoAndReturn(f func(context.Context, string, ...any) database.Row) *MockTransactionQueryRowCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Rollback mocks base method. +func (m *MockTransaction) Rollback(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Rollback", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Rollback indicates an expected call of Rollback. +func (mr *MockTransactionMockRecorder) Rollback(arg0 any) *MockTransactionRollbackCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockTransaction)(nil).Rollback), arg0) + return &MockTransactionRollbackCall{Call: call} +} + +// MockTransactionRollbackCall wrap *gomock.Call +type MockTransactionRollbackCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTransactionRollbackCall) Return(arg0 error) *MockTransactionRollbackCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTransactionRollbackCall) Do(f func(context.Context) error) *MockTransactionRollbackCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTransactionRollbackCall) DoAndReturn(f func(context.Context) error) *MockTransactionRollbackCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/backend/v3/storage/database/repository/clause.go b/backend/v3/storage/database/repository/clause.go new file mode 100644 index 0000000000..df0603d1bc --- /dev/null +++ b/backend/v3/storage/database/repository/clause.go @@ -0,0 +1,160 @@ +package repository + +import ( + "fmt" + + "github.com/zitadel/zitadel/backend/v3/domain" +) + +type field interface { + fmt.Stringer +} + +type fieldDescriptor struct { + schema string + table string + name string +} + +func (f fieldDescriptor) String() string { + return f.schema + "." + f.table + "." + f.name +} + +type ignoreCaseFieldDescriptor struct { + fieldDescriptor + fieldNameSuffix string +} + +func (f ignoreCaseFieldDescriptor) String() string { + return f.fieldDescriptor.String() + f.fieldNameSuffix +} + +type textFieldDescriptor struct { + field + isIgnoreCase bool +} + +type clause[Op domain.Operation] struct { + field field + op Op +} + +const ( + schema = "zitadel" + userTable = "users" +) + +var userFields = map[domain.UserField]field{ + domain.UserFieldInstanceID: fieldDescriptor{ + schema: schema, + table: userTable, + name: "instance_id", + }, + domain.UserFieldOrgID: fieldDescriptor{ + schema: schema, + table: userTable, + name: "org_id", + }, + domain.UserFieldID: fieldDescriptor{ + schema: schema, + table: userTable, + name: "id", + }, + domain.UserFieldUsername: textFieldDescriptor{ + field: ignoreCaseFieldDescriptor{ + fieldDescriptor: fieldDescriptor{ + schema: schema, + table: userTable, + name: "username", + }, + fieldNameSuffix: "_lower", + }, + }, + domain.UserHumanFieldEmail: textFieldDescriptor{ + field: ignoreCaseFieldDescriptor{ + fieldDescriptor: fieldDescriptor{ + schema: schema, + table: userTable, + name: "email", + }, + fieldNameSuffix: "_lower", + }, + }, + domain.UserHumanFieldEmailVerified: fieldDescriptor{ + schema: schema, + table: userTable, + name: "email_is_verified", + }, +} + +type textClause[V domain.Text] struct { + clause[domain.TextOperation] + value V +} + +var textOp map[domain.TextOperation]string = map[domain.TextOperation]string{ + domain.TextOperationEqual: " = ", + domain.TextOperationNotEqual: " <> ", + domain.TextOperationStartsWith: " LIKE ", + domain.TextOperationStartsWithIgnoreCase: " LIKE ", +} + +func (tc textClause[V]) Write(stmt *statement) { + placeholder := stmt.appendArg(tc.value) + var ( + left, right string + ) + switch tc.clause.op { + case domain.TextOperationEqual: + left = tc.clause.field.String() + right = placeholder + case domain.TextOperationNotEqual: + left = tc.clause.field.String() + right = placeholder + case domain.TextOperationStartsWith: + left = tc.clause.field.String() + right = placeholder + "%" + case domain.TextOperationStartsWithIgnoreCase: + left = tc.clause.field.String() + if _, ok := tc.clause.field.(ignoreCaseFieldDescriptor); !ok { + left = "LOWER(" + left + ")" + } + right = "LOWER(" + placeholder + "%)" + } + + stmt.builder.WriteString(left) + stmt.builder.WriteString(textOp[tc.clause.op]) + stmt.builder.WriteString(right) +} + +type boolClause[V domain.Bool] struct { + clause[domain.BoolOperation] + value V +} + +func (bc boolClause[V]) Write(stmt *statement) { + if !bc.value { + stmt.builder.WriteString("NOT ") + } + stmt.builder.WriteString(bc.clause.field.String()) +} + +type numberClause[V domain.Number] struct { + clause[domain.NumberOperation] + value V +} + +var numberOp map[domain.NumberOperation]string = map[domain.NumberOperation]string{ + domain.NumberOperationEqual: " = ", + domain.NumberOperationNotEqual: " <> ", + domain.NumberOperationLessThan: " < ", + domain.NumberOperationLessThanOrEqual: " <= ", + domain.NumberOperationGreaterThan: " > ", + domain.NumberOperationGreaterThanOrEqual: " >= ", +} + +func (nc numberClause[V]) Write(stmt *statement) { + stmt.builder.WriteString(nc.clause.field.String()) + stmt.builder.WriteString(numberOp[nc.clause.op]) + stmt.builder.WriteString(stmt.appendArg(nc.value)) +} diff --git a/backend/v3/storage/database/repository/crypto.go b/backend/v3/storage/database/repository/crypto.go new file mode 100644 index 0000000000..cd547af147 --- /dev/null +++ b/backend/v3/storage/database/repository/crypto.go @@ -0,0 +1,45 @@ +package repository + +import ( + "context" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/internal/crypto" +) + +type cryptoRepo struct { + database.QueryExecutor +} + +func Crypto(db database.QueryExecutor) domain.CryptoRepository { + return &cryptoRepo{ + QueryExecutor: db, + } +} + +const getEncryptionConfigQuery = "SELECT" + + " length" + + ", expiry" + + ", should_include_lower_letters" + + ", should_include_upper_letters" + + ", should_include_digits" + + ", should_include_symbols" + + " FROM encryption_config" + +func (repo *cryptoRepo) GetEncryptionConfig(ctx context.Context) (*crypto.GeneratorConfig, error) { + var config crypto.GeneratorConfig + row := repo.QueryRow(ctx, getEncryptionConfigQuery) + err := row.Scan( + &config.Length, + &config.Expiry, + &config.IncludeLowerLetters, + &config.IncludeUpperLetters, + &config.IncludeDigits, + &config.IncludeSymbols, + ) + if err != nil { + return nil, err + } + return &config, nil +} diff --git a/backend/v3/storage/database/repository/doc.go b/backend/v3/storage/database/repository/doc.go new file mode 100644 index 0000000000..ba567e747c --- /dev/null +++ b/backend/v3/storage/database/repository/doc.go @@ -0,0 +1,7 @@ +// Repository package provides the database repository for the application. +// It contains the implementation of the [repository pattern](https://martinfowler.com/eaaCatalog/repository.html) for the database. + +// funcs which need to interact with the database should create interfaces which are implemented by the +// [query] and [exec] structs respectively their factory methods [Query] and [Execute]. The [query] struct is used for read operations, while the [exec] struct is used for write operations. + +package repository diff --git a/backend/v3/storage/database/repository/instance.go b/backend/v3/storage/database/repository/instance.go new file mode 100644 index 0000000000..d8432c546c --- /dev/null +++ b/backend/v3/storage/database/repository/instance.go @@ -0,0 +1,54 @@ +package repository + +import ( + "context" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type instance struct { + database.QueryExecutor +} + +func Instance(client database.QueryExecutor) domain.InstanceRepository { + return &instance{QueryExecutor: client} +} + +func (i *instance) ByID(ctx context.Context, id string) (*domain.Instance, error) { + var instance domain.Instance + err := i.QueryExecutor.QueryRow(ctx, `SELECT id, name, created_at, updated_at, deleted_at FROM instances WHERE id = $1`, id).Scan( + &instance.ID, + &instance.Name, + &instance.CreatedAt, + &instance.UpdatedAt, + &instance.DeletedAt, + ) + if err != nil { + return nil, err + } + return &instance, nil +} + +const createInstanceStmt = `INSERT INTO instances (id, name) VALUES ($1, $2) RETURNING created_at, updated_at` + +// Create implements [domain.InstanceRepository]. +func (i *instance) Create(ctx context.Context, instance *domain.Instance) error { + return i.QueryExecutor.QueryRow(ctx, createInstanceStmt, + instance.ID, + instance.Name, + ).Scan( + &instance.CreatedAt, + &instance.UpdatedAt, + ) +} + +// On implements [domain.InstanceRepository]. +func (i *instance) On(id string) domain.InstanceOperation { + return &instanceOperation{ + QueryExecutor: i.QueryExecutor, + id: id, + } +} + +var _ domain.InstanceRepository = (*instance)(nil) diff --git a/backend/v3/storage/database/repository/instance_operation.go b/backend/v3/storage/database/repository/instance_operation.go new file mode 100644 index 0000000000..928c164a3c --- /dev/null +++ b/backend/v3/storage/database/repository/instance_operation.go @@ -0,0 +1,52 @@ +package repository + +import ( + "context" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type instanceOperation struct { + database.QueryExecutor + id string +} + +const addInstanceAdminStmt = `INSERT INTO instance_admins (instance_id, user_id, roles) VALUES ($1, $2, $3)` + +// AddAdmin implements [domain.InstanceOperation]. +func (i *instanceOperation) AddAdmin(ctx context.Context, userID string, roles []string) error { + return i.QueryExecutor.Exec(ctx, addInstanceAdminStmt, i.id, userID, roles) +} + +// Delete implements [domain.InstanceOperation]. +func (i *instanceOperation) Delete(ctx context.Context) error { + return i.QueryExecutor.Exec(ctx, `DELETE FROM instances WHERE id = $1`, i.id) +} + +const removeInstanceAdminStmt = `DELETE FROM instance_admins WHERE instance_id = $1 AND user_id = $2` + +// RemoveAdmin implements [domain.InstanceOperation]. +func (i *instanceOperation) RemoveAdmin(ctx context.Context, userID string) error { + return i.QueryExecutor.Exec(ctx, removeInstanceAdminStmt, i.id, userID) +} + +const setInstanceAdminRolesStmt = `UPDATE instance_admins SET roles = $1 WHERE instance_id = $2 AND user_id = $3` + +// SetAdminRoles implements [domain.InstanceOperation]. +func (i *instanceOperation) SetAdminRoles(ctx context.Context, userID string, roles []string) error { + return i.QueryExecutor.Exec(ctx, setInstanceAdminRolesStmt, roles, i.id, userID) +} + +const updateInstanceStmt = `UPDATE instances SET name = $1, updated_at = $2 WHERE id = $3 RETURNING updated_at` + +// Update implements [domain.InstanceOperation]. +func (i *instanceOperation) Update(ctx context.Context, instance *domain.Instance) error { + return i.QueryExecutor.QueryRow(ctx, updateInstanceStmt, + instance.Name, + instance.UpdatedAt, + i.id, + ).Scan(&instance.UpdatedAt) +} + +var _ domain.InstanceOperation = (*instanceOperation)(nil) diff --git a/backend/v3/storage/database/repository/query.go b/backend/v3/storage/database/repository/query.go new file mode 100644 index 0000000000..fc026bae43 --- /dev/null +++ b/backend/v3/storage/database/repository/query.go @@ -0,0 +1,17 @@ +package repository + +import ( + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type query struct{ database.Querier } + +func Query(querier database.Querier) *query { + return &query{Querier: querier} +} + +type executor struct{ database.Executor } + +func Execute(exec database.Executor) *executor { + return &executor{Executor: exec} +} diff --git a/backend/v3/storage/database/repository/statement.go b/backend/v3/storage/database/repository/statement.go new file mode 100644 index 0000000000..50138c02b2 --- /dev/null +++ b/backend/v3/storage/database/repository/statement.go @@ -0,0 +1,21 @@ +package repository + +import "strings" + +type statement struct { + builder strings.Builder + args []any +} + +func (s *statement) appendArg(arg any) (placeholder string) { + s.args = append(s.args, arg) + return "$" + string(len(s.args)) +} + +func (s *statement) appendArgs(args ...any) (placeholders []string) { + placeholders = make([]string, len(args)) + for i, arg := range args { + placeholders[i] = s.appendArg(arg) + } + return placeholders +} diff --git a/backend/v3/storage/database/repository/stmt/column.go b/backend/v3/storage/database/repository/stmt/column.go new file mode 100644 index 0000000000..c11b2b256e --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/column.go @@ -0,0 +1,43 @@ +package stmt + +import "fmt" + +type Column[T any] interface { + fmt.Stringer + statementApplier[T] + scanner(t *T) any +} + +type columnDescriptor[T any] struct { + name string + scan func(*T) any +} + +func (cd columnDescriptor[T]) scanner(t *T) any { + return cd.scan(t) +} + +// Apply implements [Column]. +func (f columnDescriptor[T]) Apply(stmt *statement[T]) { + stmt.builder.WriteString(stmt.columnPrefix()) + stmt.builder.WriteString(f.String()) +} + +// String implements [Column]. +func (f columnDescriptor[T]) String() string { + return f.name +} + +var _ Column[any] = (*columnDescriptor[any])(nil) + +type ignoreCaseColumnDescriptor[T any] struct { + columnDescriptor[T] + fieldNameSuffix string +} + +func (f ignoreCaseColumnDescriptor[T]) ApplyIgnoreCase(stmt *statement[T]) { + stmt.builder.WriteString(f.String()) + stmt.builder.WriteString(f.fieldNameSuffix) +} + +var _ Column[any] = (*ignoreCaseColumnDescriptor[any])(nil) diff --git a/backend/v3/storage/database/repository/stmt/condition.go b/backend/v3/storage/database/repository/stmt/condition.go new file mode 100644 index 0000000000..bce0ca1b44 --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/condition.go @@ -0,0 +1,97 @@ +package stmt + +import "fmt" + +type statementApplier[T any] interface { + // Apply writes the statement to the builder. + Apply(stmt *statement[T]) +} + +type Condition[T any] interface { + statementApplier[T] +} + +type op interface { + TextOperation | NumberOperation | ListOperation + fmt.Stringer +} + +type operation[T any, O op] struct { + o O +} + +func (o operation[T, O]) String() string { + return o.o.String() +} + +func (o operation[T, O]) Apply(stmt *statement[T]) { + stmt.builder.WriteString(o.o.String()) +} + +type condition[V, T any, OP op] struct { + field Column[T] + op OP + value V +} + +func (c *condition[V, T, OP]) Apply(stmt *statement[T]) { + // placeholder := stmt.appendArg(c.value) + stmt.builder.WriteString(stmt.columnPrefix()) + stmt.builder.WriteString(c.field.String()) + // stmt.builder.WriteString(c.op) + // stmt.builder.WriteString(placeholder) +} + +type and[T any] struct { + conditions []Condition[T] +} + +func And[T any](conditions ...Condition[T]) *and[T] { + return &and[T]{ + conditions: conditions, + } +} + +// Apply implements [Condition]. +func (a *and[T]) Apply(stmt *statement[T]) { + if len(a.conditions) > 1 { + stmt.builder.WriteString("(") + defer stmt.builder.WriteString(")") + } + + for i, condition := range a.conditions { + if i > 0 { + stmt.builder.WriteString(" AND ") + } + condition.Apply(stmt) + } +} + +var _ Condition[any] = (*and[any])(nil) + +type or[T any] struct { + conditions []Condition[T] +} + +func Or[T any](conditions ...Condition[T]) *or[T] { + return &or[T]{ + conditions: conditions, + } +} + +// Apply implements [Condition]. +func (o *or[T]) Apply(stmt *statement[T]) { + if len(o.conditions) > 1 { + stmt.builder.WriteString("(") + defer stmt.builder.WriteString(")") + } + + for i, condition := range o.conditions { + if i > 0 { + stmt.builder.WriteString(" OR ") + } + condition.Apply(stmt) + } +} + +var _ Condition[any] = (*or[any])(nil) diff --git a/backend/v3/storage/database/repository/stmt/list.go b/backend/v3/storage/database/repository/stmt/list.go new file mode 100644 index 0000000000..90114ace73 --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/list.go @@ -0,0 +1,71 @@ +package stmt + +type ListEntry interface { + Number | Text | any +} + +type ListCondition[E ListEntry, T any] struct { + condition[[]E, T, ListOperation] +} + +func (lc *ListCondition[E, T]) Apply(stmt *statement[T]) { + placeholder := stmt.appendArg(lc.value) + + switch lc.op { + case ListOperationEqual, ListOperationNotEqual: + lc.field.Apply(stmt) + operation[T, ListOperation]{lc.op}.Apply(stmt) + stmt.builder.WriteString(placeholder) + case ListOperationContainsAny, ListOperationContainsAll: + lc.field.Apply(stmt) + operation[T, ListOperation]{lc.op}.Apply(stmt) + stmt.builder.WriteString(placeholder) + case ListOperationNotContainsAny, ListOperationNotContainsAll: + stmt.builder.WriteString("NOT (") + lc.field.Apply(stmt) + operation[T, ListOperation]{lc.op}.Apply(stmt) + stmt.builder.WriteString(placeholder) + stmt.builder.WriteString(")") + default: + panic("unknown list operation") + } +} + +type ListOperation uint8 + +const ( + // ListOperationEqual checks if the arrays are equal including the order of the elements + ListOperationEqual ListOperation = iota + 1 + // ListOperationNotEqual checks if the arrays are not equal including the order of the elements + ListOperationNotEqual + + // ListOperationContains checks if the array column contains all the values of the specified array + ListOperationContainsAll + // ListOperationContainsAny checks if the arrays have at least one value in common + ListOperationContainsAny + // ListOperationContainsAll checks if the array column contains all the values of the specified array + + // ListOperationNotContainsAll checks if the specified array is not contained by the column + ListOperationNotContainsAll + // ListOperationNotContainsAny checks if the arrays column contains none of the values of the specified array + ListOperationNotContainsAny +) + +var listOperations = map[ListOperation]string{ + // ListOperationEqual checks if the lists are equal + ListOperationEqual: " = ", + // ListOperationNotEqual checks if the lists are not equal + ListOperationNotEqual: " <> ", + // ListOperationContainsAny checks if the arrays have at least one value in common + ListOperationContainsAny: " && ", + // ListOperationContainsAll checks if the array column contains all the values of the specified array + ListOperationContainsAll: " @> ", + // ListOperationNotContainsAny checks if the arrays column contains none of the values of the specified array + ListOperationNotContainsAny: " && ", // Base operator for NOT (A && B) + // ListOperationNotContainsAll checks if the array column is not contained by the specified array + ListOperationNotContainsAll: " <@ ", // Base operator for NOT (A <@ B) +} + +func (lo ListOperation) String() string { + return listOperations[lo] +} diff --git a/backend/v3/storage/database/repository/stmt/number.go b/backend/v3/storage/database/repository/stmt/number.go new file mode 100644 index 0000000000..9dfb6e44bf --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/number.go @@ -0,0 +1,61 @@ +package stmt + +import ( + "time" + + "golang.org/x/exp/constraints" +) + +type Number interface { + constraints.Integer | constraints.Float | constraints.Complex | time.Time | time.Duration +} + +type between[N Number] struct { + min, max N +} + +type NumberBetween[V Number, T any] struct { + condition[between[V], T, NumberOperation] +} + +func (nb *NumberBetween[V, T]) Apply(stmt *statement[T]) { + nb.field.Apply(stmt) + stmt.builder.WriteString(" BETWEEN ") + stmt.builder.WriteString(stmt.appendArg(nb.value.min)) + stmt.builder.WriteString(" AND ") + stmt.builder.WriteString(stmt.appendArg(nb.value.max)) +} + +type NumberCondition[V Number, T any] struct { + condition[V, T, NumberOperation] +} + +func (nc *NumberCondition[V, T]) Apply(stmt *statement[T]) { + nc.field.Apply(stmt) + operation[T, NumberOperation]{nc.op}.Apply(stmt) + stmt.builder.WriteString(stmt.appendArg(nc.value)) +} + +type NumberOperation uint8 + +const ( + NumberOperationEqual NumberOperation = iota + 1 + NumberOperationNotEqual + NumberOperationLessThan + NumberOperationLessThanOrEqual + NumberOperationGreaterThan + NumberOperationGreaterThanOrEqual +) + +var numberOperations = map[NumberOperation]string{ + NumberOperationEqual: " = ", + NumberOperationNotEqual: " <> ", + NumberOperationLessThan: " < ", + NumberOperationLessThanOrEqual: " <= ", + NumberOperationGreaterThan: " > ", + NumberOperationGreaterThanOrEqual: " >= ", +} + +func (no NumberOperation) String() string { + return numberOperations[no] +} diff --git a/backend/v3/storage/database/repository/stmt/statement.go b/backend/v3/storage/database/repository/stmt/statement.go new file mode 100644 index 0000000000..edc0b79967 --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/statement.go @@ -0,0 +1,104 @@ +package stmt + +import ( + "fmt" + "strings" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type statement[T any] struct { + builder strings.Builder + client database.QueryExecutor + + columns []Column[T] + + schema string + table string + alias string + + condition Condition[T] + + limit uint32 + offset uint32 + // order by fieldname and sort direction false for asc true for desc + // orderBy SortingColumns[C] + args []any + existingArgs map[any]string +} + +func (s *statement[T]) scanners(t *T) []any { + scanners := make([]any, len(s.columns)) + for i, column := range s.columns { + scanners[i] = column.scanner(t) + } + return scanners +} + +func (s *statement[T]) query() string { + s.builder.WriteString(`SELECT `) + for i, column := range s.columns { + if i > 0 { + s.builder.WriteString(", ") + } + column.Apply(s) + } + s.builder.WriteString(` FROM `) + s.builder.WriteString(s.schema) + s.builder.WriteRune('.') + s.builder.WriteString(s.table) + if s.alias != "" { + s.builder.WriteString(" AS ") + s.builder.WriteString(s.alias) + } + + s.builder.WriteString(` WHERE `) + + s.condition.Apply(s) + + if s.limit > 0 { + s.builder.WriteString(` LIMIT `) + s.builder.WriteString(s.appendArg(s.limit)) + } + if s.offset > 0 { + s.builder.WriteString(` OFFSET `) + s.builder.WriteString(s.appendArg(s.offset)) + } + + return s.builder.String() +} + +// func (s *statement[T]) Where(condition Condition[T]) *statement[T] { +// s.condition = condition +// return s +// } + +// func (s *statement[T]) Limit(limit uint32) *statement[T] { +// s.limit = limit +// return s +// } + +// func (s *statement[T]) Offset(offset uint32) *statement[T] { +// s.offset = offset +// return s +// } + +func (s *statement[T]) columnPrefix() string { + if s.alias != "" { + return s.alias + "." + } + return s.schema + "." + s.table + "." +} + +func (s *statement[T]) appendArg(arg any) string { + if s.existingArgs == nil { + s.existingArgs = make(map[any]string) + } + if existing, ok := s.existingArgs[arg]; ok { + return existing + } + s.args = append(s.args, arg) + placeholder := fmt.Sprintf("$%d", len(s.args)) + s.existingArgs[arg] = placeholder + return placeholder +} diff --git a/backend/v3/storage/database/repository/stmt/stmt_test.go b/backend/v3/storage/database/repository/stmt/stmt_test.go new file mode 100644 index 0000000000..2956a3dd10 --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/stmt_test.go @@ -0,0 +1,18 @@ +package stmt_test + +import ( + "context" + "testing" + + "github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt" +) + +func Test_Bla(t *testing.T) { + stmt.User(nil).Where( + stmt.Or( + stmt.UserIDCondition("123"), + stmt.UserIDCondition("123"), + stmt.UserUsernameCondition(stmt.TextOperationEqualIgnoreCase, "test"), + ), + ).Limit(1).Offset(1).Get(context.Background()) +} diff --git a/backend/v3/storage/database/repository/stmt/text.go b/backend/v3/storage/database/repository/stmt/text.go new file mode 100644 index 0000000000..da1b00abc0 --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/text.go @@ -0,0 +1,72 @@ +package stmt + +type Text interface { + ~string | ~[]byte +} + +type TextCondition[V Text, T any] struct { + condition[V, T, TextOperation] +} + +func (tc *TextCondition[V, T]) Apply(stmt *statement[T]) { + placeholder := stmt.appendArg(tc.value) + + switch tc.op { + case TextOperationEqual, TextOperationNotEqual: + tc.field.Apply(stmt) + operation[T, TextOperation]{tc.op}.Apply(stmt) + stmt.builder.WriteString(placeholder) + case TextOperationEqualIgnoreCase: + if desc, ok := tc.field.(ignoreCaseColumnDescriptor[T]); ok { + desc.ApplyIgnoreCase(stmt) + } else { + stmt.builder.WriteString("LOWER(") + tc.field.Apply(stmt) + stmt.builder.WriteString(")") + } + operation[T, TextOperation]{tc.op}.Apply(stmt) + stmt.builder.WriteString("LOWER(") + stmt.builder.WriteString(placeholder) + stmt.builder.WriteString(")") + case TextOperationStartsWith: + tc.field.Apply(stmt) + operation[T, TextOperation]{tc.op}.Apply(stmt) + stmt.builder.WriteString(placeholder) + stmt.builder.WriteString("|| '%'") + case TextOperationStartsWithIgnoreCase: + if desc, ok := tc.field.(ignoreCaseColumnDescriptor[T]); ok { + desc.ApplyIgnoreCase(stmt) + } else { + stmt.builder.WriteString("LOWER(") + tc.field.Apply(stmt) + stmt.builder.WriteString(")") + } + operation[T, TextOperation]{tc.op}.Apply(stmt) + stmt.builder.WriteString("LOWER(") + stmt.builder.WriteString(placeholder) + stmt.builder.WriteString(")") + stmt.builder.WriteString("|| '%'") + } +} + +type TextOperation uint8 + +const ( + TextOperationEqual TextOperation = iota + 1 + TextOperationEqualIgnoreCase + TextOperationNotEqual + TextOperationStartsWith + TextOperationStartsWithIgnoreCase +) + +var textOperations = map[TextOperation]string{ + TextOperationEqual: " = ", + TextOperationEqualIgnoreCase: " = ", + TextOperationNotEqual: " <> ", + TextOperationStartsWith: " LIKE ", + TextOperationStartsWithIgnoreCase: " LIKE ", +} + +func (to TextOperation) String() string { + return textOperations[to] +} diff --git a/backend/v3/storage/database/repository/stmt/user.go b/backend/v3/storage/database/repository/stmt/user.go new file mode 100644 index 0000000000..e0f8d388e0 --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/user.go @@ -0,0 +1,193 @@ +package stmt + +import ( + "context" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type userStatement struct { + statement[domain.User] +} + +func User(client database.QueryExecutor) *userStatement { + return &userStatement{ + statement: statement[domain.User]{ + schema: "zitadel", + table: "users", + alias: "u", + client: client, + columns: []Column[domain.User]{ + userColumns[UserInstanceID], + userColumns[UserOrgID], + userColumns[UserColumnID], + userColumns[UserColumnUsername], + userColumns[UserCreatedAt], + userColumns[UserUpdatedAt], + userColumns[UserDeletedAt], + }, + }, + } +} + +func (s *userStatement) Where(condition Condition[domain.User]) *userStatement { + s.condition = condition + return s +} + +func (s *userStatement) Limit(limit uint32) *userStatement { + s.limit = limit + return s +} + +func (s *userStatement) Offset(offset uint32) *userStatement { + s.offset = offset + return s +} + +func (s *userStatement) Get(ctx context.Context) (*domain.User, error) { + var user domain.User + err := s.client.QueryRow(ctx, s.query(), s.statement.args...).Scan(s.scanners(&user)...) + + if err != nil { + return nil, err + } + + return &user, nil +} + +func (s *userStatement) List(ctx context.Context) ([]*domain.User, error) { + var users []*domain.User + rows, err := s.client.Query(ctx, s.query(), s.statement.args...) + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + var user domain.User + err = rows.Scan(s.scanners(&user)...) + if err != nil { + return nil, err + } + users = append(users, &user) + } + + return users, nil +} + +func (s *userStatement) SetUsername(ctx context.Context, username string) error { + return nil +} + +type UserColumn uint8 + +var ( + userColumns map[UserColumn]Column[domain.User] = map[UserColumn]Column[domain.User]{ + UserInstanceID: columnDescriptor[domain.User]{ + name: "instance_id", + scan: func(u *domain.User) any { + return &u.InstanceID + }, + }, + UserOrgID: columnDescriptor[domain.User]{ + name: "org_id", + scan: func(u *domain.User) any { + return &u.OrgID + }, + }, + UserColumnID: columnDescriptor[domain.User]{ + name: "id", + scan: func(u *domain.User) any { + return &u.ID + }, + }, + UserColumnUsername: ignoreCaseColumnDescriptor[domain.User]{ + columnDescriptor: columnDescriptor[domain.User]{ + name: "username", + scan: func(u *domain.User) any { + return &u.Username + }, + }, + fieldNameSuffix: "_lower", + }, + UserCreatedAt: columnDescriptor[domain.User]{ + name: "created_at", + scan: func(u *domain.User) any { + return &u.CreatedAt + }, + }, + UserUpdatedAt: columnDescriptor[domain.User]{ + name: "updated_at", + scan: func(u *domain.User) any { + return &u.UpdatedAt + }, + }, + UserDeletedAt: columnDescriptor[domain.User]{ + name: "deleted_at", + scan: func(u *domain.User) any { + return &u.DeletedAt + }, + }, + } + humanColumns = map[UserColumn]Column[domain.User]{ + UserHumanColumnEmail: ignoreCaseColumnDescriptor[domain.User]{ + columnDescriptor: columnDescriptor[domain.User]{ + name: "email", + scan: func(u *domain.User) any { + human, ok := u.Traits.(*domain.Human) + if !ok { + return nil + } + if human.Email == nil { + human.Email = new(domain.Email) + } + return &human.Email.Address + }, + }, + fieldNameSuffix: "_lower", + }, + UserHumanColumnEmailVerified: columnDescriptor[domain.User]{ + name: "email_is_verified", + scan: func(u *domain.User) any { + human, ok := u.Traits.(*domain.Human) + if !ok { + return nil + } + if human.Email == nil { + human.Email = new(domain.Email) + } + return &human.Email.IsVerified + }, + }, + } + machineColumns = map[UserColumn]Column[domain.User]{ + UserMachineDescription: columnDescriptor[domain.User]{ + name: "description", + scan: func(u *domain.User) any { + machine, ok := u.Traits.(*domain.Machine) + if !ok { + return nil + } + if machine == nil { + machine = new(domain.Machine) + } + return &machine.Description + }, + }, + } +) + +const ( + UserInstanceID UserColumn = iota + 1 + UserOrgID + UserColumnID + UserColumnUsername + UserHumanColumnEmail + UserHumanColumnEmailVerified + UserMachineDescription + UserCreatedAt + UserUpdatedAt + UserDeletedAt +) diff --git a/backend/v3/storage/database/repository/stmt/user_condition.go b/backend/v3/storage/database/repository/stmt/user_condition.go new file mode 100644 index 0000000000..ba138efe6b --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/user_condition.go @@ -0,0 +1,23 @@ +package stmt + +import "github.com/zitadel/zitadel/backend/v3/domain" + +func UserIDCondition(id string) *TextCondition[string, domain.User] { + return &TextCondition[string, domain.User]{ + condition: condition[string, domain.User, TextOperation]{ + field: userColumns[UserColumnID], + op: TextOperationEqual, + value: id, + }, + } +} + +func UserUsernameCondition(op TextOperation, username string) *TextCondition[string, domain.User] { + return &TextCondition[string, domain.User]{ + condition: condition[string, domain.User, TextOperation]{ + field: userColumns[UserColumnUsername], + op: op, + value: username, + }, + } +} diff --git a/backend/v3/storage/database/repository/stmt/v2/table.go b/backend/v3/storage/database/repository/stmt/v2/table.go new file mode 100644 index 0000000000..0efc396fe6 --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/v2/table.go @@ -0,0 +1,135 @@ +package stmt + +// type table struct { +// schema string +// name string + +// possibleJoins []*join + +// columns []*col +// } + +// type col struct { +// *table + +// name string +// } + +// type join struct { +// *table + +// on []*joinColumns +// } + +// type joinColumns struct { +// left, right *col +// } + +// var ( +// userTable = &table{ +// schema: "zitadel", +// name: "users", +// } +// userColumns = []*col{ +// userInstanceIDColumn, +// userOrgIDColumn, +// userIDColumn, +// userUsernameColumn, +// } +// userInstanceIDColumn = &col{ +// table: userTable, +// name: "instance_id", +// } +// userOrgIDColumn = &col{ +// table: userTable, +// name: "org_id", +// } +// userIDColumn = &col{ +// table: userTable, +// name: "id", +// } +// userUsernameColumn = &col{ +// table: userTable, +// name: "username", +// } +// userJoins = []*join{ +// { +// table: instanceTable, +// on: []*joinColumns{ +// { +// left: instanceIDColumn, +// right: userInstanceIDColumn, +// }, +// }, +// }, +// { +// table: orgTable, +// on: []*joinColumns{ +// { +// left: orgIDColumn, +// right: userOrgIDColumn, +// }, +// }, +// }, +// } +// ) + +// var ( +// instanceTable = &table{ +// schema: "zitadel", +// name: "instances", +// } +// instanceColumns = []*col{ +// instanceIDColumn, +// instanceNameColumn, +// } +// instanceIDColumn = &col{ +// table: instanceTable, +// name: "id", +// } +// instanceNameColumn = &col{ +// table: instanceTable, +// name: "name", +// } +// ) + +// var ( +// orgTable = &table{ +// schema: "zitadel", +// name: "orgs", +// } +// orgColumns = []*col{ +// orgInstanceIDColumn, +// orgIDColumn, +// orgNameColumn, +// } +// orgInstanceIDColumn = &col{ +// table: orgTable, +// name: "instance_id", +// } +// orgIDColumn = &col{ +// table: orgTable, +// name: "id", +// } +// orgNameColumn = &col{ +// table: orgTable, +// name: "name", +// } +// ) + +// func init() { +// instanceTable.columns = instanceColumns +// userTable.columns = userColumns + +// userTable.possibleJoins = []join{ +// { +// table: userTable, +// on: []joinColumns{ +// { +// left: userIDColumn, +// right: userIDColumn, +// }, +// }, +// }, +// } +// } diff --git a/backend/v3/storage/database/repository/stmt/v3/column.go b/backend/v3/storage/database/repository/stmt/v3/column.go new file mode 100644 index 0000000000..60ba0e6750 --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/v3/column.go @@ -0,0 +1,55 @@ +package v3 + +type Column interface { + Name() string + Write(builder statementBuilder) +} + +type ignoreCaseColumn interface { + Column + WriteIgnoreCase(builder statementBuilder) +} + +var ( + columnNameID = "id" + columnNameName = "name" + columnNameCreatedAt = "created_at" + columnNameUpdatedAt = "updated_at" + columnNameDeletedAt = "deleted_at" + + columnNameInstanceID = "instance_id" + + columnNameOrgID = "org_id" +) + +type column struct { + table Table + name string +} + +// Write implements Column. +func (c *column) Write(builder statementBuilder) { + c.table.writeOn(builder) + builder.writeRune('.') + builder.writeString(c.name) +} + +// Name implements [Column]. +func (c *column) Name() string { + return c.name +} + +var _ Column = (*column)(nil) + +type columnIgnoreCase struct { + column + suffix string +} + +// WriteIgnoreCase implements ignoreCaseColumn. +func (c *columnIgnoreCase) WriteIgnoreCase(builder statementBuilder) { + c.Write(builder) + builder.writeString(c.suffix) +} + +var _ ignoreCaseColumn = (*columnIgnoreCase)(nil) diff --git a/backend/v3/storage/database/repository/stmt/v3/condition.go b/backend/v3/storage/database/repository/stmt/v3/condition.go new file mode 100644 index 0000000000..1766242b89 --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/v3/condition.go @@ -0,0 +1,182 @@ +package v3 + +type statementBuilder interface { + write([]byte) + writeString(string) + writeRune(rune) + + appendArg(any) (placeholder string) + table() Table +} + +type Condition interface { + writeOn(builder statementBuilder) +} + +type and struct { + conditions []Condition +} + +func And(conditions ...Condition) *and { + return &and{conditions: conditions} +} + +// writeOn implements [Condition]. +func (a *and) writeOn(builder statementBuilder) { + if len(a.conditions) > 1 { + builder.writeString("(") + defer builder.writeString(")") + } + + for i, condition := range a.conditions { + if i > 0 { + builder.writeString(" AND ") + } + condition.writeOn(builder) + } +} + +var _ Condition = (*and)(nil) + +type or struct { + conditions []Condition +} + +func Or(conditions ...Condition) *or { + return &or{conditions: conditions} +} + +// writeOn implements [Condition]. +func (o *or) writeOn(builder statementBuilder) { + if len(o.conditions) > 1 { + builder.writeString("(") + defer builder.writeString(")") + } + + for i, condition := range o.conditions { + if i > 0 { + builder.writeString(" OR ") + } + condition.writeOn(builder) + } +} + +var _ Condition = (*or)(nil) + +type isNull struct { + column Column +} + +func IsNull(column Column) *isNull { + return &isNull{column: column} +} + +// writeOn implements [Condition]. +func (cond *isNull) writeOn(builder statementBuilder) { + cond.column.Write(builder) + builder.writeString(" IS NULL") +} + +var _ Condition = (*isNull)(nil) + +type isNotNull struct { + column Column +} + +func IsNotNull(column Column) *isNotNull { + return &isNotNull{column: column} +} + +// writeOn implements [Condition]. +func (cond *isNotNull) writeOn(builder statementBuilder) { + cond.column.Write(builder) + builder.writeString(" IS NOT NULL") +} + +var _ Condition = (*isNotNull)(nil) + +type condition[Op Operator, V Value] struct { + column Column + operator Op + value V +} + +// writeOn implements [Condition]. +func (cond condition[Op, V]) writeOn(builder statementBuilder) { + cond.column.Write(builder) + builder.writeString(cond.operator.String()) + builder.writeString(builder.appendArg(cond.value)) +} + +var _ Condition = (*condition[TextOperator, string])(nil) + +type textCondition[V Text] struct { + condition[TextOperator, V] +} + +func NewTextCondition[V Text](column Column, operator TextOperator, value V) *textCondition[V] { + return &textCondition[V]{ + condition: condition[TextOperator, V]{ + column: column, + operator: operator, + value: value, + }, + } +} + +// writeOn implements [Condition]. +func (cond *textCondition[V]) writeOn(builder statementBuilder) { + switch cond.operator { + case TextOperatorEqual, TextOperatorNotEqual: + cond.column.Write(builder) + builder.writeString(cond.operator.String()) + builder.writeString(builder.appendArg(cond.value)) + case TextOperatorEqualIgnoreCase, TextOperatorNotEqualIgnoreCase: + if col, ok := cond.column.(ignoreCaseColumn); ok { + col.WriteIgnoreCase(builder) + } else { + builder.writeString("LOWER(") + cond.column.Write(builder) + builder.writeString(")") + } + builder.writeString(cond.operator.String()) + builder.writeString("LOWER(") + builder.writeString(builder.appendArg(cond.value)) + builder.writeString(")") + case TextOperatorStartsWith: + cond.column.Write(builder) + builder.writeString(cond.operator.String()) + builder.writeString(builder.appendArg(cond.value)) + builder.writeString(" || '%'") + case TextOperatorStartsWithIgnoreCase: + if col, ok := cond.column.(ignoreCaseColumn); ok { + col.WriteIgnoreCase(builder) + } else { + builder.writeString("LOWER(") + cond.column.Write(builder) + builder.writeString(")") + } + builder.writeString(cond.operator.String()) + builder.writeString("LOWER(") + builder.writeString(builder.appendArg(cond.value)) + builder.writeString(") || '%'") + } +} + +var _ Condition = (*textCondition[string])(nil) + +type numberCondition[V Number] struct { + condition[NumberOperator, V] +} + +func NewNumberCondition[V Number](column Column, operator NumberOperator, value V) *numberCondition[V] { + return &numberCondition[V]{ + condition: condition[NumberOperator, V]{ + column: column, + operator: operator, + value: value, + }, + } +} + +var _ Condition = (*numberCondition[int])(nil) diff --git a/backend/v3/storage/database/repository/stmt/v3/instance.go b/backend/v3/storage/database/repository/stmt/v3/instance.go new file mode 100644 index 0000000000..7967d4f788 --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/v3/instance.go @@ -0,0 +1,104 @@ +package v3 + +import ( + "time" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type Instance struct { + id string + name string + + createdAt time.Time + updatedAt time.Time + deletedAt time.Time +} + +// Columns implements [object]. +func (Instance) Columns(table Table) []Column { + return []Column{ + &column{ + table: table, + name: columnNameID, + }, + &column{ + table: table, + name: columnNameName, + }, + &column{ + table: table, + name: columnNameCreatedAt, + }, + &column{ + table: table, + name: columnNameUpdatedAt, + }, + &column{ + table: table, + name: columnNameDeletedAt, + }, + } +} + +// Scan implements [object]. +func (i Instance) Scan(row database.Scanner) error { + return row.Scan( + &i.id, + &i.name, + &i.createdAt, + &i.updatedAt, + &i.deletedAt, + ) +} + +type instanceTable struct { + *table +} + +func InstanceTable() *instanceTable { + table := &instanceTable{ + table: newTable[Instance]("zitadel", "instances"), + } + + table.possibleJoins = func(t Table) map[string]Column { + switch on := t.(type) { + case *instanceTable: + return map[string]Column{ + columnNameID: on.IDColumn(), + } + case *orgTable: + return map[string]Column{ + columnNameID: on.InstanceIDColumn(), + } + case *userTable: + return map[string]Column{ + columnNameID: on.InstanceIDColumn(), + } + default: + return nil + } + } + + return table +} + +func (i *instanceTable) IDColumn() Column { + return i.columns[columnNameID] +} + +func (i *instanceTable) NameColumn() Column { + return i.columns[columnNameName] +} + +func (i *instanceTable) CreatedAtColumn() Column { + return i.columns[columnNameCreatedAt] +} + +func (i *instanceTable) UpdatedAtColumn() Column { + return i.columns[columnNameUpdatedAt] +} + +func (i *instanceTable) DeletedAtColumn() Column { + return i.columns[columnNameDeletedAt] +} diff --git a/backend/v3/storage/database/repository/stmt/v3/join.go b/backend/v3/storage/database/repository/stmt/v3/join.go new file mode 100644 index 0000000000..e35948cab4 --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/v3/join.go @@ -0,0 +1,11 @@ +package v3 + +type join struct { + table Table + conditions []joinCondition +} + +type joinCondition struct { + left Column + right Column +} diff --git a/backend/v3/storage/database/repository/stmt/v3/operator.go b/backend/v3/storage/database/repository/stmt/v3/operator.go new file mode 100644 index 0000000000..e9c1ff9c9f --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/v3/operator.go @@ -0,0 +1,82 @@ +package v3 + +import ( + "fmt" + "time" + + "golang.org/x/exp/constraints" +) + +type Value interface { + Bool | Number | Text +} + +type Text interface { + ~string | ~[]byte +} + +type Number interface { + constraints.Integer | constraints.Float | constraints.Complex | time.Time | time.Duration +} + +type Bool interface { + ~bool +} + +type Operator interface { + fmt.Stringer +} + +type TextOperator uint8 + +// String implements [Operator]. +func (t TextOperator) String() string { + return textOperators[t] +} + +const ( + TextOperatorEqual TextOperator = iota + 1 + TextOperatorEqualIgnoreCase + TextOperatorNotEqual + TextOperatorNotEqualIgnoreCase + TextOperatorStartsWith + TextOperatorStartsWithIgnoreCase +) + +var textOperators = map[TextOperator]string{ + TextOperatorEqual: " = ", + TextOperatorEqualIgnoreCase: " LIKE ", + TextOperatorNotEqual: " <> ", + TextOperatorNotEqualIgnoreCase: " NOT LIKE ", + TextOperatorStartsWith: " LIKE ", + TextOperatorStartsWithIgnoreCase: " LIKE ", +} + +var _ Operator = TextOperator(0) + +type NumberOperator uint8 + +// String implements Operator. +func (n NumberOperator) String() string { + return numberOperators[n] +} + +const ( + NumberOperatorEqual NumberOperator = iota + 1 + NumberOperatorNotEqual + NumberOperatorLessThan + NumberOperatorLessThanOrEqual + NumberOperatorGreaterThan + NumberOperatorGreaterThanOrEqual +) + +var numberOperators = map[NumberOperator]string{ + NumberOperatorEqual: " = ", + NumberOperatorNotEqual: " <> ", + NumberOperatorLessThan: " < ", + NumberOperatorLessThanOrEqual: " <= ", + NumberOperatorGreaterThan: " > ", + NumberOperatorGreaterThanOrEqual: " >= ", +} + +var _ Operator = NumberOperator(0) diff --git a/backend/v3/storage/database/repository/stmt/v3/org.go b/backend/v3/storage/database/repository/stmt/v3/org.go new file mode 100644 index 0000000000..27926ed7c7 --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/v3/org.go @@ -0,0 +1,117 @@ +package v3 + +import ( + "time" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type Org struct { + instanceID string + id string + + name string + + createdAt time.Time + updatedAt time.Time + deletedAt time.Time +} + +// Columns implements [object]. +func (Org) Columns(table Table) []Column { + return []Column{ + &column{ + table: table, + name: columnNameInstanceID, + }, + &column{ + table: table, + name: columnNameID, + }, + &column{ + table: table, + name: columnNameName, + }, + &column{ + table: table, + name: columnNameCreatedAt, + }, + &column{ + table: table, + name: columnNameUpdatedAt, + }, + &column{ + table: table, + name: columnNameDeletedAt, + }, + } +} + +// Scan implements [object]. +func (o Org) Scan(row database.Scanner) error { + return row.Scan( + &o.instanceID, + &o.id, + &o.name, + &o.createdAt, + &o.updatedAt, + &o.deletedAt, + ) +} + +type orgTable struct { + *table +} + +func OrgTable() *orgTable { + table := &orgTable{ + table: newTable[Org]("zitadel", "orgs"), + } + + table.possibleJoins = func(table Table) map[string]Column { + switch on := table.(type) { + case *instanceTable: + return map[string]Column{ + columnNameInstanceID: on.IDColumn(), + } + case *orgTable: + return map[string]Column{ + columnNameInstanceID: on.InstanceIDColumn(), + columnNameID: on.IDColumn(), + } + case *userTable: + return map[string]Column{ + columnNameInstanceID: on.InstanceIDColumn(), + columnNameID: on.IDColumn(), + } + default: + return nil + } + } + + return table +} + +func (o *orgTable) InstanceIDColumn() Column { + return o.columns[columnNameInstanceID] +} + +func (o *orgTable) IDColumn() Column { + return o.columns[columnNameID] +} + +func (o *orgTable) NameColumn() Column { + return o.columns[columnNameName] +} + +func (o *orgTable) CreatedAtColumn() Column { + return o.columns[columnNameCreatedAt] +} + +func (o *orgTable) UpdatedAtColumn() Column { + return o.columns[columnNameUpdatedAt] +} + +func (o *orgTable) DeletedAtColumn() Column { + return o.columns[columnNameDeletedAt] +} diff --git a/backend/v3/storage/database/repository/stmt/v3/query.go b/backend/v3/storage/database/repository/stmt/v3/query.go new file mode 100644 index 0000000000..4d1ada6a68 --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/v3/query.go @@ -0,0 +1,188 @@ +package v3 + +import ( + "context" + "fmt" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type Query[O object] interface { + Where(condition Condition) + Join(tables ...Table) + Limit(limit uint32) + Offset(offset uint32) + OrderBy(columns ...Column) + + Result(ctx context.Context, client database.Querier) (*O, error) + Results(ctx context.Context, client database.Querier) ([]O, error) + + fmt.Stringer + statementBuilder +} + +type query[O object] struct { + *statement[O] + joins []join + limit uint32 + offset uint32 + orderBy []Column +} + +func NewQuery[O object](table Table) Query[O] { + return &query[O]{ + statement: newStatement[O](table), + } +} + +// Result implements [Query]. +func (q *query[O]) Result(ctx context.Context, client database.Querier) (*O, error) { + var object O + row := client.QueryRow(ctx, q.String(), q.args...) + if err := object.Scan(row); err != nil { + return nil, err + } + return &object, nil +} + +// Results implements [Query]. +func (q *query[O]) Results(ctx context.Context, client database.Querier) ([]O, error) { + var objects []O + rows, err := client.Query(ctx, q.String(), q.args...) + if err != nil { + return nil, err + } + defer rows.Close() + + for rows.Next() { + var object O + if err := object.Scan(rows); err != nil { + return nil, err + } + objects = append(objects, object) + } + + return objects, rows.Err() +} + +// Join implements [Query]. +func (q *query[O]) Join(tables ...Table) { + for _, tbl := range tables { + cols := q.tbl.(*table).possibleJoins(tbl) + if len(cols) == 0 { + panic(fmt.Sprintf("table %q does not have any possible joins with table %q", q.tbl.Name(), tbl.Name())) + } + + q.joins = append(q.joins, join{ + table: tbl, + conditions: make([]joinCondition, 0, len(cols)), + }) + + for colName, col := range cols { + q.joins[len(q.joins)-1].conditions = append(q.joins[len(q.joins)-1].conditions, joinCondition{ + left: q.tbl.(*table).columns[colName], + right: col, + }) + } + } +} + +func (q *query[O]) Limit(limit uint32) { + q.limit = limit +} + +func (q *query[O]) Offset(offset uint32) { + q.offset = offset +} + +func (q *query[O]) OrderBy(columns ...Column) { + for _, allowedColumn := range q.columns { + for _, column := range columns { + if allowedColumn.Name() == column.Name() { + q.orderBy = append(q.orderBy, column) + } + } + } +} + +// String implements [fmt.Stringer] and [Query]. +func (q *query[O]) String() string { + q.writeSelectColumns() + q.writeFrom() + q.writeJoins() + q.writeCondition() + q.writeOrderBy() + q.writeLimit() + q.writeOffset() + q.writeGroupBy() + return q.builder.String() +} + +func (q *query[O]) writeSelectColumns() { + q.builder.WriteString("SELECT ") + for i, column := range q.columns { + if i > 0 { + q.builder.WriteString(", ") + } + q.builder.WriteString(q.tbl.Alias()) + q.builder.WriteRune('.') + q.builder.WriteString(column.Name()) + } +} + +func (q *query[O]) writeJoins() { + for _, join := range q.joins { + q.builder.WriteString(" JOIN ") + q.builder.WriteString(join.table.Schema()) + q.builder.WriteRune('.') + q.builder.WriteString(join.table.Name()) + if join.table.Alias() != "" { + q.builder.WriteString(" AS ") + q.builder.WriteString(join.table.Alias()) + } + + q.builder.WriteString(" ON ") + for i, condition := range join.conditions { + if i > 0 { + q.builder.WriteString(" AND ") + } + q.builder.WriteString(condition.left.Name()) + q.builder.WriteString(" = ") + q.builder.WriteString(condition.right.Name()) + } + } +} + +func (q *query[O]) writeOrderBy() { + if len(q.orderBy) == 0 { + return + } + + q.builder.WriteString(" ORDER BY ") + for i, order := range q.orderBy { + if i > 0 { + q.builder.WriteString(", ") + } + order.Write(q) + } +} + +func (q *query[O]) writeLimit() { + if q.limit == 0 { + return + } + q.builder.WriteString(" LIMIT ") + q.builder.WriteString(q.appendArg(q.limit)) +} + +func (q *query[O]) writeOffset() { + if q.offset == 0 { + return + } + q.builder.WriteString(" OFFSET ") + q.builder.WriteString(q.appendArg(q.offset)) +} + +func (q *query[O]) writeGroupBy() { + q.builder.WriteString(" GROUP BY ") +} diff --git a/backend/v3/storage/database/repository/stmt/v3/statement.go b/backend/v3/storage/database/repository/stmt/v3/statement.go new file mode 100644 index 0000000000..57884f357b --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/v3/statement.go @@ -0,0 +1,85 @@ +package v3 + +import ( + "fmt" + "strings" +) + +type statement[T object] struct { + tbl Table + columns []Column + condition Condition + + builder strings.Builder + args []any + existingArgs map[any]string +} + +func newStatement[O object](t Table) *statement[O] { + var o O + return &statement[O]{ + tbl: t, + columns: o.Columns(t), + } +} + +// Where implements [Query]. +func (stmt *statement[T]) Where(condition Condition) { + stmt.condition = condition +} + +func (stmt *statement[T]) writeFrom() { + stmt.builder.WriteString(" FROM ") + stmt.builder.WriteString(stmt.tbl.Schema()) + stmt.builder.WriteRune('.') + stmt.builder.WriteString(stmt.tbl.Name()) + if stmt.tbl.Alias() != "" { + stmt.builder.WriteString(" AS ") + stmt.builder.WriteString(stmt.tbl.Alias()) + } +} + +func (stmt *statement[T]) writeCondition() { + if stmt.condition == nil { + return + } + stmt.builder.WriteString(" WHERE ") + stmt.condition.writeOn(stmt) +} + +// appendArg implements [statementBuilder]. +func (stmt *statement[T]) appendArg(arg any) (placeholder string) { + if stmt.existingArgs == nil { + stmt.existingArgs = make(map[any]string) + } + if placeholder, ok := stmt.existingArgs[arg]; ok { + return placeholder + } + + stmt.args = append(stmt.args, arg) + placeholder = fmt.Sprintf("$%d", len(stmt.args)) + stmt.existingArgs[arg] = placeholder + return placeholder +} + +// table implements [statementBuilder]. +func (stmt *statement[T]) table() Table { + return stmt.tbl +} + +// write implements [statementBuilder]. +func (stmt *statement[T]) write(data []byte) { + stmt.builder.Write(data) +} + +// writeRune implements [statementBuilder]. +func (stmt *statement[T]) writeRune(r rune) { + stmt.builder.WriteRune(r) +} + +// writeString implements [statementBuilder]. +func (stmt *statement[T]) writeString(s string) { + stmt.builder.WriteString(s) +} + +var _ statementBuilder = (*statement[Instance])(nil) diff --git a/backend/v3/storage/database/repository/stmt/v3/table.go b/backend/v3/storage/database/repository/stmt/v3/table.go new file mode 100644 index 0000000000..95a0f6f58b --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/v3/table.go @@ -0,0 +1,84 @@ +package v3 + +import "github.com/zitadel/zitadel/backend/v3/storage/database" + +type object interface { + User | Org | Instance + Columns(t Table) []Column + Scan(s database.Scanner) error +} + +type Table interface { + Schema() string + Name() string + Alias() string + Columns() []Column + + writeOn(builder statementBuilder) +} + +type table struct { + schema string + name string + alias string + + possibleJoins func(table Table) map[string]Column + + columns map[string]Column + colList []Column +} + +func newTable[O object](schema, name string) *table { + t := &table{ + schema: schema, + name: name, + } + + var o O + t.colList = o.Columns(t) + t.columns = make(map[string]Column, len(t.colList)) + for _, col := range t.colList { + t.columns[col.Name()] = col + } + + return t +} + +// Columns implements [Table]. +func (t *table) Columns() []Column { + if len(t.colList) > 0 { + return t.colList + } + + t.colList = make([]Column, 0, len(t.columns)) + for _, column := range t.columns { + t.colList = append(t.colList, column) + } + + return t.colList +} + +// Name implements [Table]. +func (t *table) Name() string { + return t.name +} + +// Schema implements [Table]. +func (t *table) Schema() string { + return t.schema +} + +// Alias implements [Table]. +func (t *table) Alias() string { + if t.alias != "" { + return t.alias + } + return t.schema + "." + t.name +} + +// writeOn implements [Table]. +func (t *table) writeOn(builder statementBuilder) { + builder.writeString(t.Alias()) +} + +var _ Table = (*table)(nil) diff --git a/backend/v3/storage/database/repository/stmt/v3/user.go b/backend/v3/storage/database/repository/stmt/v3/user.go new file mode 100644 index 0000000000..f872382902 --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/v3/user.go @@ -0,0 +1,170 @@ +package v3 + +import ( + "time" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type User struct { + instanceID string + orgID string + id string + username string + + createdAt time.Time + updatedAt time.Time + deletedAt time.Time +} + +// Columns implements [object]. +func (u User) Columns(table Table) []Column { + return []Column{ + &column{ + table: table, + name: columnNameInstanceID, + }, + &column{ + table: table, + name: columnNameOrgID, + }, + &column{ + table: table, + name: columnNameID, + }, + &columnIgnoreCase{ + column: column{ + table: table, + name: userTableUsernameColumn, + }, + suffix: "_lower", + }, + &column{ + table: table, + name: columnNameCreatedAt, + }, + &column{ + table: table, + name: columnNameUpdatedAt, + }, + &column{ + table: table, + name: columnNameDeletedAt, + }, + } +} + +// Scan implements [object]. +func (u User) Scan(row database.Scanner) error { + return row.Scan( + &u.instanceID, + &u.orgID, + &u.id, + &u.username, + &u.createdAt, + &u.updatedAt, + &u.deletedAt, + ) +} + +type userTable struct { + *table +} + +const ( + userTableUsernameColumn = "username" +) + +func UserTable() *userTable { + table := &userTable{ + table: newTable[User]("zitadel", "users"), + } + + table.possibleJoins = func(table Table) map[string]Column { + switch on := table.(type) { + case *userTable: + return map[string]Column{ + columnNameInstanceID: on.InstanceIDColumn(), + columnNameOrgID: on.OrgIDColumn(), + columnNameID: on.IDColumn(), + } + case *orgTable: + return map[string]Column{ + columnNameInstanceID: on.InstanceIDColumn(), + columnNameOrgID: on.IDColumn(), + } + case *instanceTable: + return map[string]Column{ + columnNameInstanceID: on.IDColumn(), + } + default: + return nil + } + } + + return table +} + +func (t *userTable) InstanceIDColumn() Column { + return t.columns[columnNameInstanceID] +} + +func (t *userTable) OrgIDColumn() Column { + return t.columns[columnNameOrgID] +} + +func (t *userTable) IDColumn() Column { + return t.columns[columnNameID] +} + +func (t *userTable) UsernameColumn() Column { + return t.columns[userTableUsernameColumn] +} + +func (t *userTable) CreatedAtColumn() Column { + return t.columns[columnNameCreatedAt] +} + +func (t *userTable) UpdatedAtColumn() Column { + return t.columns[columnNameUpdatedAt] +} + +func (t *userTable) DeletedAtColumn() Column { + return t.columns[columnNameDeletedAt] +} + +func NewUserQuery() Query[User] { + q := NewQuery[User](UserTable()) + return q +} + +type userByIDCondition[T Text] struct { + id T +} + +func UserByID[T Text](id T) Condition { + return &userByIDCondition[T]{id: id} +} + +// writeOn implements Condition. +func (u *userByIDCondition[T]) writeOn(builder statementBuilder) { + NewTextCondition(builder.table().(*userTable).IDColumn(), TextOperatorEqual, u.id).writeOn(builder) +} + +var _ Condition = (*userByIDCondition[string])(nil) + +type userByUsernameCondition[T Text] struct { + username T + operator TextOperator +} + +func UserByUsername[T Text](username T, operator TextOperator) Condition { + return &userByUsernameCondition[T]{username: username, operator: operator} +} + +// writeOn implements Condition. +func (u *userByUsernameCondition[T]) writeOn(builder statementBuilder) { + NewTextCondition(builder.table().(*userTable).UsernameColumn(), u.operator, u.username).writeOn(builder) +} + +var _ Condition = (*userByUsernameCondition[string])(nil) diff --git a/backend/v3/storage/database/repository/stmt/v3/user_test.go b/backend/v3/storage/database/repository/stmt/v3/user_test.go new file mode 100644 index 0000000000..4bcbca7ee9 --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/v3/user_test.go @@ -0,0 +1,25 @@ +package v3_test + +import ( + "context" + "testing" + + v3 "github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt/v3" +) + +type user struct{} + +func TestUser(t *testing.T) { + query := v3.NewUserQuery() + query.Where( + v3.Or( + v3.UserByID("123"), + v3.UserByUsername("test", v3.TextOperatorStartsWithIgnoreCase), + ), + ) + query.Limit(10) + query.Offset(5) + // query.OrderBy( + + query.Result(context.TODO(), nil) +} diff --git a/backend/v3/storage/database/repository/stmt/v4/column.go b/backend/v3/storage/database/repository/stmt/v4/column.go new file mode 100644 index 0000000000..5dec86469b --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/v4/column.go @@ -0,0 +1,78 @@ +package v4 + +type Change interface { + Column +} + +type change[V Value] struct { + column Column + value V +} + +func newChange[V Value](col Column, value V) Change { + return &change[V]{ + column: col, + value: value, + } +} + +func newUpdatePtrColumn[V Value](col Column, value *V) Change { + if value == nil { + return newChange(col, nullDBInstruction) + } + return newChange(col, *value) +} + +// writeTo implements [Change]. +func (c change[V]) writeTo(builder *statementBuilder) { + c.column.writeTo(builder) + builder.WriteString(" = ") + builder.writeArg(c.value) +} + +type Changes []Change + +func newChanges(cols ...Change) Change { + return Changes(cols) +} + +// writeTo implements [Change]. +func (m Changes) writeTo(builder *statementBuilder) { + for i, col := range m { + if i > 0 { + builder.WriteString(", ") + } + col.writeTo(builder) + } +} + +var _ Change = Changes(nil) + +var _ Change = (*change[string])(nil) + +type Column interface { + writeTo(builder *statementBuilder) +} + +type column struct { + name string +} + +func (c column) writeTo(builder *statementBuilder) { + builder.WriteString(c.name) +} + +type ignoreCaseColumn interface { + Column + writeIgnoreCaseTo(builder *statementBuilder) +} + +type ignoreCaseCol struct { + column + suffix string +} + +func (c ignoreCaseCol) writeIgnoreCaseTo(builder *statementBuilder) { + c.column.writeTo(builder) + builder.WriteString(c.suffix) +} diff --git a/backend/v3/storage/database/repository/stmt/v4/condition.go b/backend/v3/storage/database/repository/stmt/v4/condition.go new file mode 100644 index 0000000000..e9cfad4317 --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/v4/condition.go @@ -0,0 +1,112 @@ +package v4 + +type Condition interface { + writeTo(builder *statementBuilder) +} + +type and struct { + conditions []Condition +} + +// writeTo implements [Condition]. +func (a *and) writeTo(builder *statementBuilder) { + if len(a.conditions) > 1 { + builder.WriteString("(") + defer builder.WriteString(")") + } + for i, condition := range a.conditions { + if i > 0 { + builder.WriteString(" AND ") + } + condition.writeTo(builder) + } +} + +func And(conditions ...Condition) *and { + return &and{conditions: conditions} +} + +var _ Condition = (*and)(nil) + +type or struct { + conditions []Condition +} + +// writeTo implements [Condition]. +func (o *or) writeTo(builder *statementBuilder) { + if len(o.conditions) > 1 { + builder.WriteString("(") + defer builder.WriteString(")") + } + for i, condition := range o.conditions { + if i > 0 { + builder.WriteString(" OR ") + } + condition.writeTo(builder) + } +} + +func Or(conditions ...Condition) *or { + return &or{conditions: conditions} +} + +var _ Condition = (*or)(nil) + +type isNull struct { + column Column +} + +// writeTo implements [Condition]. +func (i *isNull) writeTo(builder *statementBuilder) { + i.column.writeTo(builder) + builder.WriteString(" IS NULL") +} + +func IsNull(column Column) *isNull { + return &isNull{column: column} +} + +var _ Condition = (*isNull)(nil) + +type isNotNull struct { + column Column +} + +// writeTo implements [Condition]. +func (i *isNotNull) writeTo(builder *statementBuilder) { + i.column.writeTo(builder) + builder.WriteString(" IS NOT NULL") +} + +func IsNotNull(column Column) *isNotNull { + return &isNotNull{column: column} +} + +var _ Condition = (*isNotNull)(nil) + +type valueCondition func(builder *statementBuilder) + +func newTextCondition[V Text](col Column, op TextOperator, value V) Condition { + return valueCondition(func(builder *statementBuilder) { + writeTextOperation(builder, col, op, value) + }) +} + +func newNumberCondition[V Number](col Column, op NumberOperator, value V) Condition { + return valueCondition(func(builder *statementBuilder) { + writeNumberOperation(builder, col, op, value) + }) +} + +func newBooleanCondition[V Boolean](col Column, value V) Condition { + return valueCondition(func(builder *statementBuilder) { + writeBooleanOperation(builder, col, value) + }) +} + +// writeTo implements [Condition]. +func (c valueCondition) writeTo(builder *statementBuilder) { + c(builder) +} + +var _ Condition = (*valueCondition)(nil) diff --git a/backend/v3/storage/database/repository/stmt/v4/doc.go b/backend/v3/storage/database/repository/stmt/v4/doc.go new file mode 100644 index 0000000000..41871c5e1d --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/v4/doc.go @@ -0,0 +1,2 @@ +// this test focuses on queries rather than on tables +package v4 diff --git a/backend/v3/storage/database/repository/stmt/v4/inheritance.sql b/backend/v3/storage/database/repository/stmt/v4/inheritance.sql new file mode 100644 index 0000000000..9d15188e28 --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/v4/inheritance.sql @@ -0,0 +1,149 @@ +CREATE TABLE objects ( + id SERIAL PRIMARY KEY, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, + deleted_at TIMESTAMP +); + +CREATE OR REPLACE FUNCTION update_updated_at_column() +RETURNS TRIGGER AS $$ +BEGIN + NEW.updated_at = NOW(); + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TABLE instances( + name VARCHAR(50) NOT NULL + , PRIMARY KEY (id) +) INHERITS (objects); + +CREATE TRIGGER set_updated_at +BEFORE UPDATE +ON instances +FOR EACH ROW +EXECUTE FUNCTION update_updated_at_column(); + +CREATE TABLE instance_objects( + instance_id INT NOT NULL + , PRIMARY KEY (instance_id, id) + -- as foreign keys are not inherited we need to define them on the child tables + --, CONSTRAINT fk_instance FOREIGN KEY (instance_id) REFERENCES instances(id) +) INHERITS (objects); + +CREATE TABLE orgs( + name VARCHAR(50) NOT NULL + , PRIMARY KEY (instance_id, id) + , CONSTRAINT fk_instance FOREIGN KEY (instance_id) REFERENCES instances(id) +) INHERITS (instance_objects); + +CREATE TRIGGER set_updated_at +BEFORE UPDATE +ON orgs +FOR EACH ROW +EXECUTE FUNCTION update_updated_at_column(); + +CREATE TABLE org_objects( + org_id INT NOT NULL + , PRIMARY KEY (instance_id, org_id, id) + -- as foreign keys are not inherited we need to define them on the child tables + -- CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id), + -- CONSTRAINT fk_instance FOREIGN KEY (instance_id) REFERENCES instances(id) +) INHERITS (instance_objects); + +CREATE TABLE users ( + username VARCHAR(50) NOT NULL + , PRIMARY KEY (instance_id, org_id, id) + -- as foreign keys are not inherited we need to define them on the child tables + -- , CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id) + -- , CONSTRAINT fk_instances FOREIGN KEY (instance_id) REFERENCES instances(id) +) INHERITS (org_objects); + +CREATE TRIGGER set_updated_at +BEFORE UPDATE +ON users +FOR EACH ROW +EXECUTE FUNCTION update_updated_at_column(); + +CREATE TABLE human_users( + first_name VARCHAR(50) + , last_name VARCHAR(50) + , PRIMARY KEY (instance_id, org_id, id) + -- CONSTRAINT fk_user FOREIGN KEY (instance_id, org_id, id) REFERENCES users(instance_id, org_id, id), + , CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id) + , CONSTRAINT fk_instances FOREIGN KEY (instance_id) REFERENCES instances(id) +) INHERITS (users); + +CREATE TRIGGER set_updated_at +BEFORE UPDATE +ON human_users +FOR EACH ROW +EXECUTE FUNCTION update_updated_at_column(); + +CREATE TABLE machine_users( + description VARCHAR(50) + , PRIMARY KEY (instance_id, org_id, id) + -- , CONSTRAINT fk_user FOREIGN KEY (instance_id, org_id, id) REFERENCES users(instance_id, org_id, id) + , CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id) + , CONSTRAINT fk_instances FOREIGN KEY (instance_id) REFERENCES instances(id) +) INHERITS (users); + +CREATE TRIGGER set_updated_at +BEFORE UPDATE +ON machine_users +FOR EACH ROW +EXECUTE FUNCTION update_updated_at_column(); + + +select u.*, hu.first_name, hu.last_name, mu.description from users u +left join human_users hu on u.instance_id = hu.instance_id and u.org_id = hu.org_id and u.id = hu.id +left join machine_users mu on u.instance_id = mu.instance_id and u.org_id = mu.org_id and u.id = mu.id +-- where +-- u.instance_id = 1 +-- and u.org_id = 3 +-- and u.id = 7 +; + +create view users_view as ( +SELECT + id + , created_at + , updated_at + , deleted_at + , instance_id + , org_id + , username + , first_name + , last_name + , description +FROM ( +(SELECT + id + , created_at + , updated_at + , deleted_at + , instance_id + , org_id + , username + , first_name + , last_name + , NULL AS description +FROM + human_users) + +UNION + +(SELECT + id + , created_at + , updated_at + , deleted_at + , instance_id + , org_id + , username + , NULL AS first_name + , NULL AS last_name + , description +FROM + machine_users) +)); \ No newline at end of file diff --git a/backend/v3/storage/database/repository/stmt/v4/operators.go b/backend/v3/storage/database/repository/stmt/v4/operators.go new file mode 100644 index 0000000000..44f7568184 --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/v4/operators.go @@ -0,0 +1,139 @@ +package v4 + +import ( + "time" + + "golang.org/x/exp/constraints" +) + +type Value interface { + Boolean | Number | Text | databaseInstruction +} + +type Operator interface { + BooleanOperator | NumberOperator | TextOperator +} + +type Text interface { + ~string | ~[]byte +} + +type TextOperator uint8 + +const ( + // TextOperatorEqual compares two strings for equality. + TextOperatorEqual TextOperator = iota + 1 + // TextOperatorEqualIgnoreCase compares two strings for equality, ignoring case. + TextOperatorEqualIgnoreCase + // TextOperatorNotEqual compares two strings for inequality. + TextOperatorNotEqual + // TextOperatorNotEqualIgnoreCase compares two strings for inequality, ignoring case. + TextOperatorNotEqualIgnoreCase + // TextOperatorStartsWith checks if the first string starts with the second. + TextOperatorStartsWith + // TextOperatorStartsWithIgnoreCase checks if the first string starts with the second, ignoring case. + TextOperatorStartsWithIgnoreCase +) + +var textOperators = map[TextOperator]string{ + TextOperatorEqual: " = ", + TextOperatorEqualIgnoreCase: " LIKE ", + TextOperatorNotEqual: " <> ", + TextOperatorNotEqualIgnoreCase: " NOT LIKE ", + TextOperatorStartsWith: " LIKE ", + TextOperatorStartsWithIgnoreCase: " LIKE ", +} + +func writeTextOperation[T Text](builder *statementBuilder, col Column, op TextOperator, value T) { + switch op { + case TextOperatorEqual, TextOperatorNotEqual: + col.writeTo(builder) + builder.WriteString(textOperators[op]) + builder.WriteString(builder.appendArg(value)) + case TextOperatorEqualIgnoreCase, TextOperatorNotEqualIgnoreCase: + if ignoreCaseCol, ok := col.(ignoreCaseColumn); ok { + ignoreCaseCol.writeIgnoreCaseTo(builder) + } else { + builder.WriteString("LOWER(") + col.writeTo(builder) + builder.WriteString(")") + } + builder.WriteString(textOperators[op]) + builder.WriteString("LOWER(") + builder.WriteString(builder.appendArg(value)) + builder.WriteString(")") + case TextOperatorStartsWith: + col.writeTo(builder) + builder.WriteString(textOperators[op]) + builder.WriteString(builder.appendArg(value)) + builder.WriteString(" || '%'") + case TextOperatorStartsWithIgnoreCase: + if ignoreCaseCol, ok := col.(ignoreCaseColumn); ok { + ignoreCaseCol.writeIgnoreCaseTo(builder) + } else { + builder.WriteString("LOWER(") + col.writeTo(builder) + builder.WriteString(")") + } + builder.WriteString(textOperators[op]) + builder.WriteString("LOWER(") + builder.WriteString(builder.appendArg(value)) + builder.WriteString(")") + builder.WriteString(" || '%'") + default: + panic("unsupported text operation") + } +} + +type Number interface { + constraints.Integer | constraints.Float | constraints.Complex | time.Time | time.Duration +} + +type NumberOperator uint8 + +const ( + // NumberOperatorEqual compares two numbers for equality. + NumberOperatorEqual NumberOperator = iota + 1 + // NumberOperatorNotEqual compares two numbers for inequality. + NumberOperatorNotEqual + // NumberOperatorLessThan compares two numbers to check if the first is less than the second. + NumberOperatorLessThan + // NumberOperatorLessThanOrEqual compares two numbers to check if the first is less than or equal to the second. + NumberOperatorAtLeast + // NumberOperatorGreaterThan compares two numbers to check if the first is greater than the second. + NumberOperatorGreaterThan + // NumberOperatorGreaterThanOrEqual compares two numbers to check if the first is greater than or equal to the second. + NumberOperatorAtMost +) + +var numberOperators = map[NumberOperator]string{ + NumberOperatorEqual: " = ", + NumberOperatorNotEqual: " <> ", + NumberOperatorLessThan: " < ", + NumberOperatorAtLeast: " <= ", + NumberOperatorGreaterThan: " > ", + NumberOperatorAtMost: " >= ", +} + +func writeNumberOperation[T Number](builder *statementBuilder, col Column, op NumberOperator, value T) { + col.writeTo(builder) + builder.WriteString(numberOperators[op]) + builder.WriteString(builder.appendArg(value)) +} + +type Boolean interface { + ~bool +} + +type BooleanOperator uint8 + +const ( + BooleanOperatorIsTrue BooleanOperator = iota + 1 + BooleanOperatorIsFalse +) + +func writeBooleanOperation[T Boolean](builder *statementBuilder, col Column, value T) { + col.writeTo(builder) + builder.WriteString(" IS ") + builder.WriteString(builder.appendArg(value)) +} diff --git a/backend/v3/storage/database/repository/stmt/v4/org.go b/backend/v3/storage/database/repository/stmt/v4/org.go new file mode 100644 index 0000000000..42681d2442 --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/v4/org.go @@ -0,0 +1,18 @@ +package v4 + +type Org struct { + InstanceID string + ID string + Name string + Dates +} + +type GetOrg struct{} + +type ListOrgs struct{} + +type CreateOrg struct{} + +type UpdateOrg struct{} + +type DeleteOrg struct{} diff --git a/backend/v3/storage/database/repository/stmt/v4/statement.go b/backend/v3/storage/database/repository/stmt/v4/statement.go new file mode 100644 index 0000000000..9b652ee805 --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/v4/statement.go @@ -0,0 +1,46 @@ +package v4 + +import ( + "strconv" + "strings" +) + +type databaseInstruction string + +const ( + nowDBInstruction databaseInstruction = "NOW()" + nullDBInstruction databaseInstruction = "NULL" +) + +type statementBuilder struct { + strings.Builder + args []any + existingArgs map[any]string +} + +func (b *statementBuilder) writeArg(arg any) { + b.WriteString(b.appendArg(arg)) +} + +func (b *statementBuilder) appendArg(arg any) (placeholder string) { + if b.existingArgs == nil { + b.existingArgs = make(map[any]string) + } + if placeholder, ok := b.existingArgs[arg]; ok { + return placeholder + } + if instruction, ok := arg.(databaseInstruction); ok { + return string(instruction) + } + + b.args = append(b.args, arg) + placeholder = "$" + strconv.Itoa(len(b.args)) + b.existingArgs[arg] = placeholder + return placeholder +} + +func (b *statementBuilder) appendArgs(args ...any) { + for _, arg := range args { + b.appendArg(arg) + } +} diff --git a/backend/v3/storage/database/repository/stmt/v4/user.go b/backend/v3/storage/database/repository/stmt/v4/user.go new file mode 100644 index 0000000000..b728fc82c8 --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/v4/user.go @@ -0,0 +1,239 @@ +package v4 + +import ( + "context" + "time" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type Dates struct { + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt time.Time +} + +type User struct { + InstanceID string + OrgID string + ID string + Username string + Traits userTrait + Dates +} + +type UserType string + +type userTrait interface { + userTrait() + Type() UserType +} + +const userQuery = `SELECT u.instance_id, u.org_id, u.id, u.username, u.type, u.created_at, u.updated_at, u.deleted_at,` + + ` h.first_name, h.last_name, h.email_address, h.email_verified_at, h.phone_number, h.phone_verified_at, m.description` + + ` FROM users u` + + ` LEFT JOIN user_humans h ON u.instance_id = h.instance_id AND u.org_id = h.org_id AND u.id = h.id` + + ` LEFT JOIN user_machines m ON u.instance_id = m.instance_id AND u.org_id = m.org_id AND u.id = m.id` + +type user struct { + builder statementBuilder + client database.QueryExecutor + + condition Condition +} + +func UserRepository(client database.QueryExecutor) *user { + return &user{ + client: client, + } +} + +func (u *user) WithCondition(condition Condition) *user { + u.condition = condition + return u +} + +func (u *user) Get(ctx context.Context) (*User, error) { + u.builder.WriteString(userQuery) + u.writeCondition() + return scanUser(u.client.QueryRow(ctx, u.builder.String(), u.builder.args...)) +} + +func (u *user) List(ctx context.Context) (users []*User, err error) { + u.builder.WriteString(userQuery) + u.writeCondition() + + rows, err := u.client.Query(ctx, u.builder.String(), u.builder.args...) + if err != nil { + return nil, err + } + defer func() { + closeErr := rows.Close() + if err != nil { + return + } + err = closeErr + }() + for rows.Next() { + user, err := scanUser(rows) + if err != nil { + return nil, err + } + users = append(users, user) + } + if err := rows.Err(); err != nil { + return nil, err + } + return users, nil +} + +const ( + createUserCte = `WITH user AS (` + + `INSERT INTO users (instance_id, org_id, id, username, type) VALUES ($1, $2, $3, $4, $5)` + + ` RETURNING *)` + createHumanStmt = createUserCte + ` INSERT INTO user_humans h (instance_id, org_id, user_id, first_name, last_name, email_address, email_verified_at, phone_number, phone_verified_at)` + + ` SELECT u.instance_id, u.org_id, u.id, $6, $7, $8, $9, $10, $11` + + ` FROM user u` + + ` RETURNING u.created_at, u.updated_at, u.deleted_at` + createMachineStmt = createUserCte + ` INSERT INTO user_machines (instance_id, org_id, user_id, description)` + + ` SELECT u.instance_id, u.org_id, u.id, $6` + + ` FROM user u` + + ` RETURNING u.created_at, u.updated_at` +) + +func (u *user) Create(ctx context.Context, user *User) error { + u.builder.appendArgs(user.InstanceID, user.OrgID, user.ID, user.Username, user.Traits.Type()) + switch trait := user.Traits.(type) { + case *Human: + u.builder.WriteString(createHumanStmt) + u.builder.appendArgs(trait.FirstName, trait.LastName, trait.Email.Address, trait.Email.VerifiedAt, trait.Phone.Number, trait.Phone.VerifiedAt) + case *Machine: + u.builder.WriteString(createMachineStmt) + u.builder.appendArgs(trait.Description) + } + return u.client.QueryRow(ctx, u.builder.String(), u.builder.args...).Scan(user.CreatedAt, user.UpdatedAt) +} + +func (u *user) InstanceIDColumn() Column { + return column{name: "u.instance_id"} +} + +func (u *user) InstanceIDCondition(instanceID string) Condition { + return newTextCondition(u.InstanceIDColumn(), TextOperatorEqual, instanceID) +} + +func (u *user) OrgIDColumn() Column { + return column{name: "u.org_id"} +} + +func (u *user) OrgIDCondition(orgID string) Condition { + return newTextCondition(u.OrgIDColumn(), TextOperatorEqual, orgID) +} + +func (u *user) IDColumn() Column { + return column{name: "u.id"} +} + +func (u *user) IDCondition(userID string) Condition { + return newTextCondition(u.IDColumn(), TextOperatorEqual, userID) +} + +func (u *user) UsernameColumn() Column { + return ignoreCaseCol{ + column: column{name: "u.username"}, + suffix: "_lower", + } +} + +func (u user) SetUsername(username string) Change { + return newChange(u.UsernameColumn(), username) +} + +func (u *user) UsernameCondition(op TextOperator, username string) Condition { + return newTextCondition(u.UsernameColumn(), op, username) +} + +func (u *user) CreatedAtColumn() Column { + return column{name: "u.created_at"} +} + +func (u *user) CreatedAtCondition(op NumberOperator, createdAt time.Time) Condition { + return newNumberCondition(u.CreatedAtColumn(), op, createdAt) +} + +func (u *user) UpdatedAtColumn() Column { + return column{name: "u.updated_at"} +} + +func (u *user) UpdatedAtCondition(op NumberOperator, updatedAt time.Time) Condition { + return newNumberCondition(u.UpdatedAtColumn(), op, updatedAt) +} + +func (u *user) DeletedAtColumn() Column { + return column{name: "u.deleted_at"} +} + +func (u *user) DeletedCondition(isDeleted bool) Condition { + if isDeleted { + return IsNotNull(u.DeletedAtColumn()) + } + return IsNull(u.DeletedAtColumn()) +} + +func (u *user) DeletedAtCondition(op NumberOperator, deletedAt time.Time) Condition { + return newNumberCondition(u.DeletedAtColumn(), op, deletedAt) +} + +func (u *user) writeCondition() { + if u.condition == nil { + return + } + u.builder.WriteString(" WHERE ") + u.condition.writeTo(&u.builder) +} + +func scanUser(scanner database.Scanner) (*User, error) { + var ( + user User + human Human + email Email + phone Phone + machine Machine + typ UserType + ) + err := scanner.Scan( + &user.InstanceID, + &user.OrgID, + &user.ID, + &user.Username, + &typ, + &user.Dates.CreatedAt, + &user.Dates.UpdatedAt, + &user.Dates.DeletedAt, + &human.FirstName, + &human.LastName, + &email.Address, + &email.VerifiedAt, + &phone.Number, + &phone.VerifiedAt, + &machine.Description, + ) + if err != nil { + return nil, err + } + + switch typ { + case UserTypeHuman: + if email.Address != "" { + human.Email = &email + } + if phone.Number != "" { + human.Phone = &phone + } + user.Traits = &human + case UserTypeMachine: + user.Traits = &machine + } + + return &user, nil +} diff --git a/backend/v3/storage/database/repository/stmt/v4/user_human.go b/backend/v3/storage/database/repository/stmt/v4/user_human.go new file mode 100644 index 0000000000..3aadb3ab1b --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/v4/user_human.go @@ -0,0 +1,187 @@ +package v4 + +import ( + "context" + "time" +) + +type Human struct { + FirstName string + LastName string + Email *Email + Phone *Phone +} + +const UserTypeHuman UserType = "human" + +func (Human) userTrait() {} + +func (h Human) Type() UserType { + return UserTypeHuman +} + +var _ userTrait = (*Human)(nil) + +type Email struct { + Address string + Verification +} + +type Phone struct { + Number string + Verification +} + +type Verification struct { + VerifiedAt time.Time +} + +type userHuman struct { + *user +} + +func (u *user) Human() *userHuman { + return &userHuman{user: u} +} + +const userEmailQuery = `SELECT h.email_address, h.email_verified_at FROM user_humans h` + +func (u *userHuman) GetEmail(ctx context.Context) (*Email, error) { + var email Email + + u.builder.WriteString(userEmailQuery) + u.writeCondition() + + err := u.client.QueryRow(ctx, u.builder.String(), u.builder.args...).Scan( + &email.Address, + &email.Verification.VerifiedAt, + ) + + if err != nil { + return nil, err + } + return &email, nil +} + +func (h userHuman) Update(ctx context.Context, changes ...Change) error { + h.builder.WriteString(`UPDATE human_users h SET `) + Changes(changes).writeTo(&h.builder) + h.writeCondition() + + stmt := h.builder.String() + + return h.client.Exec(ctx, stmt, h.builder.args...) +} + +func (h userHuman) SetFirstName(firstName string) Change { + return newChange(h.FirstNameColumn(), firstName) +} + +func (h userHuman) FirstNameColumn() Column { + return column{"h.first_name"} +} + +func (h userHuman) FirstNameCondition(op TextOperator, firstName string) Condition { + return newTextCondition(h.FirstNameColumn(), op, firstName) +} + +func (h userHuman) SetLastName(lastName string) Change { + return newChange(h.LastNameColumn(), lastName) +} + +func (h userHuman) LastNameColumn() Column { + return column{"h.last_name"} +} + +func (h userHuman) LastNameCondition(op TextOperator, lastName string) Condition { + return newTextCondition(h.LastNameColumn(), op, lastName) +} + +func (h userHuman) EmailAddressColumn() Column { + return ignoreCaseCol{ + column: column{"h.email_address"}, + suffix: "_lower", + } +} + +func (h userHuman) EmailAddressCondition(op TextOperator, email string) Condition { + return newTextCondition(h.EmailAddressColumn(), op, email) +} + +func (h userHuman) EmailVerifiedAtColumn() Column { + return column{"h.email_verified_at"} +} + +func (h *userHuman) EmailAddressVerifiedCondition(isVerified bool) Condition { + if isVerified { + return IsNotNull(h.EmailVerifiedAtColumn()) + } + return IsNull(h.EmailVerifiedAtColumn()) +} + +func (h userHuman) EmailVerifiedAtCondition(op TextOperator, emailVerifiedAt string) Condition { + return newTextCondition(h.EmailVerifiedAtColumn(), op, emailVerifiedAt) +} + +func (h userHuman) SetEmailAddress(address string) Change { + return newChange(h.EmailAddressColumn(), address) +} + +// SetEmailVerified sets the verified column of the email +// if at is zero the statement uses the database timestamp +func (h userHuman) SetEmailVerified(at time.Time) Change { + if at.IsZero() { + return newChange(h.EmailVerifiedAtColumn(), nowDBInstruction) + } + return newChange(h.EmailVerifiedAtColumn(), at) +} + +func (h userHuman) SetEmail(address string, verified *time.Time) Change { + return newChanges( + h.SetEmailAddress(address), + newUpdatePtrColumn(h.EmailVerifiedAtColumn(), verified), + ) +} + +func (h userHuman) PhoneNumberColumn() Column { + return column{"h.phone_number"} +} + +func (h userHuman) SetPhoneNumber(number string) Change { + return newChange(h.PhoneNumberColumn(), number) +} + +func (h userHuman) PhoneNumberCondition(op TextOperator, phoneNumber string) Condition { + return newTextCondition(h.PhoneNumberColumn(), op, phoneNumber) +} + +func (h userHuman) PhoneVerifiedAtColumn() Column { + return column{"h.phone_verified_at"} +} + +func (h userHuman) PhoneNumberVerifiedCondition(isVerified bool) Condition { + if isVerified { + return IsNotNull(h.PhoneVerifiedAtColumn()) + } + return IsNull(h.PhoneVerifiedAtColumn()) +} + +// SetPhoneVerified sets the verified column of the phone +// if at is zero the statement uses the database timestamp +func (h userHuman) SetPhoneVerified(at time.Time) Change { + if at.IsZero() { + return newChange(h.PhoneVerifiedAtColumn(), nowDBInstruction) + } + return newChange(h.PhoneVerifiedAtColumn(), at) +} + +func (h userHuman) PhoneVerifiedAtCondition(op TextOperator, phoneVerifiedAt string) Condition { + return newTextCondition(h.PhoneVerifiedAtColumn(), op, phoneVerifiedAt) +} + +func (h userHuman) SetPhone(number string, verifiedAt *time.Time) Change { + return newChanges( + h.SetPhoneNumber(number), + newUpdatePtrColumn(h.PhoneVerifiedAtColumn(), verifiedAt), + ) +} diff --git a/backend/v3/storage/database/repository/stmt/v4/user_machine.go b/backend/v3/storage/database/repository/stmt/v4/user_machine.go new file mode 100644 index 0000000000..57bb7f14fb --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/v4/user_machine.go @@ -0,0 +1,41 @@ +package v4 + +import "context" + +type Machine struct { + Description string +} + +func (Machine) userTrait() {} + +func (m Machine) Type() UserType { + return UserTypeMachine +} + +const UserTypeMachine UserType = "machine" + +var _ userTrait = (*Machine)(nil) + +type userMachine struct { + *user +} + +func (u *user) Machine() *userMachine { + return &userMachine{user: u} +} + +func (m userMachine) Update(ctx context.Context, cols ...Change) (*Machine, error) { + return nil, nil +} + +func (userMachine) DescriptionColumn() Column { + return column{"m.description"} +} + +func (m userMachine) SetDescription(description string) Change { + return newChange(m.DescriptionColumn(), description) +} + +func (m userMachine) DescriptionCondition(op TextOperator, description string) Condition { + return newTextCondition(m.DescriptionColumn(), op, description) +} diff --git a/backend/v3/storage/database/repository/stmt/v4/user_test.go b/backend/v3/storage/database/repository/stmt/v4/user_test.go new file mode 100644 index 0000000000..e4c65efd1b --- /dev/null +++ b/backend/v3/storage/database/repository/stmt/v4/user_test.go @@ -0,0 +1,65 @@ +package v4_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + v4 "github.com/zitadel/zitadel/backend/v3/storage/database/repository/stmt/v4" +) + +func TestQueryUser(t *testing.T) { + t.Run("User filters", func(t *testing.T) { + user := v4.UserRepository(nil) + user.WithCondition( + v4.And( + v4.Or( + user.IDCondition("test"), + user.IDCondition("2"), + ), + user.UsernameCondition(v4.TextOperatorStartsWithIgnoreCase, "test"), + ), + ).Get(context.Background()) + }) + + t.Run("machine and human filters", func(t *testing.T) { + user := v4.UserRepository(nil) + machine := user.Machine() + human := user.Human() + user.WithCondition( + v4.And( + user.UsernameCondition(v4.TextOperatorStartsWithIgnoreCase, "test"), + v4.Or( + machine.DescriptionCondition(v4.TextOperatorStartsWithIgnoreCase, "test"), + human.EmailAddressVerifiedCondition(true), + v4.IsNotNull(machine.DescriptionColumn()), + ), + ), + ) + human.GetEmail(context.Background()) + }) +} + +type dbInstruction string + +func TestArg(t *testing.T) { + var bla any = "asdf" + instr, ok := bla.(dbInstruction) + assert.False(t, ok) + assert.Empty(t, instr) + bla = dbInstruction("asdf") + instr, ok = bla.(dbInstruction) + assert.True(t, ok) + assert.Equal(t, instr, dbInstruction("asdf")) +} + +func TestWriteUser(t *testing.T) { + t.Run("update user", func(t *testing.T) { + user := v4.UserRepository(nil) + user.WithCondition(user.IDCondition("test")).Human().Update( + context.Background(), + user.SetUsername("test"), + ) + + }) +} diff --git a/backend/v3/storage/database/repository/user.go b/backend/v3/storage/database/repository/user.go new file mode 100644 index 0000000000..dcc0b64f0c --- /dev/null +++ b/backend/v3/storage/database/repository/user.go @@ -0,0 +1,39 @@ +package repository + +import ( + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type user struct { + database.QueryExecutor +} + +func User(client database.QueryExecutor) domain.UserRepository { + // return &user{QueryExecutor: client} + return nil +} + +// On implements [domain.UserRepository]. +func (exec *user) On(clauses ...domain.UserClause) domain.UserOperation { + return &userOperation{ + QueryExecutor: exec.QueryExecutor, + clauses: clauses, + } +} + +// OnHuman implements [domain.UserRepository]. +func (exec *user) OnHuman(clauses ...domain.UserClause) domain.HumanOperation { + return &humanOperation{ + userOperation: *exec.On(clauses...).(*userOperation), + } +} + +// OnMachine implements [domain.UserRepository]. +func (exec *user) OnMachine(clauses ...domain.UserClause) domain.MachineOperation { + return &machineOperation{ + userOperation: *exec.On(clauses...).(*userOperation), + } +} + +// var _ domain.UserRepository = (*user)(nil) diff --git a/backend/v3/storage/database/repository/user_human_operation.go b/backend/v3/storage/database/repository/user_human_operation.go new file mode 100644 index 0000000000..cc4de1d5db --- /dev/null +++ b/backend/v3/storage/database/repository/user_human_operation.go @@ -0,0 +1,36 @@ +package repository + +import ( + "context" + + "github.com/zitadel/zitadel/backend/v3/domain" +) + +type humanOperation struct { + userOperation +} + +// GetEmail implements domain.HumanOperation. +func (h *humanOperation) GetEmail(ctx context.Context) (*domain.Email, error) { + var email domain.Email + err := h.QueryExecutor.QueryRow(ctx, `SELECT email, is_email_verified FROM human_users WHERE id = $1`, h.clauses).Scan( + &email.Address, + &email.IsVerified, + ) + if err != nil { + return nil, err + } + return &email, nil +} + +// SetEmail implements domain.HumanOperation. +func (h *humanOperation) SetEmail(ctx context.Context, email string) error { + return h.QueryExecutor.Exec(ctx, `UPDATE human_users SET email = $1 WHERE id = $2`, email, h.clauses) +} + +// SetEmailVerified implements domain.HumanOperation. +func (h *humanOperation) SetEmailVerified(ctx context.Context, email string) error { + return h.QueryExecutor.Exec(ctx, `UPDATE human_users SET is_email_verified = $1 WHERE id = $2 AND email = $3`, true, h.clauses, email) +} + +var _ domain.HumanOperation = (*humanOperation)(nil) diff --git a/backend/v3/storage/database/repository/user_machine_operation.go b/backend/v3/storage/database/repository/user_machine_operation.go new file mode 100644 index 0000000000..b01451f566 --- /dev/null +++ b/backend/v3/storage/database/repository/user_machine_operation.go @@ -0,0 +1,18 @@ +package repository + +import ( + "context" + + "github.com/zitadel/zitadel/backend/v3/domain" +) + +type machineOperation struct { + userOperation +} + +// SetDescription implements domain.MachineOperation. +func (m *machineOperation) SetDescription(ctx context.Context, description string) error { + return m.QueryExecutor.Exec(ctx, `UPDATE machines SET description = $1 WHERE id = $2`, description, m.clauses) +} + +var _ domain.MachineOperation = (*machineOperation)(nil) diff --git a/backend/v3/storage/database/repository/user_operation.go b/backend/v3/storage/database/repository/user_operation.go new file mode 100644 index 0000000000..f2e90dc55b --- /dev/null +++ b/backend/v3/storage/database/repository/user_operation.go @@ -0,0 +1,68 @@ +package repository + +import ( + "context" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type userOperation struct { + database.QueryExecutor + clauses []domain.UserClause +} + +// Delete implements [domain.UserOperation]. +func (u *userOperation) Delete(ctx context.Context) error { + return u.QueryExecutor.Exec(ctx, `DELETE FROM users WHERE id = $1`, u.clauses) +} + +// SetUsername implements [domain.UserOperation]. +func (u *userOperation) SetUsername(ctx context.Context, username string) error { + var stmt statement + + stmt.builder.WriteString(`UPDATE users SET username = $1 WHERE `) + stmt.appendArg(username) + clausesToSQL(&stmt, u.clauses) + return u.QueryExecutor.Exec(ctx, stmt.builder.String(), stmt.args...) +} + +var _ domain.UserOperation = (*userOperation)(nil) + +func UserIDQuery(id string) domain.UserClause { + return textClause[string]{ + clause: clause[domain.TextOperation]{ + field: userFields[domain.UserFieldID], + op: domain.TextOperationEqual, + }, + value: id, + } +} + +func HumanEmailQuery(op domain.TextOperation, email string) domain.UserClause { + return textClause[string]{ + clause: clause[domain.TextOperation]{ + field: userFields[domain.UserHumanFieldEmail], + op: op, + }, + value: email, + } +} + +func HumanEmailVerifiedQuery(op domain.BoolOperation) domain.UserClause { + return boolClause[domain.BoolOperation]{ + clause: clause[domain.BoolOperation]{ + field: userFields[domain.UserHumanFieldEmailVerified], + op: op, + }, + } +} + +func clausesToSQL(stmt *statement, clauses []domain.UserClause) { + for _, clause := range clauses { + + stmt.builder.WriteString(userFields[clause.Field()].String()) + stmt.builder.WriteString(clause.Operation().String()) + stmt.appendArg(clause.Args()...) + } +} diff --git a/backend/v3/storage/database/tx.go b/backend/v3/storage/database/tx.go new file mode 100644 index 0000000000..02c945dc77 --- /dev/null +++ b/backend/v3/storage/database/tx.go @@ -0,0 +1,36 @@ +package database + +import "context" + +type Transaction interface { + Commit(ctx context.Context) error + Rollback(ctx context.Context) error + End(ctx context.Context, err error) error + + Begin(ctx context.Context) (Transaction, error) + + QueryExecutor +} + +type Beginner interface { + Begin(ctx context.Context, opts *TransactionOptions) (Transaction, error) +} + +type TransactionOptions struct { + IsolationLevel IsolationLevel + AccessMode AccessMode +} + +type IsolationLevel uint8 + +const ( + IsolationLevelSerializable IsolationLevel = iota + IsolationLevelReadCommitted +) + +type AccessMode uint8 + +const ( + AccessModeReadWrite AccessMode = iota + AccessModeReadOnly +) diff --git a/backend/v3/storage/eventstore/event.go b/backend/v3/storage/eventstore/event.go new file mode 100644 index 0000000000..1306a9329e --- /dev/null +++ b/backend/v3/storage/eventstore/event.go @@ -0,0 +1,23 @@ +package eventstore + +import ( + "context" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type Event struct { + AggregateType string `json:"aggregateType"` + AggregateID string `json:"aggregateId"` + Type string `json:"type"` + Payload any `json:"payload,omitempty"` +} + +func Publish(ctx context.Context, events []*Event, db database.Executor) error { + for _, event := range events { + if err := db.Exec(ctx, `INSERT INTO events (aggregate_type, aggregate_id) VALUES ($1, $2)`, event.AggregateType, event.AggregateID); err != nil { + return err + } + } + return nil +} diff --git a/backend/v3/telemetry/logging/logger.go b/backend/v3/telemetry/logging/logger.go new file mode 100644 index 0000000000..580120cce6 --- /dev/null +++ b/backend/v3/telemetry/logging/logger.go @@ -0,0 +1,7 @@ +package logging + +import "log/slog" + +type Logger struct { + *slog.Logger +} diff --git a/backend/v3/telemetry/tracing/tracer.go b/backend/v3/telemetry/tracing/tracer.go new file mode 100644 index 0000000000..4536092947 --- /dev/null +++ b/backend/v3/telemetry/tracing/tracer.go @@ -0,0 +1,23 @@ +package tracing + +import ( + "context" + + "go.opentelemetry.io/otel/trace" + "go.opentelemetry.io/otel/trace/noop" +) + +type Tracer struct { + trace.Tracer +} + +var noopTracer = Tracer{ + Tracer: noop.NewTracerProvider().Tracer(""), +} + +func (t *Tracer) Start(ctx context.Context, spanName string, opts ...trace.SpanStartOption) (context.Context, trace.Span) { + if t.Tracer == nil { + return noopTracer.Start(ctx, spanName, opts...) + } + return t.Tracer.Start(ctx, spanName, opts...) +}