diff --git a/backend/domain/instance.go b/backend/domain/instance.go index 0b16228914..9a966c0c40 100644 --- a/backend/domain/instance.go +++ b/backend/domain/instance.go @@ -13,7 +13,8 @@ import ( type Instance struct { db database.Pool - orchestrator instanceOrchestrator + instance instanceOrchestrator + user userOrchestrator } type instanceOrchestrator interface { @@ -24,19 +25,20 @@ type instanceOrchestrator interface { func NewInstance(db database.Pool, tracer *tracing.Tracer, logger *logging.Logger) *Instance { b := &Instance{ - db: db, - orchestrator: orchestrate.Instance(), + db: db, + instance: orchestrate.Instance(), + user: orchestrate.User(), } return b } func (b *Instance) ByID(ctx context.Context, id string) (*repository.Instance, error) { - return b.orchestrator.ByID(ctx, b.db, id) + return b.instance.ByID(ctx, b.db, id) } func (b *Instance) ByDomain(ctx context.Context, domain string) (*repository.Instance, error) { - return b.orchestrator.ByDomain(ctx, b.db, domain) + return b.instance.ByDomain(ctx, b.db, domain) } type SetUpInstance struct { @@ -52,9 +54,10 @@ func (b *Instance) SetUp(ctx context.Context, request *SetUpInstance) (err error defer func() { err = tx.End(ctx, err) }() - _, err = b.orchestrator.SetUp(ctx, tx, request.Instance) + _, err = b.instance.SetUp(ctx, tx, request.Instance) if err != nil { return err } - return b.userCommandRepo(tx).Create(ctx, request.User) + _, err = b.user.Create(ctx, tx, request.User) + return err } diff --git a/backend/domain/user.go b/backend/domain/user.go index f98f51795a..292cd999cf 100644 --- a/backend/domain/user.go +++ b/backend/domain/user.go @@ -1,39 +1,13 @@ package domain import ( + "context" + "github.com/zitadel/zitadel/backend/repository" - "github.com/zitadel/zitadel/backend/repository/event" - "github.com/zitadel/zitadel/backend/repository/sql" - "github.com/zitadel/zitadel/backend/repository/telemetry/logged" - "github.com/zitadel/zitadel/backend/repository/telemetry/traced" "github.com/zitadel/zitadel/backend/storage/database" - "github.com/zitadel/zitadel/backend/storage/eventstore" ) -func (b *Instance) userCommandRepo(tx database.Transaction) repository.UserRepository { - return logged.NewUser( - b.logger, - traced.NewUser( - b.tracer, - event.NewUser( - eventstore.New(tx), - sql.NewUser(tx), - ), - ), - ) -} - -func (b *Instance) userQueryRepo(tx database.QueryExecutor) repository.UserRepository { - return logged.NewUser( - b.logger, - traced.NewUser( - b.tracer, - sql.NewUser(tx), - ), - ) -} - -type User struct { - ID string - Username string +type userOrchestrator interface { + Create(ctx context.Context, client database.Transaction, user *repository.User) (*repository.User, error) + ByID(ctx context.Context, querier database.Querier, id string) (*repository.User, error) } diff --git a/backend/repository/cache/instance.go b/backend/repository/cache/instance.go index 7787dc9128..8e29148023 100644 --- a/backend/repository/cache/instance.go +++ b/backend/repository/cache/instance.go @@ -5,8 +5,8 @@ import ( "sync" "github.com/zitadel/zitadel/backend/repository" - "github.com/zitadel/zitadel/backend/repository/orchestrate/handler" "github.com/zitadel/zitadel/backend/storage/cache" + "github.com/zitadel/zitadel/backend/storage/cache/gomap" ) type Instance struct { @@ -15,85 +15,31 @@ type Instance struct { byDomain cache.Cache[string, *repository.Instance] } -func SetUpInstance( - cache *Instance, - handle handler.Handle[*repository.Instance, *repository.Instance], -) handler.Handle[*repository.Instance, *repository.Instance] { - return func(ctx context.Context, instance *repository.Instance) (*repository.Instance, error) { - instance, err := handle(ctx, instance) - if err != nil { - return nil, err - } - - cache.set(instance, "") - return instance, nil +func NewInstance() *Instance { + return &Instance{ + mu: &sync.RWMutex{}, + byID: gomap.New[string, *repository.Instance](), + byDomain: gomap.New[string, *repository.Instance](), } } -func SetUpInstanceWithout(cache *Instance) handler.Handle[*repository.Instance, *repository.Instance] { - return func(ctx context.Context, instance *repository.Instance) (*repository.Instance, error) { - cache.set(instance, "") - return instance, nil - } +func (i *Instance) Set(ctx context.Context, instance *repository.Instance) (*repository.Instance, error) { + i.set(instance, "") + return instance, nil } -func SetUpInstanceDecorated( - cache *Instance, - handle handler.Handle[*repository.Instance, *repository.Instance], - decorator handler.Decorate[*repository.Instance, *repository.Instance], -) handler.Handle[*repository.Instance, *repository.Instance] { - return func(ctx context.Context, instance *repository.Instance) (*repository.Instance, error) { - instance, err := handle(ctx, instance) - if err != nil { - return nil, err - } - - return decorator(ctx, instance, func(ctx context.Context, instance *repository.Instance) (*repository.Instance, error) { - cache.set(instance, "") - return instance, nil - }) - } +func (i *Instance) ByID(ctx context.Context, id string) (*repository.Instance, error) { + i.mu.RLock() + defer i.mu.RUnlock() + instance, _ := i.byID.Get(id) + return instance, nil } -func ForInstanceByID(cache *Instance, handle handler.Handle[string, *repository.Instance]) handler.Handle[string, *repository.Instance] { - return func(ctx context.Context, id string) (*repository.Instance, error) { - cache.mu.RLock() - - instance, ok := cache.byID.Get(id) - cache.mu.RUnlock() - if ok { - return instance, nil - } - - instance, err := handle(ctx, id) - if err != nil { - return nil, err - - } - - cache.set(instance, "") - return instance, nil - } -} - -func ForInstanceByDomain(cache *Instance, handle handler.Handle[string, *repository.Instance]) handler.Handle[string, *repository.Instance] { - return func(ctx context.Context, domain string) (*repository.Instance, error) { - cache.mu.RLock() - - instance, ok := cache.byDomain.Get(domain) - cache.mu.RUnlock() - if ok { - return instance, nil - } - - instance, err := handle(ctx, domain) - if err != nil { - return nil, err - } - - cache.set(instance, domain) - return instance, nil - } +func (i *Instance) ByDomain(ctx context.Context, domain string) (*repository.Instance, error) { + i.mu.RLock() + defer i.mu.RUnlock() + instance, _ := i.byDomain.Get(domain) + return instance, nil } func (i *Instance) set(instance *repository.Instance, domain string) { diff --git a/backend/repository/cache/user.go b/backend/repository/cache/user.go index a7bcad25d0..98c944c2a7 100644 --- a/backend/repository/cache/user.go +++ b/backend/repository/cache/user.go @@ -5,41 +5,31 @@ import ( "github.com/zitadel/zitadel/backend/repository" "github.com/zitadel/zitadel/backend/storage/cache" + "github.com/zitadel/zitadel/backend/storage/cache/gomap" ) type User struct { cache.Cache[string, *repository.User] +} - next repository.UserRepository +func NewUser() *User { + return &User{ + Cache: gomap.New[string, *repository.User](), + } } // ByID implements repository.UserRepository. func (u *User) ByID(ctx context.Context, id string) (*repository.User, error) { - if user, ok := u.Get(id); ok { - return user, nil - } + user, _ := u.Get(id) + return user, nil - user, err := u.next.ByID(ctx, id) - if err != nil { - return nil, err - } +} +func (u *User) Set(ctx context.Context, user *repository.User) (*repository.User, error) { u.set(user) return user, nil } -// Create implements repository.UserRepository. -func (u *User) Create(ctx context.Context, user *repository.User) error { - err := u.next.Create(ctx, user) - if err != nil { - return err - } - u.set(user) - return nil -} - -var _ repository.UserRepository = (*User)(nil) - func (u *User) set(user *repository.User) { u.Cache.Set(user.ID, user) } diff --git a/backend/repository/event/instance.go b/backend/repository/event/instance.go index 5252a6a17c..67b07df266 100644 --- a/backend/repository/event/instance.go +++ b/backend/repository/event/instance.go @@ -4,59 +4,12 @@ import ( "context" "github.com/zitadel/zitadel/backend/repository" - "github.com/zitadel/zitadel/backend/repository/orchestrate/handler" - "github.com/zitadel/zitadel/backend/storage/database" - "github.com/zitadel/zitadel/backend/storage/eventstore" ) -func SetUpInstance( - client database.Executor, - next handler.Handle[*repository.Instance, *repository.Instance], -) handler.Handle[*repository.Instance, *repository.Instance] { - es := eventstore.New(client) - return func(ctx context.Context, instance *repository.Instance) (*repository.Instance, error) { - instance, err := next(ctx, instance) - if err != nil { - return nil, err - } - - err = es.Push(ctx, instance) - if err != nil { - return nil, err - } - return instance, nil - } -} - -func SetUpInstanceWithout(client database.Executor) handler.Handle[*repository.Instance, *repository.Instance] { - es := eventstore.New(client) - return func(ctx context.Context, instance *repository.Instance) (*repository.Instance, error) { - err := es.Push(ctx, instance) - if err != nil { - return nil, err - } - return instance, nil - } -} - -func SetUpInstanceDecorated( - client database.Executor, - next handler.Handle[*repository.Instance, *repository.Instance], - decorate handler.Decorate[*repository.Instance, *repository.Instance], -) handler.Handle[*repository.Instance, *repository.Instance] { - es := eventstore.New(client) - return func(ctx context.Context, instance *repository.Instance) (*repository.Instance, error) { - instance, err := next(ctx, instance) - if err != nil { - return nil, err - } - - return decorate(ctx, instance, func(ctx context.Context, instance *repository.Instance) (*repository.Instance, error) { - err = es.Push(ctx, instance) - if err != nil { - return nil, err - } - return instance, nil - }) +func (s *store) CreateInstance(ctx context.Context, instance *repository.Instance) (*repository.Instance, error) { + err := s.es.Push(ctx, instance) + if err != nil { + return nil, err } + return instance, nil } diff --git a/backend/repository/event/store.go b/backend/repository/event/store.go new file mode 100644 index 0000000000..eff0636a35 --- /dev/null +++ b/backend/repository/event/store.go @@ -0,0 +1,16 @@ +package event + +import ( + "github.com/zitadel/zitadel/backend/storage/database" + "github.com/zitadel/zitadel/backend/storage/eventstore" +) + +type store struct { + es *eventstore.Eventstore +} + +func Store(client database.Executor) *store { + return &store{ + es: eventstore.New(client), + } +} diff --git a/backend/repository/event/user.go b/backend/repository/event/user.go index 88ba64240c..90ea9b15c7 100644 --- a/backend/repository/event/user.go +++ b/backend/repository/event/user.go @@ -4,30 +4,12 @@ import ( "context" "github.com/zitadel/zitadel/backend/repository" - "github.com/zitadel/zitadel/backend/storage/eventstore" ) -var _ repository.UserRepository = (*User)(nil) - -type User struct { - *eventstore.Eventstore - - next repository.UserRepository -} - -func NewUser(eventstore *eventstore.Eventstore, next repository.UserRepository) *User { - return &User{next: next, Eventstore: eventstore} -} - -func (i *User) ByID(ctx context.Context, id string) (*repository.User, error) { - return i.next.ByID(ctx, id) -} - -func (i *User) Create(ctx context.Context, user *repository.User) error { - err := i.next.Create(ctx, user) +func (s *store) CreateUser(ctx context.Context, user *repository.User) (*repository.User, error) { + err := s.es.Push(ctx, user) if err != nil { - return err + return nil, err } - - return i.Push(ctx, user) + return user, nil } diff --git a/backend/repository/orchestrate/handler/handle.go b/backend/repository/orchestrate/handler/handle.go new file mode 100644 index 0000000000..11278d289a --- /dev/null +++ b/backend/repository/orchestrate/handler/handle.go @@ -0,0 +1,64 @@ +package handler + +import "context" + +// Handler is a function that handles the request. +type Handler[Req, Res any] func(ctx context.Context, request Req) (res Res, err error) + +// Decorator is a function that decorates the handle function. +type Decorator[Req, Res any] func(ctx context.Context, request Req, handle Handler[Req, Res]) (res Res, err error) + +// Chain chains the handle function with the next handler. +// The next handler is called after the handle function. +func Chain[Req, Res any](handle Handler[Req, Res], next Handler[Res, Res]) Handler[Req, Res] { + return func(ctx context.Context, request Req) (res Res, err error) { + res, err = handle(ctx, request) + if err != nil { + return res, err + } + return next(ctx, res) + } +} + +// Decorate decorates the handle function with the decorate function. +// The decorate function is called before the handle function. +func Decorate[Req, Res any](handle Handler[Req, Res], decorate Decorator[Req, Res]) Handler[Req, Res] { + return func(ctx context.Context, request Req) (res Res, err error) { + return decorate(ctx, request, handle) + } +} + +// Decorates decorates the handle function with the decorate function. +// The decorate function is called before the handle function. +func Decorates[Req, Res any](handle Handler[Req, Res], decorates ...Decorator[Req, Res]) Handler[Req, Res] { + return func(ctx context.Context, request Req) (res Res, err error) { + for _, decorate := range decorates { + handle = Decorate(handle, decorate) + } + return handle(ctx, request) + } +} + +// SkipNext skips the next handler if the handle function returns a non-nil response. +func SkipNext[Req, Res any](handle Handler[Req, Res], next Handler[Req, Res]) Handler[Req, Res] { + return func(ctx context.Context, request Req) (res Res, err error) { + var empty Res + res, err = handle(ctx, request) + // TODO: does this work? + if any(res) == any(empty) || err != nil { + return res, err + } + return next(ctx, request) + } +} + +// SkipNilHandler skips the handle function if the handler is nil. +// The function is safe to call with nil handler. +func SkipNilHandler[O, R any](handler *O, handle Handler[R, R]) Handler[R, R] { + return func(ctx context.Context, request R) (res R, err error) { + if handler == nil { + return request, nil + } + return handle(ctx, request) + } +} diff --git a/backend/repository/orchestrate/handler/request_handler.go b/backend/repository/orchestrate/handler/request_handler.go deleted file mode 100644 index 88519f1eb3..0000000000 --- a/backend/repository/orchestrate/handler/request_handler.go +++ /dev/null @@ -1,26 +0,0 @@ -package handler - -import "context" - -type Handle[Req, Res any] func(ctx context.Context, request Req) (res Res, err error) - -type Decorate[Req, Res any] func(ctx context.Context, request Req, handle Handle[Req, Res]) (res Res, err error) - -func NewChained[Req, Res any](handle Handle[Req, Res], next Handle[Res, Res]) Handle[Req, Res] { - return func(ctx context.Context, request Req) (res Res, err error) { - res, err = handle(ctx, request) - if err != nil { - return res, err - } - if next == nil { - return res, nil - } - return next(ctx, res) - } -} - -func NewDecorated[Req, Res any](decorate Decorate[Req, Res], handle Handle[Req, Res]) Handle[Req, Res] { - return func(ctx context.Context, request Req) (res Res, err error) { - return decorate(ctx, request, handle) - } -} diff --git a/backend/repository/orchestrate/instance.go b/backend/repository/orchestrate/instance.go index 13580d9c95..f2ba338269 100644 --- a/backend/repository/orchestrate/instance.go +++ b/backend/repository/orchestrate/instance.go @@ -33,40 +33,51 @@ func (i *instance) apply(o Option) { } func (i *instance) SetUp(ctx context.Context, tx database.Transaction, instance *repository.Instance) (*repository.Instance, error) { - return handler.NewChained( - handler.NewDecorated( - traced.DecorateHandle[*repository.Instance, *repository.Instance](i.tracer, tracing.WithSpanName("instance.sql.SetUp")), - sql.SetUpInstance(tx), - ), - handler.NewChained( - handler.NewDecorated( - traced.DecorateHandle[*repository.Instance, *repository.Instance](i.tracer, tracing.WithSpanName("instance.event.SetUp")), - event.SetUpInstanceWithout(tx), + return traced.Wrap(i.tracer, "instance.SetUp", + handler.Chain( + handler.Decorates( + sql.Execute(tx).CreateInstance, + traced.Decorate[*repository.Instance, *repository.Instance](i.tracer, tracing.WithSpanName("instance.sql.SetUp")), + logged.Decorate[*repository.Instance, *repository.Instance](i.logger, "instance.sql.SetUp"), ), - handler.NewDecorated( - traced.DecorateHandle[*repository.Instance, *repository.Instance](i.tracer, tracing.WithSpanName("instance.cache.SetUp")), - cache.SetUpInstanceWithout(i.cache), + handler.Chain( + handler.Decorates( + event.Store(tx).CreateInstance, + traced.Decorate[*repository.Instance, *repository.Instance](i.tracer, tracing.WithSpanName("instance.event.SetUp")), + logged.Decorate[*repository.Instance, *repository.Instance](i.logger, "instance.event.SetUp"), + ), + handler.Decorates( + handler.SkipNilHandler(i.cache, i.cache.Set), + traced.Decorate[*repository.Instance, *repository.Instance](i.tracer, tracing.WithSpanName("instance.cache.SetUp")), + logged.Decorate[*repository.Instance, *repository.Instance](i.logger, "instance.cache.SetUp"), + ), ), ), )(ctx, instance) } func (i *instance) ByID(ctx context.Context, querier database.Querier, id string) (*repository.Instance, error) { - return traced.Wrap(i.tracer, "instance.byID", - logged.Wrap(i.logger, "instance.byID", - cache.ForInstanceByID(i.cache, - sql.InstanceByID(querier), + return handler.SkipNext( + i.cache.ByID, + handler.Chain( + handler.Decorate( + sql.Query(querier).InstanceByID, + traced.Decorate[string, *repository.Instance](i.tracer, tracing.WithSpanName("instance.sql.ByID")), ), + handler.SkipNilHandler(i.cache, i.cache.Set), ), )(ctx, id) } func (i *instance) ByDomain(ctx context.Context, querier database.Querier, domain string) (*repository.Instance, error) { - return traced.Wrap(i.tracer, "instance.byDomain", - logged.Wrap(i.logger, "instance.byDomain", - cache.ForInstanceByDomain(i.cache, - sql.InstanceByDomain(querier), + return handler.SkipNext( + i.cache.ByDomain, + handler.Chain( + handler.Decorate( + sql.Query(querier).InstanceByDomain, + traced.Decorate[string, *repository.Instance](i.tracer, tracing.WithSpanName("instance.sql.ByDomain")), ), + handler.SkipNilHandler(i.cache, i.cache.Set), ), )(ctx, domain) } diff --git a/backend/repository/orchestrate/instance_test.go b/backend/repository/orchestrate/instance_test.go new file mode 100644 index 0000000000..0b80d7adb5 --- /dev/null +++ b/backend/repository/orchestrate/instance_test.go @@ -0,0 +1,75 @@ +package orchestrate + +import ( + "context" + "log/slog" + "os" + "reflect" + "testing" + + "github.com/zitadel/zitadel/backend/repository" + "github.com/zitadel/zitadel/backend/repository/cache" + "github.com/zitadel/zitadel/backend/storage/database" + "github.com/zitadel/zitadel/backend/storage/database/mock" + "github.com/zitadel/zitadel/backend/telemetry/logging" + "github.com/zitadel/zitadel/backend/telemetry/tracing" +) + +func Test_instance_SetUp(t *testing.T) { + type fields struct { + options options + cache *cache.Instance + } + type args struct { + ctx context.Context + tx database.Transaction + instance *repository.Instance + } + tests := []struct { + name string + fields fields + args args + want *repository.Instance + wantErr bool + }{ + { + name: "simple", + fields: fields{ + options: options{ + tracer: tracing.NewTracer("test"), + logger: logging.New(slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))), + }, + cache: cache.NewInstance(), + }, + args: args{ + ctx: context.Background(), + tx: mock.NewTransaction(), + instance: &repository.Instance{ + ID: "ID", + Name: "Name", + }, + }, + want: &repository.Instance{ + ID: "ID", + Name: "Name", + }, + wantErr: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + i := &instance{ + options: tt.fields.options, + cache: tt.fields.cache, + } + got, err := i.SetUp(tt.args.ctx, tt.args.tx, tt.args.instance) + if (err != nil) != tt.wantErr { + t.Errorf("instance.SetUp() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("instance.SetUp() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/backend/repository/orchestrate/user.go b/backend/repository/orchestrate/user.go new file mode 100644 index 0000000000..937f0bc83b --- /dev/null +++ b/backend/repository/orchestrate/user.go @@ -0,0 +1,60 @@ +package orchestrate + +import ( + "context" + + "github.com/zitadel/zitadel/backend/repository" + "github.com/zitadel/zitadel/backend/repository/cache" + "github.com/zitadel/zitadel/backend/repository/event" + "github.com/zitadel/zitadel/backend/repository/orchestrate/handler" + "github.com/zitadel/zitadel/backend/repository/sql" + "github.com/zitadel/zitadel/backend/repository/telemetry/traced" + "github.com/zitadel/zitadel/backend/storage/database" + "github.com/zitadel/zitadel/backend/telemetry/tracing" +) + +type user struct { + options + + cache *cache.User +} + +func User(opts ...Option) *user { + i := new(user) + for _, opt := range opts { + opt(&i.options) + } + return i +} + +func (i *user) apply(o Option) { + o(&i.options) +} + +func (i *user) Create(ctx context.Context, tx database.Transaction, user *repository.User) (*repository.User, error) { + return traced.Wrap(i.tracer, "user.Create", + handler.Chain( + handler.Decorate( + sql.Execute(tx).CreateUser, + traced.Decorate[*repository.User, *repository.User](i.tracer, tracing.WithSpanName("user.sql.Create")), + ), + handler.Decorate( + event.Store(tx).CreateUser, + traced.Decorate[*repository.User, *repository.User](i.tracer, tracing.WithSpanName("user.event.Create")), + ), + ), + )(ctx, user) +} + +func (i *user) ByID(ctx context.Context, querier database.Querier, id string) (*repository.User, error) { + return handler.SkipNext( + i.cache.ByID, + handler.Chain( + handler.Decorate( + sql.Query(querier).UserByID, + traced.Decorate[string, *repository.User](i.tracer, tracing.WithSpanName("user.sql.ByID")), + ), + handler.SkipNilHandler(i.cache, i.cache.Set), + ), + )(ctx, id) +} diff --git a/backend/repository/sql/client.go b/backend/repository/sql/client.go new file mode 100644 index 0000000000..860397c0a6 --- /dev/null +++ b/backend/repository/sql/client.go @@ -0,0 +1,21 @@ +package sql + +import ( + "github.com/zitadel/zitadel/backend/storage/database" +) + +type executor[C database.Executor] struct { + client C +} + +func Execute[C database.Executor](client C) *executor[C] { + return &executor[C]{client: client} +} + +type querier[C database.Querier] struct { + client C +} + +func Query[C database.Querier](client C) *querier[C] { + return &querier[C]{client: client} +} diff --git a/backend/repository/sql/instance.go b/backend/repository/sql/instance.go index 0fb3129196..787ddee100 100644 --- a/backend/repository/sql/instance.go +++ b/backend/repository/sql/instance.go @@ -4,42 +4,53 @@ import ( "context" "github.com/zitadel/zitadel/backend/repository" - "github.com/zitadel/zitadel/backend/repository/orchestrate/handler" - "github.com/zitadel/zitadel/backend/storage/database" ) -const instanceByDomainQuery = `SELECT i.id, i.name FROM instances i JOIN instance_domains id ON i.id = id.instance_id WHERE id.domain = $1` - -func InstanceByDomain(client database.Querier) handler.Handle[string, *repository.Instance] { - return func(ctx context.Context, domain string) (*repository.Instance, error) { - row := client.QueryRow(ctx, instanceByDomainQuery, domain) - var instance repository.Instance - if err := row.Scan(&instance.ID, &instance.Name); err != nil { - return nil, err - } - return &instance, nil - } -} - const instanceByIDQuery = `SELECT id, name FROM instances WHERE id = $1` -func InstanceByID(client database.Querier) handler.Handle[string, *repository.Instance] { - return func(ctx context.Context, id string) (*repository.Instance, error) { - row := client.QueryRow(ctx, instanceByIDQuery, id) - var instance repository.Instance - if err := row.Scan(&instance.ID, &instance.Name); err != nil { - return nil, err - } - return &instance, nil +func (q *querier[C]) InstanceByID(ctx context.Context, id string) (*repository.Instance, error) { + row := q.client.QueryRow(ctx, instanceByIDQuery, id) + var instance repository.Instance + if err := row.Scan(&instance.ID, &instance.Name); err != nil { + return nil, err } + return &instance, nil } -func SetUpInstance(tx database.Transaction) handler.Handle[*repository.Instance, *repository.Instance] { - return func(ctx context.Context, instance *repository.Instance) (*repository.Instance, error) { - err := tx.Exec(ctx, "INSERT INTO instances (id, name) VALUES ($1, $2)", instance.ID, instance.Name) - if err != nil { +const instanceByDomainQuery = `SELECT i.id, i.name FROM instances i JOIN instance_domains id ON i.id = id.instance_id WHERE id.domain = $1` + +func (q *querier[C]) InstanceByDomain(ctx context.Context, domain string) (*repository.Instance, error) { + row := q.client.QueryRow(ctx, instanceByDomainQuery, domain) + var instance repository.Instance + if err := row.Scan(&instance.ID, &instance.Name); err != nil { + return nil, err + } + return &instance, nil +} + +func (q *querier[C]) ListInstances(ctx context.Context, request *repository.ListRequest) (res []*repository.Instance, err error) { + rows, err := q.client.Query(ctx, "SELECT id, name FROM instances") + if err != nil { + return nil, err + } + defer rows.Close() + for rows.Next() { + var instance repository.Instance + if err = rows.Scan(&instance.ID, &instance.Name); err != nil { return nil, err } - return instance, nil + res = append(res, &instance) } + if err = rows.Err(); err != nil { + return nil, err + } + return res, nil +} + +func (e *executor[C]) CreateInstance(ctx context.Context, instance *repository.Instance) (*repository.Instance, error) { + err := e.client.Exec(ctx, "INSERT INTO instances (id, name) VALUES ($1, $2)", instance.ID, instance.Name) + if err != nil { + return nil, err + } + return instance, nil } diff --git a/backend/repository/sql/user.go b/backend/repository/sql/user.go index 6f5b4c8035..d8f54da191 100644 --- a/backend/repository/sql/user.go +++ b/backend/repository/sql/user.go @@ -4,22 +4,12 @@ import ( "context" "github.com/zitadel/zitadel/backend/repository" - "github.com/zitadel/zitadel/backend/storage/database" ) -func NewUser(client database.QueryExecutor) repository.UserRepository { - return &User{client: client} -} - -type User struct { - client database.QueryExecutor -} - const userByIDQuery = `SELECT id, username FROM users WHERE id = $1` -// ByID implements [UserRepository]. -func (r *User) ByID(ctx context.Context, id string) (*repository.User, error) { - row := r.client.QueryRow(ctx, userByIDQuery, id) +func (q *querier[C]) UserByID(ctx context.Context, id string) (res *repository.User, err error) { + row := q.client.QueryRow(ctx, userByIDQuery, id) var user repository.User if err := row.Scan(&user.ID, &user.Username); err != nil { return nil, err @@ -27,7 +17,10 @@ func (r *User) ByID(ctx context.Context, id string) (*repository.User, error) { return &user, nil } -// Create implements [UserRepository]. -func (r *User) Create(ctx context.Context, user *repository.User) error { - return r.client.Exec(ctx, "INSERT INTO users (id, username) VALUES ($1, $2)", user.ID, user.Username) +func (e *executor[C]) CreateUser(ctx context.Context, user *repository.User) (res *repository.User, err error) { + err = e.client.Exec(ctx, "INSERT INTO users (id, username) VALUES ($1, $2)", user.ID, user.Username) + if err != nil { + return nil, err + } + return user, nil } diff --git a/backend/repository/telemetry/logged/global.go b/backend/repository/telemetry/logged/global.go index cd9ca5307f..baa945b36f 100644 --- a/backend/repository/telemetry/logged/global.go +++ b/backend/repository/telemetry/logged/global.go @@ -10,7 +10,7 @@ import ( // Wrap decorates the given handle function with logging. // The function is safe to call with nil logger. -func Wrap[Req, Res any](logger *logging.Logger, name string, handle handler.Handle[Req, Res]) handler.Handle[Req, Res] { +func Wrap[Req, Res any](logger *logging.Logger, name string, handle handler.Handler[Req, Res]) handler.Handler[Req, Res] { if logger == nil { return handle } @@ -21,23 +21,11 @@ func Wrap[Req, Res any](logger *logging.Logger, name string, handle handler.Hand } } -func WrapInside(logger *logging.Logger, name string) func(ctx context.Context, fn func(context.Context) error) { - logger = logger.With(slog.String("handler", name)) - return func(ctx context.Context, fn func(context.Context) error) { - logger.Debug("execute") - var err error - defer func() { - if err != nil { - logger.Error("failed", slog.String("cause", err.Error())) - } - logger.Debug("done") - }() - err = fn(ctx) - } -} - -func DecorateHandle[Req, Res any](logger *logging.Logger, handle func(context.Context, Req) (Res, error)) func(ctx context.Context, r Req) (_ Res, err error) { - return func(ctx context.Context, r Req) (_ Res, err error) { +// Decorate decorates the given handle function with logging. +// The function is safe to call with nil logger. +func Decorate[Req, Res any](logger *logging.Logger, name string) handler.Decorator[Req, Res] { + logger = logger.With("handler", name) + return func(ctx context.Context, request Req, handle handler.Handler[Req, Res]) (res Res, err error) { logger.DebugContext(ctx, "execute") defer func() { if err != nil { @@ -45,24 +33,6 @@ func DecorateHandle[Req, Res any](logger *logging.Logger, handle func(context.Co } logger.DebugContext(ctx, "done") }() - return handle(ctx, r) + return handle(ctx, request) } } - -// // Handler wraps the given handle function with logging. -// // The function is safe to call with nil logger. -// func Handler[Req, Res any, H handler.Handle[Req, Res]](logger *logging.Logger, name string, handle H) *handler.Handler[Req, Res, H] { -// return &handler.Handler[Req, Res, H]{ -// Handle: Wrap(logger, name, handle), -// } -// } - -// // Chained wraps the given handle function with logging. -// // The function is safe to call with nil logger. -// // The next handler is called after the handle function. -// func Chained[Req, Res any, H, N handler.Handle[Req, Res]](logger *logging.Logger, name string, handle H, next N) *handler.Chained[Req, Res, H, N] { -// return handler.NewChained( -// Wrap(logger, name, handle), -// next, -// ) -// } diff --git a/backend/repository/telemetry/traced/global.go b/backend/repository/telemetry/traced/global.go index e256080e29..a4902a77f4 100644 --- a/backend/repository/telemetry/traced/global.go +++ b/backend/repository/telemetry/traced/global.go @@ -2,6 +2,7 @@ package traced import ( "context" + "log" "github.com/zitadel/zitadel/backend/repository/orchestrate/handler" "github.com/zitadel/zitadel/backend/telemetry/tracing" @@ -9,7 +10,7 @@ import ( // Wrap decorates the given handle function with tracing. // The function is safe to call with nil tracer. -func Wrap[Req, Res any](tracer *tracing.Tracer, name string, handle handler.Handle[Req, Res]) handler.Handle[Req, Res] { +func Wrap[Req, Res any](tracer *tracing.Tracer, name string, handle handler.Handler[Req, Res]) handler.Handler[Req, Res] { if tracer == nil { return handle } @@ -28,50 +29,18 @@ func Wrap[Req, Res any](tracer *tracing.Tracer, name string, handle handler.Hand } } -func WrapInside(tracer *tracing.Tracer, name string) func(ctx context.Context, fn func() error) { - return func(ctx context.Context, fn func() error) { - var err error - _, span := tracer.Start( - ctx, - name, - ) - defer func() { - if err != nil { - span.RecordError(err) - } - span.End() - }() - err = fn() - } -} - -func DecorateHandle[Req, Res any](tracer *tracing.Tracer, opts ...tracing.DecorateOption) handler.Decorate[Req, Res] { - return func(ctx context.Context, r Req, handle handler.Handle[Req, Res]) (_ Res, err error) { +// Decorate decorates the given handle function with tracing. +// The function is safe to call with nil tracer. +func Decorate[Req, Res any](tracer *tracing.Tracer, opts ...tracing.DecorateOption) handler.Decorator[Req, Res] { + return func(ctx context.Context, r Req, handle handler.Handler[Req, Res]) (_ Res, err error) { o := new(tracing.DecorateOptions) for _, opt := range opts { opt(o) } + log.Println("trace") - ctx = o.Start(ctx, tracer) - defer o.End(err) + ctx, end := o.Start(ctx, tracer) + defer end(err) return handle(ctx, r) } } - -// // Handler wraps the given handle function with tracing. -// // The function is safe to call with nil logger. -// func Handler[Req, Res any, H handler.Handle[Req, Res]](tracer *tracing.Tracer, name string, handle H) *handler.Handler[Req, Res, H] { -// return &handler.Handler[Req, Res, H]{ -// Handle: Wrap(tracer, name, handle), -// } -// } - -// // Chained wraps the given handle function with tracing. -// // The function is safe to call with nil logger. -// // The next handler is called after the handle function. -// func Chained[Req, Res any, H, N handler.Handle[Req, Res]](tracer *tracing.Tracer, name string, handle H, next N) *handler.Chained[Req, Res, H, N] { -// return handler.NewChained( -// Wrap(tracer, name, handle), -// next, -// ) -// } diff --git a/backend/repository/user.go b/backend/repository/user.go index 53dd417ff3..3a5f10d4c3 100644 --- a/backend/repository/user.go +++ b/backend/repository/user.go @@ -1,12 +1,5 @@ package repository -import "context" - -type UserRepository interface { - Create(ctx context.Context, user *User) error - ByID(ctx context.Context, id string) (*User, error) -} - type User struct { ID string Username string diff --git a/backend/storage/cache/gomap/map.go b/backend/storage/cache/gomap/map.go index 26e9eaa772..608e636bd1 100644 --- a/backend/storage/cache/gomap/map.go +++ b/backend/storage/cache/gomap/map.go @@ -11,6 +11,13 @@ type Map[K comparable, V any] struct { items map[K]V } +func New[K comparable, V any]() *Map[K, V] { + return &Map[K, V]{ + items: make(map[K]V), + mu: sync.RWMutex{}, + } +} + // Clear implements cache.Cache. func (m *Map[K, V]) Clear() { m.mu.Lock() diff --git a/backend/storage/database/mock/transaction.go b/backend/storage/database/mock/transaction.go new file mode 100644 index 0000000000..ba750324f9 --- /dev/null +++ b/backend/storage/database/mock/transaction.go @@ -0,0 +1,67 @@ +package mock + +import ( + "context" + "errors" + + "github.com/zitadel/zitadel/backend/storage/database" +) + +type Transaction struct { + committed bool + rolledBack bool +} + +func NewTransaction() *Transaction { + return new(Transaction) +} + +// Commit implements [database.Transaction]. +func (t *Transaction) Commit(ctx context.Context) error { + if t.hasEnded() { + return errors.New("transaction already committed or rolled back") + } + t.committed = true + return nil +} + +// End implements [database.Transaction]. +func (t *Transaction) End(ctx context.Context, err error) error { + if t.hasEnded() { + return errors.New("transaction already committed or rolled back") + } + if err != nil { + return t.Rollback(ctx) + } + return t.Commit(ctx) +} + +// Exec implements [database.Transaction]. +func (t *Transaction) Exec(ctx context.Context, sql string, args ...any) error { + return nil +} + +// Query implements [database.Transaction]. +func (t *Transaction) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { + return nil, nil +} + +// QueryRow implements [database.Transaction]. +func (t *Transaction) QueryRow(ctx context.Context, sql string, args ...any) database.Row { + return nil +} + +// Rollback implements [database.Transaction]. +func (t *Transaction) Rollback(ctx context.Context) error { + if t.hasEnded() { + return errors.New("transaction already committed or rolled back") + } + t.rolledBack = true + return nil +} + +var _ database.Transaction = (*Transaction)(nil) + +func (t *Transaction) hasEnded() bool { + return t.committed || t.rolledBack +} diff --git a/backend/telemetry/logging/logger.go b/backend/telemetry/logging/logger.go index 51fa1fb247..c452d1021f 100644 --- a/backend/telemetry/logging/logger.go +++ b/backend/telemetry/logging/logger.go @@ -6,6 +6,10 @@ type Logger struct { *slog.Logger } +func New(l *slog.Logger) *Logger { + return &Logger{Logger: l} +} + func (l *Logger) With(args ...any) *Logger { return &Logger{l.Logger.With(args...)} } diff --git a/backend/telemetry/tracing/tracer.go b/backend/telemetry/tracing/tracer.go index 25b61868df..6c0801fb2c 100644 --- a/backend/telemetry/tracing/tracer.go +++ b/backend/telemetry/tracing/tracer.go @@ -10,8 +10,8 @@ import ( type Tracer struct{ trace.Tracer } -func NewTracer(name string) Tracer { - return Tracer{otel.Tracer(name)} +func NewTracer(name string) *Tracer { + return &Tracer{otel.Tracer(name)} } type DecorateOption func(*DecorateOptions) @@ -43,15 +43,15 @@ func WithSpanEndOptions(opts ...trace.SpanEndOption) DecorateOption { } } -func (o *DecorateOptions) Start(ctx context.Context, tracer *Tracer) context.Context { +func (o *DecorateOptions) Start(ctx context.Context, tracer *Tracer) (context.Context, func(error)) { if o.spanName == "" { o.spanName = functionName() } ctx, o.span = tracer.Tracer.Start(ctx, o.spanName, o.startOpts...) - return ctx + return ctx, o.end } -func (o *DecorateOptions) End(err error) { +func (o *DecorateOptions) end(err error) { o.span.RecordError(err) o.span.End(o.endOpts...) }