works like a charm

This commit is contained in:
adlerhurst
2025-03-16 09:10:57 +01:00
parent fa02beb883
commit 599a600bfb
15 changed files with 169 additions and 41 deletions

View File

@@ -20,7 +20,7 @@ type Instance struct {
type instanceOrchestrator interface {
ByID(ctx context.Context, querier database.Querier, id string) (*repository.Instance, error)
ByDomain(ctx context.Context, querier database.Querier, domain string) (*repository.Instance, error)
SetUp(ctx context.Context, tx database.Transaction, instance *repository.Instance) (*repository.Instance, error)
Create(ctx context.Context, tx database.Transaction, instance *repository.Instance) (*repository.Instance, error)
}
func NewInstance(db database.Pool, tracer *tracing.Tracer, logger *logging.Logger) *Instance {
@@ -54,7 +54,7 @@ func (b *Instance) SetUp(ctx context.Context, request *SetUpInstance) (err error
defer func() {
err = tx.End(ctx, err)
}()
_, err = b.instance.SetUp(ctx, tx, request.Instance)
_, err = b.instance.Create(ctx, tx, request.Instance)
if err != nil {
return err
}

View File

@@ -8,6 +8,6 @@ import (
)
type userOrchestrator interface {
Create(ctx context.Context, client database.Transaction, user *repository.User) (*repository.User, error)
Create(ctx context.Context, tx database.Transaction, user *repository.User) (*repository.User, error)
ByID(ctx context.Context, querier database.Querier, id string) (*repository.User, error)
}

View File

@@ -2,6 +2,7 @@ package cache
import (
"context"
"log"
"sync"
"github.com/zitadel/zitadel/backend/repository"
@@ -24,6 +25,7 @@ func NewInstance() *Instance {
}
func (i *Instance) Set(ctx context.Context, instance *repository.Instance) (*repository.Instance, error) {
log.Println("cache.instance.set")
i.set(instance, "")
return instance, nil
}
@@ -31,6 +33,7 @@ func (i *Instance) Set(ctx context.Context, instance *repository.Instance) (*rep
func (i *Instance) ByID(ctx context.Context, id string) (*repository.Instance, error) {
i.mu.RLock()
defer i.mu.RUnlock()
log.Println("cache.instance.byID")
instance, _ := i.byID.Get(id)
return instance, nil
}
@@ -38,6 +41,7 @@ func (i *Instance) ByID(ctx context.Context, id string) (*repository.Instance, e
func (i *Instance) ByDomain(ctx context.Context, domain string) (*repository.Instance, error) {
i.mu.RLock()
defer i.mu.RUnlock()
log.Println("cache.instance.byDomain")
instance, _ := i.byDomain.Get(domain)
return instance, nil
}

View File

@@ -2,6 +2,7 @@ package cache
import (
"context"
"log"
"github.com/zitadel/zitadel/backend/repository"
"github.com/zitadel/zitadel/backend/storage/cache"
@@ -20,12 +21,14 @@ func NewUser() *User {
// ByID implements repository.UserRepository.
func (u *User) ByID(ctx context.Context, id string) (*repository.User, error) {
log.Println("cache.user.byid")
user, _ := u.Get(id)
return user, nil
}
func (u *User) Set(ctx context.Context, user *repository.User) (*repository.User, error) {
log.Println("cache.user.set")
u.set(user)
return user, nil
}

View File

@@ -2,11 +2,13 @@ package event
import (
"context"
"log"
"github.com/zitadel/zitadel/backend/repository"
)
func (s *store) CreateInstance(ctx context.Context, instance *repository.Instance) (*repository.Instance, error) {
log.Println("event.instance.create")
err := s.es.Push(ctx, instance)
if err != nil {
return nil, err

View File

@@ -2,11 +2,13 @@ package event
import (
"context"
"log"
"github.com/zitadel/zitadel/backend/repository"
)
func (s *store) CreateUser(ctx context.Context, user *repository.User) (*repository.User, error) {
log.Println("event.user.create")
err := s.es.Push(ctx, user)
if err != nil {
return nil, err

View File

@@ -1,6 +1,8 @@
package handler
import "context"
import (
"context"
)
// Handler is a function that handles the request.
type Handler[Req, Res any] func(ctx context.Context, request Req) (res Res, err error)
@@ -20,6 +22,15 @@ func Chain[Req, Res any](handle Handler[Req, Res], next Handler[Res, Res]) Handl
}
}
func Chains[Req, Res any](handle Handler[Req, Res], nexts ...Handler[Res, Res]) Handler[Req, Res] {
return func(ctx context.Context, request Req) (res Res, err error) {
for _, next := range nexts {
handle = Chain(handle, next)
}
return handle(ctx, request)
}
}
// Decorate decorates the handle function with the decorate function.
// The decorate function is called before the handle function.
func Decorate[Req, Res any](handle Handler[Req, Res], decorate Decorator[Req, Res]) Handler[Req, Res] {
@@ -28,12 +39,12 @@ func Decorate[Req, Res any](handle Handler[Req, Res], decorate Decorator[Req, Re
}
}
// Decorates decorates the handle function with the decorate function.
// The decorate function is called before the handle function.
// Decorates decorates the handle function with the decorate functions.
// The decorates function is called before the handle function.
func Decorates[Req, Res any](handle Handler[Req, Res], decorates ...Decorator[Req, Res]) Handler[Req, Res] {
return func(ctx context.Context, request Req) (res Res, err error) {
for _, decorate := range decorates {
handle = Decorate(handle, decorate)
for i := len(decorates) - 1; i >= 0; i-- {
handle = Decorate(handle, decorates[i])
}
return handle(ctx, request)
}

View File

@@ -2,6 +2,7 @@ package orchestrate
import (
"context"
"fmt"
"github.com/zitadel/zitadel/backend/repository"
"github.com/zitadel/zitadel/backend/repository/cache"
@@ -20,34 +21,52 @@ type instance struct {
cache *cache.Instance
}
func Instance(opts ...Option) *instance {
func Instance(opts ...InstanceConfig) *instance {
i := new(instance)
for _, opt := range opts {
opt(&i.options)
opt.applyInstance(i)
}
return i
}
func (i *instance) apply(o Option) {
func WithInstanceCache(cache *cache.Instance) instanceOption {
return func(i *instance) {
i.cache = cache
}
}
type InstanceConfig interface {
applyInstance(*instance)
}
// instanceOption applies an option to the instance.
type instanceOption func(*instance)
func (io instanceOption) applyInstance(i *instance) {
io(i)
}
func (o Option) applyInstance(i *instance) {
o(&i.options)
}
func (i *instance) SetUp(ctx context.Context, tx database.Transaction, instance *repository.Instance) (*repository.Instance, error) {
func (i *instance) Create(ctx context.Context, tx database.Transaction, instance *repository.Instance) (*repository.Instance, error) {
fmt.Println("----------------")
return traced.Wrap(i.tracer, "instance.SetUp",
handler.Chain(
handler.Chains(
handler.Decorates(
sql.Execute(tx).CreateInstance,
traced.Decorate[*repository.Instance, *repository.Instance](i.tracer, tracing.WithSpanName("instance.sql.SetUp")),
logged.Decorate[*repository.Instance, *repository.Instance](i.logger, "instance.sql.SetUp"),
),
handler.Chain(
handler.Decorates(
event.Store(tx).CreateInstance,
traced.Decorate[*repository.Instance, *repository.Instance](i.tracer, tracing.WithSpanName("instance.event.SetUp")),
logged.Decorate[*repository.Instance, *repository.Instance](i.logger, "instance.event.SetUp"),
),
handler.SkipNilHandler(i.cache,
handler.Decorates(
handler.SkipNilHandler(i.cache, i.cache.Set),
i.cache.Set,
traced.Decorate[*repository.Instance, *repository.Instance](i.tracer, tracing.WithSpanName("instance.cache.SetUp")),
logged.Decorate[*repository.Instance, *repository.Instance](i.logger, "instance.cache.SetUp"),
),

View File

@@ -1,4 +1,4 @@
package orchestrate
package orchestrate_test
import (
"context"
@@ -9,6 +9,7 @@ import (
"github.com/zitadel/zitadel/backend/repository"
"github.com/zitadel/zitadel/backend/repository/cache"
"github.com/zitadel/zitadel/backend/repository/orchestrate"
"github.com/zitadel/zitadel/backend/storage/database"
"github.com/zitadel/zitadel/backend/storage/database/mock"
"github.com/zitadel/zitadel/backend/telemetry/logging"
@@ -16,10 +17,6 @@ import (
)
func Test_instance_SetUp(t *testing.T) {
type fields struct {
options options
cache *cache.Instance
}
type args struct {
ctx context.Context
tx database.Transaction
@@ -27,20 +24,73 @@ func Test_instance_SetUp(t *testing.T) {
}
tests := []struct {
name string
fields fields
opts []orchestrate.InstanceConfig
args args
want *repository.Instance
wantErr bool
}{
{
name: "simple",
fields: fields{
options: options{
tracer: tracing.NewTracer("test"),
logger: logging.New(slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))),
opts: []orchestrate.InstanceConfig{
orchestrate.WithTracer(tracing.NewTracer("test")),
orchestrate.WithLogger(logging.New(slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})))),
orchestrate.WithInstanceCache(cache.NewInstance()),
},
cache: cache.NewInstance(),
args: args{
ctx: context.Background(),
tx: mock.NewTransaction(),
instance: &repository.Instance{
ID: "ID",
Name: "Name",
},
},
want: &repository.Instance{
ID: "ID",
Name: "Name",
},
wantErr: false,
},
{
name: "without cache",
opts: []orchestrate.InstanceConfig{
orchestrate.WithTracer(tracing.NewTracer("test")),
orchestrate.WithLogger(logging.New(slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})))),
},
args: args{
ctx: context.Background(),
tx: mock.NewTransaction(),
instance: &repository.Instance{
ID: "ID",
Name: "Name",
},
},
want: &repository.Instance{
ID: "ID",
Name: "Name",
},
wantErr: false,
},
{
name: "without cache, tracer",
opts: []orchestrate.InstanceConfig{
orchestrate.WithLogger(logging.New(slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug})))),
},
args: args{
ctx: context.Background(),
tx: mock.NewTransaction(),
instance: &repository.Instance{
ID: "ID",
Name: "Name",
},
},
want: &repository.Instance{
ID: "ID",
Name: "Name",
},
wantErr: false,
},
{
name: "without cache, tracer, logger",
args: args{
ctx: context.Background(),
tx: mock.NewTransaction(),
@@ -58,17 +108,14 @@ func Test_instance_SetUp(t *testing.T) {
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
i := &instance{
options: tt.fields.options,
cache: tt.fields.cache,
}
got, err := i.SetUp(tt.args.ctx, tt.args.tx, tt.args.instance)
i := orchestrate.Instance(tt.opts...)
got, err := i.Create(tt.args.ctx, tt.args.tx, tt.args.instance)
if (err != nil) != tt.wantErr {
t.Errorf("instance.SetUp() error = %v, wantErr %v", err, tt.wantErr)
t.Errorf("instance.Create() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !reflect.DeepEqual(got, tt.want) {
t.Errorf("instance.SetUp() = %v, want %v", got, tt.want)
t.Errorf("instance.Create() = %v, want %v", got, tt.want)
}
})
}

View File

@@ -5,6 +5,7 @@ import (
"github.com/zitadel/zitadel/backend/telemetry/tracing"
)
// options are the default options for orchestrators.
type options struct {
tracer *tracing.Tracer
logger *logging.Logger
@@ -23,3 +24,7 @@ func WithLogger(logger *logging.Logger) Option {
o.logger = logger
}
}
func (o Option) apply(opts *options) {
o(opts)
}

View File

@@ -27,7 +27,24 @@ func User(opts ...Option) *user {
return i
}
func (i *user) apply(o Option) {
func WithUserCache(cache *cache.User) userOption {
return func(i *user) {
i.cache = cache
}
}
type UserConfig interface {
applyUser(*user)
}
// userOption applies an option to the user.
type userOption func(*user)
func (io userOption) applyUser(i *user) {
io(i)
}
func (o Option) applyUser(i *user) {
o(&i.options)
}

View File

@@ -2,6 +2,7 @@ package sql
import (
"context"
"log"
"github.com/zitadel/zitadel/backend/repository"
)
@@ -9,6 +10,7 @@ import (
const instanceByIDQuery = `SELECT id, name FROM instances WHERE id = $1`
func (q *querier[C]) InstanceByID(ctx context.Context, id string) (*repository.Instance, error) {
log.Println("sql.instance.byID")
row := q.client.QueryRow(ctx, instanceByIDQuery, id)
var instance repository.Instance
if err := row.Scan(&instance.ID, &instance.Name); err != nil {
@@ -20,6 +22,7 @@ func (q *querier[C]) InstanceByID(ctx context.Context, id string) (*repository.I
const instanceByDomainQuery = `SELECT i.id, i.name FROM instances i JOIN instance_domains id ON i.id = id.instance_id WHERE id.domain = $1`
func (q *querier[C]) InstanceByDomain(ctx context.Context, domain string) (*repository.Instance, error) {
log.Println("sql.instance.byDomain")
row := q.client.QueryRow(ctx, instanceByDomainQuery, domain)
var instance repository.Instance
if err := row.Scan(&instance.ID, &instance.Name); err != nil {
@@ -29,6 +32,7 @@ func (q *querier[C]) InstanceByDomain(ctx context.Context, domain string) (*repo
}
func (q *querier[C]) ListInstances(ctx context.Context, request *repository.ListRequest) (res []*repository.Instance, err error) {
log.Println("sql.instance.list")
rows, err := q.client.Query(ctx, "SELECT id, name FROM instances")
if err != nil {
return nil, err
@@ -48,6 +52,7 @@ func (q *querier[C]) ListInstances(ctx context.Context, request *repository.List
}
func (e *executor[C]) CreateInstance(ctx context.Context, instance *repository.Instance) (*repository.Instance, error) {
log.Println("sql.instance.create")
err := e.client.Exec(ctx, "INSERT INTO instances (id, name) VALUES ($1, $2)", instance.ID, instance.Name)
if err != nil {
return nil, err

View File

@@ -2,6 +2,7 @@ package sql
import (
"context"
"log"
"github.com/zitadel/zitadel/backend/repository"
)
@@ -9,6 +10,7 @@ import (
const userByIDQuery = `SELECT id, username FROM users WHERE id = $1`
func (q *querier[C]) UserByID(ctx context.Context, id string) (res *repository.User, err error) {
log.Println("sql.user.byID")
row := q.client.QueryRow(ctx, userByIDQuery, id)
var user repository.User
if err := row.Scan(&user.ID, &user.Username); err != nil {
@@ -18,6 +20,7 @@ func (q *querier[C]) UserByID(ctx context.Context, id string) (res *repository.U
}
func (e *executor[C]) CreateUser(ctx context.Context, user *repository.User) (res *repository.User, err error) {
log.Println("sql.user.create")
err = e.client.Exec(ctx, "INSERT INTO users (id, username) VALUES ($1, $2)", user.ID, user.Username)
if err != nil {
return nil, err

View File

@@ -2,6 +2,7 @@ package logged
import (
"context"
"log"
"log/slog"
"github.com/zitadel/zitadel/backend/repository/orchestrate/handler"
@@ -14,6 +15,7 @@ func Wrap[Req, Res any](logger *logging.Logger, name string, handle handler.Hand
if logger == nil {
return handle
}
log.Println("log.wrap", name)
return func(ctx context.Context, r Req) (_ Res, err error) {
logger.Debug("execute", slog.String("handler", name))
defer logger.Debug("done", slog.String("handler", name))
@@ -24,9 +26,13 @@ func Wrap[Req, Res any](logger *logging.Logger, name string, handle handler.Hand
// Decorate decorates the given handle function with logging.
// The function is safe to call with nil logger.
func Decorate[Req, Res any](logger *logging.Logger, name string) handler.Decorator[Req, Res] {
logger = logger.With("handler", name)
return func(ctx context.Context, request Req, handle handler.Handler[Req, Res]) (res Res, err error) {
if logger == nil {
return handle(ctx, request)
}
logger = logger.With("handler", name)
logger.DebugContext(ctx, "execute")
log.Println("log.decorate", name)
defer func() {
if err != nil {
logger.ErrorContext(ctx, "failed", slog.String("cause", err.Error()))

View File

@@ -19,6 +19,7 @@ func Wrap[Req, Res any](tracer *tracing.Tracer, name string, handle handler.Hand
ctx,
name,
)
log.Println("trace.wrap", name)
defer func() {
if err != nil {
span.RecordError(err)
@@ -33,6 +34,9 @@ func Wrap[Req, Res any](tracer *tracing.Tracer, name string, handle handler.Hand
// The function is safe to call with nil tracer.
func Decorate[Req, Res any](tracer *tracing.Tracer, opts ...tracing.DecorateOption) handler.Decorator[Req, Res] {
return func(ctx context.Context, r Req, handle handler.Handler[Req, Res]) (_ Res, err error) {
if tracer == nil {
return handle(ctx, r)
}
o := new(tracing.DecorateOptions)
for _, opt := range opts {
opt(o)