From 258e973503818ca253e38feb6e62245930735d78 Mon Sep 17 00:00:00 2001 From: adlerhurst <27845747+adlerhurst@users.noreply.github.com> Date: Mon, 17 Mar 2025 08:08:32 +0100 Subject: [PATCH] better --- backend/repository/cached/instance.go | 30 +++++ backend/repository/cached/user.go | 25 ++++ .../repository/orchestrate/handler/handle.go | 117 +++++++++------- backend/repository/orchestrate/instance.go | 19 +-- .../repository/orchestrate/instance_test.go | 126 +++++++++++++++++- backend/repository/orchestrate/user.go | 11 +- backend/repository/sql/instance.go | 8 +- backend/repository/user.go | 12 +- backend/storage/database/database.go | 6 +- backend/storage/database/mock/row.go | 31 +++++ backend/storage/database/mock/transaction.go | 113 +++++++++++++--- 11 files changed, 406 insertions(+), 92 deletions(-) create mode 100644 backend/repository/cached/instance.go create mode 100644 backend/repository/cached/user.go create mode 100644 backend/storage/database/mock/row.go diff --git a/backend/repository/cached/instance.go b/backend/repository/cached/instance.go new file mode 100644 index 0000000000..c492dedbf5 --- /dev/null +++ b/backend/repository/cached/instance.go @@ -0,0 +1,30 @@ +package cached + +import ( + "context" + + "github.com/zitadel/zitadel/backend/repository" + "github.com/zitadel/zitadel/backend/storage/cache" +) + +type Instance struct { + cache.Cache[repository.InstanceIndex, string, *repository.Instance] +} + +func NewInstance(c cache.Cache[repository.InstanceIndex, string, *repository.Instance]) *Instance { + return &Instance{c} +} + +func (i *Instance) ByID(ctx context.Context, id string) *repository.Instance { + instance, _ := i.Cache.Get(ctx, repository.InstanceByID, id) + return instance +} + +func (i *Instance) ByDomain(ctx context.Context, domain string) *repository.Instance { + instance, _ := i.Cache.Get(ctx, repository.InstanceByDomain, domain) + return instance +} + +func (i *Instance) Set(ctx context.Context, instance *repository.Instance) { + i.Cache.Set(ctx, instance) +} diff --git a/backend/repository/cached/user.go b/backend/repository/cached/user.go new file mode 100644 index 0000000000..0d47b74acf --- /dev/null +++ b/backend/repository/cached/user.go @@ -0,0 +1,25 @@ +package cached + +import ( + "context" + + "github.com/zitadel/zitadel/backend/repository" + "github.com/zitadel/zitadel/backend/storage/cache" +) + +type User struct { + cache.Cache[repository.UserIndex, string, *repository.User] +} + +func NewUser(c cache.Cache[repository.UserIndex, string, *repository.User]) *User { + return &User{c} +} + +func (i *User) ByID(ctx context.Context, id string) *repository.User { + user, _ := i.Cache.Get(ctx, repository.UserByIDIndex, id) + return user +} + +func (i *User) Set(ctx context.Context, user *repository.User) { + i.Cache.Set(ctx, user) +} diff --git a/backend/repository/orchestrate/handler/handle.go b/backend/repository/orchestrate/handler/handle.go index 3fb3f23a26..e8bcfac426 100644 --- a/backend/repository/orchestrate/handler/handle.go +++ b/backend/repository/orchestrate/handler/handle.go @@ -2,100 +2,125 @@ package handler import ( "context" - - "github.com/zitadel/zitadel/backend/storage/cache" ) -// Handler is a function that handles the request. -type Handler[Req, Res any] func(ctx context.Context, request Req) (res Res, err error) +// Handler is a function that handles the in. +type Handler[Out, In any] func(ctx context.Context, in Out) (out In, err error) + +func (h Handler[Out, In]) Chain(next Handler[In, In]) Handler[Out, In] { + return Chain(h, next) +} + +func (h Handler[Out, In]) Decorate(decorate Decorator[Out, In]) Handler[Out, In] { + return Decorate(h, decorate) +} + +func (h Handler[Out, In]) SkipNext(next Handler[Out, In]) Handler[Out, In] { + return SkipNext(h, next) +} + +func (h Handler[Out, In]) SkipNilHandler(handler *Out) Handler[Out, In] { + return SkipNilHandler(handler, h) +} // Decorator is a function that decorates the handle function. -type Decorator[Req, Res any] func(ctx context.Context, request Req, handle Handler[Req, Res]) (res Res, err error) +type Decorator[In, Out any] func(ctx context.Context, in In, handle Handler[In, Out]) (out Out, err error) // Chain chains the handle function with the next handler. // The next handler is called after the handle function. -func Chain[Req, Res any](handle Handler[Req, Res], next Handler[Res, Res]) Handler[Req, Res] { - return func(ctx context.Context, request Req) (res Res, err error) { - res, err = handle(ctx, request) +func Chain[In, Out any](handle Handler[In, Out], next Handler[Out, Out]) Handler[In, Out] { + return func(ctx context.Context, in In) (out Out, err error) { + out, err = handle(ctx, in) if err != nil { - return res, err + return out, err } - return next(ctx, res) + return next(ctx, out) } } -func Chains[Req, Res any](handle Handler[Req, Res], nexts ...Handler[Res, Res]) Handler[Req, Res] { - return func(ctx context.Context, request Req) (res Res, err error) { +func Chains[In, Out any](handle Handler[In, Out], nexts ...Handler[Out, Out]) Handler[In, Out] { + return func(ctx context.Context, in In) (out Out, err error) { for _, next := range nexts { handle = Chain(handle, next) } - return handle(ctx, request) + return handle(ctx, in) } } // Decorate decorates the handle function with the decorate function. // The decorate function is called before the handle function. -func Decorate[Req, Res any](handle Handler[Req, Res], decorate Decorator[Req, Res]) Handler[Req, Res] { - return func(ctx context.Context, request Req) (res Res, err error) { - return decorate(ctx, request, handle) +func Decorate[In, Out any](handle Handler[In, Out], decorate Decorator[In, Out]) Handler[In, Out] { + return func(ctx context.Context, in In) (out Out, err error) { + return decorate(ctx, in, handle) } } // Decorates decorates the handle function with the decorate functions. // The decorates function is called before the handle function. -func Decorates[Req, Res any](handle Handler[Req, Res], decorates ...Decorator[Req, Res]) Handler[Req, Res] { - return func(ctx context.Context, request Req) (res Res, err error) { +func Decorates[In, Out any](handle Handler[In, Out], decorates ...Decorator[In, Out]) Handler[In, Out] { + return func(ctx context.Context, in In) (out Out, err error) { for i := len(decorates) - 1; i >= 0; i-- { handle = Decorate(handle, decorates[i]) } - return handle(ctx, request) + return handle(ctx, in) } } // SkipNext skips the next handler if the handle function returns a non-nil response. -func SkipNext[Req, Res any](handle Handler[Req, Res], next Handler[Req, Res]) Handler[Req, Res] { - return func(ctx context.Context, request Req) (res Res, err error) { - var empty Res - res, err = handle(ctx, request) +func SkipNext[In, Out any](handle Handler[In, Out], next Handler[In, Out]) Handler[In, Out] { + return func(ctx context.Context, in In) (out Out, err error) { + var empty Out + out, err = handle(ctx, in) // TODO: does this work? - if any(res) == any(empty) || err != nil { - return res, err + if any(out) != any(empty) || err != nil { + return out, err } - return next(ctx, request) + return next(ctx, in) } } // SkipNilHandler skips the handle function if the handler is nil. +// If handle is nil, an empty output is returned. // The function is safe to call with nil handler. -func SkipNilHandler[R any](handler any, handle Handler[R, R]) Handler[R, R] { - return func(ctx context.Context, request R) (res R, err error) { +func SkipNilHandler[O, In, Out any](handler *O, handle Handler[In, Out]) Handler[In, Out] { + return func(ctx context.Context, in In) (out Out, err error) { if handler == nil { - return request, nil + return out, nil } - return handle(ctx, request) + return handle(ctx, in) } } -func ErrFuncToHandle[R any](fn func(context.Context, R) error) Handler[R, R] { - return func(ctx context.Context, request R) (res R, err error) { - err = fn(ctx, request) +// SkipReturnPreviousHandler skips the handle function if the handler is nil and returns the input. +// The function is safe to call with nil handler. +func SkipReturnPreviousHandler[O, In any](handler *O, handle Handler[In, In]) Handler[In, In] { + return func(ctx context.Context, in In) (out In, err error) { + if handler == nil { + return in, nil + } + return handle(ctx, in) + } +} + +func ResFuncToHandle[In any, Out any](fn func(context.Context, In) Out) Handler[In, Out] { + return func(ctx context.Context, in In) (out Out, err error) { + return fn(ctx, in), nil + } +} + +func ErrFuncToHandle[In any](fn func(context.Context, In) error) Handler[In, In] { + return func(ctx context.Context, in In) (out In, err error) { + err = fn(ctx, in) if err != nil { - return res, err + return out, err } - return request, nil + return in, nil } } -func NoReturnToHandle[R any](fn func(context.Context, R)) Handler[R, R] { - return func(ctx context.Context, request R) (res R, err error) { - fn(ctx, request) - return request, nil - } -} - -func CacheGetToHandle[I, K comparable, E cache.Entry[I, K]](fn func(context.Context, I, K) (E, bool), index I) Handler[K, E] { - return func(ctx context.Context, request K) (res E, err error) { - res, _ = fn(ctx, index, request) - return res, nil +func NoReturnToHandle[In any](fn func(context.Context, In)) Handler[In, In] { + return func(ctx context.Context, in In) (out In, err error) { + fn(ctx, in) + return in, nil } } diff --git a/backend/repository/orchestrate/instance.go b/backend/repository/orchestrate/instance.go index 0f156e3f9c..7b2c68a096 100644 --- a/backend/repository/orchestrate/instance.go +++ b/backend/repository/orchestrate/instance.go @@ -4,19 +4,19 @@ import ( "context" "github.com/zitadel/zitadel/backend/repository" + "github.com/zitadel/zitadel/backend/repository/cached" "github.com/zitadel/zitadel/backend/repository/event" "github.com/zitadel/zitadel/backend/repository/orchestrate/handler" "github.com/zitadel/zitadel/backend/repository/sql" "github.com/zitadel/zitadel/backend/repository/telemetry/logged" "github.com/zitadel/zitadel/backend/repository/telemetry/traced" "github.com/zitadel/zitadel/backend/storage/cache" - "github.com/zitadel/zitadel/backend/storage/cache/connector/noop" "github.com/zitadel/zitadel/backend/storage/database" "github.com/zitadel/zitadel/backend/telemetry/tracing" ) type InstanceOptions struct { - cache cache.Cache[repository.InstanceIndex, string, *repository.Instance] + cache *cached.Instance } type instance struct { @@ -27,7 +27,6 @@ type instance struct { func Instance(opts ...Option[InstanceOptions]) *instance { i := new(instance) i.InstanceOptions = &i.options.custom - i.cache = noop.NewCache[repository.InstanceIndex, string, *repository.Instance]() for _, opt := range opts { opt.apply(&i.options) @@ -35,9 +34,9 @@ func Instance(opts ...Option[InstanceOptions]) *instance { return i } -func WithInstanceCache(cache cache.Cache[repository.InstanceIndex, string, *repository.Instance]) Option[InstanceOptions] { +func WithInstanceCache(c cache.Cache[repository.InstanceIndex, string, *repository.Instance]) Option[InstanceOptions] { return func(opts *options[InstanceOptions]) { - opts.custom.cache = cache + opts.custom.cache = cached.NewInstance(c) } } @@ -54,7 +53,7 @@ func (i *instance) Create(ctx context.Context, tx database.Transaction, instance traced.Decorate[*repository.Instance, *repository.Instance](i.tracer, tracing.WithSpanName("instance.event.SetUp")), logged.Decorate[*repository.Instance, *repository.Instance](i.logger, "instance.event.SetUp"), ), - handler.SkipNilHandler(i.cache, + handler.SkipReturnPreviousHandler(i.cache, handler.Decorates( handler.NoReturnToHandle(i.cache.Set), traced.Decorate[*repository.Instance, *repository.Instance](i.tracer, tracing.WithSpanName("instance.cache.SetUp")), @@ -67,7 +66,9 @@ func (i *instance) Create(ctx context.Context, tx database.Transaction, instance func (i *instance) ByID(ctx context.Context, querier database.Querier, id string) (*repository.Instance, error) { return handler.SkipNext( - handler.CacheGetToHandle(i.cache.Get, repository.InstanceByID), + handler.SkipNilHandler(i.cache, + handler.ResFuncToHandle(i.cache.ByID), + ), handler.Chain( handler.Decorate( sql.Query(querier).InstanceByID, @@ -80,7 +81,9 @@ func (i *instance) ByID(ctx context.Context, querier database.Querier, id string func (i *instance) ByDomain(ctx context.Context, querier database.Querier, domain string) (*repository.Instance, error) { return handler.SkipNext( - handler.CacheGetToHandle(i.cache.Get, repository.InstanceByDomain), + handler.SkipNilHandler(i.cache, + handler.ResFuncToHandle(i.cache.ByDomain), + ), handler.Chain( handler.Decorate( sql.Query(querier).InstanceByDomain, diff --git a/backend/repository/orchestrate/instance_test.go b/backend/repository/orchestrate/instance_test.go index cab67e63a5..525ab78da0 100644 --- a/backend/repository/orchestrate/instance_test.go +++ b/backend/repository/orchestrate/instance_test.go @@ -9,6 +9,7 @@ import ( "github.com/zitadel/zitadel/backend/repository" "github.com/zitadel/zitadel/backend/repository/orchestrate" + "github.com/zitadel/zitadel/backend/repository/sql" "github.com/zitadel/zitadel/backend/storage/cache" "github.com/zitadel/zitadel/backend/storage/cache/connector/gomap" "github.com/zitadel/zitadel/backend/storage/database" @@ -17,7 +18,7 @@ import ( "github.com/zitadel/zitadel/backend/telemetry/tracing" ) -func Test_instance_SetUp(t *testing.T) { +func Test_instance_Create(t *testing.T) { type args struct { ctx context.Context tx database.Transaction @@ -41,7 +42,7 @@ func Test_instance_SetUp(t *testing.T) { }, args: args{ ctx: context.Background(), - tx: mock.NewTransaction(), + tx: mock.NewTransaction(t, mock.ExpectExec(sql.InstanceCreateStmt, "ID", "Name")), instance: &repository.Instance{ ID: "ID", Name: "Name", @@ -61,7 +62,7 @@ func Test_instance_SetUp(t *testing.T) { }, args: args{ ctx: context.Background(), - tx: mock.NewTransaction(), + tx: mock.NewTransaction(t, mock.ExpectExec(sql.InstanceCreateStmt, "ID", "Name")), instance: &repository.Instance{ ID: "ID", Name: "Name", @@ -80,7 +81,7 @@ func Test_instance_SetUp(t *testing.T) { }, args: args{ ctx: context.Background(), - tx: mock.NewTransaction(), + tx: mock.NewTransaction(t, mock.ExpectExec(sql.InstanceCreateStmt, "ID", "Name")), instance: &repository.Instance{ ID: "ID", Name: "Name", @@ -96,7 +97,7 @@ func Test_instance_SetUp(t *testing.T) { name: "without cache, tracer, logger", args: args{ ctx: context.Background(), - tx: mock.NewTransaction(), + tx: mock.NewTransaction(t, mock.ExpectExec(sql.InstanceCreateStmt, "ID", "Name")), instance: &repository.Instance{ ID: "ID", Name: "Name", @@ -123,3 +124,118 @@ func Test_instance_SetUp(t *testing.T) { }) } } + +func Test_instance_ByID(t *testing.T) { + type args struct { + ctx context.Context + tx database.Transaction + id string + } + tests := []struct { + name string + opts []orchestrate.Option[orchestrate.InstanceOptions] + args args + want *repository.Instance + wantErr bool + }{ + { + name: "simple, not cached", + opts: []orchestrate.Option[orchestrate.InstanceOptions]{ + orchestrate.WithTracer[orchestrate.InstanceOptions](tracing.NewTracer("test")), + orchestrate.WithLogger[orchestrate.InstanceOptions](logging.New(slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})))), + orchestrate.WithInstanceCache( + gomap.NewCache[repository.InstanceIndex, string, *repository.Instance](context.Background(), repository.InstanceIndices, cache.Config{}), + ), + }, + args: args{ + ctx: context.Background(), + tx: mock.NewTransaction(t, + mock.ExpectQueryRow(mock.NewRow(t, "id", "Name"), sql.InstanceByIDStmt, "id"), + ), + id: "id", + }, + want: &repository.Instance{ + ID: "id", + Name: "Name", + }, + wantErr: false, + }, + { + name: "simple, cached", + opts: []orchestrate.Option[orchestrate.InstanceOptions]{ + orchestrate.WithTracer[orchestrate.InstanceOptions](tracing.NewTracer("test")), + orchestrate.WithLogger[orchestrate.InstanceOptions](logging.New(slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})))), + orchestrate.WithInstanceCache( + func() cache.Cache[repository.InstanceIndex, string, *repository.Instance] { + c := gomap.NewCache[repository.InstanceIndex, string, *repository.Instance](context.Background(), repository.InstanceIndices, cache.Config{}) + c.Set(context.Background(), &repository.Instance{ + ID: "id", + Name: "Name", + }) + return c + }(), + ), + }, + args: args{ + ctx: context.Background(), + tx: mock.NewTransaction(t, + mock.ExpectQueryRow(mock.NewRow(t, "id", "Name"), sql.InstanceByIDStmt, "id"), + ), + id: "id", + }, + want: &repository.Instance{ + ID: "id", + Name: "Name", + }, + wantErr: false, + }, + // { + // name: "without cache, tracer", + // opts: []orchestrate.Option[orchestrate.InstanceOptions]{ + // orchestrate.WithLogger[orchestrate.InstanceOptions](logging.New(slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})))), + // }, + // args: args{ + // ctx: context.Background(), + // tx: mock.NewTransaction(), + // id: &repository.Instance{ + // ID: "ID", + // Name: "Name", + // }, + // }, + // want: &repository.Instance{ + // ID: "ID", + // Name: "Name", + // }, + // wantErr: false, + // }, + // { + // name: "without cache, tracer, logger", + // args: args{ + // ctx: context.Background(), + // tx: mock.NewTransaction(), + // id: &repository.Instance{ + // ID: "ID", + // Name: "Name", + // }, + // }, + // want: &repository.Instance{ + // ID: "ID", + // Name: "Name", + // }, + // wantErr: false, + // }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + i := orchestrate.Instance(tt.opts...) + got, err := i.ByID(tt.args.ctx, tt.args.tx, tt.args.id) + if (err != nil) != tt.wantErr { + t.Errorf("instance.Create() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("instance.Create() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/backend/repository/orchestrate/user.go b/backend/repository/orchestrate/user.go index 81cc1b0849..ff3de5bb10 100644 --- a/backend/repository/orchestrate/user.go +++ b/backend/repository/orchestrate/user.go @@ -4,18 +4,18 @@ import ( "context" "github.com/zitadel/zitadel/backend/repository" + "github.com/zitadel/zitadel/backend/repository/cached" "github.com/zitadel/zitadel/backend/repository/event" "github.com/zitadel/zitadel/backend/repository/orchestrate/handler" "github.com/zitadel/zitadel/backend/repository/sql" "github.com/zitadel/zitadel/backend/repository/telemetry/traced" "github.com/zitadel/zitadel/backend/storage/cache" - "github.com/zitadel/zitadel/backend/storage/cache/connector/noop" "github.com/zitadel/zitadel/backend/storage/database" "github.com/zitadel/zitadel/backend/telemetry/tracing" ) type UserOptions struct { - cache cache.Cache[repository.UserIndex, string, *repository.User] + cache *cached.User } type user struct { @@ -26,7 +26,6 @@ type user struct { func User(opts ...Option[UserOptions]) *user { i := new(user) i.UserOptions = &i.options.custom - i.cache = noop.NewCache[repository.UserIndex, string, *repository.User]() for _, opt := range opts { opt(&i.options) @@ -36,7 +35,7 @@ func User(opts ...Option[UserOptions]) *user { func WithUserCache(cache cache.Cache[repository.UserIndex, string, *repository.User]) Option[UserOptions] { return func(i *options[UserOptions]) { - i.custom.cache = cache + i.custom.cache = cached.NewUser(cache) } } @@ -57,7 +56,9 @@ func (i *user) Create(ctx context.Context, tx database.Transaction, user *reposi func (i *user) ByID(ctx context.Context, querier database.Querier, id string) (*repository.User, error) { return handler.SkipNext( - handler.CacheGetToHandle(i.cache.Get, repository.UserByID), + handler.SkipNilHandler(i.cache, + handler.ResFuncToHandle(i.cache.ByID), + ), handler.Chain( handler.Decorate( sql.Query(querier).UserByID, diff --git a/backend/repository/sql/instance.go b/backend/repository/sql/instance.go index 2e0f17ade5..459912be3e 100644 --- a/backend/repository/sql/instance.go +++ b/backend/repository/sql/instance.go @@ -7,11 +7,11 @@ import ( "github.com/zitadel/zitadel/backend/repository" ) -const instanceByIDQuery = `SELECT id, name FROM instances WHERE id = $1` +const InstanceByIDStmt = `SELECT id, name FROM instances WHERE id = $1` func (q *querier[C]) InstanceByID(ctx context.Context, id string) (*repository.Instance, error) { log.Println("sql.instance.byID") - row := q.client.QueryRow(ctx, instanceByIDQuery, id) + row := q.client.QueryRow(ctx, InstanceByIDStmt, id) var instance repository.Instance if err := row.Scan(&instance.ID, &instance.Name); err != nil { return nil, err @@ -51,9 +51,11 @@ func (q *querier[C]) ListInstances(ctx context.Context, request *repository.List return res, nil } +const InstanceCreateStmt = `INSERT INTO instances (id, name) VALUES ($1, $2)` + func (e *executor[C]) CreateInstance(ctx context.Context, instance *repository.Instance) (*repository.Instance, error) { log.Println("sql.instance.create") - err := e.client.Exec(ctx, "INSERT INTO instances (id, name) VALUES ($1, $2)", instance.ID, instance.Name) + err := e.client.Exec(ctx, InstanceCreateStmt, instance.ID, instance.Name) if err != nil { return nil, err } diff --git a/backend/repository/user.go b/backend/repository/user.go index 6ffe767a26..1a5dc80279 100644 --- a/backend/repository/user.go +++ b/backend/repository/user.go @@ -10,13 +10,13 @@ type User struct { type UserIndex uint8 var UserIndices = []UserIndex{ - UserByID, - UserByUsername, + UserByIDIndex, + UserByUsernameIndex, } const ( - UserByID UserIndex = iota - UserByUsername + UserByIDIndex UserIndex = iota + UserByUsernameIndex ) var _ cache.Entry[UserIndex, string] = (*User)(nil) @@ -24,9 +24,9 @@ var _ cache.Entry[UserIndex, string] = (*User)(nil) // Keys implements [cache.Entry]. func (u *User) Keys(index UserIndex) (key []string) { switch index { - case UserByID: + case UserByIDIndex: return []string{u.ID} - case UserByUsername: + case UserByUsernameIndex: return []string{u.Username} } return nil diff --git a/backend/storage/database/database.go b/backend/storage/database/database.go index 9d9cc6f676..2108f36423 100644 --- a/backend/storage/database/database.go +++ b/backend/storage/database/database.go @@ -68,12 +68,12 @@ type QueryExecutor interface { } type Querier interface { - Query(ctx context.Context, sql string, args ...any) (Rows, error) - QueryRow(ctx context.Context, sql string, args ...any) Row + Query(ctx context.Context, stmt string, args ...any) (Rows, error) + QueryRow(ctx context.Context, stmt string, args ...any) Row } type Executor interface { - Exec(ctx context.Context, sql string, args ...any) error + Exec(ctx context.Context, stmt string, args ...any) error } // LoadStatements sets the sql statements strings diff --git a/backend/storage/database/mock/row.go b/backend/storage/database/mock/row.go new file mode 100644 index 0000000000..3d95247393 --- /dev/null +++ b/backend/storage/database/mock/row.go @@ -0,0 +1,31 @@ +package mock + +import ( + "reflect" + "testing" + + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/backend/storage/database" +) + +type Row struct { + t *testing.T + + res []any +} + +func NewRow(t *testing.T, res ...any) *Row { + return &Row{t: t, res: res} +} + +// Scan implements [database.Row]. +func (r *Row) Scan(dest ...any) error { + require.Len(r.t, dest, len(r.res)) + for i := range dest { + reflect.ValueOf(dest[i]).Elem().Set(reflect.ValueOf(r.res[i])) + } + return nil +} + +var _ database.Row = (*Row)(nil) diff --git a/backend/storage/database/mock/transaction.go b/backend/storage/database/mock/transaction.go index ba750324f9..b3fd47bb15 100644 --- a/backend/storage/database/mock/transaction.go +++ b/backend/storage/database/mock/transaction.go @@ -3,17 +3,92 @@ package mock import ( "context" "errors" + "testing" + + "github.com/stretchr/testify/assert" "github.com/zitadel/zitadel/backend/storage/database" ) type Transaction struct { + t *testing.T + committed bool rolledBack bool + + expectations []expecter } -func NewTransaction() *Transaction { - return new(Transaction) +func NewTransaction(t *testing.T, opts ...TransactionOption) *Transaction { + tx := &Transaction{t: t} + for _, opt := range opts { + opt(tx) + } + return tx +} + +func (tx *Transaction) nextExpecter() expecter { + if len(tx.expectations) == 0 { + tx.t.Error("no more expectations on transaction") + tx.t.FailNow() + } + + e := tx.expectations[0] + tx.expectations = tx.expectations[1:] + return e +} + +type TransactionOption func(tx *Transaction) + +type expecter interface { + assertArgs(ctx context.Context, stmt string, args ...any) +} + +func ExpectExec(stmt string, args ...any) TransactionOption { + return func(tx *Transaction) { + tx.expectations = append(tx.expectations, &expectation[struct{}]{ + t: tx.t, + expectedStmt: stmt, + expectedArgs: args, + }) + } +} + +func ExpectQuery(res database.Rows, stmt string, args ...any) TransactionOption { + return func(tx *Transaction) { + tx.expectations = append(tx.expectations, &expectation[database.Rows]{ + t: tx.t, + expectedStmt: stmt, + expectedArgs: args, + result: res, + }) + } +} + +func ExpectQueryRow(res database.Row, stmt string, args ...any) TransactionOption { + return func(tx *Transaction) { + tx.expectations = append(tx.expectations, &expectation[database.Row]{ + t: tx.t, + expectedStmt: stmt, + expectedArgs: args, + result: res, + }) + } +} + +type expectation[R any] struct { + t *testing.T + + expectedStmt string + expectedArgs []any + + result R +} + +func (e *expectation[R]) assertArgs(ctx context.Context, stmt string, args ...any) { + e.t.Helper() + assert.Equal(e.t, e.expectedStmt, stmt) + assert.Equal(e.t, e.expectedArgs, args) } // Commit implements [database.Transaction]. @@ -26,42 +101,48 @@ func (t *Transaction) Commit(ctx context.Context) error { } // End implements [database.Transaction]. -func (t *Transaction) End(ctx context.Context, err error) error { - if t.hasEnded() { +func (tx *Transaction) End(ctx context.Context, err error) error { + if tx.hasEnded() { return errors.New("transaction already committed or rolled back") } if err != nil { - return t.Rollback(ctx) + return tx.Rollback(ctx) } - return t.Commit(ctx) + return tx.Commit(ctx) } // Exec implements [database.Transaction]. -func (t *Transaction) Exec(ctx context.Context, sql string, args ...any) error { +func (tx *Transaction) Exec(ctx context.Context, stmt string, args ...any) error { + tx.nextExpecter().assertArgs(ctx, stmt, args...) + return nil } // Query implements [database.Transaction]. -func (t *Transaction) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { - return nil, nil +func (tx *Transaction) Query(ctx context.Context, stmt string, args ...any) (database.Rows, error) { + e := tx.nextExpecter() + e.assertArgs(ctx, stmt, args...) + return e.(*expectation[database.Rows]).result, nil } // QueryRow implements [database.Transaction]. -func (t *Transaction) QueryRow(ctx context.Context, sql string, args ...any) database.Row { - return nil +func (tx *Transaction) QueryRow(ctx context.Context, stmt string, args ...any) database.Row { + e := tx.nextExpecter() + e.assertArgs(ctx, stmt, args...) + return e.(*expectation[database.Row]).result } // Rollback implements [database.Transaction]. -func (t *Transaction) Rollback(ctx context.Context) error { - if t.hasEnded() { +func (tx *Transaction) Rollback(ctx context.Context) error { + if tx.hasEnded() { return errors.New("transaction already committed or rolled back") } - t.rolledBack = true + tx.rolledBack = true return nil } var _ database.Transaction = (*Transaction)(nil) -func (t *Transaction) hasEnded() bool { - return t.committed || t.rolledBack +func (tx *Transaction) hasEnded() bool { + return tx.committed || tx.rolledBack }