This commit is contained in:
adlerhurst
2025-05-08 15:30:06 +02:00
parent 8ba497cb87
commit 47e63ed801
14 changed files with 259 additions and 63 deletions

View File

@@ -2,12 +2,14 @@ package domain
import (
"context"
"fmt"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type Commander interface {
Execute(ctx context.Context, opts *CommandOpts) (err error)
fmt.Stringer
}
type Invoker interface {
@@ -93,13 +95,28 @@ func DefaultOpts(invoker Invoker) *CommandOpts {
}
}
type noopInvoker struct {
next Invoker
type commandBatch struct {
Commands []Commander
}
func (i *noopInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) error {
if i.next != nil {
return i.next.Invoke(ctx, command, opts)
func BatchCommands(cmds ...Commander) *commandBatch {
return &commandBatch{
Commands: cmds,
}
return command.Execute(ctx, opts)
}
// String implements [Commander].
func (cmd *commandBatch) String() string {
return "commandBatch"
}
func (b *commandBatch) Execute(ctx context.Context, opts *CommandOpts) (err error) {
for _, cmd := range b.Commands {
if err = opts.Invoke(ctx, cmd); err != nil {
return err
}
}
return nil
}
var _ Commander = (*commandBatch)(nil)

View File

@@ -30,9 +30,21 @@ func NewCreateHumanCommand(username string, opts ...CreateHumanOpt) *CreateUserC
return cmd
}
// String implements [Commander].
func (cmd *CreateUserCommand) String() string {
return "CreateUserCommand"
}
// Events implements [eventer].
func (c *CreateUserCommand) Events() []*eventstore.Event {
panic("unimplemented")
return []*eventstore.Event{
{
AggregateType: "user",
AggregateID: c.user.ID,
Type: "user.added",
Payload: c.user,
},
}
}
// Execute implements [Commander].

View File

@@ -15,6 +15,11 @@ type CryptoRepository interface {
GetEncryptionConfig(ctx context.Context) (*crypto.GeneratorConfig, error)
}
// String implements [Commander].
func (cmd *generateCodeCommand) String() string {
return "generateCodeCommand"
}
func (cmd *generateCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error {
config, err := cryptoRepo(opts.DB).GetEncryptionConfig(ctx)
if err != nil {
@@ -24,3 +29,5 @@ func (cmd *generateCodeCommand) Execute(ctx context.Context, opts *CommandOpts)
cmd.value, cmd.code, err = crypto.NewCode(generator)
return err
}
var _ Commander = (*generateCodeCommand)(nil)

View File

@@ -6,6 +6,7 @@ import (
"github.com/zitadel/zitadel/backend/v3/storage/cache"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/telemetry/logging"
"github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
"github.com/zitadel/zitadel/internal/crypto"
)
@@ -14,13 +15,15 @@ var (
pool database.Pool
userCodeAlgorithm crypto.EncryptionAlgorithm
tracer tracing.Tracer
logger logging.Logger
userRepo func(database.QueryExecutor) UserRepository
instanceRepo func(database.QueryExecutor) InstanceRepository
cryptoRepo func(database.QueryExecutor) CryptoRepository
orgRepo func(database.QueryExecutor) OrgRepository
instanceCache cache.Cache[string, string, *Instance]
instanceCache cache.Cache[instanceCacheIndex, string, *Instance]
orgCache cache.Cache[orgCacheIndex, string, *Org]
generateID func() (string, error) = func() (string, error) {
return strconv.FormatUint(rand.Uint64(), 10), nil
@@ -39,10 +42,18 @@ func SetTracer(t tracing.Tracer) {
tracer = t
}
func SetLogger(l logging.Logger) {
logger = l
}
func SetUserRepository(repo func(database.QueryExecutor) UserRepository) {
userRepo = repo
}
func SetOrgRepository(repo func(database.QueryExecutor) OrgRepository) {
orgRepo = repo
}
func SetInstanceRepository(repo func(database.QueryExecutor) InstanceRepository) {
instanceRepo = repo
}

View File

@@ -1,45 +1,65 @@
package domain_test
// import (
// "context"
// "testing"
import (
"context"
"log/slog"
"testing"
// "github.com/stretchr/testify/assert"
// "github.com/stretchr/testify/require"
// "go.opentelemetry.io/otel"
// "go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
// sdktrace "go.opentelemetry.io/otel/sdk/trace"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"go.opentelemetry.io/otel"
"go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
sdktrace "go.opentelemetry.io/otel/sdk/trace"
"go.uber.org/mock/gomock"
// . "github.com/zitadel/zitadel/backend/v3/domain"
// "github.com/zitadel/zitadel/backend/v3/storage/database/repository"
// "github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
// )
. "github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database/dbmock"
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
"github.com/zitadel/zitadel/backend/v3/telemetry/logging"
"github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
)
// func TestExample(t *testing.T) {
// ctx := context.Background()
func TestExample(t *testing.T) {
ctx := context.Background()
// // SetPool(pool)
ctrl := gomock.NewController(t)
pool := dbmock.NewMockPool(ctrl)
tx := dbmock.NewMockTransaction(ctrl)
// exporter, err := stdouttrace.New(stdouttrace.WithPrettyPrint())
// require.NoError(t, err)
// tracerProvider := sdktrace.NewTracerProvider(
// sdktrace.WithSyncer(exporter),
// )
// otel.SetTracerProvider(tracerProvider)
// SetTracer(tracing.Tracer{Tracer: tracerProvider.Tracer("test")})
// defer func() { assert.NoError(t, tracerProvider.Shutdown(ctx)) }()
pool.EXPECT().Begin(gomock.Any(), gomock.Any()).Return(tx, nil)
tx.EXPECT().End(gomock.Any(), gomock.Any()).Return(nil)
SetPool(pool)
// SetUserRepository(repository.User)
// SetInstanceRepository(repository.Instance)
// SetCryptoRepository(repository.Crypto)
exporter, err := stdouttrace.New(stdouttrace.WithPrettyPrint())
require.NoError(t, err)
tracerProvider := sdktrace.NewTracerProvider(
sdktrace.WithSyncer(exporter),
)
otel.SetTracerProvider(tracerProvider)
SetTracer(tracing.Tracer{Tracer: tracerProvider.Tracer("test")})
defer func() { assert.NoError(t, tracerProvider.Shutdown(ctx)) }()
// t.Run("verified email", func(t *testing.T) {
// err := Invoke(ctx, NewSetEmailCommand("u1", "test@example.com", NewEmailVerifiedCommand("u1", true)))
// assert.NoError(t, err)
// })
SetLogger(logging.Logger{Logger: slog.Default()})
// t.Run("unverified email", func(t *testing.T) {
// err := Invoke(ctx, NewSetEmailCommand("u2", "test2@example.com", NewEmailVerifiedCommand("u2", false)))
// assert.NoError(t, err)
// })
// }
SetUserRepository(repository.UserRepository)
SetOrgRepository(repository.OrgRepository)
// SetInstanceRepository(repository.Instance)
// SetCryptoRepository(repository.Crypto)
t.Run("create org", func(t *testing.T) {
org := NewAddOrgCommand("testorg", NewAddMemberCommand("testuser", "ORG_OWNER"))
user := NewCreateHumanCommand("testuser")
err := Invoke(ctx, BatchCommands(org, user))
assert.NoError(t, err)
})
t.Run("verified email", func(t *testing.T) {
err := Invoke(ctx, NewSetEmailCommand("u1", "test@example.com", NewEmailVerifiedCommand("u1", true)))
assert.NoError(t, err)
})
t.Run("unverified email", func(t *testing.T) {
err := Invoke(ctx, NewSetEmailCommand("u2", "test2@example.com", NewEmailVerifiedCommand("u2", false)))
assert.NoError(t, err)
})
}

View File

@@ -19,6 +19,11 @@ func NewEmailVerifiedCommand(userID string, isVerified bool) *EmailVerifiedComma
}
}
// String implements [Commander].
func (cmd *EmailVerifiedCommand) String() string {
return "EmailVerifiedCommand"
}
var (
_ Commander = (*EmailVerifiedCommand)(nil)
_ SetEmailOpt = (*EmailVerifiedCommand)(nil)
@@ -57,6 +62,11 @@ func NewSendCodeCommand(userID string, urlTemplate *string) *SendCodeCommand {
}
}
// String implements [Commander].
func (cmd *SendCodeCommand) String() string {
return "SendCodeCommand"
}
// Execute implements [Commander]
func (cmd *SendCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error {
if err := cmd.ensureEmail(ctx, opts); err != nil {
@@ -122,6 +132,11 @@ func NewReturnCodeCommand(userID string) *ReturnCodeCommand {
}
}
// String implements [Commander].
func (cmd *ReturnCodeCommand) String() string {
return "ReturnCodeCommand"
}
// Execute implements [Commander]
func (cmd *ReturnCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error {
if err := cmd.ensureEmail(ctx, opts); err != nil {

View File

@@ -4,6 +4,7 @@ import (
"context"
"time"
"github.com/zitadel/zitadel/backend/v3/storage/cache"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
@@ -15,12 +16,23 @@ type Instance struct {
DeletedAt time.Time `json:"-"`
}
type instanceCacheIndex uint8
const (
instanceCacheIndexUndefined instanceCacheIndex = iota
instanceCacheIndexID
)
// Keys implements the [cache.Entry].
func (i *Instance) Keys(index string) (key []string) {
// TODO: Return the correct keys for the instance cache, e.g., i.ID, i.Domain
return []string{}
func (i *Instance) Keys(index instanceCacheIndex) (key []string) {
if index == instanceCacheIndexID {
return []string{i.ID}
}
return nil
}
var _ cache.Entry[instanceCacheIndex, string] = (*Instance)(nil)
type instanceColumns interface {
// IDColumn returns the column for the id field.
IDColumn() database.Column

View File

@@ -7,12 +7,11 @@ import (
"github.com/zitadel/zitadel/backend/v3/storage/eventstore"
)
var defaultInvoker = newEventStoreInvoker(newTraceInvoker(nil))
func Invoke(ctx context.Context, cmd Commander) error {
invoker := newEventStoreInvoker(newTraceInvoker(nil))
invoker := newEventStoreInvoker(newLoggingInvoker(newTraceInvoker(nil)))
opts := &CommandOpts{
Invoker: invoker.collector,
DB: pool,
}
return invoker.Invoke(ctx, cmd, opts)
}
@@ -60,14 +59,9 @@ func (i *eventCollector) Invoke(ctx context.Context, command Commander, opts *Co
i.events = append(i.events, e.Events()...)
}
if i.next != nil {
err = i.next.Invoke(ctx, command, opts)
} else {
err = command.Execute(ctx, opts)
return i.next.Invoke(ctx, command, opts)
}
if err != nil {
return err
}
return nil
return command.Execute(ctx, opts)
}
type traceInvoker struct {
@@ -80,15 +74,71 @@ func newTraceInvoker(next Invoker) *traceInvoker {
func (i *traceInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
ctx, span := tracer.Start(ctx, fmt.Sprintf("%T", command))
defer span.End()
defer func() {
if err != nil {
span.RecordError(err)
}
span.End()
}()
if i.next != nil {
return i.next.Invoke(ctx, command, opts)
}
return command.Execute(ctx, opts)
}
type loggingInvoker struct {
next Invoker
}
func newLoggingInvoker(next Invoker) *loggingInvoker {
return &loggingInvoker{next: next}
}
func (i *loggingInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
logger.InfoContext(ctx, "Invoking command", "command", command.String())
if i.next != nil {
err = i.next.Invoke(ctx, command, opts)
} else {
err = command.Execute(ctx, opts)
}
if err != nil {
span.RecordError(err)
logger.ErrorContext(ctx, "Command invocation failed", "command", command.String(), "error", err)
return err
}
logger.InfoContext(ctx, "Command invocation succeeded", "command", command.String())
return nil
}
type noopInvoker struct {
next Invoker
}
func (i *noopInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) error {
if i.next != nil {
return i.next.Invoke(ctx, command, opts)
}
return command.Execute(ctx, opts)
}
type cacheInvoker struct {
next Invoker
}
type cacher interface {
Cache(opts *CommandOpts)
}
func (i *cacheInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
if c, ok := command.(cacher); ok {
c.Cache(opts)
}
if i.next != nil {
err = i.next.Invoke(ctx, command, opts)
} else {
err = command.Execute(ctx, opts)
}
return err
}

View File

@@ -4,6 +4,7 @@ import (
"context"
"time"
"github.com/zitadel/zitadel/backend/v3/storage/cache"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
@@ -24,6 +25,23 @@ type Org struct {
UpdatedAt time.Time `json:"updatedAt"`
}
type orgCacheIndex uint8
const (
orgCacheIndexUndefined orgCacheIndex = iota
orgCacheIndexID
)
// Keys implements [cache.Entry].
func (o *Org) Keys(index orgCacheIndex) (key []string) {
if index == orgCacheIndexID {
return []string{o.ID}
}
return nil
}
var _ cache.Entry[orgCacheIndex, string] = (*Org)(nil)
type orgColumns interface {
// InstanceIDColumn returns the column for the instance id field.
InstanceIDColumn() database.Column

View File

@@ -19,6 +19,11 @@ func NewAddOrgCommand(name string, admins ...*AddMemberCommand) *AddOrgCommand {
}
}
// String implements [Commander].
func (cmd *AddOrgCommand) String() string {
return "AddOrgCommand"
}
// Execute implements Commander.
func (cmd *AddOrgCommand) Execute(ctx context.Context, opts *CommandOpts) (err error) {
if len(cmd.Admins) == 0 {
@@ -48,6 +53,11 @@ func (cmd *AddOrgCommand) Execute(ctx context.Context, opts *CommandOpts) (err e
}
}
orgCache.Set(ctx, &Org{
ID: cmd.ID,
Name: cmd.Name,
})
return nil
}
@@ -82,6 +92,18 @@ type AddMemberCommand struct {
Roles []string `json:"roles"`
}
func NewAddMemberCommand(userID string, roles ...string) *AddMemberCommand {
return &AddMemberCommand{
UserID: userID,
Roles: roles,
}
}
// String implements [Commander].
func (cmd *AddMemberCommand) String() string {
return "AddMemberCommand"
}
// Execute implements Commander.
func (a *AddMemberCommand) Execute(ctx context.Context, opts *CommandOpts) (err error) {
close, err := opts.EnsureTx(ctx)

View File

@@ -31,6 +31,11 @@ func NewSetEmailCommand(userID, email string, verificationType SetEmailOpt) *Set
return cmd
}
// String implements [Commander].
func (cmd *SetEmailCommand) String() string {
return "SetEmailCommand"
}
func (cmd *SetEmailCommand) Execute(ctx context.Context, opts *CommandOpts) error {
close, err := opts.EnsureTx(ctx)
if err != nil {

View File

@@ -2,6 +2,7 @@ package repository
import (
"context"
"time"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
@@ -25,12 +26,14 @@ func OrgRepository(client database.QueryExecutor) domain.OrgRepository {
// Create implements [domain.OrgRepository].
func (o *org) Create(ctx context.Context, org *domain.Org) error {
panic("unimplemented")
org.CreatedAt = time.Now()
org.UpdatedAt = org.CreatedAt
return nil
}
// Delete implements [domain.OrgRepository].
func (o *org) Delete(ctx context.Context, condition database.Condition) error {
panic("unimplemented")
return nil
}
// Get implements [domain.OrgRepository].

View File

@@ -12,17 +12,17 @@ type orgMember struct {
// AddMember implements [domain.MemberRepository].
func (o *orgMember) AddMember(ctx context.Context, orgID string, userID string, roles []string) error {
panic("unimplemented")
return nil
}
// RemoveMember implements [domain.MemberRepository].
func (o *orgMember) RemoveMember(ctx context.Context, orgID string, userID string) error {
panic("unimplemented")
return nil
}
// SetMemberRoles implements [domain.MemberRepository].
func (o *orgMember) SetMemberRoles(ctx context.Context, orgID string, userID string, roles []string) error {
panic("unimplemented")
return nil
}
var _ domain.MemberRepository = (*orgMember)(nil)

View File

@@ -5,3 +5,7 @@ import "log/slog"
type Logger struct {
*slog.Logger
}
func NewLogger(logger *slog.Logger) *Logger {
return &Logger{Logger: logger}
}