mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-12 03:17:33 +00:00
better
This commit is contained in:
30
backend/repository/cached/instance.go
Normal file
30
backend/repository/cached/instance.go
Normal file
@@ -0,0 +1,30 @@
|
|||||||
|
package cached
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/zitadel/zitadel/backend/repository"
|
||||||
|
"github.com/zitadel/zitadel/backend/storage/cache"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Instance struct {
|
||||||
|
cache.Cache[repository.InstanceIndex, string, *repository.Instance]
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewInstance(c cache.Cache[repository.InstanceIndex, string, *repository.Instance]) *Instance {
|
||||||
|
return &Instance{c}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Instance) ByID(ctx context.Context, id string) *repository.Instance {
|
||||||
|
instance, _ := i.Cache.Get(ctx, repository.InstanceByID, id)
|
||||||
|
return instance
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Instance) ByDomain(ctx context.Context, domain string) *repository.Instance {
|
||||||
|
instance, _ := i.Cache.Get(ctx, repository.InstanceByDomain, domain)
|
||||||
|
return instance
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *Instance) Set(ctx context.Context, instance *repository.Instance) {
|
||||||
|
i.Cache.Set(ctx, instance)
|
||||||
|
}
|
25
backend/repository/cached/user.go
Normal file
25
backend/repository/cached/user.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
package cached
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
|
||||||
|
"github.com/zitadel/zitadel/backend/repository"
|
||||||
|
"github.com/zitadel/zitadel/backend/storage/cache"
|
||||||
|
)
|
||||||
|
|
||||||
|
type User struct {
|
||||||
|
cache.Cache[repository.UserIndex, string, *repository.User]
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewUser(c cache.Cache[repository.UserIndex, string, *repository.User]) *User {
|
||||||
|
return &User{c}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *User) ByID(ctx context.Context, id string) *repository.User {
|
||||||
|
user, _ := i.Cache.Get(ctx, repository.UserByIDIndex, id)
|
||||||
|
return user
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *User) Set(ctx context.Context, user *repository.User) {
|
||||||
|
i.Cache.Set(ctx, user)
|
||||||
|
}
|
@@ -2,100 +2,125 @@ package handler
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/zitadel/zitadel/backend/storage/cache"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Handler is a function that handles the request.
|
// Handler is a function that handles the in.
|
||||||
type Handler[Req, Res any] func(ctx context.Context, request Req) (res Res, err error)
|
type Handler[Out, In any] func(ctx context.Context, in Out) (out In, err error)
|
||||||
|
|
||||||
|
func (h Handler[Out, In]) Chain(next Handler[In, In]) Handler[Out, In] {
|
||||||
|
return Chain(h, next)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h Handler[Out, In]) Decorate(decorate Decorator[Out, In]) Handler[Out, In] {
|
||||||
|
return Decorate(h, decorate)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h Handler[Out, In]) SkipNext(next Handler[Out, In]) Handler[Out, In] {
|
||||||
|
return SkipNext(h, next)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h Handler[Out, In]) SkipNilHandler(handler *Out) Handler[Out, In] {
|
||||||
|
return SkipNilHandler(handler, h)
|
||||||
|
}
|
||||||
|
|
||||||
// Decorator is a function that decorates the handle function.
|
// Decorator is a function that decorates the handle function.
|
||||||
type Decorator[Req, Res any] func(ctx context.Context, request Req, handle Handler[Req, Res]) (res Res, err error)
|
type Decorator[In, Out any] func(ctx context.Context, in In, handle Handler[In, Out]) (out Out, err error)
|
||||||
|
|
||||||
// Chain chains the handle function with the next handler.
|
// Chain chains the handle function with the next handler.
|
||||||
// The next handler is called after the handle function.
|
// The next handler is called after the handle function.
|
||||||
func Chain[Req, Res any](handle Handler[Req, Res], next Handler[Res, Res]) Handler[Req, Res] {
|
func Chain[In, Out any](handle Handler[In, Out], next Handler[Out, Out]) Handler[In, Out] {
|
||||||
return func(ctx context.Context, request Req) (res Res, err error) {
|
return func(ctx context.Context, in In) (out Out, err error) {
|
||||||
res, err = handle(ctx, request)
|
out, err = handle(ctx, in)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return res, err
|
return out, err
|
||||||
}
|
}
|
||||||
return next(ctx, res)
|
return next(ctx, out)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func Chains[Req, Res any](handle Handler[Req, Res], nexts ...Handler[Res, Res]) Handler[Req, Res] {
|
func Chains[In, Out any](handle Handler[In, Out], nexts ...Handler[Out, Out]) Handler[In, Out] {
|
||||||
return func(ctx context.Context, request Req) (res Res, err error) {
|
return func(ctx context.Context, in In) (out Out, err error) {
|
||||||
for _, next := range nexts {
|
for _, next := range nexts {
|
||||||
handle = Chain(handle, next)
|
handle = Chain(handle, next)
|
||||||
}
|
}
|
||||||
return handle(ctx, request)
|
return handle(ctx, in)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decorate decorates the handle function with the decorate function.
|
// Decorate decorates the handle function with the decorate function.
|
||||||
// The decorate function is called before the handle function.
|
// The decorate function is called before the handle function.
|
||||||
func Decorate[Req, Res any](handle Handler[Req, Res], decorate Decorator[Req, Res]) Handler[Req, Res] {
|
func Decorate[In, Out any](handle Handler[In, Out], decorate Decorator[In, Out]) Handler[In, Out] {
|
||||||
return func(ctx context.Context, request Req) (res Res, err error) {
|
return func(ctx context.Context, in In) (out Out, err error) {
|
||||||
return decorate(ctx, request, handle)
|
return decorate(ctx, in, handle)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Decorates decorates the handle function with the decorate functions.
|
// Decorates decorates the handle function with the decorate functions.
|
||||||
// The decorates function is called before the handle function.
|
// The decorates function is called before the handle function.
|
||||||
func Decorates[Req, Res any](handle Handler[Req, Res], decorates ...Decorator[Req, Res]) Handler[Req, Res] {
|
func Decorates[In, Out any](handle Handler[In, Out], decorates ...Decorator[In, Out]) Handler[In, Out] {
|
||||||
return func(ctx context.Context, request Req) (res Res, err error) {
|
return func(ctx context.Context, in In) (out Out, err error) {
|
||||||
for i := len(decorates) - 1; i >= 0; i-- {
|
for i := len(decorates) - 1; i >= 0; i-- {
|
||||||
handle = Decorate(handle, decorates[i])
|
handle = Decorate(handle, decorates[i])
|
||||||
}
|
}
|
||||||
return handle(ctx, request)
|
return handle(ctx, in)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SkipNext skips the next handler if the handle function returns a non-nil response.
|
// SkipNext skips the next handler if the handle function returns a non-nil response.
|
||||||
func SkipNext[Req, Res any](handle Handler[Req, Res], next Handler[Req, Res]) Handler[Req, Res] {
|
func SkipNext[In, Out any](handle Handler[In, Out], next Handler[In, Out]) Handler[In, Out] {
|
||||||
return func(ctx context.Context, request Req) (res Res, err error) {
|
return func(ctx context.Context, in In) (out Out, err error) {
|
||||||
var empty Res
|
var empty Out
|
||||||
res, err = handle(ctx, request)
|
out, err = handle(ctx, in)
|
||||||
// TODO: does this work?
|
// TODO: does this work?
|
||||||
if any(res) == any(empty) || err != nil {
|
if any(out) != any(empty) || err != nil {
|
||||||
return res, err
|
return out, err
|
||||||
}
|
}
|
||||||
return next(ctx, request)
|
return next(ctx, in)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// SkipNilHandler skips the handle function if the handler is nil.
|
// SkipNilHandler skips the handle function if the handler is nil.
|
||||||
|
// If handle is nil, an empty output is returned.
|
||||||
// The function is safe to call with nil handler.
|
// The function is safe to call with nil handler.
|
||||||
func SkipNilHandler[R any](handler any, handle Handler[R, R]) Handler[R, R] {
|
func SkipNilHandler[O, In, Out any](handler *O, handle Handler[In, Out]) Handler[In, Out] {
|
||||||
return func(ctx context.Context, request R) (res R, err error) {
|
return func(ctx context.Context, in In) (out Out, err error) {
|
||||||
if handler == nil {
|
if handler == nil {
|
||||||
return request, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
return handle(ctx, request)
|
return handle(ctx, in)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func ErrFuncToHandle[R any](fn func(context.Context, R) error) Handler[R, R] {
|
// SkipReturnPreviousHandler skips the handle function if the handler is nil and returns the input.
|
||||||
return func(ctx context.Context, request R) (res R, err error) {
|
// The function is safe to call with nil handler.
|
||||||
err = fn(ctx, request)
|
func SkipReturnPreviousHandler[O, In any](handler *O, handle Handler[In, In]) Handler[In, In] {
|
||||||
|
return func(ctx context.Context, in In) (out In, err error) {
|
||||||
|
if handler == nil {
|
||||||
|
return in, nil
|
||||||
|
}
|
||||||
|
return handle(ctx, in)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ResFuncToHandle[In any, Out any](fn func(context.Context, In) Out) Handler[In, Out] {
|
||||||
|
return func(ctx context.Context, in In) (out Out, err error) {
|
||||||
|
return fn(ctx, in), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ErrFuncToHandle[In any](fn func(context.Context, In) error) Handler[In, In] {
|
||||||
|
return func(ctx context.Context, in In) (out In, err error) {
|
||||||
|
err = fn(ctx, in)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return res, err
|
return out, err
|
||||||
}
|
}
|
||||||
return request, nil
|
return in, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func NoReturnToHandle[R any](fn func(context.Context, R)) Handler[R, R] {
|
func NoReturnToHandle[In any](fn func(context.Context, In)) Handler[In, In] {
|
||||||
return func(ctx context.Context, request R) (res R, err error) {
|
return func(ctx context.Context, in In) (out In, err error) {
|
||||||
fn(ctx, request)
|
fn(ctx, in)
|
||||||
return request, nil
|
return in, nil
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func CacheGetToHandle[I, K comparable, E cache.Entry[I, K]](fn func(context.Context, I, K) (E, bool), index I) Handler[K, E] {
|
|
||||||
return func(ctx context.Context, request K) (res E, err error) {
|
|
||||||
res, _ = fn(ctx, index, request)
|
|
||||||
return res, nil
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@@ -4,19 +4,19 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/zitadel/zitadel/backend/repository"
|
"github.com/zitadel/zitadel/backend/repository"
|
||||||
|
"github.com/zitadel/zitadel/backend/repository/cached"
|
||||||
"github.com/zitadel/zitadel/backend/repository/event"
|
"github.com/zitadel/zitadel/backend/repository/event"
|
||||||
"github.com/zitadel/zitadel/backend/repository/orchestrate/handler"
|
"github.com/zitadel/zitadel/backend/repository/orchestrate/handler"
|
||||||
"github.com/zitadel/zitadel/backend/repository/sql"
|
"github.com/zitadel/zitadel/backend/repository/sql"
|
||||||
"github.com/zitadel/zitadel/backend/repository/telemetry/logged"
|
"github.com/zitadel/zitadel/backend/repository/telemetry/logged"
|
||||||
"github.com/zitadel/zitadel/backend/repository/telemetry/traced"
|
"github.com/zitadel/zitadel/backend/repository/telemetry/traced"
|
||||||
"github.com/zitadel/zitadel/backend/storage/cache"
|
"github.com/zitadel/zitadel/backend/storage/cache"
|
||||||
"github.com/zitadel/zitadel/backend/storage/cache/connector/noop"
|
|
||||||
"github.com/zitadel/zitadel/backend/storage/database"
|
"github.com/zitadel/zitadel/backend/storage/database"
|
||||||
"github.com/zitadel/zitadel/backend/telemetry/tracing"
|
"github.com/zitadel/zitadel/backend/telemetry/tracing"
|
||||||
)
|
)
|
||||||
|
|
||||||
type InstanceOptions struct {
|
type InstanceOptions struct {
|
||||||
cache cache.Cache[repository.InstanceIndex, string, *repository.Instance]
|
cache *cached.Instance
|
||||||
}
|
}
|
||||||
|
|
||||||
type instance struct {
|
type instance struct {
|
||||||
@@ -27,7 +27,6 @@ type instance struct {
|
|||||||
func Instance(opts ...Option[InstanceOptions]) *instance {
|
func Instance(opts ...Option[InstanceOptions]) *instance {
|
||||||
i := new(instance)
|
i := new(instance)
|
||||||
i.InstanceOptions = &i.options.custom
|
i.InstanceOptions = &i.options.custom
|
||||||
i.cache = noop.NewCache[repository.InstanceIndex, string, *repository.Instance]()
|
|
||||||
|
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt.apply(&i.options)
|
opt.apply(&i.options)
|
||||||
@@ -35,9 +34,9 @@ func Instance(opts ...Option[InstanceOptions]) *instance {
|
|||||||
return i
|
return i
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithInstanceCache(cache cache.Cache[repository.InstanceIndex, string, *repository.Instance]) Option[InstanceOptions] {
|
func WithInstanceCache(c cache.Cache[repository.InstanceIndex, string, *repository.Instance]) Option[InstanceOptions] {
|
||||||
return func(opts *options[InstanceOptions]) {
|
return func(opts *options[InstanceOptions]) {
|
||||||
opts.custom.cache = cache
|
opts.custom.cache = cached.NewInstance(c)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -54,7 +53,7 @@ func (i *instance) Create(ctx context.Context, tx database.Transaction, instance
|
|||||||
traced.Decorate[*repository.Instance, *repository.Instance](i.tracer, tracing.WithSpanName("instance.event.SetUp")),
|
traced.Decorate[*repository.Instance, *repository.Instance](i.tracer, tracing.WithSpanName("instance.event.SetUp")),
|
||||||
logged.Decorate[*repository.Instance, *repository.Instance](i.logger, "instance.event.SetUp"),
|
logged.Decorate[*repository.Instance, *repository.Instance](i.logger, "instance.event.SetUp"),
|
||||||
),
|
),
|
||||||
handler.SkipNilHandler(i.cache,
|
handler.SkipReturnPreviousHandler(i.cache,
|
||||||
handler.Decorates(
|
handler.Decorates(
|
||||||
handler.NoReturnToHandle(i.cache.Set),
|
handler.NoReturnToHandle(i.cache.Set),
|
||||||
traced.Decorate[*repository.Instance, *repository.Instance](i.tracer, tracing.WithSpanName("instance.cache.SetUp")),
|
traced.Decorate[*repository.Instance, *repository.Instance](i.tracer, tracing.WithSpanName("instance.cache.SetUp")),
|
||||||
@@ -67,7 +66,9 @@ func (i *instance) Create(ctx context.Context, tx database.Transaction, instance
|
|||||||
|
|
||||||
func (i *instance) ByID(ctx context.Context, querier database.Querier, id string) (*repository.Instance, error) {
|
func (i *instance) ByID(ctx context.Context, querier database.Querier, id string) (*repository.Instance, error) {
|
||||||
return handler.SkipNext(
|
return handler.SkipNext(
|
||||||
handler.CacheGetToHandle(i.cache.Get, repository.InstanceByID),
|
handler.SkipNilHandler(i.cache,
|
||||||
|
handler.ResFuncToHandle(i.cache.ByID),
|
||||||
|
),
|
||||||
handler.Chain(
|
handler.Chain(
|
||||||
handler.Decorate(
|
handler.Decorate(
|
||||||
sql.Query(querier).InstanceByID,
|
sql.Query(querier).InstanceByID,
|
||||||
@@ -80,7 +81,9 @@ func (i *instance) ByID(ctx context.Context, querier database.Querier, id string
|
|||||||
|
|
||||||
func (i *instance) ByDomain(ctx context.Context, querier database.Querier, domain string) (*repository.Instance, error) {
|
func (i *instance) ByDomain(ctx context.Context, querier database.Querier, domain string) (*repository.Instance, error) {
|
||||||
return handler.SkipNext(
|
return handler.SkipNext(
|
||||||
handler.CacheGetToHandle(i.cache.Get, repository.InstanceByDomain),
|
handler.SkipNilHandler(i.cache,
|
||||||
|
handler.ResFuncToHandle(i.cache.ByDomain),
|
||||||
|
),
|
||||||
handler.Chain(
|
handler.Chain(
|
||||||
handler.Decorate(
|
handler.Decorate(
|
||||||
sql.Query(querier).InstanceByDomain,
|
sql.Query(querier).InstanceByDomain,
|
||||||
|
@@ -9,6 +9,7 @@ import (
|
|||||||
|
|
||||||
"github.com/zitadel/zitadel/backend/repository"
|
"github.com/zitadel/zitadel/backend/repository"
|
||||||
"github.com/zitadel/zitadel/backend/repository/orchestrate"
|
"github.com/zitadel/zitadel/backend/repository/orchestrate"
|
||||||
|
"github.com/zitadel/zitadel/backend/repository/sql"
|
||||||
"github.com/zitadel/zitadel/backend/storage/cache"
|
"github.com/zitadel/zitadel/backend/storage/cache"
|
||||||
"github.com/zitadel/zitadel/backend/storage/cache/connector/gomap"
|
"github.com/zitadel/zitadel/backend/storage/cache/connector/gomap"
|
||||||
"github.com/zitadel/zitadel/backend/storage/database"
|
"github.com/zitadel/zitadel/backend/storage/database"
|
||||||
@@ -17,7 +18,7 @@ import (
|
|||||||
"github.com/zitadel/zitadel/backend/telemetry/tracing"
|
"github.com/zitadel/zitadel/backend/telemetry/tracing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_instance_SetUp(t *testing.T) {
|
func Test_instance_Create(t *testing.T) {
|
||||||
type args struct {
|
type args struct {
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
tx database.Transaction
|
tx database.Transaction
|
||||||
@@ -41,7 +42,7 @@ func Test_instance_SetUp(t *testing.T) {
|
|||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
tx: mock.NewTransaction(),
|
tx: mock.NewTransaction(t, mock.ExpectExec(sql.InstanceCreateStmt, "ID", "Name")),
|
||||||
instance: &repository.Instance{
|
instance: &repository.Instance{
|
||||||
ID: "ID",
|
ID: "ID",
|
||||||
Name: "Name",
|
Name: "Name",
|
||||||
@@ -61,7 +62,7 @@ func Test_instance_SetUp(t *testing.T) {
|
|||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
tx: mock.NewTransaction(),
|
tx: mock.NewTransaction(t, mock.ExpectExec(sql.InstanceCreateStmt, "ID", "Name")),
|
||||||
instance: &repository.Instance{
|
instance: &repository.Instance{
|
||||||
ID: "ID",
|
ID: "ID",
|
||||||
Name: "Name",
|
Name: "Name",
|
||||||
@@ -80,7 +81,7 @@ func Test_instance_SetUp(t *testing.T) {
|
|||||||
},
|
},
|
||||||
args: args{
|
args: args{
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
tx: mock.NewTransaction(),
|
tx: mock.NewTransaction(t, mock.ExpectExec(sql.InstanceCreateStmt, "ID", "Name")),
|
||||||
instance: &repository.Instance{
|
instance: &repository.Instance{
|
||||||
ID: "ID",
|
ID: "ID",
|
||||||
Name: "Name",
|
Name: "Name",
|
||||||
@@ -96,7 +97,7 @@ func Test_instance_SetUp(t *testing.T) {
|
|||||||
name: "without cache, tracer, logger",
|
name: "without cache, tracer, logger",
|
||||||
args: args{
|
args: args{
|
||||||
ctx: context.Background(),
|
ctx: context.Background(),
|
||||||
tx: mock.NewTransaction(),
|
tx: mock.NewTransaction(t, mock.ExpectExec(sql.InstanceCreateStmt, "ID", "Name")),
|
||||||
instance: &repository.Instance{
|
instance: &repository.Instance{
|
||||||
ID: "ID",
|
ID: "ID",
|
||||||
Name: "Name",
|
Name: "Name",
|
||||||
@@ -123,3 +124,118 @@ func Test_instance_SetUp(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_instance_ByID(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
ctx context.Context
|
||||||
|
tx database.Transaction
|
||||||
|
id string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
opts []orchestrate.Option[orchestrate.InstanceOptions]
|
||||||
|
args args
|
||||||
|
want *repository.Instance
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "simple, not cached",
|
||||||
|
opts: []orchestrate.Option[orchestrate.InstanceOptions]{
|
||||||
|
orchestrate.WithTracer[orchestrate.InstanceOptions](tracing.NewTracer("test")),
|
||||||
|
orchestrate.WithLogger[orchestrate.InstanceOptions](logging.New(slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})))),
|
||||||
|
orchestrate.WithInstanceCache(
|
||||||
|
gomap.NewCache[repository.InstanceIndex, string, *repository.Instance](context.Background(), repository.InstanceIndices, cache.Config{}),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
ctx: context.Background(),
|
||||||
|
tx: mock.NewTransaction(t,
|
||||||
|
mock.ExpectQueryRow(mock.NewRow(t, "id", "Name"), sql.InstanceByIDStmt, "id"),
|
||||||
|
),
|
||||||
|
id: "id",
|
||||||
|
},
|
||||||
|
want: &repository.Instance{
|
||||||
|
ID: "id",
|
||||||
|
Name: "Name",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "simple, cached",
|
||||||
|
opts: []orchestrate.Option[orchestrate.InstanceOptions]{
|
||||||
|
orchestrate.WithTracer[orchestrate.InstanceOptions](tracing.NewTracer("test")),
|
||||||
|
orchestrate.WithLogger[orchestrate.InstanceOptions](logging.New(slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})))),
|
||||||
|
orchestrate.WithInstanceCache(
|
||||||
|
func() cache.Cache[repository.InstanceIndex, string, *repository.Instance] {
|
||||||
|
c := gomap.NewCache[repository.InstanceIndex, string, *repository.Instance](context.Background(), repository.InstanceIndices, cache.Config{})
|
||||||
|
c.Set(context.Background(), &repository.Instance{
|
||||||
|
ID: "id",
|
||||||
|
Name: "Name",
|
||||||
|
})
|
||||||
|
return c
|
||||||
|
}(),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
args: args{
|
||||||
|
ctx: context.Background(),
|
||||||
|
tx: mock.NewTransaction(t,
|
||||||
|
mock.ExpectQueryRow(mock.NewRow(t, "id", "Name"), sql.InstanceByIDStmt, "id"),
|
||||||
|
),
|
||||||
|
id: "id",
|
||||||
|
},
|
||||||
|
want: &repository.Instance{
|
||||||
|
ID: "id",
|
||||||
|
Name: "Name",
|
||||||
|
},
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
// {
|
||||||
|
// name: "without cache, tracer",
|
||||||
|
// opts: []orchestrate.Option[orchestrate.InstanceOptions]{
|
||||||
|
// orchestrate.WithLogger[orchestrate.InstanceOptions](logging.New(slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})))),
|
||||||
|
// },
|
||||||
|
// args: args{
|
||||||
|
// ctx: context.Background(),
|
||||||
|
// tx: mock.NewTransaction(),
|
||||||
|
// id: &repository.Instance{
|
||||||
|
// ID: "ID",
|
||||||
|
// Name: "Name",
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// want: &repository.Instance{
|
||||||
|
// ID: "ID",
|
||||||
|
// Name: "Name",
|
||||||
|
// },
|
||||||
|
// wantErr: false,
|
||||||
|
// },
|
||||||
|
// {
|
||||||
|
// name: "without cache, tracer, logger",
|
||||||
|
// args: args{
|
||||||
|
// ctx: context.Background(),
|
||||||
|
// tx: mock.NewTransaction(),
|
||||||
|
// id: &repository.Instance{
|
||||||
|
// ID: "ID",
|
||||||
|
// Name: "Name",
|
||||||
|
// },
|
||||||
|
// },
|
||||||
|
// want: &repository.Instance{
|
||||||
|
// ID: "ID",
|
||||||
|
// Name: "Name",
|
||||||
|
// },
|
||||||
|
// wantErr: false,
|
||||||
|
// },
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
i := orchestrate.Instance(tt.opts...)
|
||||||
|
got, err := i.ByID(tt.args.ctx, tt.args.tx, tt.args.id)
|
||||||
|
if (err != nil) != tt.wantErr {
|
||||||
|
t.Errorf("instance.Create() error = %v, wantErr %v", err, tt.wantErr)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !reflect.DeepEqual(got, tt.want) {
|
||||||
|
t.Errorf("instance.Create() = %v, want %v", got, tt.want)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@@ -4,18 +4,18 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
|
|
||||||
"github.com/zitadel/zitadel/backend/repository"
|
"github.com/zitadel/zitadel/backend/repository"
|
||||||
|
"github.com/zitadel/zitadel/backend/repository/cached"
|
||||||
"github.com/zitadel/zitadel/backend/repository/event"
|
"github.com/zitadel/zitadel/backend/repository/event"
|
||||||
"github.com/zitadel/zitadel/backend/repository/orchestrate/handler"
|
"github.com/zitadel/zitadel/backend/repository/orchestrate/handler"
|
||||||
"github.com/zitadel/zitadel/backend/repository/sql"
|
"github.com/zitadel/zitadel/backend/repository/sql"
|
||||||
"github.com/zitadel/zitadel/backend/repository/telemetry/traced"
|
"github.com/zitadel/zitadel/backend/repository/telemetry/traced"
|
||||||
"github.com/zitadel/zitadel/backend/storage/cache"
|
"github.com/zitadel/zitadel/backend/storage/cache"
|
||||||
"github.com/zitadel/zitadel/backend/storage/cache/connector/noop"
|
|
||||||
"github.com/zitadel/zitadel/backend/storage/database"
|
"github.com/zitadel/zitadel/backend/storage/database"
|
||||||
"github.com/zitadel/zitadel/backend/telemetry/tracing"
|
"github.com/zitadel/zitadel/backend/telemetry/tracing"
|
||||||
)
|
)
|
||||||
|
|
||||||
type UserOptions struct {
|
type UserOptions struct {
|
||||||
cache cache.Cache[repository.UserIndex, string, *repository.User]
|
cache *cached.User
|
||||||
}
|
}
|
||||||
|
|
||||||
type user struct {
|
type user struct {
|
||||||
@@ -26,7 +26,6 @@ type user struct {
|
|||||||
func User(opts ...Option[UserOptions]) *user {
|
func User(opts ...Option[UserOptions]) *user {
|
||||||
i := new(user)
|
i := new(user)
|
||||||
i.UserOptions = &i.options.custom
|
i.UserOptions = &i.options.custom
|
||||||
i.cache = noop.NewCache[repository.UserIndex, string, *repository.User]()
|
|
||||||
|
|
||||||
for _, opt := range opts {
|
for _, opt := range opts {
|
||||||
opt(&i.options)
|
opt(&i.options)
|
||||||
@@ -36,7 +35,7 @@ func User(opts ...Option[UserOptions]) *user {
|
|||||||
|
|
||||||
func WithUserCache(cache cache.Cache[repository.UserIndex, string, *repository.User]) Option[UserOptions] {
|
func WithUserCache(cache cache.Cache[repository.UserIndex, string, *repository.User]) Option[UserOptions] {
|
||||||
return func(i *options[UserOptions]) {
|
return func(i *options[UserOptions]) {
|
||||||
i.custom.cache = cache
|
i.custom.cache = cached.NewUser(cache)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -57,7 +56,9 @@ func (i *user) Create(ctx context.Context, tx database.Transaction, user *reposi
|
|||||||
|
|
||||||
func (i *user) ByID(ctx context.Context, querier database.Querier, id string) (*repository.User, error) {
|
func (i *user) ByID(ctx context.Context, querier database.Querier, id string) (*repository.User, error) {
|
||||||
return handler.SkipNext(
|
return handler.SkipNext(
|
||||||
handler.CacheGetToHandle(i.cache.Get, repository.UserByID),
|
handler.SkipNilHandler(i.cache,
|
||||||
|
handler.ResFuncToHandle(i.cache.ByID),
|
||||||
|
),
|
||||||
handler.Chain(
|
handler.Chain(
|
||||||
handler.Decorate(
|
handler.Decorate(
|
||||||
sql.Query(querier).UserByID,
|
sql.Query(querier).UserByID,
|
||||||
|
@@ -7,11 +7,11 @@ import (
|
|||||||
"github.com/zitadel/zitadel/backend/repository"
|
"github.com/zitadel/zitadel/backend/repository"
|
||||||
)
|
)
|
||||||
|
|
||||||
const instanceByIDQuery = `SELECT id, name FROM instances WHERE id = $1`
|
const InstanceByIDStmt = `SELECT id, name FROM instances WHERE id = $1`
|
||||||
|
|
||||||
func (q *querier[C]) InstanceByID(ctx context.Context, id string) (*repository.Instance, error) {
|
func (q *querier[C]) InstanceByID(ctx context.Context, id string) (*repository.Instance, error) {
|
||||||
log.Println("sql.instance.byID")
|
log.Println("sql.instance.byID")
|
||||||
row := q.client.QueryRow(ctx, instanceByIDQuery, id)
|
row := q.client.QueryRow(ctx, InstanceByIDStmt, id)
|
||||||
var instance repository.Instance
|
var instance repository.Instance
|
||||||
if err := row.Scan(&instance.ID, &instance.Name); err != nil {
|
if err := row.Scan(&instance.ID, &instance.Name); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -51,9 +51,11 @@ func (q *querier[C]) ListInstances(ctx context.Context, request *repository.List
|
|||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
const InstanceCreateStmt = `INSERT INTO instances (id, name) VALUES ($1, $2)`
|
||||||
|
|
||||||
func (e *executor[C]) CreateInstance(ctx context.Context, instance *repository.Instance) (*repository.Instance, error) {
|
func (e *executor[C]) CreateInstance(ctx context.Context, instance *repository.Instance) (*repository.Instance, error) {
|
||||||
log.Println("sql.instance.create")
|
log.Println("sql.instance.create")
|
||||||
err := e.client.Exec(ctx, "INSERT INTO instances (id, name) VALUES ($1, $2)", instance.ID, instance.Name)
|
err := e.client.Exec(ctx, InstanceCreateStmt, instance.ID, instance.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@@ -10,13 +10,13 @@ type User struct {
|
|||||||
type UserIndex uint8
|
type UserIndex uint8
|
||||||
|
|
||||||
var UserIndices = []UserIndex{
|
var UserIndices = []UserIndex{
|
||||||
UserByID,
|
UserByIDIndex,
|
||||||
UserByUsername,
|
UserByUsernameIndex,
|
||||||
}
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
UserByID UserIndex = iota
|
UserByIDIndex UserIndex = iota
|
||||||
UserByUsername
|
UserByUsernameIndex
|
||||||
)
|
)
|
||||||
|
|
||||||
var _ cache.Entry[UserIndex, string] = (*User)(nil)
|
var _ cache.Entry[UserIndex, string] = (*User)(nil)
|
||||||
@@ -24,9 +24,9 @@ var _ cache.Entry[UserIndex, string] = (*User)(nil)
|
|||||||
// Keys implements [cache.Entry].
|
// Keys implements [cache.Entry].
|
||||||
func (u *User) Keys(index UserIndex) (key []string) {
|
func (u *User) Keys(index UserIndex) (key []string) {
|
||||||
switch index {
|
switch index {
|
||||||
case UserByID:
|
case UserByIDIndex:
|
||||||
return []string{u.ID}
|
return []string{u.ID}
|
||||||
case UserByUsername:
|
case UserByUsernameIndex:
|
||||||
return []string{u.Username}
|
return []string{u.Username}
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
@@ -68,12 +68,12 @@ type QueryExecutor interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Querier interface {
|
type Querier interface {
|
||||||
Query(ctx context.Context, sql string, args ...any) (Rows, error)
|
Query(ctx context.Context, stmt string, args ...any) (Rows, error)
|
||||||
QueryRow(ctx context.Context, sql string, args ...any) Row
|
QueryRow(ctx context.Context, stmt string, args ...any) Row
|
||||||
}
|
}
|
||||||
|
|
||||||
type Executor interface {
|
type Executor interface {
|
||||||
Exec(ctx context.Context, sql string, args ...any) error
|
Exec(ctx context.Context, stmt string, args ...any) error
|
||||||
}
|
}
|
||||||
|
|
||||||
// LoadStatements sets the sql statements strings
|
// LoadStatements sets the sql statements strings
|
||||||
|
31
backend/storage/database/mock/row.go
Normal file
31
backend/storage/database/mock/row.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package mock
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/zitadel/zitadel/backend/storage/database"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Row struct {
|
||||||
|
t *testing.T
|
||||||
|
|
||||||
|
res []any
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewRow(t *testing.T, res ...any) *Row {
|
||||||
|
return &Row{t: t, res: res}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements [database.Row].
|
||||||
|
func (r *Row) Scan(dest ...any) error {
|
||||||
|
require.Len(r.t, dest, len(r.res))
|
||||||
|
for i := range dest {
|
||||||
|
reflect.ValueOf(dest[i]).Elem().Set(reflect.ValueOf(r.res[i]))
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ database.Row = (*Row)(nil)
|
@@ -3,17 +3,92 @@ package mock
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/zitadel/zitadel/backend/storage/database"
|
"github.com/zitadel/zitadel/backend/storage/database"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Transaction struct {
|
type Transaction struct {
|
||||||
|
t *testing.T
|
||||||
|
|
||||||
committed bool
|
committed bool
|
||||||
rolledBack bool
|
rolledBack bool
|
||||||
|
|
||||||
|
expectations []expecter
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTransaction() *Transaction {
|
func NewTransaction(t *testing.T, opts ...TransactionOption) *Transaction {
|
||||||
return new(Transaction)
|
tx := &Transaction{t: t}
|
||||||
|
for _, opt := range opts {
|
||||||
|
opt(tx)
|
||||||
|
}
|
||||||
|
return tx
|
||||||
|
}
|
||||||
|
|
||||||
|
func (tx *Transaction) nextExpecter() expecter {
|
||||||
|
if len(tx.expectations) == 0 {
|
||||||
|
tx.t.Error("no more expectations on transaction")
|
||||||
|
tx.t.FailNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
e := tx.expectations[0]
|
||||||
|
tx.expectations = tx.expectations[1:]
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
type TransactionOption func(tx *Transaction)
|
||||||
|
|
||||||
|
type expecter interface {
|
||||||
|
assertArgs(ctx context.Context, stmt string, args ...any)
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExpectExec(stmt string, args ...any) TransactionOption {
|
||||||
|
return func(tx *Transaction) {
|
||||||
|
tx.expectations = append(tx.expectations, &expectation[struct{}]{
|
||||||
|
t: tx.t,
|
||||||
|
expectedStmt: stmt,
|
||||||
|
expectedArgs: args,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExpectQuery(res database.Rows, stmt string, args ...any) TransactionOption {
|
||||||
|
return func(tx *Transaction) {
|
||||||
|
tx.expectations = append(tx.expectations, &expectation[database.Rows]{
|
||||||
|
t: tx.t,
|
||||||
|
expectedStmt: stmt,
|
||||||
|
expectedArgs: args,
|
||||||
|
result: res,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func ExpectQueryRow(res database.Row, stmt string, args ...any) TransactionOption {
|
||||||
|
return func(tx *Transaction) {
|
||||||
|
tx.expectations = append(tx.expectations, &expectation[database.Row]{
|
||||||
|
t: tx.t,
|
||||||
|
expectedStmt: stmt,
|
||||||
|
expectedArgs: args,
|
||||||
|
result: res,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type expectation[R any] struct {
|
||||||
|
t *testing.T
|
||||||
|
|
||||||
|
expectedStmt string
|
||||||
|
expectedArgs []any
|
||||||
|
|
||||||
|
result R
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *expectation[R]) assertArgs(ctx context.Context, stmt string, args ...any) {
|
||||||
|
e.t.Helper()
|
||||||
|
assert.Equal(e.t, e.expectedStmt, stmt)
|
||||||
|
assert.Equal(e.t, e.expectedArgs, args)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Commit implements [database.Transaction].
|
// Commit implements [database.Transaction].
|
||||||
@@ -26,42 +101,48 @@ func (t *Transaction) Commit(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// End implements [database.Transaction].
|
// End implements [database.Transaction].
|
||||||
func (t *Transaction) End(ctx context.Context, err error) error {
|
func (tx *Transaction) End(ctx context.Context, err error) error {
|
||||||
if t.hasEnded() {
|
if tx.hasEnded() {
|
||||||
return errors.New("transaction already committed or rolled back")
|
return errors.New("transaction already committed or rolled back")
|
||||||
}
|
}
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return t.Rollback(ctx)
|
return tx.Rollback(ctx)
|
||||||
}
|
}
|
||||||
return t.Commit(ctx)
|
return tx.Commit(ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Exec implements [database.Transaction].
|
// Exec implements [database.Transaction].
|
||||||
func (t *Transaction) Exec(ctx context.Context, sql string, args ...any) error {
|
func (tx *Transaction) Exec(ctx context.Context, stmt string, args ...any) error {
|
||||||
|
tx.nextExpecter().assertArgs(ctx, stmt, args...)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Query implements [database.Transaction].
|
// Query implements [database.Transaction].
|
||||||
func (t *Transaction) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
func (tx *Transaction) Query(ctx context.Context, stmt string, args ...any) (database.Rows, error) {
|
||||||
return nil, nil
|
e := tx.nextExpecter()
|
||||||
|
e.assertArgs(ctx, stmt, args...)
|
||||||
|
return e.(*expectation[database.Rows]).result, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// QueryRow implements [database.Transaction].
|
// QueryRow implements [database.Transaction].
|
||||||
func (t *Transaction) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
func (tx *Transaction) QueryRow(ctx context.Context, stmt string, args ...any) database.Row {
|
||||||
return nil
|
e := tx.nextExpecter()
|
||||||
|
e.assertArgs(ctx, stmt, args...)
|
||||||
|
return e.(*expectation[database.Row]).result
|
||||||
}
|
}
|
||||||
|
|
||||||
// Rollback implements [database.Transaction].
|
// Rollback implements [database.Transaction].
|
||||||
func (t *Transaction) Rollback(ctx context.Context) error {
|
func (tx *Transaction) Rollback(ctx context.Context) error {
|
||||||
if t.hasEnded() {
|
if tx.hasEnded() {
|
||||||
return errors.New("transaction already committed or rolled back")
|
return errors.New("transaction already committed or rolled back")
|
||||||
}
|
}
|
||||||
t.rolledBack = true
|
tx.rolledBack = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var _ database.Transaction = (*Transaction)(nil)
|
var _ database.Transaction = (*Transaction)(nil)
|
||||||
|
|
||||||
func (t *Transaction) hasEnded() bool {
|
func (tx *Transaction) hasEnded() bool {
|
||||||
return t.committed || t.rolledBack
|
return tx.committed || tx.rolledBack
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user