This commit is contained in:
adlerhurst
2025-05-08 15:30:06 +02:00
parent 8ba497cb87
commit 47e63ed801
14 changed files with 259 additions and 63 deletions

View File

@@ -2,12 +2,14 @@ package domain
import ( import (
"context" "context"
"fmt"
"github.com/zitadel/zitadel/backend/v3/storage/database" "github.com/zitadel/zitadel/backend/v3/storage/database"
) )
type Commander interface { type Commander interface {
Execute(ctx context.Context, opts *CommandOpts) (err error) Execute(ctx context.Context, opts *CommandOpts) (err error)
fmt.Stringer
} }
type Invoker interface { type Invoker interface {
@@ -93,13 +95,28 @@ func DefaultOpts(invoker Invoker) *CommandOpts {
} }
} }
type noopInvoker struct { type commandBatch struct {
next Invoker Commands []Commander
} }
func (i *noopInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) error { func BatchCommands(cmds ...Commander) *commandBatch {
if i.next != nil { return &commandBatch{
return i.next.Invoke(ctx, command, opts) Commands: cmds,
} }
return command.Execute(ctx, opts)
} }
// String implements [Commander].
func (cmd *commandBatch) String() string {
return "commandBatch"
}
func (b *commandBatch) Execute(ctx context.Context, opts *CommandOpts) (err error) {
for _, cmd := range b.Commands {
if err = opts.Invoke(ctx, cmd); err != nil {
return err
}
}
return nil
}
var _ Commander = (*commandBatch)(nil)

View File

@@ -30,9 +30,21 @@ func NewCreateHumanCommand(username string, opts ...CreateHumanOpt) *CreateUserC
return cmd return cmd
} }
// String implements [Commander].
func (cmd *CreateUserCommand) String() string {
return "CreateUserCommand"
}
// Events implements [eventer]. // Events implements [eventer].
func (c *CreateUserCommand) Events() []*eventstore.Event { func (c *CreateUserCommand) Events() []*eventstore.Event {
panic("unimplemented") return []*eventstore.Event{
{
AggregateType: "user",
AggregateID: c.user.ID,
Type: "user.added",
Payload: c.user,
},
}
} }
// Execute implements [Commander]. // Execute implements [Commander].

View File

@@ -15,6 +15,11 @@ type CryptoRepository interface {
GetEncryptionConfig(ctx context.Context) (*crypto.GeneratorConfig, error) GetEncryptionConfig(ctx context.Context) (*crypto.GeneratorConfig, error)
} }
// String implements [Commander].
func (cmd *generateCodeCommand) String() string {
return "generateCodeCommand"
}
func (cmd *generateCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error { func (cmd *generateCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error {
config, err := cryptoRepo(opts.DB).GetEncryptionConfig(ctx) config, err := cryptoRepo(opts.DB).GetEncryptionConfig(ctx)
if err != nil { if err != nil {
@@ -24,3 +29,5 @@ func (cmd *generateCodeCommand) Execute(ctx context.Context, opts *CommandOpts)
cmd.value, cmd.code, err = crypto.NewCode(generator) cmd.value, cmd.code, err = crypto.NewCode(generator)
return err return err
} }
var _ Commander = (*generateCodeCommand)(nil)

View File

@@ -6,6 +6,7 @@ import (
"github.com/zitadel/zitadel/backend/v3/storage/cache" "github.com/zitadel/zitadel/backend/v3/storage/cache"
"github.com/zitadel/zitadel/backend/v3/storage/database" "github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/telemetry/logging"
"github.com/zitadel/zitadel/backend/v3/telemetry/tracing" "github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
"github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/crypto"
) )
@@ -14,13 +15,15 @@ var (
pool database.Pool pool database.Pool
userCodeAlgorithm crypto.EncryptionAlgorithm userCodeAlgorithm crypto.EncryptionAlgorithm
tracer tracing.Tracer tracer tracing.Tracer
logger logging.Logger
userRepo func(database.QueryExecutor) UserRepository userRepo func(database.QueryExecutor) UserRepository
instanceRepo func(database.QueryExecutor) InstanceRepository instanceRepo func(database.QueryExecutor) InstanceRepository
cryptoRepo func(database.QueryExecutor) CryptoRepository cryptoRepo func(database.QueryExecutor) CryptoRepository
orgRepo func(database.QueryExecutor) OrgRepository orgRepo func(database.QueryExecutor) OrgRepository
instanceCache cache.Cache[string, string, *Instance] instanceCache cache.Cache[instanceCacheIndex, string, *Instance]
orgCache cache.Cache[orgCacheIndex, string, *Org]
generateID func() (string, error) = func() (string, error) { generateID func() (string, error) = func() (string, error) {
return strconv.FormatUint(rand.Uint64(), 10), nil return strconv.FormatUint(rand.Uint64(), 10), nil
@@ -39,10 +42,18 @@ func SetTracer(t tracing.Tracer) {
tracer = t tracer = t
} }
func SetLogger(l logging.Logger) {
logger = l
}
func SetUserRepository(repo func(database.QueryExecutor) UserRepository) { func SetUserRepository(repo func(database.QueryExecutor) UserRepository) {
userRepo = repo userRepo = repo
} }
func SetOrgRepository(repo func(database.QueryExecutor) OrgRepository) {
orgRepo = repo
}
func SetInstanceRepository(repo func(database.QueryExecutor) InstanceRepository) { func SetInstanceRepository(repo func(database.QueryExecutor) InstanceRepository) {
instanceRepo = repo instanceRepo = repo
} }

View File

@@ -1,45 +1,65 @@
package domain_test package domain_test
// import ( import (
// "context" "context"
// "testing" "log/slog"
"testing"
// "github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
// "github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
// "go.opentelemetry.io/otel" "go.opentelemetry.io/otel"
// "go.opentelemetry.io/otel/exporters/stdout/stdouttrace" "go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
// sdktrace "go.opentelemetry.io/otel/sdk/trace" sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.uber.org/mock/gomock"
// . "github.com/zitadel/zitadel/backend/v3/domain" . "github.com/zitadel/zitadel/backend/v3/domain"
// "github.com/zitadel/zitadel/backend/v3/storage/database/repository" "github.com/zitadel/zitadel/backend/v3/storage/database/dbmock"
// "github.com/zitadel/zitadel/backend/v3/telemetry/tracing" "github.com/zitadel/zitadel/backend/v3/storage/database/repository"
// ) "github.com/zitadel/zitadel/backend/v3/telemetry/logging"
"github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
)
// func TestExample(t *testing.T) { func TestExample(t *testing.T) {
// ctx := context.Background() ctx := context.Background()
// // SetPool(pool) ctrl := gomock.NewController(t)
pool := dbmock.NewMockPool(ctrl)
tx := dbmock.NewMockTransaction(ctrl)
// exporter, err := stdouttrace.New(stdouttrace.WithPrettyPrint()) pool.EXPECT().Begin(gomock.Any(), gomock.Any()).Return(tx, nil)
// require.NoError(t, err) tx.EXPECT().End(gomock.Any(), gomock.Any()).Return(nil)
// tracerProvider := sdktrace.NewTracerProvider( SetPool(pool)
// sdktrace.WithSyncer(exporter),
// )
// otel.SetTracerProvider(tracerProvider)
// SetTracer(tracing.Tracer{Tracer: tracerProvider.Tracer("test")})
// defer func() { assert.NoError(t, tracerProvider.Shutdown(ctx)) }()
// SetUserRepository(repository.User) exporter, err := stdouttrace.New(stdouttrace.WithPrettyPrint())
// SetInstanceRepository(repository.Instance) require.NoError(t, err)
// SetCryptoRepository(repository.Crypto) tracerProvider := sdktrace.NewTracerProvider(
sdktrace.WithSyncer(exporter),
)
otel.SetTracerProvider(tracerProvider)
SetTracer(tracing.Tracer{Tracer: tracerProvider.Tracer("test")})
defer func() { assert.NoError(t, tracerProvider.Shutdown(ctx)) }()
// t.Run("verified email", func(t *testing.T) { SetLogger(logging.Logger{Logger: slog.Default()})
// err := Invoke(ctx, NewSetEmailCommand("u1", "test@example.com", NewEmailVerifiedCommand("u1", true)))
// assert.NoError(t, err)
// })
// t.Run("unverified email", func(t *testing.T) { SetUserRepository(repository.UserRepository)
// err := Invoke(ctx, NewSetEmailCommand("u2", "test2@example.com", NewEmailVerifiedCommand("u2", false))) SetOrgRepository(repository.OrgRepository)
// assert.NoError(t, err) // SetInstanceRepository(repository.Instance)
// }) // SetCryptoRepository(repository.Crypto)
// }
t.Run("create org", func(t *testing.T) {
org := NewAddOrgCommand("testorg", NewAddMemberCommand("testuser", "ORG_OWNER"))
user := NewCreateHumanCommand("testuser")
err := Invoke(ctx, BatchCommands(org, user))
assert.NoError(t, err)
})
t.Run("verified email", func(t *testing.T) {
err := Invoke(ctx, NewSetEmailCommand("u1", "test@example.com", NewEmailVerifiedCommand("u1", true)))
assert.NoError(t, err)
})
t.Run("unverified email", func(t *testing.T) {
err := Invoke(ctx, NewSetEmailCommand("u2", "test2@example.com", NewEmailVerifiedCommand("u2", false)))
assert.NoError(t, err)
})
}

View File

@@ -19,6 +19,11 @@ func NewEmailVerifiedCommand(userID string, isVerified bool) *EmailVerifiedComma
} }
} }
// String implements [Commander].
func (cmd *EmailVerifiedCommand) String() string {
return "EmailVerifiedCommand"
}
var ( var (
_ Commander = (*EmailVerifiedCommand)(nil) _ Commander = (*EmailVerifiedCommand)(nil)
_ SetEmailOpt = (*EmailVerifiedCommand)(nil) _ SetEmailOpt = (*EmailVerifiedCommand)(nil)
@@ -57,6 +62,11 @@ func NewSendCodeCommand(userID string, urlTemplate *string) *SendCodeCommand {
} }
} }
// String implements [Commander].
func (cmd *SendCodeCommand) String() string {
return "SendCodeCommand"
}
// Execute implements [Commander] // Execute implements [Commander]
func (cmd *SendCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error { func (cmd *SendCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error {
if err := cmd.ensureEmail(ctx, opts); err != nil { if err := cmd.ensureEmail(ctx, opts); err != nil {
@@ -122,6 +132,11 @@ func NewReturnCodeCommand(userID string) *ReturnCodeCommand {
} }
} }
// String implements [Commander].
func (cmd *ReturnCodeCommand) String() string {
return "ReturnCodeCommand"
}
// Execute implements [Commander] // Execute implements [Commander]
func (cmd *ReturnCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error { func (cmd *ReturnCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error {
if err := cmd.ensureEmail(ctx, opts); err != nil { if err := cmd.ensureEmail(ctx, opts); err != nil {

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"time" "time"
"github.com/zitadel/zitadel/backend/v3/storage/cache"
"github.com/zitadel/zitadel/backend/v3/storage/database" "github.com/zitadel/zitadel/backend/v3/storage/database"
) )
@@ -15,12 +16,23 @@ type Instance struct {
DeletedAt time.Time `json:"-"` DeletedAt time.Time `json:"-"`
} }
type instanceCacheIndex uint8
const (
instanceCacheIndexUndefined instanceCacheIndex = iota
instanceCacheIndexID
)
// Keys implements the [cache.Entry]. // Keys implements the [cache.Entry].
func (i *Instance) Keys(index string) (key []string) { func (i *Instance) Keys(index instanceCacheIndex) (key []string) {
// TODO: Return the correct keys for the instance cache, e.g., i.ID, i.Domain if index == instanceCacheIndexID {
return []string{} return []string{i.ID}
}
return nil
} }
var _ cache.Entry[instanceCacheIndex, string] = (*Instance)(nil)
type instanceColumns interface { type instanceColumns interface {
// IDColumn returns the column for the id field. // IDColumn returns the column for the id field.
IDColumn() database.Column IDColumn() database.Column

View File

@@ -7,12 +7,11 @@ import (
"github.com/zitadel/zitadel/backend/v3/storage/eventstore" "github.com/zitadel/zitadel/backend/v3/storage/eventstore"
) )
var defaultInvoker = newEventStoreInvoker(newTraceInvoker(nil))
func Invoke(ctx context.Context, cmd Commander) error { func Invoke(ctx context.Context, cmd Commander) error {
invoker := newEventStoreInvoker(newTraceInvoker(nil)) invoker := newEventStoreInvoker(newLoggingInvoker(newTraceInvoker(nil)))
opts := &CommandOpts{ opts := &CommandOpts{
Invoker: invoker.collector, Invoker: invoker.collector,
DB: pool,
} }
return invoker.Invoke(ctx, cmd, opts) return invoker.Invoke(ctx, cmd, opts)
} }
@@ -60,14 +59,9 @@ func (i *eventCollector) Invoke(ctx context.Context, command Commander, opts *Co
i.events = append(i.events, e.Events()...) i.events = append(i.events, e.Events()...)
} }
if i.next != nil { if i.next != nil {
err = i.next.Invoke(ctx, command, opts) return i.next.Invoke(ctx, command, opts)
} else {
err = command.Execute(ctx, opts)
} }
if err != nil { return command.Execute(ctx, opts)
return err
}
return nil
} }
type traceInvoker struct { type traceInvoker struct {
@@ -80,15 +74,71 @@ func newTraceInvoker(next Invoker) *traceInvoker {
func (i *traceInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) { func (i *traceInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
ctx, span := tracer.Start(ctx, fmt.Sprintf("%T", command)) ctx, span := tracer.Start(ctx, fmt.Sprintf("%T", command))
defer span.End() defer func() {
if err != nil {
span.RecordError(err)
}
span.End()
}()
if i.next != nil {
return i.next.Invoke(ctx, command, opts)
}
return command.Execute(ctx, opts)
}
type loggingInvoker struct {
next Invoker
}
func newLoggingInvoker(next Invoker) *loggingInvoker {
return &loggingInvoker{next: next}
}
func (i *loggingInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
logger.InfoContext(ctx, "Invoking command", "command", command.String())
if i.next != nil { if i.next != nil {
err = i.next.Invoke(ctx, command, opts) err = i.next.Invoke(ctx, command, opts)
} else { } else {
err = command.Execute(ctx, opts) err = command.Execute(ctx, opts)
} }
if err != nil { if err != nil {
span.RecordError(err) logger.ErrorContext(ctx, "Command invocation failed", "command", command.String(), "error", err)
return err
}
logger.InfoContext(ctx, "Command invocation succeeded", "command", command.String())
return nil
}
type noopInvoker struct {
next Invoker
}
func (i *noopInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) error {
if i.next != nil {
return i.next.Invoke(ctx, command, opts)
}
return command.Execute(ctx, opts)
}
type cacheInvoker struct {
next Invoker
}
type cacher interface {
Cache(opts *CommandOpts)
}
func (i *cacheInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
if c, ok := command.(cacher); ok {
c.Cache(opts)
}
if i.next != nil {
err = i.next.Invoke(ctx, command, opts)
} else {
err = command.Execute(ctx, opts)
} }
return err return err
} }

View File

@@ -4,6 +4,7 @@ import (
"context" "context"
"time" "time"
"github.com/zitadel/zitadel/backend/v3/storage/cache"
"github.com/zitadel/zitadel/backend/v3/storage/database" "github.com/zitadel/zitadel/backend/v3/storage/database"
) )
@@ -24,6 +25,23 @@ type Org struct {
UpdatedAt time.Time `json:"updatedAt"` UpdatedAt time.Time `json:"updatedAt"`
} }
type orgCacheIndex uint8
const (
orgCacheIndexUndefined orgCacheIndex = iota
orgCacheIndexID
)
// Keys implements [cache.Entry].
func (o *Org) Keys(index orgCacheIndex) (key []string) {
if index == orgCacheIndexID {
return []string{o.ID}
}
return nil
}
var _ cache.Entry[orgCacheIndex, string] = (*Org)(nil)
type orgColumns interface { type orgColumns interface {
// InstanceIDColumn returns the column for the instance id field. // InstanceIDColumn returns the column for the instance id field.
InstanceIDColumn() database.Column InstanceIDColumn() database.Column

View File

@@ -19,6 +19,11 @@ func NewAddOrgCommand(name string, admins ...*AddMemberCommand) *AddOrgCommand {
} }
} }
// String implements [Commander].
func (cmd *AddOrgCommand) String() string {
return "AddOrgCommand"
}
// Execute implements Commander. // Execute implements Commander.
func (cmd *AddOrgCommand) Execute(ctx context.Context, opts *CommandOpts) (err error) { func (cmd *AddOrgCommand) Execute(ctx context.Context, opts *CommandOpts) (err error) {
if len(cmd.Admins) == 0 { if len(cmd.Admins) == 0 {
@@ -48,6 +53,11 @@ func (cmd *AddOrgCommand) Execute(ctx context.Context, opts *CommandOpts) (err e
} }
} }
orgCache.Set(ctx, &Org{
ID: cmd.ID,
Name: cmd.Name,
})
return nil return nil
} }
@@ -82,6 +92,18 @@ type AddMemberCommand struct {
Roles []string `json:"roles"` Roles []string `json:"roles"`
} }
func NewAddMemberCommand(userID string, roles ...string) *AddMemberCommand {
return &AddMemberCommand{
UserID: userID,
Roles: roles,
}
}
// String implements [Commander].
func (cmd *AddMemberCommand) String() string {
return "AddMemberCommand"
}
// Execute implements Commander. // Execute implements Commander.
func (a *AddMemberCommand) Execute(ctx context.Context, opts *CommandOpts) (err error) { func (a *AddMemberCommand) Execute(ctx context.Context, opts *CommandOpts) (err error) {
close, err := opts.EnsureTx(ctx) close, err := opts.EnsureTx(ctx)

View File

@@ -31,6 +31,11 @@ func NewSetEmailCommand(userID, email string, verificationType SetEmailOpt) *Set
return cmd return cmd
} }
// String implements [Commander].
func (cmd *SetEmailCommand) String() string {
return "SetEmailCommand"
}
func (cmd *SetEmailCommand) Execute(ctx context.Context, opts *CommandOpts) error { func (cmd *SetEmailCommand) Execute(ctx context.Context, opts *CommandOpts) error {
close, err := opts.EnsureTx(ctx) close, err := opts.EnsureTx(ctx)
if err != nil { if err != nil {

View File

@@ -2,6 +2,7 @@ package repository
import ( import (
"context" "context"
"time"
"github.com/zitadel/zitadel/backend/v3/domain" "github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database" "github.com/zitadel/zitadel/backend/v3/storage/database"
@@ -25,12 +26,14 @@ func OrgRepository(client database.QueryExecutor) domain.OrgRepository {
// Create implements [domain.OrgRepository]. // Create implements [domain.OrgRepository].
func (o *org) Create(ctx context.Context, org *domain.Org) error { func (o *org) Create(ctx context.Context, org *domain.Org) error {
panic("unimplemented") org.CreatedAt = time.Now()
org.UpdatedAt = org.CreatedAt
return nil
} }
// Delete implements [domain.OrgRepository]. // Delete implements [domain.OrgRepository].
func (o *org) Delete(ctx context.Context, condition database.Condition) error { func (o *org) Delete(ctx context.Context, condition database.Condition) error {
panic("unimplemented") return nil
} }
// Get implements [domain.OrgRepository]. // Get implements [domain.OrgRepository].

View File

@@ -12,17 +12,17 @@ type orgMember struct {
// AddMember implements [domain.MemberRepository]. // AddMember implements [domain.MemberRepository].
func (o *orgMember) AddMember(ctx context.Context, orgID string, userID string, roles []string) error { func (o *orgMember) AddMember(ctx context.Context, orgID string, userID string, roles []string) error {
panic("unimplemented") return nil
} }
// RemoveMember implements [domain.MemberRepository]. // RemoveMember implements [domain.MemberRepository].
func (o *orgMember) RemoveMember(ctx context.Context, orgID string, userID string) error { func (o *orgMember) RemoveMember(ctx context.Context, orgID string, userID string) error {
panic("unimplemented") return nil
} }
// SetMemberRoles implements [domain.MemberRepository]. // SetMemberRoles implements [domain.MemberRepository].
func (o *orgMember) SetMemberRoles(ctx context.Context, orgID string, userID string, roles []string) error { func (o *orgMember) SetMemberRoles(ctx context.Context, orgID string, userID string, roles []string) error {
panic("unimplemented") return nil
} }
var _ domain.MemberRepository = (*orgMember)(nil) var _ domain.MemberRepository = (*orgMember)(nil)

View File

@@ -5,3 +5,7 @@ import "log/slog"
type Logger struct { type Logger struct {
*slog.Logger *slog.Logger
} }
func NewLogger(logger *slog.Logger) *Logger {
return &Logger{Logger: logger}
}