diff --git a/Makefile b/Makefile index 630e676bc1a..c647fad5d70 100644 --- a/Makefile +++ b/Makefile @@ -138,7 +138,7 @@ core_integration_server_start: core_integration_setup .PHONY: core_integration_test_packages core_integration_test_packages: - go test -race -count 1 -tags integration -timeout 60m -parallel 1 $$(go list -tags integration ./... | grep "integration_test") + go test -race -count 1 -tags integration -timeout 5m -parallel 1 $$(go list -tags integration ./... | grep -e "integration_test" -e "events_testing") -run ^TestServer_TestInstanceReduces$ .PHONY: core_integration_server_stop core_integration_server_stop: @@ -152,7 +152,7 @@ core_integration_server_stop: .PHONY: core_integration_reports core_integration_reports: - go tool covdata textfmt -i=tmp/coverage -pkg=github.com/zitadel/zitadel/internal/...,github.com/zitadel/zitadel/cmd/... -o profile.cov + go tool covdata textfmt -i=tmp/coverage -pkg=github.com/zitadel/zitadel/internal/...,github.com/zitadel/zitadel/cmd/...,github.com/zitadel/zitadel/backend/... -o profile.cov .PHONY: core_integration_test core_integration_test: core_integration_server_start core_integration_test_packages core_integration_server_stop core_integration_reports diff --git a/backend/main.go b/backend/main.go new file mode 100644 index 00000000000..97681df7ada --- /dev/null +++ b/backend/main.go @@ -0,0 +1,10 @@ +// huhu thanks for looking at this. +// you can find more comments and doc.go files in the v3 package. +// to get an overview i would start in the api package. +package main + +// import "github.com/zitadel/zitadel/backend/cmd" + +// func main() { +// cmd.Execute() +// } diff --git a/backend/v3/api/instance/v2/server.go b/backend/v3/api/instance/v2/server.go new file mode 100644 index 00000000000..f50eed8dde4 --- /dev/null +++ b/backend/v3/api/instance/v2/server.go @@ -0,0 +1,21 @@ +package v2 + +// this file has been commented out to pass the linter + +// import ( +// "github.com/zitadel/zitadel/backend/v3/telemetry/logging" +// "github.com/zitadel/zitadel/backend/v3/telemetry/tracing" +// ) + +// var ( +// logger logging.Logger +// tracer tracing.Tracer +// ) + +// func SetLogger(l logging.Logger) { +// logger = l +// } + +// func SetTracer(t tracing.Tracer) { +// tracer = t +// } diff --git a/backend/v3/api/org/v2/org.go b/backend/v3/api/org/v2/org.go new file mode 100644 index 00000000000..601846848ac --- /dev/null +++ b/backend/v3/api/org/v2/org.go @@ -0,0 +1,33 @@ +package orgv2 + +// import ( +// "context" + +// "github.com/zitadel/zitadel/backend/v3/domain" +// "github.com/zitadel/zitadel/pkg/grpc/org/v2" +// ) + +// func CreateOrg(ctx context.Context, req *org.AddOrganizationRequest) (resp *org.AddOrganizationResponse, err error) { +// cmd := domain.NewAddOrgCommand( +// req.GetName(), +// addOrgAdminToCommand(req.GetAdmins()...)..., +// ) +// err = domain.Invoke(ctx, cmd) +// if err != nil { +// return nil, err +// } +// return &org.AddOrganizationResponse{ +// OrganizationId: cmd.ID, +// }, nil +// } + +// func addOrgAdminToCommand(admins ...*org.AddOrganizationRequest_Admin) []*domain.AddMemberCommand { +// cmds := make([]*domain.AddMemberCommand, len(admins)) +// for i, admin := range admins { +// cmds[i] = &domain.AddMemberCommand{ +// UserID: admin.GetUserId(), +// Roles: admin.GetRoles(), +// } +// } +// return cmds +// } diff --git a/backend/v3/api/org/v2/server.go b/backend/v3/api/org/v2/server.go new file mode 100644 index 00000000000..fdfc45a2b92 --- /dev/null +++ b/backend/v3/api/org/v2/server.go @@ -0,0 +1,21 @@ +package orgv2 + +// this file has been commented out to pass the linter + +// import ( +// "github.com/zitadel/zitadel/backend/v3/telemetry/logging" +// "github.com/zitadel/zitadel/backend/v3/telemetry/tracing" +// ) + +// var ( +// logger logging.Logger +// tracer tracing.Tracer +// ) + +// func SetLogger(l logging.Logger) { +// logger = l +// } + +// func SetTracer(t tracing.Tracer) { +// tracer = t +// } diff --git a/backend/v3/api/user/v2/email.go b/backend/v3/api/user/v2/email.go new file mode 100644 index 00000000000..5b5285ab905 --- /dev/null +++ b/backend/v3/api/user/v2/email.go @@ -0,0 +1,93 @@ +package userv2 + +// import ( +// "context" + +// "github.com/zitadel/zitadel/backend/v3/domain" +// "github.com/zitadel/zitadel/pkg/grpc/user/v2" +// ) + +// func SetEmail(ctx context.Context, req *user.SetEmailRequest) (resp *user.SetEmailResponse, err error) { +// var ( +// verification domain.SetEmailOpt +// returnCode *domain.ReturnCodeCommand +// ) + +// switch req.GetVerification().(type) { +// case *user.SetEmailRequest_IsVerified: +// verification = domain.NewEmailVerifiedCommand(req.GetUserId(), req.GetIsVerified()) +// case *user.SetEmailRequest_SendCode: +// verification = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate) +// case *user.SetEmailRequest_ReturnCode: +// returnCode = domain.NewReturnCodeCommand(req.GetUserId()) +// verification = returnCode +// default: +// verification = domain.NewSendCodeCommand(req.GetUserId(), nil) +// } + +// err = domain.Invoke(ctx, domain.NewSetEmailCommand(req.GetUserId(), req.GetEmail(), verification)) +// if err != nil { +// return nil, err +// } + +// var code *string +// if returnCode != nil && returnCode.Code != "" { +// code = &returnCode.Code +// } + +// return &user.SetEmailResponse{ +// VerificationCode: code, +// }, nil +// } + +// func SendEmailCode(ctx context.Context, req *user.SendEmailCodeRequest) (resp *user.SendEmailCodeResponse, err error) { +// var ( +// returnCode *domain.ReturnCodeCommand +// cmd domain.Commander +// ) + +// switch req.GetVerification().(type) { +// case *user.SendEmailCodeRequest_SendCode: +// cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate) +// case *user.SendEmailCodeRequest_ReturnCode: +// returnCode = domain.NewReturnCodeCommand(req.GetUserId()) +// cmd = returnCode +// default: +// cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate) +// } +// err = domain.Invoke(ctx, cmd) +// if err != nil { +// return nil, err +// } +// resp = new(user.SendEmailCodeResponse) +// if returnCode != nil { +// resp.VerificationCode = &returnCode.Code +// } +// return resp, nil +// } + +// func ResendEmailCode(ctx context.Context, req *user.ResendEmailCodeRequest) (resp *user.SendEmailCodeResponse, err error) { +// var ( +// returnCode *domain.ReturnCodeCommand +// cmd domain.Commander +// ) + +// switch req.GetVerification().(type) { +// case *user.ResendEmailCodeRequest_SendCode: +// cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate) +// case *user.ResendEmailCodeRequest_ReturnCode: +// returnCode = domain.NewReturnCodeCommand(req.GetUserId()) +// cmd = returnCode +// default: +// cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate) +// } +// err = domain.Invoke(ctx, cmd) +// if err != nil { +// return nil, err +// } +// resp = new(user.SendEmailCodeResponse) +// if returnCode != nil { +// resp.VerificationCode = &returnCode.Code +// } +// return resp, nil +// } diff --git a/backend/v3/api/user/v2/server.go b/backend/v3/api/user/v2/server.go new file mode 100644 index 00000000000..79b87478516 --- /dev/null +++ b/backend/v3/api/user/v2/server.go @@ -0,0 +1,19 @@ +package userv2 + +// this file has been commented out to pass the linter + +// import ( +// "github.com/zitadel/zitadel/backend/v3/telemetry/logging" +// "github.com/zitadel/zitadel/backend/v3/telemetry/tracing" +// ) + +// logger logging.Logger +// var tracer tracing.Tracer + +// func SetLogger(l logging.Logger) { +// logger = l +// } + +// func SetTracer(t tracing.Tracer) { +// tracer = t +// } diff --git a/backend/v3/doc.go b/backend/v3/doc.go new file mode 100644 index 00000000000..8798c1ce94e --- /dev/null +++ b/backend/v3/doc.go @@ -0,0 +1,21 @@ +// the test used the manly relies on the following patterns: +// - api: +// - some example stubs for the grpc api, it maps the calls and responses to the domain objects +// +// - domain: +// - hexagonal architecture, it defines its dependencies as interfaces and the dependencies must use the objects defined by this package +// - command pattern which implements the changes +// - the invoker decorates the commands by checking for events, tracing, logging, potentially caching, etc. +// - the database connections are manged in this package +// - the database connections are passed to the repositories +// +// - storage: +// - repository pattern, the repositories are defined as interfaces and the implementations are in the storage package +// - the repositories are used by the domain package to access the database +// - the eventstore to store events. At the beginning it writes to the same events table as the /internal package, afterwards it writes to a different table +// +// - telemetry: +// - logging for standard output +// - tracing for distributed tracing +// - metrics for monitoring +package v3 diff --git a/backend/v3/domain/command.go b/backend/v3/domain/command.go new file mode 100644 index 00000000000..d63da49512d --- /dev/null +++ b/backend/v3/domain/command.go @@ -0,0 +1,131 @@ +package domain + +// import ( +// "context" +// "fmt" + +// "github.com/zitadel/zitadel/backend/v3/storage/database" +// ) + +// // Commander is the all it needs to implement the command pattern. +// // It is the interface all manipulations need to implement. +// // If possible it should also be used for queries. We will find out if this is possible in the future. +// type Commander interface { +// Execute(ctx context.Context, opts *CommandOpts) (err error) +// fmt.Stringer +// } + +// // Invoker is part of the command pattern. +// // It is the interface that is used to execute commands. +// type Invoker interface { +// Invoke(ctx context.Context, command Commander, opts *CommandOpts) error +// } + +// // CommandOpts are passed to each command +// // the provide common fields used by commands like the database client. +// type CommandOpts struct { +// DB database.QueryExecutor +// Invoker Invoker +// } + +// type ensureTxOpts struct { +// *database.TransactionOptions +// } + +// type EnsureTransactionOpt func(*ensureTxOpts) + +// // EnsureTx ensures that the DB is a transaction. If it is not, it will start a new transaction. +// // The returned close function will end the transaction. If the DB is already a transaction, the close function +// // will do nothing because another [Commander] is already responsible for ending the transaction. +// func (o *CommandOpts) EnsureTx(ctx context.Context, opts ...EnsureTransactionOpt) (close func(context.Context, error) error, err error) { +// beginner, ok := o.DB.(database.Beginner) +// if !ok { +// // db is already a transaction +// return func(_ context.Context, err error) error { +// return err +// }, nil +// } + +// txOpts := &ensureTxOpts{ +// TransactionOptions: new(database.TransactionOptions), +// } +// for _, opt := range opts { +// opt(txOpts) +// } + +// tx, err := beginner.Begin(ctx, txOpts.TransactionOptions) +// if err != nil { +// return nil, err +// } +// o.DB = tx + +// return func(ctx context.Context, err error) error { +// return tx.End(ctx, err) +// }, nil +// } + +// // EnsureClient ensures that the o.DB is a client. If it is not, it will get a new client from the [database.Pool]. +// // The returned close function will release the client. If the o.DB is already a client or transaction, the close function +// // will do nothing because another [Commander] is already responsible for releasing the client. +// func (o *CommandOpts) EnsureClient(ctx context.Context) (close func(_ context.Context) error, err error) { +// pool, ok := o.DB.(database.Pool) +// if !ok { +// // o.DB is already a client +// return func(_ context.Context) error { +// return nil +// }, nil +// } +// client, err := pool.Acquire(ctx) +// if err != nil { +// return nil, err +// } +// o.DB = client +// return func(ctx context.Context) error { +// return client.Release(ctx) +// }, nil +// } + +// func (o *CommandOpts) Invoke(ctx context.Context, command Commander) error { +// if o.Invoker == nil { +// return command.Execute(ctx, o) +// } +// return o.Invoker.Invoke(ctx, command, o) +// } + +// func DefaultOpts(invoker Invoker) *CommandOpts { +// if invoker == nil { +// invoker = &noopInvoker{} +// } +// return &CommandOpts{ +// DB: pool, +// Invoker: invoker, +// } +// } + +// // commandBatch is a batch of commands. +// // It uses the [Invoker] provided by the opts to execute each command. +// type commandBatch struct { +// Commands []Commander +// } + +// func BatchCommands(cmds ...Commander) *commandBatch { +// return &commandBatch{ +// Commands: cmds, +// } +// } + +// // String implements [Commander]. +// func (cmd *commandBatch) String() string { +// return "commandBatch" +// } + +// func (b *commandBatch) Execute(ctx context.Context, opts *CommandOpts) (err error) { +// for _, cmd := range b.Commands { +// if err = opts.Invoke(ctx, cmd); err != nil { +// return err +// } +// } +// return nil +// } + +// var _ Commander = (*commandBatch)(nil) diff --git a/backend/v3/domain/create_user.go b/backend/v3/domain/create_user.go new file mode 100644 index 00000000000..c4eb48e1d06 --- /dev/null +++ b/backend/v3/domain/create_user.go @@ -0,0 +1,90 @@ +package domain + +// import ( +// "context" + +// "github.com/zitadel/zitadel/backend/v3/storage/eventstore" +// ) + +// // CreateUserCommand adds a new user including the email verification for humans. +// // In the future it might make sense to separate the command into two commands: +// // - CreateHumanCommand: creates a new human user +// // - CreateMachineCommand: creates a new machine user +// type CreateUserCommand struct { +// user *User +// email *SetEmailCommand +// } + +// var ( +// _ Commander = (*CreateUserCommand)(nil) +// _ eventer = (*CreateUserCommand)(nil) +// ) + +// // opts heavily reduces the complexity for email verification because each type of verification is a simple option which implements the [Commander] interface. +// func NewCreateHumanCommand(username string, opts ...CreateHumanOpt) *CreateUserCommand { +// cmd := &CreateUserCommand{ +// user: &User{ +// Username: username, +// Traits: &Human{}, +// }, +// } + +// for _, opt := range opts { +// opt.applyOnCreateHuman(cmd) +// } +// return cmd +// } + +// // String implements [Commander]. +// func (cmd *CreateUserCommand) String() string { +// return "CreateUserCommand" +// } + +// // Events implements [eventer]. +// func (c *CreateUserCommand) Events() []*eventstore.Event { +// return []*eventstore.Event{ +// { +// AggregateType: "user", +// AggregateID: c.user.ID, +// Type: "user.added", +// Payload: c.user, +// }, +// } +// } + +// // Execute implements [Commander]. +// func (c *CreateUserCommand) Execute(ctx context.Context, opts *CommandOpts) error { +// if err := c.ensureUserID(); err != nil { +// return err +// } +// c.email.UserID = c.user.ID +// if err := opts.Invoke(ctx, c.email); err != nil { +// return err +// } +// return nil +// } + +// type CreateHumanOpt interface { +// applyOnCreateHuman(*CreateUserCommand) +// } + +// type createHumanIDOpt string + +// // applyOnCreateHuman implements [CreateHumanOpt]. +// func (c createHumanIDOpt) applyOnCreateHuman(cmd *CreateUserCommand) { +// cmd.user.ID = string(c) +// } + +// var _ CreateHumanOpt = (*createHumanIDOpt)(nil) + +// func CreateHumanWithID(id string) CreateHumanOpt { +// return createHumanIDOpt(id) +// } + +// func (c *CreateUserCommand) ensureUserID() (err error) { +// if c.user.ID != "" { +// return nil +// } +// c.user.ID, err = generateID() +// return err +// } diff --git a/backend/v3/domain/crypto.go b/backend/v3/domain/crypto.go new file mode 100644 index 00000000000..c0d76c35814 --- /dev/null +++ b/backend/v3/domain/crypto.go @@ -0,0 +1,37 @@ +package domain + +// import ( +// "context" + +// "github.com/zitadel/zitadel/internal/crypto" +// ) + +// type generateCodeCommand struct { +// code string +// value *crypto.CryptoValue +// } + +// // I didn't update this repository to the solution proposed please view one of the following interfaces for correct usage: +// // - [UserRepository] +// // - [InstanceRepository] +// // - [OrgRepository] +// type CryptoRepository interface { +// GetEncryptionConfig(ctx context.Context) (*crypto.GeneratorConfig, error) +// } + +// // String implements [Commander]. +// func (cmd *generateCodeCommand) String() string { +// return "generateCodeCommand" +// } + +// func (cmd *generateCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error { +// config, err := cryptoRepo(opts.DB).GetEncryptionConfig(ctx) +// if err != nil { +// return err +// } +// generator := crypto.NewEncryptionGenerator(*config, userCodeAlgorithm) +// cmd.value, cmd.code, err = crypto.NewCode(generator) +// return err +// } + +// var _ Commander = (*generateCodeCommand)(nil) diff --git a/backend/v3/domain/domain.go b/backend/v3/domain/domain.go new file mode 100644 index 00000000000..53aa4b672f3 --- /dev/null +++ b/backend/v3/domain/domain.go @@ -0,0 +1,127 @@ +package domain + +import ( + "time" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +//go:generate enumer -type DomainValidationType -transform lower -trimprefix DomainValidationType -sql +type DomainValidationType uint8 + +const ( + DomainValidationTypeDNS DomainValidationType = iota + DomainValidationTypeHTTP +) + +//go:generate enumer -type DomainType -transform lower -trimprefix DomainType -sql +type DomainType uint8 + +const ( + DomainTypeCustom DomainType = iota + DomainTypeTrusted +) + +type domainColumns interface { + // InstanceIDColumn returns the column for the instance id field. + InstanceIDColumn() database.Column + // DomainColumn returns the column for the domain field. + DomainColumn() database.Column + // IsPrimaryColumn returns the column for the is primary field. + IsPrimaryColumn() database.Column + // CreatedAtColumn returns the column for the created at field. + CreatedAtColumn() database.Column + // UpdatedAtColumn returns the column for the updated at field. + UpdatedAtColumn() database.Column +} + +type domainConditions interface { + // InstanceIDCondition returns a filter on the instance id field. + InstanceIDCondition(instanceID string) database.Condition + // DomainCondition returns a filter on the domain field. + DomainCondition(op database.TextOperation, domain string) database.Condition + // IsPrimaryCondition returns a filter on the is primary field. + IsPrimaryCondition(isPrimary bool) database.Condition +} + +type domainChanges interface { + // SetPrimary sets a domain as primary based on the condition. + // All other domains will be set to non-primary. + // + // An error is returned if: + // - The condition identifies multiple domains. + // - The condition does not identify any domain. + // + // This is a no-op if: + // - The domain is already primary. + // - No domain matches the condition. + SetPrimary() database.Change + // SetUpdatedAt sets the updated at column. + // This is used for reducing events. + SetUpdatedAt(t time.Time) database.Change +} + +// import ( +// "math/rand/v2" +// "strconv" + +// "github.com/zitadel/zitadel/backend/v3/storage/cache" +// "github.com/zitadel/zitadel/backend/v3/storage/database" + +// // "github.com/zitadel/zitadel/backend/v3/telemetry/logging" +// "github.com/zitadel/zitadel/backend/v3/telemetry/tracing" +// "github.com/zitadel/zitadel/internal/crypto" +// ) + +// // The variables could also be moved to a struct. +// // I just started with the singleton pattern and kept it like this. +// var ( +// pool database.Pool +// userCodeAlgorithm crypto.EncryptionAlgorithm +// tracer tracing.Tracer +// // logger logging.Logger + +// userRepo func(database.QueryExecutor) UserRepository +// // instanceRepo func(database.QueryExecutor) InstanceRepository +// cryptoRepo func(database.QueryExecutor) CryptoRepository +// orgRepo func(database.QueryExecutor) OrgRepository + +// // instanceCache cache.Cache[instanceCacheIndex, string, *Instance] +// orgCache cache.Cache[orgCacheIndex, string, *Org] + +// generateID func() (string, error) = func() (string, error) { +// return strconv.FormatUint(rand.Uint64(), 10), nil +// } +// ) + +// func SetPool(p database.Pool) { +// pool = p +// } + +// func SetUserCodeAlgorithm(algorithm crypto.EncryptionAlgorithm) { +// userCodeAlgorithm = algorithm +// } + +// func SetTracer(t tracing.Tracer) { +// tracer = t +// } + +// // func SetLogger(l logging.Logger) { +// // logger = l +// // } + +// func SetUserRepository(repo func(database.QueryExecutor) UserRepository) { +// userRepo = repo +// } + +// func SetOrgRepository(repo func(database.QueryExecutor) OrgRepository) { +// orgRepo = repo +// } + +// // func SetInstanceRepository(repo func(database.QueryExecutor) InstanceRepository) { +// // instanceRepo = repo +// // } + +// func SetCryptoRepository(repo func(database.QueryExecutor) CryptoRepository) { +// cryptoRepo = repo +// } diff --git a/backend/v3/domain/domain_test.go b/backend/v3/domain/domain_test.go new file mode 100644 index 00000000000..503d322a5ce --- /dev/null +++ b/backend/v3/domain/domain_test.go @@ -0,0 +1,67 @@ +package domain_test + +// import ( +// "context" +// "log/slog" +// "testing" + +// "github.com/stretchr/testify/assert" +// "github.com/stretchr/testify/require" +// "go.opentelemetry.io/otel" +// "go.opentelemetry.io/otel/exporters/stdout/stdouttrace" +// sdktrace "go.opentelemetry.io/otel/sdk/trace" +// "go.uber.org/mock/gomock" + +// . "github.com/zitadel/zitadel/backend/v3/domain" +// "github.com/zitadel/zitadel/backend/v3/storage/database/dbmock" +// "github.com/zitadel/zitadel/backend/v3/storage/database/repository" +// "github.com/zitadel/zitadel/backend/v3/telemetry/logging" +// "github.com/zitadel/zitadel/backend/v3/telemetry/tracing" +// ) + +// These tests give an overview of how to use the domain package. +// func TestExample(t *testing.T) { +// t.Skip("skip example test because it is not a real test") +// ctx := context.Background() + +// ctrl := gomock.NewController(t) +// pool := dbmock.NewMockPool(ctrl) +// tx := dbmock.NewMockTransaction(ctrl) + +// pool.EXPECT().Begin(gomock.Any(), gomock.Any()).Return(tx, nil) +// tx.EXPECT().End(gomock.Any(), gomock.Any()).Return(nil) +// SetPool(pool) + +// exporter, err := stdouttrace.New(stdouttrace.WithPrettyPrint()) +// require.NoError(t, err) +// tracerProvider := sdktrace.NewTracerProvider( +// sdktrace.WithSyncer(exporter), +// ) +// otel.SetTracerProvider(tracerProvider) +// SetTracer(tracing.Tracer{Tracer: tracerProvider.Tracer("test")}) +// defer func() { assert.NoError(t, tracerProvider.Shutdown(ctx)) }() + +// SetLogger(logging.Logger{Logger: slog.Default()}) + +// SetUserRepository(repository.UserRepository) +// SetOrgRepository(repository.OrgRepository) +// // SetInstanceRepository(repository.Instance) +// // SetCryptoRepository(repository.Crypto) + +// t.Run("create org", func(t *testing.T) { +// org := NewAddOrgCommand("testorg", NewAddMemberCommand("testuser", "ORG_OWNER")) +// user := NewCreateHumanCommand("testuser") +// err := Invoke(ctx, BatchCommands(org, user)) +// assert.NoError(t, err) +// }) + +// t.Run("verified email", func(t *testing.T) { +// err := Invoke(ctx, NewSetEmailCommand("u1", "test@example.com", NewEmailVerifiedCommand("u1", true))) +// assert.NoError(t, err) +// }) + +// t.Run("unverified email", func(t *testing.T) { +// err := Invoke(ctx, NewSetEmailCommand("u2", "test2@example.com", NewEmailVerifiedCommand("u2", false))) +// assert.NoError(t, err) +// }) +// } diff --git a/backend/v3/domain/domaintype_enumer.go b/backend/v3/domain/domaintype_enumer.go new file mode 100644 index 00000000000..67b3b58b007 --- /dev/null +++ b/backend/v3/domain/domaintype_enumer.go @@ -0,0 +1,109 @@ +// Code generated by "enumer -type DomainType -transform lower -trimprefix DomainType -sql"; DO NOT EDIT. + +package domain + +import ( + "database/sql/driver" + "fmt" + "strings" +) + +const _DomainTypeName = "customtrusted" + +var _DomainTypeIndex = [...]uint8{0, 6, 13} + +const _DomainTypeLowerName = "customtrusted" + +func (i DomainType) String() string { + if i >= DomainType(len(_DomainTypeIndex)-1) { + return fmt.Sprintf("DomainType(%d)", i) + } + return _DomainTypeName[_DomainTypeIndex[i]:_DomainTypeIndex[i+1]] +} + +// An "invalid array index" compiler error signifies that the constant values have changed. +// Re-run the stringer command to generate them again. +func _DomainTypeNoOp() { + var x [1]struct{} + _ = x[DomainTypeCustom-(0)] + _ = x[DomainTypeTrusted-(1)] +} + +var _DomainTypeValues = []DomainType{DomainTypeCustom, DomainTypeTrusted} + +var _DomainTypeNameToValueMap = map[string]DomainType{ + _DomainTypeName[0:6]: DomainTypeCustom, + _DomainTypeLowerName[0:6]: DomainTypeCustom, + _DomainTypeName[6:13]: DomainTypeTrusted, + _DomainTypeLowerName[6:13]: DomainTypeTrusted, +} + +var _DomainTypeNames = []string{ + _DomainTypeName[0:6], + _DomainTypeName[6:13], +} + +// DomainTypeString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func DomainTypeString(s string) (DomainType, error) { + if val, ok := _DomainTypeNameToValueMap[s]; ok { + return val, nil + } + + if val, ok := _DomainTypeNameToValueMap[strings.ToLower(s)]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to DomainType values", s) +} + +// DomainTypeValues returns all values of the enum +func DomainTypeValues() []DomainType { + return _DomainTypeValues +} + +// DomainTypeStrings returns a slice of all String values of the enum +func DomainTypeStrings() []string { + strs := make([]string, len(_DomainTypeNames)) + copy(strs, _DomainTypeNames) + return strs +} + +// IsADomainType returns "true" if the value is listed in the enum definition. "false" otherwise +func (i DomainType) IsADomainType() bool { + for _, v := range _DomainTypeValues { + if i == v { + return true + } + } + return false +} + +func (i DomainType) Value() (driver.Value, error) { + return i.String(), nil +} + +func (i *DomainType) Scan(value interface{}) error { + if value == nil { + return nil + } + + var str string + switch v := value.(type) { + case []byte: + str = string(v) + case string: + str = v + case fmt.Stringer: + str = v.String() + default: + return fmt.Errorf("invalid value of DomainType: %[1]T(%[1]v)", value) + } + + val, err := DomainTypeString(str) + if err != nil { + return err + } + + *i = val + return nil +} diff --git a/backend/v3/domain/domainvalidationtype_enumer.go b/backend/v3/domain/domainvalidationtype_enumer.go new file mode 100644 index 00000000000..9d3fdab51c4 --- /dev/null +++ b/backend/v3/domain/domainvalidationtype_enumer.go @@ -0,0 +1,109 @@ +// Code generated by "enumer -type DomainValidationType -transform lower -trimprefix DomainValidationType -sql"; DO NOT EDIT. + +package domain + +import ( + "database/sql/driver" + "fmt" + "strings" +) + +const _DomainValidationTypeName = "dnshttp" + +var _DomainValidationTypeIndex = [...]uint8{0, 3, 7} + +const _DomainValidationTypeLowerName = "dnshttp" + +func (i DomainValidationType) String() string { + if i >= DomainValidationType(len(_DomainValidationTypeIndex)-1) { + return fmt.Sprintf("DomainValidationType(%d)", i) + } + return _DomainValidationTypeName[_DomainValidationTypeIndex[i]:_DomainValidationTypeIndex[i+1]] +} + +// An "invalid array index" compiler error signifies that the constant values have changed. +// Re-run the stringer command to generate them again. +func _DomainValidationTypeNoOp() { + var x [1]struct{} + _ = x[DomainValidationTypeDNS-(0)] + _ = x[DomainValidationTypeHTTP-(1)] +} + +var _DomainValidationTypeValues = []DomainValidationType{DomainValidationTypeDNS, DomainValidationTypeHTTP} + +var _DomainValidationTypeNameToValueMap = map[string]DomainValidationType{ + _DomainValidationTypeName[0:3]: DomainValidationTypeDNS, + _DomainValidationTypeLowerName[0:3]: DomainValidationTypeDNS, + _DomainValidationTypeName[3:7]: DomainValidationTypeHTTP, + _DomainValidationTypeLowerName[3:7]: DomainValidationTypeHTTP, +} + +var _DomainValidationTypeNames = []string{ + _DomainValidationTypeName[0:3], + _DomainValidationTypeName[3:7], +} + +// DomainValidationTypeString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func DomainValidationTypeString(s string) (DomainValidationType, error) { + if val, ok := _DomainValidationTypeNameToValueMap[s]; ok { + return val, nil + } + + if val, ok := _DomainValidationTypeNameToValueMap[strings.ToLower(s)]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to DomainValidationType values", s) +} + +// DomainValidationTypeValues returns all values of the enum +func DomainValidationTypeValues() []DomainValidationType { + return _DomainValidationTypeValues +} + +// DomainValidationTypeStrings returns a slice of all String values of the enum +func DomainValidationTypeStrings() []string { + strs := make([]string, len(_DomainValidationTypeNames)) + copy(strs, _DomainValidationTypeNames) + return strs +} + +// IsADomainValidationType returns "true" if the value is listed in the enum definition. "false" otherwise +func (i DomainValidationType) IsADomainValidationType() bool { + for _, v := range _DomainValidationTypeValues { + if i == v { + return true + } + } + return false +} + +func (i DomainValidationType) Value() (driver.Value, error) { + return i.String(), nil +} + +func (i *DomainValidationType) Scan(value interface{}) error { + if value == nil { + return nil + } + + var str string + switch v := value.(type) { + case []byte: + str = string(v) + case string: + str = v + case fmt.Stringer: + str = v.String() + default: + return fmt.Errorf("invalid value of DomainValidationType: %[1]T(%[1]v)", value) + } + + val, err := DomainValidationTypeString(str) + if err != nil { + return err + } + + *i = val + return nil +} diff --git a/backend/v3/domain/email_verification.go b/backend/v3/domain/email_verification.go new file mode 100644 index 00000000000..168b2fc945a --- /dev/null +++ b/backend/v3/domain/email_verification.go @@ -0,0 +1,175 @@ +package domain + +// import ( +// "context" +// "time" +// ) + +// // EmailVerifiedCommand verifies an email address for a user. +// type EmailVerifiedCommand struct { +// UserID string `json:"userId"` +// Email *Email `json:"email"` +// } + +// func NewEmailVerifiedCommand(userID string, isVerified bool) *EmailVerifiedCommand { +// return &EmailVerifiedCommand{ +// UserID: userID, +// Email: &Email{ +// VerifiedAt: time.Time{}, +// }, +// } +// } + +// // String implements [Commander]. +// func (cmd *EmailVerifiedCommand) String() string { +// return "EmailVerifiedCommand" +// } + +// var ( +// _ Commander = (*EmailVerifiedCommand)(nil) +// _ SetEmailOpt = (*EmailVerifiedCommand)(nil) +// ) + +// // Execute implements [Commander] +// func (cmd *EmailVerifiedCommand) Execute(ctx context.Context, opts *CommandOpts) error { +// repo := userRepo(opts.DB).Human() +// return repo.Update(ctx, repo.IDCondition(cmd.UserID), repo.SetEmailVerifiedAt(time.Time{})) +// } + +// // applyOnSetEmail implements [SetEmailOpt] +// func (cmd *EmailVerifiedCommand) applyOnSetEmail(setEmailCmd *SetEmailCommand) { +// cmd.UserID = setEmailCmd.UserID +// cmd.Email.Address = setEmailCmd.Email +// setEmailCmd.verification = cmd +// } + +// // SendCodeCommand sends a verification code to the user's email address. +// // If the URLTemplate is not set it will use the default of the organization / instance. +// type SendCodeCommand struct { +// UserID string `json:"userId"` +// Email string `json:"email"` +// URLTemplate *string `json:"urlTemplate"` +// generator *generateCodeCommand +// } + +// var ( +// _ Commander = (*SendCodeCommand)(nil) +// _ SetEmailOpt = (*SendCodeCommand)(nil) +// ) + +// func NewSendCodeCommand(userID string, urlTemplate *string) *SendCodeCommand { +// return &SendCodeCommand{ +// UserID: userID, +// generator: &generateCodeCommand{}, +// URLTemplate: urlTemplate, +// } +// } + +// // String implements [Commander]. +// func (cmd *SendCodeCommand) String() string { +// return "SendCodeCommand" +// } + +// // Execute implements [Commander] +// func (cmd *SendCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error { +// if err := cmd.ensureEmail(ctx, opts); err != nil { +// return err +// } +// if err := cmd.ensureURL(ctx, opts); err != nil { +// return err +// } + +// if err := opts.Invoker.Invoke(ctx, cmd.generator, opts); err != nil { +// return err +// } +// // TODO: queue notification + +// return nil +// } + +// func (cmd *SendCodeCommand) ensureEmail(ctx context.Context, opts *CommandOpts) error { +// if cmd.Email != "" { +// return nil +// } +// repo := userRepo(opts.DB).Human() +// email, err := repo.GetEmail(ctx, repo.IDCondition(cmd.UserID)) +// if err != nil || !email.VerifiedAt.IsZero() { +// return err +// } +// cmd.Email = email.Address +// return nil +// } + +// func (cmd *SendCodeCommand) ensureURL(ctx context.Context, opts *CommandOpts) error { +// if cmd.URLTemplate != nil && *cmd.URLTemplate != "" { +// return nil +// } +// _, _ = ctx, opts +// // TODO: load default template +// return nil +// } + +// // applyOnSetEmail implements [SetEmailOpt] +// func (cmd *SendCodeCommand) applyOnSetEmail(setEmailCmd *SetEmailCommand) { +// cmd.UserID = setEmailCmd.UserID +// cmd.Email = setEmailCmd.Email +// setEmailCmd.verification = cmd +// } + +// // ReturnCodeCommand creates the code and returns it to the caller. +// // The caller gets the code by calling the Code field after the command got executed. +// type ReturnCodeCommand struct { +// UserID string `json:"userId"` +// Email string `json:"email"` +// Code string `json:"code"` +// generator *generateCodeCommand +// } + +// var ( +// _ Commander = (*ReturnCodeCommand)(nil) +// _ SetEmailOpt = (*ReturnCodeCommand)(nil) +// ) + +// func NewReturnCodeCommand(userID string) *ReturnCodeCommand { +// return &ReturnCodeCommand{ +// UserID: userID, +// generator: &generateCodeCommand{}, +// } +// } + +// // String implements [Commander]. +// func (cmd *ReturnCodeCommand) String() string { +// return "ReturnCodeCommand" +// } + +// // Execute implements [Commander] +// func (cmd *ReturnCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error { +// if err := cmd.ensureEmail(ctx, opts); err != nil { +// return err +// } +// if err := opts.Invoker.Invoke(ctx, cmd.generator, opts); err != nil { +// return err +// } +// cmd.Code = cmd.generator.code +// return nil +// } + +// func (cmd *ReturnCodeCommand) ensureEmail(ctx context.Context, opts *CommandOpts) error { +// if cmd.Email != "" { +// return nil +// } +// repo := userRepo(opts.DB).Human() +// email, err := repo.GetEmail(ctx, repo.IDCondition(cmd.UserID)) +// if err != nil || !email.VerifiedAt.IsZero() { +// return err +// } +// cmd.Email = email.Address +// return nil +// } + +// // applyOnSetEmail implements [SetEmailOpt] +// func (cmd *ReturnCodeCommand) applyOnSetEmail(setEmailCmd *SetEmailCommand) { +// cmd.UserID = setEmailCmd.UserID +// cmd.Email = setEmailCmd.Email +// setEmailCmd.verification = cmd +// } diff --git a/backend/v3/domain/errors.go b/backend/v3/domain/errors.go new file mode 100644 index 00000000000..a11c31c07d4 --- /dev/null +++ b/backend/v3/domain/errors.go @@ -0,0 +1,7 @@ +package domain + +import "errors" + +var ( + ErrNoAdminSpecified = errors.New("at least one admin must be specified") +) diff --git a/backend/v3/domain/instance.go b/backend/v3/domain/instance.go new file mode 100644 index 00000000000..03dd40cf2ec --- /dev/null +++ b/backend/v3/domain/instance.go @@ -0,0 +1,117 @@ +package domain + +import ( + "context" + "time" + + "golang.org/x/text/language" + + "github.com/zitadel/zitadel/backend/v3/storage/cache" + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type Instance struct { + ID string `json:"id,omitempty" db:"id"` + Name string `json:"name,omitempty" db:"name"` + DefaultOrgID string `json:"defaultOrgId,omitempty" db:"default_org_id"` + IAMProjectID string `json:"iamProjectId,omitempty" db:"iam_project_id"` + ConsoleClientID string `json:"consoleClientId,omitempty" db:"console_client_id"` + ConsoleAppID string `json:"consoleAppId,omitempty" db:"console_app_id"` + DefaultLanguage string `json:"defaultLanguage,omitempty" db:"default_language"` + CreatedAt time.Time `json:"createdAt" db:"created_at"` + UpdatedAt time.Time `json:"updatedAt" db:"updated_at"` + + Domains []*InstanceDomain `json:"domains,omitempty" db:"-"` +} + +type instanceCacheIndex uint8 + +const ( + instanceCacheIndexUndefined instanceCacheIndex = iota + instanceCacheIndexID +) + +// Keys implements the [cache.Entry]. +func (i *Instance) Keys(index instanceCacheIndex) (key []string) { + if index == instanceCacheIndexID { + return []string{i.ID} + } + return nil +} + +var _ cache.Entry[instanceCacheIndex, string] = (*Instance)(nil) + +// instanceColumns define all the columns of the instance table. +type instanceColumns interface { + // IDColumn returns the column for the id field. + IDColumn() database.Column + // NameColumn returns the column for the name field. + NameColumn() database.Column + // DefaultOrgIDColumn returns the column for the default org id field + DefaultOrgIDColumn() database.Column + // IAMProjectIDColumn returns the column for the default IAM org id field + IAMProjectIDColumn() database.Column + // ConsoleClientIDColumn returns the column for the default IAM org id field + ConsoleClientIDColumn() database.Column + // ConsoleAppIDColumn returns the column for the console client id field + ConsoleAppIDColumn() database.Column + // DefaultLanguageColumn returns the column for the default language field + DefaultLanguageColumn() database.Column + // CreatedAtColumn returns the column for the created at field. + CreatedAtColumn() database.Column + // UpdatedAtColumn returns the column for the updated at field. + UpdatedAtColumn() database.Column +} + +// instanceConditions define all the conditions for the instance table. +type instanceConditions interface { + // IDCondition returns an equal filter on the id field. + IDCondition(instanceID string) database.Condition + // NameCondition returns a filter on the name field. + NameCondition(op database.TextOperation, name string) database.Condition +} + +// instanceChanges define all the changes for the instance table. +type instanceChanges interface { + // SetName sets the name column. + SetName(name string) database.Change + // SetUpdatedAt sets the updated at column. + SetUpdatedAt(time time.Time) database.Change + // SetIAMProject sets the iam project column. + SetIAMProject(id string) database.Change + // SetDefaultOrg sets the default org column. + SetDefaultOrg(id string) database.Change + // SetDefaultLanguage sets the default language column. + SetDefaultLanguage(language language.Tag) database.Change + // SetConsoleClientID sets the console client id column. + SetConsoleClientID(id string) database.Change + // SetConsoleAppID sets the console app id column. + SetConsoleAppID(id string) database.Change +} + +// InstanceRepository is the interface for the instance repository. +type InstanceRepository interface { + instanceColumns + instanceConditions + instanceChanges + + // TODO + // Member returns the member repository which is a sub repository of the instance repository. + // Member() MemberRepository + + Get(ctx context.Context, opts ...database.QueryOption) (*Instance, error) + List(ctx context.Context, opts ...database.QueryOption) ([]*Instance, error) + + Create(ctx context.Context, instance *Instance) error + Update(ctx context.Context, id string, changes ...database.Change) (int64, error) + Delete(ctx context.Context, id string) (int64, error) + + // Domains returns the domain sub repository for the instance. + // If shouldLoad is true, the domains will be loaded from the database and written to the [Instance].Domains field. + // If shouldLoad is set to true once, the Domains field will be set even if shouldLoad is false in the future. + Domains(shouldLoad bool) InstanceDomainRepository +} + +type CreateInstance struct { + Name string `json:"name"` +} diff --git a/backend/v3/domain/instance_domain.go b/backend/v3/domain/instance_domain.go new file mode 100644 index 00000000000..4c6a71b2e9f --- /dev/null +++ b/backend/v3/domain/instance_domain.go @@ -0,0 +1,79 @@ +package domain + +import ( + "context" + "time" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type InstanceDomain struct { + InstanceID string `json:"instanceId,omitempty" db:"instance_id"` + Domain string `json:"domain,omitempty" db:"domain"` + // IsPrimary indicates if the domain is the primary domain of the instance. + // It is only set for custom domains. + IsPrimary *bool `json:"isPrimary,omitempty" db:"is_primary"` + // IsGenerated indicates if the domain is a generated domain. + // It is only set for custom domains. + IsGenerated *bool `json:"isGenerated,omitempty" db:"is_generated"` + Type DomainType `json:"type,omitempty" db:"type"` + + CreatedAt time.Time `json:"createdAt,omitzero" db:"created_at"` + UpdatedAt time.Time `json:"updatedAt,omitzero" db:"updated_at"` +} + +type AddInstanceDomain struct { + InstanceID string `json:"instanceId,omitempty" db:"instance_id"` + Domain string `json:"domain,omitempty" db:"domain"` + IsPrimary *bool `json:"isPrimary,omitempty" db:"is_primary"` + IsGenerated *bool `json:"isGenerated,omitempty" db:"is_generated"` + Type DomainType `json:"type,omitempty" db:"type"` + + // CreatedAt is the time when the domain was added. + // It is set by the repository and should not be set by the caller. + CreatedAt time.Time `json:"createdAt,omitzero" db:"created_at"` + // UpdatedAt is the time when the domain was last updated. + // It is set by the repository and should not be set by the caller. + UpdatedAt time.Time `json:"updatedAt,omitzero" db:"updated_at"` +} + +type instanceDomainColumns interface { + domainColumns + // IsGeneratedColumn returns the column for the is generated field. + IsGeneratedColumn() database.Column + // TypeColumn returns the column for the type field. + TypeColumn() database.Column +} + +type instanceDomainConditions interface { + domainConditions + // TypeCondition returns a filter for the type field. + TypeCondition(typ DomainType) database.Condition +} + +type instanceDomainChanges interface { + domainChanges + // SetType sets the type column. + SetType(typ DomainType) database.Change +} + +type InstanceDomainRepository interface { + instanceDomainColumns + instanceDomainConditions + instanceDomainChanges + + // Get returns a single domain based on the criteria. + // If no domain is found, it returns an error of type [database.ErrNotFound]. + // If multiple domains are found, it returns an error of type [database.ErrMultipleRows]. + Get(ctx context.Context, opts ...database.QueryOption) (*InstanceDomain, error) + // List returns a list of domains based on the criteria. + // If no domains are found, it returns an empty slice. + List(ctx context.Context, opts ...database.QueryOption) ([]*InstanceDomain, error) + + // Add adds a new domain to the instance. + Add(ctx context.Context, domain *AddInstanceDomain) error + // Update updates an existing domain in the instance. + Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) + // Remove removes a domain from the instance. + Remove(ctx context.Context, condition database.Condition) (int64, error) +} diff --git a/backend/v3/domain/invoke.go b/backend/v3/domain/invoke.go new file mode 100644 index 00000000000..8c25bad6bae --- /dev/null +++ b/backend/v3/domain/invoke.go @@ -0,0 +1,158 @@ +package domain + +// import ( +// "context" +// "fmt" + +// "github.com/zitadel/zitadel/backend/v3/storage/eventstore" +// ) + +// // Invoke provides a way to execute commands within the domain package. +// // It uses a chain of responsibility pattern to handle the command execution. +// // The default chain includes logging, tracing, and event publishing. +// // If you want to invoke multiple commands in a single transaction, you can use the [commandBatch]. +// func Invoke(ctx context.Context, cmd Commander) error { +// invoker := newEventStoreInvoker(newLoggingInvoker(newTraceInvoker(nil))) +// opts := &CommandOpts{ +// Invoker: invoker.collector, +// DB: pool, +// } +// return invoker.Invoke(ctx, cmd, opts) +// } + +// // eventStoreInvoker checks if the command implements the [eventer] interface. +// // If it does, it collects the events and publishes them to the event store. +// type eventStoreInvoker struct { +// collector *eventCollector +// } + +// func newEventStoreInvoker(next Invoker) *eventStoreInvoker { +// return &eventStoreInvoker{collector: &eventCollector{next: next}} +// } + +// func (i *eventStoreInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) { +// err = i.collector.Invoke(ctx, command, opts) +// if err != nil { +// return err +// } +// if len(i.collector.events) > 0 { +// err = eventstore.Publish(ctx, i.collector.events, opts.DB) +// if err != nil { +// return err +// } +// } +// return nil +// } + +// // eventCollector collects events from all commands. The [eventStoreInvoker] pushes the collected events after all commands are executed. +// type eventCollector struct { +// next Invoker +// events []*eventstore.Event +// } + +// type eventer interface { +// Events() []*eventstore.Event +// } + +// func (i *eventCollector) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) { +// if e, ok := command.(eventer); ok && len(e.Events()) > 0 { +// // we need to ensure all commands are executed in the same transaction +// close, err := opts.EnsureTx(ctx) +// if err != nil { +// return err +// } +// defer func() { err = close(ctx, err) }() + +// i.events = append(i.events, e.Events()...) +// } +// if i.next != nil { +// return i.next.Invoke(ctx, command, opts) +// } +// return command.Execute(ctx, opts) +// } + +// // traceInvoker decorates each command with tracing. +// type traceInvoker struct { +// next Invoker +// } + +// func newTraceInvoker(next Invoker) *traceInvoker { +// return &traceInvoker{next: next} +// } + +// func (i *traceInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) { +// ctx, span := tracer.Start(ctx, fmt.Sprintf("%T", command)) +// defer func() { +// if err != nil { +// span.RecordError(err) +// } +// span.End() +// }() + +// if i.next != nil { +// return i.next.Invoke(ctx, command, opts) +// } +// return command.Execute(ctx, opts) +// } + +// // loggingInvoker decorates each command with logging. +// // It is an example implementation and logs the command name at the beginning and success or failure after the command got executed. +// type loggingInvoker struct { +// next Invoker +// } + +// func newLoggingInvoker(next Invoker) *loggingInvoker { +// return &loggingInvoker{next: next} +// } + +// func (i *loggingInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) { +// logger.InfoContext(ctx, "Invoking command", "command", command.String()) + +// if i.next != nil { +// err = i.next.Invoke(ctx, command, opts) +// } else { +// err = command.Execute(ctx, opts) +// } + +// if err != nil { +// logger.ErrorContext(ctx, "Command invocation failed", "command", command.String(), "error", err) +// return err +// } +// logger.InfoContext(ctx, "Command invocation succeeded", "command", command.String()) +// return nil +// } + +// type noopInvoker struct { +// next Invoker +// } + +// func (i *noopInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) error { +// if i.next != nil { +// return i.next.Invoke(ctx, command, opts) +// } +// return command.Execute(ctx, opts) +// } + +// // cacheInvoker could be used in the future to do the caching. +// // My goal would be to have two interfaces: +// // - cacheSetter: which caches an object +// // - cacheGetter: which gets an object from the cache, this should also skip the command execution +// type cacheInvoker struct { +// next Invoker +// } + +// type cacher interface { +// Cache(opts *CommandOpts) +// } + +// func (i *cacheInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) { +// if c, ok := command.(cacher); ok { +// c.Cache(opts) +// } +// if i.next != nil { +// err = i.next.Invoke(ctx, command, opts) +// } else { +// err = command.Execute(ctx, opts) +// } +// return err +// } diff --git a/backend/v3/domain/org_add.go b/backend/v3/domain/org_add.go new file mode 100644 index 00000000000..808c18f06f4 --- /dev/null +++ b/backend/v3/domain/org_add.go @@ -0,0 +1,137 @@ +package domain + +// import ( +// "context" + +// "github.com/zitadel/zitadel/backend/v3/storage/eventstore" +// ) + +// // AddOrgCommand adds a new organization. +// // I'm unsure if we should add the Admins here or if this should be a separate command. +// type AddOrgCommand struct { +// ID string `json:"id"` +// Name string `json:"name"` +// Admins []*AddMemberCommand `json:"admins"` +// } + +// func NewAddOrgCommand(name string, admins ...*AddMemberCommand) *AddOrgCommand { +// return &AddOrgCommand{ +// Name: name, +// Admins: admins, +// } +// } + +// // String implements [Commander]. +// func (cmd *AddOrgCommand) String() string { +// return "AddOrgCommand" +// } + +// // Execute implements Commander. +// func (cmd *AddOrgCommand) Execute(ctx context.Context, opts *CommandOpts) (err error) { +// if len(cmd.Admins) == 0 { +// return ErrNoAdminSpecified +// } +// if err = cmd.ensureID(); err != nil { +// return err +// } + +// close, err := opts.EnsureTx(ctx) +// if err != nil { +// return err +// } +// defer func() { err = close(ctx, err) }() +// err = orgRepo(opts.DB).Create(ctx, &Org{ +// ID: cmd.ID, +// Name: cmd.Name, +// }) +// if err != nil { +// return err +// } + +// for _, admin := range cmd.Admins { +// admin.orgID = cmd.ID +// if err = opts.Invoke(ctx, admin); err != nil { +// return err +// } +// } + +// orgCache.Set(ctx, &Org{ +// ID: cmd.ID, +// Name: cmd.Name, +// }) + +// return nil +// } + +// // Events implements [eventer]. +// func (cmd *AddOrgCommand) Events() []*eventstore.Event { +// return []*eventstore.Event{ +// { +// AggregateType: "org", +// AggregateID: cmd.ID, +// Type: "org.added", +// Payload: cmd, +// }, +// } +// } + +// var ( +// _ Commander = (*AddOrgCommand)(nil) +// _ eventer = (*AddOrgCommand)(nil) +// ) + +// func (cmd *AddOrgCommand) ensureID() (err error) { +// if cmd.ID != "" { +// return nil +// } +// cmd.ID, err = generateID() +// return err +// } + +// // AddMemberCommand adds a new member to an organization. +// // I'm not sure if we should make it more generic to also use it for instances. +// type AddMemberCommand struct { +// orgID string +// UserID string `json:"userId"` +// Roles []string `json:"roles"` +// } + +// func NewAddMemberCommand(userID string, roles ...string) *AddMemberCommand { +// return &AddMemberCommand{ +// UserID: userID, +// Roles: roles, +// } +// } + +// // String implements [Commander]. +// func (cmd *AddMemberCommand) String() string { +// return "AddMemberCommand" +// } + +// // Execute implements Commander. +// func (a *AddMemberCommand) Execute(ctx context.Context, opts *CommandOpts) (err error) { +// close, err := opts.EnsureTx(ctx) +// if err != nil { +// return err +// } +// defer func() { err = close(ctx, err) }() + +// return orgRepo(opts.DB).Member().AddMember(ctx, a.orgID, a.UserID, a.Roles) +// } + +// // Events implements [eventer]. +// func (a *AddMemberCommand) Events() []*eventstore.Event { +// return []*eventstore.Event{ +// { +// AggregateType: "org", +// AggregateID: a.UserID, +// Type: "member.added", +// Payload: a, +// }, +// } +// } + +// var ( +// _ Commander = (*AddMemberCommand)(nil) +// _ eventer = (*AddMemberCommand)(nil) +// ) diff --git a/backend/v3/domain/organization.go b/backend/v3/domain/organization.go new file mode 100644 index 00000000000..570e96c39e9 --- /dev/null +++ b/backend/v3/domain/organization.go @@ -0,0 +1,100 @@ +package domain + +import ( + "context" + "time" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +//go:generate enumer -type OrgState -transform lower -trimprefix OrgState -sql +type OrgState uint8 + +const ( + OrgStateActive OrgState = iota + OrgStateInactive +) + +type Organization struct { + ID string `json:"id,omitempty" db:"id"` + Name string `json:"name,omitempty" db:"name"` + InstanceID string `json:"instanceId,omitempty" db:"instance_id"` + State OrgState `json:"state,omitempty" db:"state"` + CreatedAt time.Time `json:"createdAt,omitzero" db:"created_at"` + UpdatedAt time.Time `json:"updatedAt,omitzero" db:"updated_at"` + + Domains []*OrganizationDomain `json:"domains,omitempty" db:"-"` // domains need to be handled separately +} + +// OrgIdentifierCondition is used to help specify a single Organization, +// it will either be used as the organization ID or organization name, +// as organizations can be identified either using (instanceID + ID) OR (instanceID + name) +type OrgIdentifierCondition interface { + database.Condition +} + +// organizationColumns define all the columns of the instance table. +type organizationColumns interface { + // IDColumn returns the column for the id field. + IDColumn() database.Column + // NameColumn returns the column for the name field. + NameColumn() database.Column + // InstanceIDColumn returns the column for the default org id field + InstanceIDColumn() database.Column + // StateColumn returns the column for the name field. + StateColumn() database.Column + // CreatedAtColumn returns the column for the created at field. + CreatedAtColumn() database.Column + // UpdatedAtColumn returns the column for the updated at field. + UpdatedAtColumn() database.Column +} + +// organizationConditions define all the conditions for the instance table. +type organizationConditions interface { + // IDCondition returns an equal filter on the id field. + IDCondition(instanceID string) OrgIdentifierCondition + // NameCondition returns a filter on the name field. + NameCondition(name string) OrgIdentifierCondition + // InstanceIDCondition returns a filter on the instance id field. + InstanceIDCondition(instanceID string) database.Condition + // StateCondition returns a filter on the name field. + StateCondition(state OrgState) database.Condition +} + +// organizationChanges define all the changes for the instance table. +type organizationChanges interface { + // SetName sets the name column. + SetName(name string) database.Change + // SetState sets the name column. + SetState(state OrgState) database.Change +} + +// OrganizationRepository is the interface for the instance repository. +type OrganizationRepository interface { + organizationColumns + organizationConditions + organizationChanges + + Get(ctx context.Context, opts ...database.QueryOption) (*Organization, error) + List(ctx context.Context, opts ...database.QueryOption) ([]*Organization, error) + + Create(ctx context.Context, instance *Organization) error + Update(ctx context.Context, id OrgIdentifierCondition, instance_id string, changes ...database.Change) (int64, error) + Delete(ctx context.Context, id OrgIdentifierCondition, instance_id string) (int64, error) + + // Domains returns the domain sub repository for the organization. + // If shouldLoad is true, the domains will be loaded from the database and written to the [Instance].Domains field. + // If shouldLoad is set to true once, the Domains field will be set event if shouldLoad is false in the future. + Domains(shouldLoad bool) OrganizationDomainRepository +} + +type CreateOrganization struct { + Name string `json:"name"` +} + +// MemberRepository is a sub repository of the org repository and maybe the instance repository. +type MemberRepository interface { + AddMember(ctx context.Context, orgID, userID string, roles []string) error + SetMemberRoles(ctx context.Context, orgID, userID string, roles []string) error + RemoveMember(ctx context.Context, orgID, userID string) error +} diff --git a/backend/v3/domain/organization_domain.go b/backend/v3/domain/organization_domain.go new file mode 100644 index 00000000000..c0868e3a620 --- /dev/null +++ b/backend/v3/domain/organization_domain.go @@ -0,0 +1,84 @@ +package domain + +import ( + "context" + "time" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type OrganizationDomain struct { + InstanceID string `json:"instanceId,omitempty" db:"instance_id"` + OrgID string `json:"orgId,omitempty" db:"org_id"` + Domain string `json:"domain,omitempty" db:"domain"` + IsVerified bool `json:"isVerified,omitempty" db:"is_verified"` + IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"` + ValidationType *DomainValidationType `json:"validationType,omitempty" db:"validation_type"` + + CreatedAt time.Time `json:"createdAt,omitzero" db:"created_at"` + UpdatedAt time.Time `json:"updatedAt,omitzero" db:"updated_at"` +} + +type AddOrganizationDomain struct { + InstanceID string `json:"instanceId,omitempty" db:"instance_id"` + OrgID string `json:"orgId,omitempty" db:"org_id"` + Domain string `json:"domain,omitempty" db:"domain"` + IsVerified bool `json:"isVerified,omitempty" db:"is_verified"` + IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"` + ValidationType *DomainValidationType `json:"validationType,omitempty" db:"validation_type"` + + // CreatedAt is the time when the domain was added. + // It is set by the repository and should not be set by the caller. + CreatedAt time.Time `json:"createdAt,omitzero" db:"created_at"` + // UpdatedAt is the time when the domain was added. + // It is set by the repository and should not be set by the caller. + UpdatedAt time.Time `json:"updatedAt,omitzero" db:"updated_at"` +} + +type organizationDomainColumns interface { + domainColumns + // OrgIDColumn returns the column for the org id field. + OrgIDColumn() database.Column + // IsVerifiedColumn returns the column for the is verified field. + IsVerifiedColumn() database.Column + // ValidationTypeColumn returns the column for the verification type field. + ValidationTypeColumn() database.Column +} + +type organizationDomainConditions interface { + domainConditions + // OrgIDCondition returns a filter on the org id field. + OrgIDCondition(orgID string) database.Condition + // IsVerifiedCondition returns a filter on the is verified field. + IsVerifiedCondition(isVerified bool) database.Condition +} + +type organizationDomainChanges interface { + domainChanges + // SetVerified sets the is verified column to true. + SetVerified() database.Change + // SetValidationType sets the verification type column. + // If the domain is already verified, this is a no-op. + SetValidationType(verificationType DomainValidationType) database.Change +} + +type OrganizationDomainRepository interface { + organizationDomainColumns + organizationDomainConditions + organizationDomainChanges + + // Get returns a single domain based on the criteria. + // If no domain is found, it returns an error of type [database.ErrNotFound]. + // If multiple domains are found, it returns an error of type [database.ErrMultipleRows]. + Get(ctx context.Context, opts ...database.QueryOption) (*OrganizationDomain, error) + // List returns a list of domains based on the criteria. + // If no domains are found, it returns an empty slice. + List(ctx context.Context, opts ...database.QueryOption) ([]*OrganizationDomain, error) + + // Add adds a new domain to the organization. + Add(ctx context.Context, domain *AddOrganizationDomain) error + // Update updates an existing domain in the organization. + Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) + // Remove removes a domain from the organization. + Remove(ctx context.Context, condition database.Condition) (int64, error) +} diff --git a/backend/v3/domain/orgstate_enumer.go b/backend/v3/domain/orgstate_enumer.go new file mode 100644 index 00000000000..69d4c053bc6 --- /dev/null +++ b/backend/v3/domain/orgstate_enumer.go @@ -0,0 +1,109 @@ +// Code generated by "enumer -type OrgState -transform lower -trimprefix OrgState -sql"; DO NOT EDIT. + +package domain + +import ( + "database/sql/driver" + "fmt" + "strings" +) + +const _OrgStateName = "activeinactive" + +var _OrgStateIndex = [...]uint8{0, 6, 14} + +const _OrgStateLowerName = "activeinactive" + +func (i OrgState) String() string { + if i >= OrgState(len(_OrgStateIndex)-1) { + return fmt.Sprintf("OrgState(%d)", i) + } + return _OrgStateName[_OrgStateIndex[i]:_OrgStateIndex[i+1]] +} + +// An "invalid array index" compiler error signifies that the constant values have changed. +// Re-run the stringer command to generate them again. +func _OrgStateNoOp() { + var x [1]struct{} + _ = x[OrgStateActive-(0)] + _ = x[OrgStateInactive-(1)] +} + +var _OrgStateValues = []OrgState{OrgStateActive, OrgStateInactive} + +var _OrgStateNameToValueMap = map[string]OrgState{ + _OrgStateName[0:6]: OrgStateActive, + _OrgStateLowerName[0:6]: OrgStateActive, + _OrgStateName[6:14]: OrgStateInactive, + _OrgStateLowerName[6:14]: OrgStateInactive, +} + +var _OrgStateNames = []string{ + _OrgStateName[0:6], + _OrgStateName[6:14], +} + +// OrgStateString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func OrgStateString(s string) (OrgState, error) { + if val, ok := _OrgStateNameToValueMap[s]; ok { + return val, nil + } + + if val, ok := _OrgStateNameToValueMap[strings.ToLower(s)]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to OrgState values", s) +} + +// OrgStateValues returns all values of the enum +func OrgStateValues() []OrgState { + return _OrgStateValues +} + +// OrgStateStrings returns a slice of all String values of the enum +func OrgStateStrings() []string { + strs := make([]string, len(_OrgStateNames)) + copy(strs, _OrgStateNames) + return strs +} + +// IsAOrgState returns "true" if the value is listed in the enum definition. "false" otherwise +func (i OrgState) IsAOrgState() bool { + for _, v := range _OrgStateValues { + if i == v { + return true + } + } + return false +} + +func (i OrgState) Value() (driver.Value, error) { + return i.String(), nil +} + +func (i *OrgState) Scan(value interface{}) error { + if value == nil { + return nil + } + + var str string + switch v := value.(type) { + case []byte: + str = string(v) + case string: + str = v + case fmt.Stringer: + str = v.String() + default: + return fmt.Errorf("invalid value of OrgState: %[1]T(%[1]v)", value) + } + + val, err := OrgStateString(str) + if err != nil { + return err + } + + *i = val + return nil +} diff --git a/backend/v3/domain/set_email.go b/backend/v3/domain/set_email.go new file mode 100644 index 00000000000..630529bd9df --- /dev/null +++ b/backend/v3/domain/set_email.go @@ -0,0 +1,74 @@ +package domain + +// import ( +// "context" + +// "github.com/zitadel/zitadel/backend/v3/storage/eventstore" +// ) + +// // SetEmailCommand sets the email address of a user. +// // If allows verification as a sub command. +// // The verification command is executed after the email address is set. +// // The verification command is executed in the same transaction as the email address update. +// type SetEmailCommand struct { +// UserID string `json:"userId"` +// Email string `json:"email"` +// verification Commander +// } + +// var ( +// _ Commander = (*SetEmailCommand)(nil) +// _ eventer = (*SetEmailCommand)(nil) +// _ CreateHumanOpt = (*SetEmailCommand)(nil) +// ) + +// type SetEmailOpt interface { +// applyOnSetEmail(*SetEmailCommand) +// } + +// func NewSetEmailCommand(userID, email string, verificationType SetEmailOpt) *SetEmailCommand { +// cmd := &SetEmailCommand{ +// UserID: userID, +// Email: email, +// } +// verificationType.applyOnSetEmail(cmd) +// return cmd +// } + +// // String implements [Commander]. +// func (cmd *SetEmailCommand) String() string { +// return "SetEmailCommand" +// } + +// func (cmd *SetEmailCommand) Execute(ctx context.Context, opts *CommandOpts) error { +// close, err := opts.EnsureTx(ctx) +// if err != nil { +// return err +// } +// defer func() { err = close(ctx, err) }() +// // userStatement(opts.DB).Human().ByID(cmd.UserID).SetEmail(ctx, cmd.Email) +// repo := userRepo(opts.DB).Human() +// err = repo.Update(ctx, repo.IDCondition(cmd.UserID), repo.SetEmailAddress(cmd.Email)) +// if err != nil { +// return err +// } + +// return opts.Invoke(ctx, cmd.verification) +// } + +// // Events implements [eventer]. +// func (cmd *SetEmailCommand) Events() []*eventstore.Event { +// return []*eventstore.Event{ +// { +// AggregateType: "user", +// AggregateID: cmd.UserID, +// Type: "user.email.set", +// Payload: cmd, +// }, +// } +// } + +// // applyOnCreateHuman implements [CreateHumanOpt]. +// func (cmd *SetEmailCommand) applyOnCreateHuman(createUserCmd *CreateUserCommand) { +// createUserCmd.email = cmd +// } diff --git a/backend/v3/domain/user.go b/backend/v3/domain/user.go new file mode 100644 index 00000000000..fae0d75b6e7 --- /dev/null +++ b/backend/v3/domain/user.go @@ -0,0 +1,241 @@ +package domain + +import ( + "context" + "time" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +// userColumns define all the columns of the user table. +type userColumns interface { + // InstanceIDColumn returns the column for the instance id field. + InstanceIDColumn() database.Column + // OrgIDColumn returns the column for the org id field. + OrgIDColumn() database.Column + // IDColumn returns the column for the id field. + IDColumn() database.Column + // UsernameColumn returns the column for the username field. + UsernameColumn() database.Column + // CreatedAtColumn returns the column for the created at field. + CreatedAtColumn() database.Column + // UpdatedAtColumn returns the column for the updated at field. + UpdatedAtColumn() database.Column + // DeletedAtColumn returns the column for the deleted at field. + DeletedAtColumn() database.Column +} + +// userConditions define all the conditions for the user table. +type userConditions interface { + // InstanceIDCondition returns an equal filter on the instance id field. + InstanceIDCondition(instanceID string) database.Condition + // OrgIDCondition returns an equal filter on the org id field. + OrgIDCondition(orgID string) database.Condition + // IDCondition returns an equal filter on the id field. + IDCondition(userID string) database.Condition + // UsernameCondition returns a filter on the username field. + UsernameCondition(op database.TextOperation, username string) database.Condition + // CreatedAtCondition returns a filter on the created at field. + CreatedAtCondition(op database.NumberOperation, createdAt time.Time) database.Condition + // UpdatedAtCondition returns a filter on the updated at field. + UpdatedAtCondition(op database.NumberOperation, updatedAt time.Time) database.Condition + // DeletedAtCondition filters for deleted users is isDeleted is set to true otherwise only not deleted users must be filtered. + DeletedCondition(isDeleted bool) database.Condition + // DeletedAtCondition filters for deleted users based on the given parameters. + DeletedAtCondition(op database.NumberOperation, deletedAt time.Time) database.Condition +} + +// userChanges define all the changes for the user table. +type userChanges interface { + // SetUsername sets the username column. + SetUsername(username string) database.Change +} + +// UserRepository is the interface for the user repository. +type UserRepository interface { + userColumns + userConditions + userChanges + // Get returns a user based on the given condition. + Get(ctx context.Context, opts ...database.QueryOption) (*User, error) + // List returns a list of users based on the given condition. + List(ctx context.Context, opts ...database.QueryOption) ([]*User, error) + // Create creates a new user. + Create(ctx context.Context, user *User) error + // Delete removes users based on the given condition. + Delete(ctx context.Context, condition database.Condition) error + // Human returns the [HumanRepository]. + Human() HumanRepository + // Machine returns the [MachineRepository]. + Machine() MachineRepository +} + +// humanColumns define all the columns of the human table which inherits the user table. +type humanColumns interface { + userColumns + // FirstNameColumn returns the column for the first name field. + FirstNameColumn() database.Column + // LastNameColumn returns the column for the last name field. + LastNameColumn() database.Column + // EmailAddressColumn returns the column for the email address field. + EmailAddressColumn() database.Column + // EmailVerifiedAtColumn returns the column for the email verified at field. + EmailVerifiedAtColumn() database.Column + // PhoneNumberColumn returns the column for the phone number field. + PhoneNumberColumn() database.Column + // PhoneVerifiedAtColumn returns the column for the phone verified at field. + PhoneVerifiedAtColumn() database.Column +} + +// humanConditions define all the conditions for the human table which inherits the user table. +type humanConditions interface { + userConditions + // FirstNameCondition returns a filter on the first name field. + FirstNameCondition(op database.TextOperation, firstName string) database.Condition + // LastNameCondition returns a filter on the last name field. + LastNameCondition(op database.TextOperation, lastName string) database.Condition + // EmailAddressCondition returns a filter on the email address field. + EmailAddressCondition(op database.TextOperation, email string) database.Condition + // EmailVerifiedCondition returns a filter that checks if the email is verified or not. + EmailVerifiedCondition(isVerified bool) database.Condition + // EmailVerifiedAtCondition returns a filter on the email verified at field. + EmailVerifiedAtCondition(op database.NumberOperation, emailVerifiedAt time.Time) database.Condition + + // PhoneNumberCondition returns a filter on the phone number field. + PhoneNumberCondition(op database.TextOperation, phoneNumber string) database.Condition + // PhoneVerifiedCondition returns a filter that checks if the phone is verified or not. + PhoneVerifiedCondition(isVerified bool) database.Condition + // PhoneVerifiedAtCondition returns a filter on the phone verified at field. + PhoneVerifiedAtCondition(op database.NumberOperation, phoneVerifiedAt time.Time) database.Condition +} + +// humanChanges define all the changes for the human table which inherits the user table. +type humanChanges interface { + userChanges + // SetFirstName sets the first name field of the human. + SetFirstName(firstName string) database.Change + // SetLastName sets the last name field of the human. + SetLastName(lastName string) database.Change + + // SetEmail sets the email address and verified field of the email + // if verifiedAt is nil the email is not verified + SetEmail(address string, verifiedAt *time.Time) database.Change + // SetEmailAddress sets the email address field of the email + SetEmailAddress(email string) database.Change + // SetEmailVerifiedAt sets the verified column of the email + // if at is zero the statement uses the database timestamp + SetEmailVerifiedAt(at time.Time) database.Change + + // SetPhone sets the phone number and verified field + // if verifiedAt is nil the phone is not verified + SetPhone(number string, verifiedAt *time.Time) database.Change + // SetPhoneNumber sets the phone number field + SetPhoneNumber(phoneNumber string) database.Change + // SetPhoneVerifiedAt sets the verified field of the phone + // if at is zero the statement uses the database timestamp + SetPhoneVerifiedAt(at time.Time) database.Change +} + +// HumanRepository is the interface for the human repository it inherits the user repository. +type HumanRepository interface { + humanColumns + humanConditions + humanChanges + + // Get returns an email based on the given condition. + GetEmail(ctx context.Context, condition database.Condition) (*Email, error) + // Update updates human users based on the given condition and changes. + Update(ctx context.Context, condition database.Condition, changes ...database.Change) error +} + +// machineColumns define all the columns of the machine table which inherits the user table. +type machineColumns interface { + userColumns + // DescriptionColumn returns the column for the description field. + DescriptionColumn() database.Column +} + +// machineConditions define all the conditions for the machine table which inherits the user table. +type machineConditions interface { + userConditions + // DescriptionCondition returns a filter on the description field. + DescriptionCondition(op database.TextOperation, description string) database.Condition +} + +// machineChanges define all the changes for the machine table which inherits the user table. +type machineChanges interface { + userChanges + // SetDescription sets the description field of the machine. + SetDescription(description string) database.Change +} + +// MachineRepository is the interface for the machine repository it inherits the user repository. +type MachineRepository interface { + // Update updates machine users based on the given condition and changes. + Update(ctx context.Context, condition database.Condition, changes ...database.Change) error + + machineColumns + machineConditions + machineChanges +} + +// UserTraits is implemented by [Human] and [Machine]. +type UserTraits interface { + Type() UserType +} + +type UserType string + +const ( + UserTypeHuman UserType = "human" + UserTypeMachine UserType = "machine" +) + +type User struct { + InstanceID string + OrgID string + ID string + + CreatedAt time.Time + UpdatedAt time.Time + DeletedAt time.Time + + Username string + + Traits UserTraits +} + +type Human struct { + FirstName string `json:"firstName"` + LastName string `json:"lastName"` + Email *Email `json:"email,omitempty"` + Phone *Phone `json:"phone,omitempty"` +} + +// Type implements [UserTraits]. +func (h *Human) Type() UserType { + return UserTypeHuman +} + +var _ UserTraits = (*Human)(nil) + +type Email struct { + Address string `json:"address"` + VerifiedAt time.Time `json:"verifiedAt"` +} + +type Phone struct { + Number string `json:"number"` + VerifiedAt time.Time `json:"verifiedAt"` +} + +type Machine struct { + Description string `json:"description"` +} + +// Type implements [UserTraits]. +func (m *Machine) Type() UserType { + return UserTypeMachine +} + +var _ UserTraits = (*Machine)(nil) diff --git a/backend/v3/storage/cache/cache.go b/backend/v3/storage/cache/cache.go new file mode 100644 index 00000000000..dc05208caa5 --- /dev/null +++ b/backend/v3/storage/cache/cache.go @@ -0,0 +1,112 @@ +// Package cache provides abstraction of cache implementations that can be used by zitadel. +package cache + +import ( + "context" + "time" + + "github.com/zitadel/logging" +) + +// Purpose describes which object types are stored by a cache. +type Purpose int + +//go:generate enumer -type Purpose -transform snake -trimprefix Purpose +const ( + PurposeUnspecified Purpose = iota + PurposeAuthzInstance + PurposeMilestones + PurposeOrganization + PurposeIdPFormCallback +) + +// Cache stores objects with a value of type `V`. +// Objects may be referred to by one or more indices. +// Implementations may encode the value for storage. +// This means non-exported fields may be lost and objects +// with function values may fail to encode. +// See https://pkg.go.dev/encoding/json#Marshal for example. +// +// `I` is the type by which indices are identified, +// typically an enum for type-safe access. +// Indices are defined when calling the constructor of an implementation of this interface. +// It is illegal to refer to an idex not defined during construction. +// +// `K` is the type used as key in each index. +// Due to the limitations in type constraints, all indices use the same key type. +// +// Implementations are free to use stricter type constraints or fixed typing. +type Cache[I, K comparable, V Entry[I, K]] interface { + // Get an object through specified index. + // An [IndexUnknownError] may be returned if the index is unknown. + // [ErrCacheMiss] is returned if the key was not found in the index, + // or the object is not valid. + Get(ctx context.Context, index I, key K) (V, bool) + + // Set an object. + // Keys are created on each index based in the [Entry.Keys] method. + // If any key maps to an existing object, the object is invalidated, + // regardless if the object has other keys defined in the new entry. + // This to prevent ghost objects when an entry reduces the amount of keys + // for a given index. + Set(ctx context.Context, value V) + + // Invalidate an object through specified index. + // Implementations may choose to instantly delete the object, + // defer until prune or a separate cleanup routine. + // Invalidated object are no longer returned from Get. + // It is safe to call Invalidate multiple times or on non-existing entries. + Invalidate(ctx context.Context, index I, key ...K) error + + // Delete one or more keys from a specific index. + // An [IndexUnknownError] may be returned if the index is unknown. + // The referred object is not invalidated and may still be accessible though + // other indices and keys. + // It is safe to call Delete multiple times or on non-existing entries + Delete(ctx context.Context, index I, key ...K) error + + // Truncate deletes all cached objects. + Truncate(ctx context.Context) error +} + +// Entry contains a value of type `V` to be cached. +// +// `I` is the type by which indices are identified, +// typically an enum for type-safe access. +// +// `K` is the type used as key in an index. +// Due to the limitations in type constraints, all indices use the same key type. +type Entry[I, K comparable] interface { + // Keys returns which keys map to the object in a specified index. + // May return nil if the index in unknown or when there are no keys. + Keys(index I) (key []K) +} + +type Connector int + +//go:generate enumer -type Connector -transform snake -trimprefix Connector -linecomment -text +const ( + // Empty line comment ensures empty string for unspecified value + ConnectorUnspecified Connector = iota // + ConnectorMemory + ConnectorPostgres + ConnectorRedis +) + +type Config struct { + Connector Connector + + // Age since an object was added to the cache, + // after which the object is considered invalid. + // 0 disables max age checks. + MaxAge time.Duration + + // Age since last use (Get) of an object, + // after which the object is considered invalid. + // 0 disables last use age checks. + LastUseAge time.Duration + + // Log allows logging of the specific cache. + // By default only errors are logged to stdout. + Log *logging.Config +} diff --git a/backend/v3/storage/cache/connector/connector.go b/backend/v3/storage/cache/connector/connector.go new file mode 100644 index 00000000000..487680155c8 --- /dev/null +++ b/backend/v3/storage/cache/connector/connector.go @@ -0,0 +1,49 @@ +// Package connector provides glue between the [cache.Cache] interface and implementations from the connector sub-packages. +package connector + +import ( + "context" + "fmt" + + "github.com/zitadel/zitadel/backend/v3/storage/cache" + "github.com/zitadel/zitadel/backend/v3/storage/cache/connector/gomap" + "github.com/zitadel/zitadel/backend/v3/storage/cache/connector/noop" +) + +type CachesConfig struct { + Connectors struct { + Memory gomap.Config + } + Instance *cache.Config + Milestones *cache.Config + Organization *cache.Config + IdPFormCallbacks *cache.Config +} + +type Connectors struct { + Config CachesConfig + Memory *gomap.Connector +} + +func StartConnectors(conf *CachesConfig) (Connectors, error) { + if conf == nil { + return Connectors{}, nil + } + return Connectors{ + Config: *conf, + Memory: gomap.NewConnector(conf.Connectors.Memory), + }, nil +} + +func StartCache[I ~int, K ~string, V cache.Entry[I, K]](background context.Context, indices []I, purpose cache.Purpose, conf *cache.Config, connectors Connectors) (cache.Cache[I, K, V], error) { + if conf == nil || conf.Connector == cache.ConnectorUnspecified { + return noop.NewCache[I, K, V](), nil + } + if conf.Connector == cache.ConnectorMemory && connectors.Memory != nil { + c := gomap.NewCache[I, K, V](background, indices, *conf) + connectors.Memory.Config.StartAutoPrune(background, c, purpose) + return c, nil + } + + return nil, fmt.Errorf("cache connector %q not enabled", conf.Connector) +} diff --git a/backend/v3/storage/cache/connector/gomap/connector.go b/backend/v3/storage/cache/connector/gomap/connector.go new file mode 100644 index 00000000000..c453e34fc5a --- /dev/null +++ b/backend/v3/storage/cache/connector/gomap/connector.go @@ -0,0 +1,23 @@ +package gomap + +import ( + "github.com/zitadel/zitadel/backend/v3/storage/cache" +) + +type Config struct { + Enabled bool + AutoPrune cache.AutoPruneConfig +} + +type Connector struct { + Config cache.AutoPruneConfig +} + +func NewConnector(config Config) *Connector { + if !config.Enabled { + return nil + } + return &Connector{ + Config: config.AutoPrune, + } +} diff --git a/backend/v3/storage/cache/connector/gomap/gomap.go b/backend/v3/storage/cache/connector/gomap/gomap.go new file mode 100644 index 00000000000..6b25d642c45 --- /dev/null +++ b/backend/v3/storage/cache/connector/gomap/gomap.go @@ -0,0 +1,200 @@ +package gomap + +import ( + "context" + "errors" + "log/slog" + "maps" + "os" + "sync" + "sync/atomic" + "time" + + "github.com/zitadel/zitadel/backend/v3/storage/cache" +) + +type mapCache[I, K comparable, V cache.Entry[I, K]] struct { + config *cache.Config + indexMap map[I]*index[K, V] + logger *slog.Logger +} + +// NewCache returns an in-memory Cache implementation based on the builtin go map type. +// Object values are stored as-is and there is no encoding or decoding involved. +func NewCache[I, K comparable, V cache.Entry[I, K]](background context.Context, indices []I, config cache.Config) cache.PrunerCache[I, K, V] { + m := &mapCache[I, K, V]{ + config: &config, + indexMap: make(map[I]*index[K, V], len(indices)), + logger: slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{ + AddSource: true, + Level: slog.LevelError, + })), + } + if config.Log != nil { + m.logger = config.Log.Slog() + } + m.logger.InfoContext(background, "map cache logging enabled") + + for _, name := range indices { + m.indexMap[name] = &index[K, V]{ + config: m.config, + entries: make(map[K]*entry[V]), + } + } + return m +} + +func (c *mapCache[I, K, V]) Get(ctx context.Context, index I, key K) (value V, ok bool) { + i, ok := c.indexMap[index] + if !ok { + c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key) + return value, false + } + entry, err := i.Get(key) + if err == nil { + c.logger.DebugContext(ctx, "map cache get", "index", index, "key", key) + return entry.value, true + } + if errors.Is(err, cache.ErrCacheMiss) { + c.logger.InfoContext(ctx, "map cache get", "err", err, "index", index, "key", key) + return value, false + } + c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key) + return value, false +} + +func (c *mapCache[I, K, V]) Set(ctx context.Context, value V) { + now := time.Now() + entry := &entry[V]{ + value: value, + created: now, + } + entry.lastUse.Store(now.UnixMicro()) + + for name, i := range c.indexMap { + keys := value.Keys(name) + i.Set(keys, entry) + c.logger.DebugContext(ctx, "map cache set", "index", name, "keys", keys) + } +} + +func (c *mapCache[I, K, V]) Invalidate(ctx context.Context, index I, keys ...K) error { + i, ok := c.indexMap[index] + if !ok { + return cache.NewIndexUnknownErr(index) + } + i.Invalidate(keys) + c.logger.DebugContext(ctx, "map cache invalidate", "index", index, "keys", keys) + return nil +} + +func (c *mapCache[I, K, V]) Delete(ctx context.Context, index I, keys ...K) error { + i, ok := c.indexMap[index] + if !ok { + return cache.NewIndexUnknownErr(index) + } + i.Delete(keys) + c.logger.DebugContext(ctx, "map cache delete", "index", index, "keys", keys) + return nil +} + +func (c *mapCache[I, K, V]) Prune(ctx context.Context) error { + for name, index := range c.indexMap { + index.Prune() + c.logger.DebugContext(ctx, "map cache prune", "index", name) + } + return nil +} + +func (c *mapCache[I, K, V]) Truncate(ctx context.Context) error { + for name, index := range c.indexMap { + index.Truncate() + c.logger.DebugContext(ctx, "map cache truncate", "index", name) + } + return nil +} + +type index[K comparable, V any] struct { + mutex sync.RWMutex + config *cache.Config + entries map[K]*entry[V] +} + +func (i *index[K, V]) Get(key K) (*entry[V], error) { + i.mutex.RLock() + entry, ok := i.entries[key] + i.mutex.RUnlock() + if ok && entry.isValid(i.config) { + return entry, nil + } + return nil, cache.ErrCacheMiss +} + +func (c *index[K, V]) Set(keys []K, entry *entry[V]) { + c.mutex.Lock() + for _, key := range keys { + c.entries[key] = entry + } + c.mutex.Unlock() +} + +func (i *index[K, V]) Invalidate(keys []K) { + i.mutex.RLock() + for _, key := range keys { + if entry, ok := i.entries[key]; ok { + entry.invalid.Store(true) + } + } + i.mutex.RUnlock() +} + +func (c *index[K, V]) Delete(keys []K) { + c.mutex.Lock() + for _, key := range keys { + delete(c.entries, key) + } + c.mutex.Unlock() +} + +func (c *index[K, V]) Prune() { + c.mutex.Lock() + maps.DeleteFunc(c.entries, func(_ K, entry *entry[V]) bool { + return !entry.isValid(c.config) + }) + c.mutex.Unlock() +} + +func (c *index[K, V]) Truncate() { + c.mutex.Lock() + c.entries = make(map[K]*entry[V]) + c.mutex.Unlock() +} + +type entry[V any] struct { + value V + created time.Time + invalid atomic.Bool + lastUse atomic.Int64 // UnixMicro time +} + +func (e *entry[V]) isValid(c *cache.Config) bool { + if e.invalid.Load() { + return false + } + now := time.Now() + if c.MaxAge > 0 { + if e.created.Add(c.MaxAge).Before(now) { + e.invalid.Store(true) + return false + } + } + if c.LastUseAge > 0 { + lastUse := e.lastUse.Load() + if time.UnixMicro(lastUse).Add(c.LastUseAge).Before(now) { + e.invalid.Store(true) + return false + } + e.lastUse.CompareAndSwap(lastUse, now.UnixMicro()) + } + return true +} diff --git a/backend/v3/storage/cache/connector/gomap/gomap_test.go b/backend/v3/storage/cache/connector/gomap/gomap_test.go new file mode 100644 index 00000000000..62bbc471a1f --- /dev/null +++ b/backend/v3/storage/cache/connector/gomap/gomap_test.go @@ -0,0 +1,329 @@ +package gomap + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/backend/v3/storage/cache" +) + +type testIndex int + +const ( + testIndexID testIndex = iota + testIndexName +) + +var testIndices = []testIndex{ + testIndexID, + testIndexName, +} + +type testObject struct { + id string + names []string +} + +func (o *testObject) Keys(index testIndex) []string { + switch index { + case testIndexID: + return []string{o.id} + case testIndexName: + return o.names + default: + return nil + } +} + +func Test_mapCache_Get(t *testing.T) { + c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{ + MaxAge: time.Second, + LastUseAge: time.Second / 4, + Log: &logging.Config{ + Level: "debug", + AddSource: true, + }, + }) + obj := &testObject{ + id: "id", + names: []string{"foo", "bar"}, + } + c.Set(context.Background(), obj) + + type args struct { + index testIndex + key string + } + tests := []struct { + name string + args args + want *testObject + wantOk bool + }{ + { + name: "ok", + args: args{ + index: testIndexID, + key: "id", + }, + want: obj, + wantOk: true, + }, + { + name: "miss", + args: args{ + index: testIndexID, + key: "spanac", + }, + want: nil, + wantOk: false, + }, + { + name: "unknown index", + args: args{ + index: 99, + key: "id", + }, + want: nil, + wantOk: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, ok := c.Get(context.Background(), tt.args.index, tt.args.key) + assert.Equal(t, tt.want, got) + assert.Equal(t, tt.wantOk, ok) + }) + } +} + +func Test_mapCache_Invalidate(t *testing.T) { + c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{ + MaxAge: time.Second, + LastUseAge: time.Second / 4, + Log: &logging.Config{ + Level: "debug", + AddSource: true, + }, + }) + obj := &testObject{ + id: "id", + names: []string{"foo", "bar"}, + } + c.Set(context.Background(), obj) + err := c.Invalidate(context.Background(), testIndexName, "bar") + require.NoError(t, err) + got, ok := c.Get(context.Background(), testIndexID, "id") + assert.Nil(t, got) + assert.False(t, ok) +} + +func Test_mapCache_Delete(t *testing.T) { + c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{ + MaxAge: time.Second, + LastUseAge: time.Second / 4, + Log: &logging.Config{ + Level: "debug", + AddSource: true, + }, + }) + obj := &testObject{ + id: "id", + names: []string{"foo", "bar"}, + } + c.Set(context.Background(), obj) + err := c.Delete(context.Background(), testIndexName, "bar") + require.NoError(t, err) + + // Shouldn't find object by deleted name + got, ok := c.Get(context.Background(), testIndexName, "bar") + assert.Nil(t, got) + assert.False(t, ok) + + // Should find object by other name + got, ok = c.Get(context.Background(), testIndexName, "foo") + assert.Equal(t, obj, got) + assert.True(t, ok) + + // Should find object by id + got, ok = c.Get(context.Background(), testIndexID, "id") + assert.Equal(t, obj, got) + assert.True(t, ok) +} + +func Test_mapCache_Prune(t *testing.T) { + c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{ + MaxAge: time.Second, + LastUseAge: time.Second / 4, + Log: &logging.Config{ + Level: "debug", + AddSource: true, + }, + }) + + objects := []*testObject{ + { + id: "id1", + names: []string{"foo", "bar"}, + }, + { + id: "id2", + names: []string{"hello"}, + }, + } + for _, obj := range objects { + c.Set(context.Background(), obj) + } + // invalidate one entry + err := c.Invalidate(context.Background(), testIndexName, "bar") + require.NoError(t, err) + + err = c.(cache.Pruner).Prune(context.Background()) + require.NoError(t, err) + + // Other object should still be found + got, ok := c.Get(context.Background(), testIndexID, "id2") + assert.Equal(t, objects[1], got) + assert.True(t, ok) +} + +func Test_mapCache_Truncate(t *testing.T) { + c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{ + MaxAge: time.Second, + LastUseAge: time.Second / 4, + Log: &logging.Config{ + Level: "debug", + AddSource: true, + }, + }) + objects := []*testObject{ + { + id: "id1", + names: []string{"foo", "bar"}, + }, + { + id: "id2", + names: []string{"hello"}, + }, + } + for _, obj := range objects { + c.Set(context.Background(), obj) + } + + err := c.Truncate(context.Background()) + require.NoError(t, err) + + mc := c.(*mapCache[testIndex, string, *testObject]) + for _, index := range mc.indexMap { + index.mutex.RLock() + assert.Len(t, index.entries, 0) + index.mutex.RUnlock() + } +} + +func Test_entry_isValid(t *testing.T) { + type fields struct { + created time.Time + invalid bool + lastUse time.Time + } + tests := []struct { + name string + fields fields + config *cache.Config + want bool + }{ + { + name: "invalid", + fields: fields{ + created: time.Now(), + invalid: true, + lastUse: time.Now(), + }, + config: &cache.Config{ + MaxAge: time.Minute, + LastUseAge: time.Second, + }, + want: false, + }, + { + name: "max age exceeded", + fields: fields{ + created: time.Now().Add(-(time.Minute + time.Second)), + invalid: false, + lastUse: time.Now(), + }, + config: &cache.Config{ + MaxAge: time.Minute, + LastUseAge: time.Second, + }, + want: false, + }, + { + name: "max age disabled", + fields: fields{ + created: time.Now().Add(-(time.Minute + time.Second)), + invalid: false, + lastUse: time.Now(), + }, + config: &cache.Config{ + LastUseAge: time.Second, + }, + want: true, + }, + { + name: "last use age exceeded", + fields: fields{ + created: time.Now().Add(-(time.Minute / 2)), + invalid: false, + lastUse: time.Now().Add(-(time.Second * 2)), + }, + config: &cache.Config{ + MaxAge: time.Minute, + LastUseAge: time.Second, + }, + want: false, + }, + { + name: "last use age disabled", + fields: fields{ + created: time.Now().Add(-(time.Minute / 2)), + invalid: false, + lastUse: time.Now().Add(-(time.Second * 2)), + }, + config: &cache.Config{ + MaxAge: time.Minute, + }, + want: true, + }, + { + name: "valid", + fields: fields{ + created: time.Now(), + invalid: false, + lastUse: time.Now(), + }, + config: &cache.Config{ + MaxAge: time.Minute, + LastUseAge: time.Second, + }, + want: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + e := &entry[any]{ + created: tt.fields.created, + } + e.invalid.Store(tt.fields.invalid) + e.lastUse.Store(tt.fields.lastUse.UnixMicro()) + got := e.isValid(tt.config) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/backend/v3/storage/cache/connector/noop/noop.go b/backend/v3/storage/cache/connector/noop/noop.go new file mode 100644 index 00000000000..12d261a77d3 --- /dev/null +++ b/backend/v3/storage/cache/connector/noop/noop.go @@ -0,0 +1,21 @@ +package noop + +import ( + "context" + + "github.com/zitadel/zitadel/backend/v3/storage/cache" +) + +type noop[I, K comparable, V cache.Entry[I, K]] struct{} + +// NewCache returns a cache that does nothing +func NewCache[I, K comparable, V cache.Entry[I, K]]() cache.Cache[I, K, V] { + return noop[I, K, V]{} +} + +func (noop[I, K, V]) Set(context.Context, V) {} +func (noop[I, K, V]) Get(context.Context, I, K) (value V, ok bool) { return } +func (noop[I, K, V]) Invalidate(context.Context, I, ...K) (err error) { return } +func (noop[I, K, V]) Delete(context.Context, I, ...K) (err error) { return } +func (noop[I, K, V]) Prune(context.Context) (err error) { return } +func (noop[I, K, V]) Truncate(context.Context) (err error) { return } diff --git a/backend/v3/storage/cache/connector_enumer.go b/backend/v3/storage/cache/connector_enumer.go new file mode 100644 index 00000000000..7ea014db166 --- /dev/null +++ b/backend/v3/storage/cache/connector_enumer.go @@ -0,0 +1,98 @@ +// Code generated by "enumer -type Connector -transform snake -trimprefix Connector -linecomment -text"; DO NOT EDIT. + +package cache + +import ( + "fmt" + "strings" +) + +const _ConnectorName = "memorypostgresredis" + +var _ConnectorIndex = [...]uint8{0, 0, 6, 14, 19} + +const _ConnectorLowerName = "memorypostgresredis" + +func (i Connector) String() string { + if i < 0 || i >= Connector(len(_ConnectorIndex)-1) { + return fmt.Sprintf("Connector(%d)", i) + } + return _ConnectorName[_ConnectorIndex[i]:_ConnectorIndex[i+1]] +} + +// An "invalid array index" compiler error signifies that the constant values have changed. +// Re-run the stringer command to generate them again. +func _ConnectorNoOp() { + var x [1]struct{} + _ = x[ConnectorUnspecified-(0)] + _ = x[ConnectorMemory-(1)] + _ = x[ConnectorPostgres-(2)] + _ = x[ConnectorRedis-(3)] +} + +var _ConnectorValues = []Connector{ConnectorUnspecified, ConnectorMemory, ConnectorPostgres, ConnectorRedis} + +var _ConnectorNameToValueMap = map[string]Connector{ + _ConnectorName[0:0]: ConnectorUnspecified, + _ConnectorLowerName[0:0]: ConnectorUnspecified, + _ConnectorName[0:6]: ConnectorMemory, + _ConnectorLowerName[0:6]: ConnectorMemory, + _ConnectorName[6:14]: ConnectorPostgres, + _ConnectorLowerName[6:14]: ConnectorPostgres, + _ConnectorName[14:19]: ConnectorRedis, + _ConnectorLowerName[14:19]: ConnectorRedis, +} + +var _ConnectorNames = []string{ + _ConnectorName[0:0], + _ConnectorName[0:6], + _ConnectorName[6:14], + _ConnectorName[14:19], +} + +// ConnectorString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func ConnectorString(s string) (Connector, error) { + if val, ok := _ConnectorNameToValueMap[s]; ok { + return val, nil + } + + if val, ok := _ConnectorNameToValueMap[strings.ToLower(s)]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to Connector values", s) +} + +// ConnectorValues returns all values of the enum +func ConnectorValues() []Connector { + return _ConnectorValues +} + +// ConnectorStrings returns a slice of all String values of the enum +func ConnectorStrings() []string { + strs := make([]string, len(_ConnectorNames)) + copy(strs, _ConnectorNames) + return strs +} + +// IsAConnector returns "true" if the value is listed in the enum definition. "false" otherwise +func (i Connector) IsAConnector() bool { + for _, v := range _ConnectorValues { + if i == v { + return true + } + } + return false +} + +// MarshalText implements the encoding.TextMarshaler interface for Connector +func (i Connector) MarshalText() ([]byte, error) { + return []byte(i.String()), nil +} + +// UnmarshalText implements the encoding.TextUnmarshaler interface for Connector +func (i *Connector) UnmarshalText(text []byte) error { + var err error + *i, err = ConnectorString(string(text)) + return err +} diff --git a/backend/v3/storage/cache/doc.go b/backend/v3/storage/cache/doc.go new file mode 100644 index 00000000000..a6e357989b5 --- /dev/null +++ b/backend/v3/storage/cache/doc.go @@ -0,0 +1,2 @@ +// this package is copy pasted from the internal/cache package +package cache diff --git a/backend/v3/storage/cache/error.go b/backend/v3/storage/cache/error.go new file mode 100644 index 00000000000..b66b9447bf1 --- /dev/null +++ b/backend/v3/storage/cache/error.go @@ -0,0 +1,29 @@ +package cache + +import ( + "errors" + "fmt" +) + +type IndexUnknownError[I comparable] struct { + index I +} + +func NewIndexUnknownErr[I comparable](index I) error { + return IndexUnknownError[I]{index} +} + +func (i IndexUnknownError[I]) Error() string { + return fmt.Sprintf("index %v unknown", i.index) +} + +func (a IndexUnknownError[I]) Is(err error) bool { + if b, ok := err.(IndexUnknownError[I]); ok { + return a.index == b.index + } + return false +} + +var ( + ErrCacheMiss = errors.New("cache miss") +) diff --git a/backend/v3/storage/cache/pruner.go b/backend/v3/storage/cache/pruner.go new file mode 100644 index 00000000000..959762d410b --- /dev/null +++ b/backend/v3/storage/cache/pruner.go @@ -0,0 +1,76 @@ +package cache + +import ( + "context" + "math/rand" + "time" + + "github.com/jonboulle/clockwork" + "github.com/zitadel/logging" +) + +// Pruner is an optional [Cache] interface. +type Pruner interface { + // Prune deletes all invalidated or expired objects. + Prune(ctx context.Context) error +} + +type PrunerCache[I, K comparable, V Entry[I, K]] interface { + Cache[I, K, V] + Pruner +} + +type AutoPruneConfig struct { + // Interval at which the cache is automatically pruned. + // 0 or lower disables automatic pruning. + Interval time.Duration + + // Timeout for an automatic prune. + // It is recommended to keep the value shorter than AutoPruneInterval + // 0 or lower disables automatic pruning. + Timeout time.Duration +} + +func (c AutoPruneConfig) StartAutoPrune(background context.Context, pruner Pruner, purpose Purpose) (close func()) { + return c.startAutoPrune(background, pruner, purpose, clockwork.NewRealClock()) +} + +func (c *AutoPruneConfig) startAutoPrune(background context.Context, pruner Pruner, purpose Purpose, clock clockwork.Clock) (close func()) { + if c.Interval <= 0 { + return func() {} + } + background, cancel := context.WithCancel(background) + // randomize the first interval + timer := clock.NewTimer(time.Duration(rand.Int63n(int64(c.Interval)))) + go c.pruneTimer(background, pruner, purpose, timer) + return cancel +} + +func (c *AutoPruneConfig) pruneTimer(background context.Context, pruner Pruner, purpose Purpose, timer clockwork.Timer) { + defer func() { + if !timer.Stop() { + <-timer.Chan() + } + }() + + for { + select { + case <-background.Done(): + return + case <-timer.Chan(): + err := c.doPrune(background, pruner) + logging.OnError(err).WithField("purpose", purpose).Error("cache auto prune") + timer.Reset(c.Interval) + } + } +} + +func (c *AutoPruneConfig) doPrune(background context.Context, pruner Pruner) error { + ctx, cancel := context.WithCancel(background) + defer cancel() + if c.Timeout > 0 { + ctx, cancel = context.WithTimeout(background, c.Timeout) + defer cancel() + } + return pruner.Prune(ctx) +} diff --git a/backend/v3/storage/cache/pruner_test.go b/backend/v3/storage/cache/pruner_test.go new file mode 100644 index 00000000000..faaedeb88ce --- /dev/null +++ b/backend/v3/storage/cache/pruner_test.go @@ -0,0 +1,43 @@ +package cache + +import ( + "context" + "testing" + "time" + + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/assert" +) + +type testPruner struct { + called chan struct{} +} + +func (p *testPruner) Prune(context.Context) error { + p.called <- struct{}{} + return nil +} + +func TestAutoPruneConfig_startAutoPrune(t *testing.T) { + c := AutoPruneConfig{ + Interval: time.Second, + Timeout: time.Millisecond, + } + ctx, cancel := context.WithTimeout(context.Background(), time.Second) + defer cancel() + + pruner := testPruner{ + called: make(chan struct{}), + } + clock := clockwork.NewFakeClock() + close := c.startAutoPrune(ctx, &pruner, PurposeAuthzInstance, clock) + defer close() + clock.Advance(time.Second) + + select { + case _, ok := <-pruner.called: + assert.True(t, ok) + case <-ctx.Done(): + t.Fatal(ctx.Err()) + } +} diff --git a/backend/v3/storage/cache/purpose_enumer.go b/backend/v3/storage/cache/purpose_enumer.go new file mode 100644 index 00000000000..a93a978efbc --- /dev/null +++ b/backend/v3/storage/cache/purpose_enumer.go @@ -0,0 +1,90 @@ +// Code generated by "enumer -type Purpose -transform snake -trimprefix Purpose"; DO NOT EDIT. + +package cache + +import ( + "fmt" + "strings" +) + +const _PurposeName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callback" + +var _PurposeIndex = [...]uint8{0, 11, 25, 35, 47, 65} + +const _PurposeLowerName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callback" + +func (i Purpose) String() string { + if i < 0 || i >= Purpose(len(_PurposeIndex)-1) { + return fmt.Sprintf("Purpose(%d)", i) + } + return _PurposeName[_PurposeIndex[i]:_PurposeIndex[i+1]] +} + +// An "invalid array index" compiler error signifies that the constant values have changed. +// Re-run the stringer command to generate them again. +func _PurposeNoOp() { + var x [1]struct{} + _ = x[PurposeUnspecified-(0)] + _ = x[PurposeAuthzInstance-(1)] + _ = x[PurposeMilestones-(2)] + _ = x[PurposeOrganization-(3)] + _ = x[PurposeIdPFormCallback-(4)] +} + +var _PurposeValues = []Purpose{PurposeUnspecified, PurposeAuthzInstance, PurposeMilestones, PurposeOrganization, PurposeIdPFormCallback} + +var _PurposeNameToValueMap = map[string]Purpose{ + _PurposeName[0:11]: PurposeUnspecified, + _PurposeLowerName[0:11]: PurposeUnspecified, + _PurposeName[11:25]: PurposeAuthzInstance, + _PurposeLowerName[11:25]: PurposeAuthzInstance, + _PurposeName[25:35]: PurposeMilestones, + _PurposeLowerName[25:35]: PurposeMilestones, + _PurposeName[35:47]: PurposeOrganization, + _PurposeLowerName[35:47]: PurposeOrganization, + _PurposeName[47:65]: PurposeIdPFormCallback, + _PurposeLowerName[47:65]: PurposeIdPFormCallback, +} + +var _PurposeNames = []string{ + _PurposeName[0:11], + _PurposeName[11:25], + _PurposeName[25:35], + _PurposeName[35:47], + _PurposeName[47:65], +} + +// PurposeString retrieves an enum value from the enum constants string name. +// Throws an error if the param is not part of the enum. +func PurposeString(s string) (Purpose, error) { + if val, ok := _PurposeNameToValueMap[s]; ok { + return val, nil + } + + if val, ok := _PurposeNameToValueMap[strings.ToLower(s)]; ok { + return val, nil + } + return 0, fmt.Errorf("%s does not belong to Purpose values", s) +} + +// PurposeValues returns all values of the enum +func PurposeValues() []Purpose { + return _PurposeValues +} + +// PurposeStrings returns a slice of all String values of the enum +func PurposeStrings() []string { + strs := make([]string, len(_PurposeNames)) + copy(strs, _PurposeNames) + return strs +} + +// IsAPurpose returns "true" if the value is listed in the enum definition. "false" otherwise +func (i Purpose) IsAPurpose() bool { + for _, v := range _PurposeValues { + if i == v { + return true + } + } + return false +} diff --git a/backend/v3/storage/database/change.go b/backend/v3/storage/database/change.go new file mode 100644 index 00000000000..724e029dbe9 --- /dev/null +++ b/backend/v3/storage/database/change.go @@ -0,0 +1,53 @@ +package database + +// Change represents a change to a column in a database table. +// Its written in the SET clause of an UPDATE statement. +type Change interface { + Write(builder *StatementBuilder) +} + +type change[V Value] struct { + column Column + value V +} + +var _ Change = (*change[string])(nil) + +func NewChange[V Value](col Column, value V) Change { + return &change[V]{ + column: col, + value: value, + } +} + +func NewChangePtr[V Value](col Column, value *V) Change { + if value == nil { + return NewChange(col, NullInstruction) + } + return NewChange(col, *value) +} + +// Write implements [Change]. +func (c change[V]) Write(builder *StatementBuilder) { + c.column.WriteUnqualified(builder) + builder.WriteString(" = ") + builder.WriteArg(c.value) +} + +type Changes []Change + +func NewChanges(cols ...Change) Change { + return Changes(cols) +} + +// Write implements [Change]. +func (m Changes) Write(builder *StatementBuilder) { + for i, col := range m { + if i > 0 { + builder.WriteString(", ") + } + col.Write(builder) + } +} + +var _ Change = Changes(nil) diff --git a/backend/v3/storage/database/column.go b/backend/v3/storage/database/column.go new file mode 100644 index 00000000000..7f57637d38a --- /dev/null +++ b/backend/v3/storage/database/column.go @@ -0,0 +1,85 @@ +package database + +type Columns []Column + +// WriteQualified implements [Column]. +func (m Columns) WriteQualified(builder *StatementBuilder) { + for i, col := range m { + if i > 0 { + builder.WriteString(", ") + } + col.WriteQualified(builder) + } +} + +// WriteUnqualified implements [Column]. +func (m Columns) WriteUnqualified(builder *StatementBuilder) { + for i, col := range m { + if i > 0 { + builder.WriteString(", ") + } + col.WriteUnqualified(builder) + } +} + +// Column represents a column in a database table. +type Column interface { + // Write(builder *StatementBuilder) + WriteQualified(builder *StatementBuilder) + WriteUnqualified(builder *StatementBuilder) +} + +type column struct { + table string + name string +} + +func NewColumn(table, name string) Column { + return column{table: table, name: name} +} + +// WriteQualified implements [Column]. +func (c column) WriteQualified(builder *StatementBuilder) { + builder.Grow(len(c.table) + len(c.name) + 1) + builder.WriteString(c.table) + builder.WriteRune('.') + builder.WriteString(c.name) +} + +// WriteUnqualified implements [Column]. +func (c column) WriteUnqualified(builder *StatementBuilder) { + builder.WriteString(c.name) +} + +var _ Column = (*column)(nil) + +// // ignoreCaseColumn represents two database columns, one for the +// // original value and one for the lower case value. +// type ignoreCaseColumn interface { +// Column +// WriteIgnoreCase(builder *StatementBuilder) +// } + +// func NewIgnoreCaseColumn(col Column, suffix string) ignoreCaseColumn { +// return ignoreCaseCol{ +// column: col, +// suffix: suffix, +// } +// } + +// type ignoreCaseCol struct { +// column Column +// suffix string +// } + +// // WriteIgnoreCase implements [ignoreCaseColumn]. +// func (c ignoreCaseCol) WriteIgnoreCase(builder *StatementBuilder) { +// c.column.WriteQualified(builder) +// builder.WriteString(c.suffix) +// } + +// // WriteQualified implements [ignoreCaseColumn]. +// func (c ignoreCaseCol) WriteQualified(builder *StatementBuilder) { +// c.column.WriteQualified(builder) +// builder.WriteString(c.suffix) +// } diff --git a/backend/v3/storage/database/condition.go b/backend/v3/storage/database/condition.go new file mode 100644 index 00000000000..5c8da5ff4b2 --- /dev/null +++ b/backend/v3/storage/database/condition.go @@ -0,0 +1,130 @@ +package database + +// Condition represents a SQL condition. +// Its written after the WHERE keyword in a SQL statement. +type Condition interface { + Write(builder *StatementBuilder) +} + +type and struct { + conditions []Condition +} + +// Write implements [Condition]. +func (a *and) Write(builder *StatementBuilder) { + if len(a.conditions) > 1 { + builder.WriteString("(") + defer builder.WriteString(")") + } + for i, condition := range a.conditions { + if i > 0 { + builder.WriteString(" AND ") + } + condition.Write(builder) + } +} + +// And combines multiple conditions with AND. +func And(conditions ...Condition) *and { + return &and{conditions: conditions} +} + +var _ Condition = (*and)(nil) + +type or struct { + conditions []Condition +} + +// Write implements [Condition]. +func (o *or) Write(builder *StatementBuilder) { + if len(o.conditions) > 1 { + builder.WriteString("(") + defer builder.WriteString(")") + } + for i, condition := range o.conditions { + if i > 0 { + builder.WriteString(" OR ") + } + condition.Write(builder) + } +} + +// Or combines multiple conditions with OR. +func Or(conditions ...Condition) *or { + return &or{conditions: conditions} +} + +var _ Condition = (*or)(nil) + +type isNull struct { + column Column +} + +// Write implements [Condition]. +func (i *isNull) Write(builder *StatementBuilder) { + i.column.WriteQualified(builder) + builder.WriteString(" IS NULL") +} + +// IsNull creates a condition that checks if a column is NULL. +func IsNull(column Column) *isNull { + return &isNull{column: column} +} + +var _ Condition = (*isNull)(nil) + +type isNotNull struct { + column Column +} + +// Write implements [Condition]. +func (i *isNotNull) Write(builder *StatementBuilder) { + i.column.WriteQualified(builder) + builder.WriteString(" IS NOT NULL") +} + +// IsNotNull creates a condition that checks if a column is NOT NULL. +func IsNotNull(column Column) *isNotNull { + return &isNotNull{column: column} +} + +var _ Condition = (*isNotNull)(nil) + +type valueCondition func(builder *StatementBuilder) + +// NewTextCondition creates a condition that compares a text column with a value. +func NewTextCondition[V Text](col Column, op TextOperation, value V) Condition { + return valueCondition(func(builder *StatementBuilder) { + writeTextOperation(builder, col, op, value) + }) +} + +// NewDateCondition creates a condition that compares a numeric column with a value. +func NewNumberCondition[V Number](col Column, op NumberOperation, value V) Condition { + return valueCondition(func(builder *StatementBuilder) { + writeNumberOperation(builder, col, op, value) + }) +} + +// NewDateCondition creates a condition that compares a boolean column with a value. +func NewBooleanCondition[V Boolean](col Column, value V) Condition { + return valueCondition(func(builder *StatementBuilder) { + writeBooleanOperation(builder, col, value) + }) +} + +// NewColumnCondition creates a condition that compares two columns on equality. +func NewColumnCondition(col1, col2 Column) Condition { + return valueCondition(func(builder *StatementBuilder) { + col1.WriteQualified(builder) + builder.WriteString(" = ") + col2.WriteQualified(builder) + }) +} + +// Write implements [Condition]. +func (c valueCondition) Write(builder *StatementBuilder) { + c(builder) +} + +var _ Condition = (*valueCondition)(nil) diff --git a/backend/v3/storage/database/config.go b/backend/v3/storage/database/config.go new file mode 100644 index 00000000000..8b20eb24d7f --- /dev/null +++ b/backend/v3/storage/database/config.go @@ -0,0 +1,10 @@ +package database + +import ( + "context" +) + +// Connector abstracts the database driver. +type Connector interface { + Connect(ctx context.Context) (Pool, error) +} diff --git a/backend/v3/storage/database/database.go b/backend/v3/storage/database/database.go new file mode 100644 index 00000000000..00a852e7a8a --- /dev/null +++ b/backend/v3/storage/database/database.go @@ -0,0 +1,83 @@ +package database + +import ( + "context" +) + +// Pool is a connection pool. e.g. pgxpool +type Pool interface { + Beginner + QueryExecutor + Migrator + + Acquire(ctx context.Context) (Client, error) + Close(ctx context.Context) error +} + +type PoolTest interface { + Pool + // MigrateTest is the same as [Migrator] but executes the migrations multiple times instead of only once. + MigrateTest(ctx context.Context) error +} + +// Client is a single database connection which can be released back to the pool. +type Client interface { + Beginner + QueryExecutor + Migrator + + Release(ctx context.Context) error +} + +// Querier is a database client that can execute queries and return rows. +type Querier interface { + Query(ctx context.Context, stmt string, args ...any) (Rows, error) + QueryRow(ctx context.Context, stmt string, args ...any) Row +} + +// Executor is a database client that can execute statements. +// It returns the number of rows affected or an error +type Executor interface { + Exec(ctx context.Context, stmt string, args ...any) (int64, error) +} + +// QueryExecutor is a database client that can execute queries and statements. +type QueryExecutor interface { + Querier + Executor +} + +// Scanner scans a single row of data into the destination. +type Scanner interface { + Scan(dest ...any) error +} + +// Row is an abstraction of sql.Row. +type Row interface { + Scanner +} + +// Rows is an abstraction of sql.Rows. +type Rows interface { + Scanner + Next() bool + Close() error + Err() error +} + +type CollectableRows interface { + // Collect collects all rows and scans them into dest. + // dest must be a pointer to a slice of pointer to structs + // e.g. *[]*MyStruct + // Rows are closed after this call. + Collect(dest any) error + // CollectFirst collects the first row and scans it into dest. + // dest must be a pointer to a struct + // e.g. *MyStruct{} + // Rows are closed after this call. + CollectFirst(dest any) error + // CollectExactlyOneRow collects exactly one row and scans it into dest. + // e.g. *MyStruct{} + // Rows are closed after this call. + CollectExactlyOneRow(dest any) error +} diff --git a/backend/v3/storage/database/dbmock/database.mock.go b/backend/v3/storage/database/dbmock/database.mock.go new file mode 100644 index 00000000000..1ff898257c0 --- /dev/null +++ b/backend/v3/storage/database/dbmock/database.mock.go @@ -0,0 +1,1146 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/zitadel/zitadel/backend/v3/storage/database (interfaces: Pool,Client,Row,Rows,Transaction) +// +// Generated by this command: +// +// mockgen -typed -package dbmock -destination ./dbmock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Client,Row,Rows,Transaction +// + +// Package dbmock is a generated GoMock package. +package dbmock + +import ( + context "context" + reflect "reflect" + + database "github.com/zitadel/zitadel/backend/v3/storage/database" + gomock "go.uber.org/mock/gomock" +) + +// MockPool is a mock of Pool interface. +type MockPool struct { + ctrl *gomock.Controller + recorder *MockPoolMockRecorder +} + +// MockPoolMockRecorder is the mock recorder for MockPool. +type MockPoolMockRecorder struct { + mock *MockPool +} + +// NewMockPool creates a new mock instance. +func NewMockPool(ctrl *gomock.Controller) *MockPool { + mock := &MockPool{ctrl: ctrl} + mock.recorder = &MockPoolMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockPool) EXPECT() *MockPoolMockRecorder { + return m.recorder +} + +// Acquire mocks base method. +func (m *MockPool) Acquire(arg0 context.Context) (database.Client, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Acquire", arg0) + ret0, _ := ret[0].(database.Client) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Acquire indicates an expected call of Acquire. +func (mr *MockPoolMockRecorder) Acquire(arg0 any) *MockPoolAcquireCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Acquire", reflect.TypeOf((*MockPool)(nil).Acquire), arg0) + return &MockPoolAcquireCall{Call: call} +} + +// MockPoolAcquireCall wrap *gomock.Call +type MockPoolAcquireCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPoolAcquireCall) Return(arg0 database.Client, arg1 error) *MockPoolAcquireCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPoolAcquireCall) Do(f func(context.Context) (database.Client, error)) *MockPoolAcquireCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPoolAcquireCall) DoAndReturn(f func(context.Context) (database.Client, error)) *MockPoolAcquireCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Begin mocks base method. +func (m *MockPool) Begin(arg0 context.Context, arg1 *database.TransactionOptions) (database.Transaction, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Begin", arg0, arg1) + ret0, _ := ret[0].(database.Transaction) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Begin indicates an expected call of Begin. +func (mr *MockPoolMockRecorder) Begin(arg0, arg1 any) *MockPoolBeginCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockPool)(nil).Begin), arg0, arg1) + return &MockPoolBeginCall{Call: call} +} + +// MockPoolBeginCall wrap *gomock.Call +type MockPoolBeginCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPoolBeginCall) Return(arg0 database.Transaction, arg1 error) *MockPoolBeginCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPoolBeginCall) Do(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockPoolBeginCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPoolBeginCall) DoAndReturn(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockPoolBeginCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Close mocks base method. +func (m *MockPool) Close(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockPoolMockRecorder) Close(arg0 any) *MockPoolCloseCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockPool)(nil).Close), arg0) + return &MockPoolCloseCall{Call: call} +} + +// MockPoolCloseCall wrap *gomock.Call +type MockPoolCloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPoolCloseCall) Return(arg0 error) *MockPoolCloseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPoolCloseCall) Do(f func(context.Context) error) *MockPoolCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPoolCloseCall) DoAndReturn(f func(context.Context) error) *MockPoolCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Exec mocks base method. +func (m *MockPool) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockPoolMockRecorder) Exec(arg0, arg1 any, arg2 ...any) *MockPoolExecCall { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockPool)(nil).Exec), varargs...) + return &MockPoolExecCall{Call: call} +} + +// MockPoolExecCall wrap *gomock.Call +type MockPoolExecCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPoolExecCall) Return(arg0 int64, arg1 error) *MockPoolExecCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPoolExecCall) Do(f func(context.Context, string, ...any) (int64, error)) *MockPoolExecCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPoolExecCall) DoAndReturn(f func(context.Context, string, ...any) (int64, error)) *MockPoolExecCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Migrate mocks base method. +func (m *MockPool) Migrate(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Migrate", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Migrate indicates an expected call of Migrate. +func (mr *MockPoolMockRecorder) Migrate(arg0 any) *MockPoolMigrateCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", reflect.TypeOf((*MockPool)(nil).Migrate), arg0) + return &MockPoolMigrateCall{Call: call} +} + +// MockPoolMigrateCall wrap *gomock.Call +type MockPoolMigrateCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPoolMigrateCall) Return(arg0 error) *MockPoolMigrateCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPoolMigrateCall) Do(f func(context.Context) error) *MockPoolMigrateCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPoolMigrateCall) DoAndReturn(f func(context.Context) error) *MockPoolMigrateCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Query mocks base method. +func (m *MockPool) Query(arg0 context.Context, arg1 string, arg2 ...any) (database.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Query", varargs...) + ret0, _ := ret[0].(database.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Query indicates an expected call of Query. +func (mr *MockPoolMockRecorder) Query(arg0, arg1 any, arg2 ...any) *MockPoolQueryCall { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockPool)(nil).Query), varargs...) + return &MockPoolQueryCall{Call: call} +} + +// MockPoolQueryCall wrap *gomock.Call +type MockPoolQueryCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPoolQueryCall) Return(arg0 database.Rows, arg1 error) *MockPoolQueryCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPoolQueryCall) Do(f func(context.Context, string, ...any) (database.Rows, error)) *MockPoolQueryCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPoolQueryCall) DoAndReturn(f func(context.Context, string, ...any) (database.Rows, error)) *MockPoolQueryCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// QueryRow mocks base method. +func (m *MockPool) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) database.Row { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryRow", varargs...) + ret0, _ := ret[0].(database.Row) + return ret0 +} + +// QueryRow indicates an expected call of QueryRow. +func (mr *MockPoolMockRecorder) QueryRow(arg0, arg1 any, arg2 ...any) *MockPoolQueryRowCall { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockPool)(nil).QueryRow), varargs...) + return &MockPoolQueryRowCall{Call: call} +} + +// MockPoolQueryRowCall wrap *gomock.Call +type MockPoolQueryRowCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockPoolQueryRowCall) Return(arg0 database.Row) *MockPoolQueryRowCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockPoolQueryRowCall) Do(f func(context.Context, string, ...any) database.Row) *MockPoolQueryRowCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockPoolQueryRowCall) DoAndReturn(f func(context.Context, string, ...any) database.Row) *MockPoolQueryRowCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MockClient is a mock of Client interface. +type MockClient struct { + ctrl *gomock.Controller + recorder *MockClientMockRecorder +} + +// MockClientMockRecorder is the mock recorder for MockClient. +type MockClientMockRecorder struct { + mock *MockClient +} + +// NewMockClient creates a new mock instance. +func NewMockClient(ctrl *gomock.Controller) *MockClient { + mock := &MockClient{ctrl: ctrl} + mock.recorder = &MockClientMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockClient) EXPECT() *MockClientMockRecorder { + return m.recorder +} + +// Begin mocks base method. +func (m *MockClient) Begin(arg0 context.Context, arg1 *database.TransactionOptions) (database.Transaction, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Begin", arg0, arg1) + ret0, _ := ret[0].(database.Transaction) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Begin indicates an expected call of Begin. +func (mr *MockClientMockRecorder) Begin(arg0, arg1 any) *MockClientBeginCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockClient)(nil).Begin), arg0, arg1) + return &MockClientBeginCall{Call: call} +} + +// MockClientBeginCall wrap *gomock.Call +type MockClientBeginCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockClientBeginCall) Return(arg0 database.Transaction, arg1 error) *MockClientBeginCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockClientBeginCall) Do(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockClientBeginCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockClientBeginCall) DoAndReturn(f func(context.Context, *database.TransactionOptions) (database.Transaction, error)) *MockClientBeginCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Exec mocks base method. +func (m *MockClient) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockClientMockRecorder) Exec(arg0, arg1 any, arg2 ...any) *MockClientExecCall { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockClient)(nil).Exec), varargs...) + return &MockClientExecCall{Call: call} +} + +// MockClientExecCall wrap *gomock.Call +type MockClientExecCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockClientExecCall) Return(arg0 int64, arg1 error) *MockClientExecCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockClientExecCall) Do(f func(context.Context, string, ...any) (int64, error)) *MockClientExecCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockClientExecCall) DoAndReturn(f func(context.Context, string, ...any) (int64, error)) *MockClientExecCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Migrate mocks base method. +func (m *MockClient) Migrate(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Migrate", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Migrate indicates an expected call of Migrate. +func (mr *MockClientMockRecorder) Migrate(arg0 any) *MockClientMigrateCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Migrate", reflect.TypeOf((*MockClient)(nil).Migrate), arg0) + return &MockClientMigrateCall{Call: call} +} + +// MockClientMigrateCall wrap *gomock.Call +type MockClientMigrateCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockClientMigrateCall) Return(arg0 error) *MockClientMigrateCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockClientMigrateCall) Do(f func(context.Context) error) *MockClientMigrateCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockClientMigrateCall) DoAndReturn(f func(context.Context) error) *MockClientMigrateCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Query mocks base method. +func (m *MockClient) Query(arg0 context.Context, arg1 string, arg2 ...any) (database.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Query", varargs...) + ret0, _ := ret[0].(database.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Query indicates an expected call of Query. +func (mr *MockClientMockRecorder) Query(arg0, arg1 any, arg2 ...any) *MockClientQueryCall { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockClient)(nil).Query), varargs...) + return &MockClientQueryCall{Call: call} +} + +// MockClientQueryCall wrap *gomock.Call +type MockClientQueryCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockClientQueryCall) Return(arg0 database.Rows, arg1 error) *MockClientQueryCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockClientQueryCall) Do(f func(context.Context, string, ...any) (database.Rows, error)) *MockClientQueryCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockClientQueryCall) DoAndReturn(f func(context.Context, string, ...any) (database.Rows, error)) *MockClientQueryCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// QueryRow mocks base method. +func (m *MockClient) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) database.Row { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryRow", varargs...) + ret0, _ := ret[0].(database.Row) + return ret0 +} + +// QueryRow indicates an expected call of QueryRow. +func (mr *MockClientMockRecorder) QueryRow(arg0, arg1 any, arg2 ...any) *MockClientQueryRowCall { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockClient)(nil).QueryRow), varargs...) + return &MockClientQueryRowCall{Call: call} +} + +// MockClientQueryRowCall wrap *gomock.Call +type MockClientQueryRowCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockClientQueryRowCall) Return(arg0 database.Row) *MockClientQueryRowCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockClientQueryRowCall) Do(f func(context.Context, string, ...any) database.Row) *MockClientQueryRowCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockClientQueryRowCall) DoAndReturn(f func(context.Context, string, ...any) database.Row) *MockClientQueryRowCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Release mocks base method. +func (m *MockClient) Release(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Release", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Release indicates an expected call of Release. +func (mr *MockClientMockRecorder) Release(arg0 any) *MockClientReleaseCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Release", reflect.TypeOf((*MockClient)(nil).Release), arg0) + return &MockClientReleaseCall{Call: call} +} + +// MockClientReleaseCall wrap *gomock.Call +type MockClientReleaseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockClientReleaseCall) Return(arg0 error) *MockClientReleaseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockClientReleaseCall) Do(f func(context.Context) error) *MockClientReleaseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockClientReleaseCall) DoAndReturn(f func(context.Context) error) *MockClientReleaseCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MockRow is a mock of Row interface. +type MockRow struct { + ctrl *gomock.Controller + recorder *MockRowMockRecorder +} + +// MockRowMockRecorder is the mock recorder for MockRow. +type MockRowMockRecorder struct { + mock *MockRow +} + +// NewMockRow creates a new mock instance. +func NewMockRow(ctrl *gomock.Controller) *MockRow { + mock := &MockRow{ctrl: ctrl} + mock.recorder = &MockRowMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRow) EXPECT() *MockRowMockRecorder { + return m.recorder +} + +// Scan mocks base method. +func (m *MockRow) Scan(arg0 ...any) error { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Scan", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Scan indicates an expected call of Scan. +func (mr *MockRowMockRecorder) Scan(arg0 ...any) *MockRowScanCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRow)(nil).Scan), arg0...) + return &MockRowScanCall{Call: call} +} + +// MockRowScanCall wrap *gomock.Call +type MockRowScanCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockRowScanCall) Return(arg0 error) *MockRowScanCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockRowScanCall) Do(f func(...any) error) *MockRowScanCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockRowScanCall) DoAndReturn(f func(...any) error) *MockRowScanCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MockRows is a mock of Rows interface. +type MockRows struct { + ctrl *gomock.Controller + recorder *MockRowsMockRecorder +} + +// MockRowsMockRecorder is the mock recorder for MockRows. +type MockRowsMockRecorder struct { + mock *MockRows +} + +// NewMockRows creates a new mock instance. +func NewMockRows(ctrl *gomock.Controller) *MockRows { + mock := &MockRows{ctrl: ctrl} + mock.recorder = &MockRowsMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRows) EXPECT() *MockRowsMockRecorder { + return m.recorder +} + +// Close mocks base method. +func (m *MockRows) Close() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Close") + ret0, _ := ret[0].(error) + return ret0 +} + +// Close indicates an expected call of Close. +func (mr *MockRowsMockRecorder) Close() *MockRowsCloseCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Close", reflect.TypeOf((*MockRows)(nil).Close)) + return &MockRowsCloseCall{Call: call} +} + +// MockRowsCloseCall wrap *gomock.Call +type MockRowsCloseCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockRowsCloseCall) Return(arg0 error) *MockRowsCloseCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockRowsCloseCall) Do(f func() error) *MockRowsCloseCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockRowsCloseCall) DoAndReturn(f func() error) *MockRowsCloseCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Err mocks base method. +func (m *MockRows) Err() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Err") + ret0, _ := ret[0].(error) + return ret0 +} + +// Err indicates an expected call of Err. +func (mr *MockRowsMockRecorder) Err() *MockRowsErrCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Err", reflect.TypeOf((*MockRows)(nil).Err)) + return &MockRowsErrCall{Call: call} +} + +// MockRowsErrCall wrap *gomock.Call +type MockRowsErrCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockRowsErrCall) Return(arg0 error) *MockRowsErrCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockRowsErrCall) Do(f func() error) *MockRowsErrCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockRowsErrCall) DoAndReturn(f func() error) *MockRowsErrCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Next mocks base method. +func (m *MockRows) Next() bool { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Next") + ret0, _ := ret[0].(bool) + return ret0 +} + +// Next indicates an expected call of Next. +func (mr *MockRowsMockRecorder) Next() *MockRowsNextCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Next", reflect.TypeOf((*MockRows)(nil).Next)) + return &MockRowsNextCall{Call: call} +} + +// MockRowsNextCall wrap *gomock.Call +type MockRowsNextCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockRowsNextCall) Return(arg0 bool) *MockRowsNextCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockRowsNextCall) Do(f func() bool) *MockRowsNextCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockRowsNextCall) DoAndReturn(f func() bool) *MockRowsNextCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Scan mocks base method. +func (m *MockRows) Scan(arg0 ...any) error { + m.ctrl.T.Helper() + varargs := []any{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Scan", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Scan indicates an expected call of Scan. +func (mr *MockRowsMockRecorder) Scan(arg0 ...any) *MockRowsScanCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Scan", reflect.TypeOf((*MockRows)(nil).Scan), arg0...) + return &MockRowsScanCall{Call: call} +} + +// MockRowsScanCall wrap *gomock.Call +type MockRowsScanCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockRowsScanCall) Return(arg0 error) *MockRowsScanCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockRowsScanCall) Do(f func(...any) error) *MockRowsScanCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockRowsScanCall) DoAndReturn(f func(...any) error) *MockRowsScanCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// MockTransaction is a mock of Transaction interface. +type MockTransaction struct { + ctrl *gomock.Controller + recorder *MockTransactionMockRecorder +} + +// MockTransactionMockRecorder is the mock recorder for MockTransaction. +type MockTransactionMockRecorder struct { + mock *MockTransaction +} + +// NewMockTransaction creates a new mock instance. +func NewMockTransaction(ctrl *gomock.Controller) *MockTransaction { + mock := &MockTransaction{ctrl: ctrl} + mock.recorder = &MockTransactionMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockTransaction) EXPECT() *MockTransactionMockRecorder { + return m.recorder +} + +// Begin mocks base method. +func (m *MockTransaction) Begin(arg0 context.Context) (database.Transaction, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Begin", arg0) + ret0, _ := ret[0].(database.Transaction) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Begin indicates an expected call of Begin. +func (mr *MockTransactionMockRecorder) Begin(arg0 any) *MockTransactionBeginCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Begin", reflect.TypeOf((*MockTransaction)(nil).Begin), arg0) + return &MockTransactionBeginCall{Call: call} +} + +// MockTransactionBeginCall wrap *gomock.Call +type MockTransactionBeginCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTransactionBeginCall) Return(arg0 database.Transaction, arg1 error) *MockTransactionBeginCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTransactionBeginCall) Do(f func(context.Context) (database.Transaction, error)) *MockTransactionBeginCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTransactionBeginCall) DoAndReturn(f func(context.Context) (database.Transaction, error)) *MockTransactionBeginCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Commit mocks base method. +func (m *MockTransaction) Commit(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Commit", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Commit indicates an expected call of Commit. +func (mr *MockTransactionMockRecorder) Commit(arg0 any) *MockTransactionCommitCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Commit", reflect.TypeOf((*MockTransaction)(nil).Commit), arg0) + return &MockTransactionCommitCall{Call: call} +} + +// MockTransactionCommitCall wrap *gomock.Call +type MockTransactionCommitCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTransactionCommitCall) Return(arg0 error) *MockTransactionCommitCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTransactionCommitCall) Do(f func(context.Context) error) *MockTransactionCommitCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTransactionCommitCall) DoAndReturn(f func(context.Context) error) *MockTransactionCommitCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// End mocks base method. +func (m *MockTransaction) End(arg0 context.Context, arg1 error) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "End", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// End indicates an expected call of End. +func (mr *MockTransactionMockRecorder) End(arg0, arg1 any) *MockTransactionEndCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "End", reflect.TypeOf((*MockTransaction)(nil).End), arg0, arg1) + return &MockTransactionEndCall{Call: call} +} + +// MockTransactionEndCall wrap *gomock.Call +type MockTransactionEndCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTransactionEndCall) Return(arg0 error) *MockTransactionEndCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTransactionEndCall) Do(f func(context.Context, error) error) *MockTransactionEndCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTransactionEndCall) DoAndReturn(f func(context.Context, error) error) *MockTransactionEndCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Exec mocks base method. +func (m *MockTransaction) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Exec", varargs...) + ret0, _ := ret[0].(int64) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Exec indicates an expected call of Exec. +func (mr *MockTransactionMockRecorder) Exec(arg0, arg1 any, arg2 ...any) *MockTransactionExecCall { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Exec", reflect.TypeOf((*MockTransaction)(nil).Exec), varargs...) + return &MockTransactionExecCall{Call: call} +} + +// MockTransactionExecCall wrap *gomock.Call +type MockTransactionExecCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTransactionExecCall) Return(arg0 int64, arg1 error) *MockTransactionExecCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTransactionExecCall) Do(f func(context.Context, string, ...any) (int64, error)) *MockTransactionExecCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTransactionExecCall) DoAndReturn(f func(context.Context, string, ...any) (int64, error)) *MockTransactionExecCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Query mocks base method. +func (m *MockTransaction) Query(arg0 context.Context, arg1 string, arg2 ...any) (database.Rows, error) { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Query", varargs...) + ret0, _ := ret[0].(database.Rows) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Query indicates an expected call of Query. +func (mr *MockTransactionMockRecorder) Query(arg0, arg1 any, arg2 ...any) *MockTransactionQueryCall { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Query", reflect.TypeOf((*MockTransaction)(nil).Query), varargs...) + return &MockTransactionQueryCall{Call: call} +} + +// MockTransactionQueryCall wrap *gomock.Call +type MockTransactionQueryCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTransactionQueryCall) Return(arg0 database.Rows, arg1 error) *MockTransactionQueryCall { + c.Call = c.Call.Return(arg0, arg1) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTransactionQueryCall) Do(f func(context.Context, string, ...any) (database.Rows, error)) *MockTransactionQueryCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTransactionQueryCall) DoAndReturn(f func(context.Context, string, ...any) (database.Rows, error)) *MockTransactionQueryCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// QueryRow mocks base method. +func (m *MockTransaction) QueryRow(arg0 context.Context, arg1 string, arg2 ...any) database.Row { + m.ctrl.T.Helper() + varargs := []any{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "QueryRow", varargs...) + ret0, _ := ret[0].(database.Row) + return ret0 +} + +// QueryRow indicates an expected call of QueryRow. +func (mr *MockTransactionMockRecorder) QueryRow(arg0, arg1 any, arg2 ...any) *MockTransactionQueryRowCall { + mr.mock.ctrl.T.Helper() + varargs := append([]any{arg0, arg1}, arg2...) + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "QueryRow", reflect.TypeOf((*MockTransaction)(nil).QueryRow), varargs...) + return &MockTransactionQueryRowCall{Call: call} +} + +// MockTransactionQueryRowCall wrap *gomock.Call +type MockTransactionQueryRowCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTransactionQueryRowCall) Return(arg0 database.Row) *MockTransactionQueryRowCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTransactionQueryRowCall) Do(f func(context.Context, string, ...any) database.Row) *MockTransactionQueryRowCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTransactionQueryRowCall) DoAndReturn(f func(context.Context, string, ...any) database.Row) *MockTransactionQueryRowCall { + c.Call = c.Call.DoAndReturn(f) + return c +} + +// Rollback mocks base method. +func (m *MockTransaction) Rollback(arg0 context.Context) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Rollback", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// Rollback indicates an expected call of Rollback. +func (mr *MockTransactionMockRecorder) Rollback(arg0 any) *MockTransactionRollbackCall { + mr.mock.ctrl.T.Helper() + call := mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Rollback", reflect.TypeOf((*MockTransaction)(nil).Rollback), arg0) + return &MockTransactionRollbackCall{Call: call} +} + +// MockTransactionRollbackCall wrap *gomock.Call +type MockTransactionRollbackCall struct { + *gomock.Call +} + +// Return rewrite *gomock.Call.Return +func (c *MockTransactionRollbackCall) Return(arg0 error) *MockTransactionRollbackCall { + c.Call = c.Call.Return(arg0) + return c +} + +// Do rewrite *gomock.Call.Do +func (c *MockTransactionRollbackCall) Do(f func(context.Context) error) *MockTransactionRollbackCall { + c.Call = c.Call.Do(f) + return c +} + +// DoAndReturn rewrite *gomock.Call.DoAndReturn +func (c *MockTransactionRollbackCall) DoAndReturn(f func(context.Context) error) *MockTransactionRollbackCall { + c.Call = c.Call.DoAndReturn(f) + return c +} diff --git a/backend/v3/storage/database/dialect/config.go b/backend/v3/storage/database/dialect/config.go new file mode 100644 index 00000000000..6abe57c5a82 --- /dev/null +++ b/backend/v3/storage/database/dialect/config.go @@ -0,0 +1,92 @@ +package dialect + +import ( + "context" + "errors" + "reflect" + + "github.com/mitchellh/mapstructure" + "github.com/spf13/viper" + + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres" +) + +type Hook struct { + Match func(string) bool + Decode func(config any) (database.Connector, error) + Name string + Constructor func() database.Connector +} + +var hooks = []Hook{ + { + Match: postgres.NameMatcher, + Decode: postgres.DecodeConfig, + Name: postgres.Name, + Constructor: func() database.Connector { return new(postgres.Config) }, + }, + // { + // Match: gosql.NameMatcher, + // Decode: gosql.DecodeConfig, + // Name: gosql.Name, + // Constructor: func() database.Connector { return new(gosql.Config) }, + // }, +} + +type Config struct { + Dialects map[string]any `mapstructure:",remain" yaml:",inline"` + + connector database.Connector +} + +func (c Config) Connect(ctx context.Context) (database.Pool, error) { + if len(c.Dialects) != 1 { + return nil, errors.New("exactly one dialect must be configured") + } + + return c.connector.Connect(ctx) +} + +// Hooks implements [configure.Unmarshaller]. +func (c Config) Hooks() []viper.DecoderConfigOption { + return []viper.DecoderConfigOption{ + viper.DecodeHook(decodeHook), + } +} + +func decodeHook(from, to reflect.Value) (_ any, err error) { + if to.Type() != reflect.TypeOf(Config{}) { + return from.Interface(), nil + } + + config := new(Config) + if err = mapstructure.Decode(from.Interface(), config); err != nil { + return nil, err + } + + if err = config.decodeDialect(); err != nil { + return nil, err + } + + return config, nil +} + +func (c *Config) decodeDialect() error { + for _, hook := range hooks { + for name, config := range c.Dialects { + if !hook.Match(name) { + continue + } + + connector, err := hook.Decode(config) + if err != nil { + return err + } + + c.connector = connector + return nil + } + } + return errors.New("no dialect found") +} diff --git a/backend/v3/storage/database/dialect/postgres/config.go b/backend/v3/storage/database/dialect/postgres/config.go new file mode 100644 index 00000000000..1e4e5d751b5 --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/config.go @@ -0,0 +1,89 @@ +package postgres + +import ( + "context" + "errors" + "slices" + "strings" + + "github.com/jackc/pgx/v5/pgxpool" + "github.com/mitchellh/mapstructure" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +var ( + _ database.Connector = (*Config)(nil) + Name = "postgres" + isMigrated bool +) + +type Config struct { + *pgxpool.Config + *pgxpool.Pool + + // Host string + // Port int32 + // Database string + // MaxOpenConns uint32 + // MaxIdleConns uint32 + // MaxConnLifetime time.Duration + // MaxConnIdleTime time.Duration + // User User + // // Additional options to be appended as options= + // // The value will be taken as is. Multiple options are space separated. + // Options string + + // configuredFields []string +} + +// Connect implements [database.Connector]. +func (c *Config) Connect(ctx context.Context) (database.Pool, error) { + pool, err := c.getPool(ctx) + if err != nil { + return nil, wrapError(err) + } + if err = pool.Ping(ctx); err != nil { + return nil, wrapError(err) + } + return &pgxPool{Pool: pool}, nil +} + +func (c *Config) getPool(ctx context.Context) (*pgxpool.Pool, error) { + if c.Pool != nil { + return c.Pool, nil + } + return pgxpool.NewWithConfig(ctx, c.Config) +} + +func NameMatcher(name string) bool { + return slices.Contains([]string{"postgres", "pg"}, strings.ToLower(name)) +} + +func DecodeConfig(input any) (database.Connector, error) { + switch c := input.(type) { + case string: + config, err := pgxpool.ParseConfig(c) + if err != nil { + return nil, err + } + return &Config{Config: config}, nil + case map[string]any: + connector := new(Config) + decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{ + DecodeHook: mapstructure.StringToTimeDurationHookFunc(), + WeaklyTypedInput: true, + Result: connector, + }) + if err != nil { + return nil, err + } + if err = decoder.Decode(c); err != nil { + return nil, err + } + return &Config{ + Config: &pgxpool.Config{}, + }, nil + } + return nil, errors.New("invalid configuration") +} diff --git a/backend/v3/storage/database/dialect/postgres/conn.go b/backend/v3/storage/database/dialect/postgres/conn.go new file mode 100644 index 00000000000..a556b7a545a --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/conn.go @@ -0,0 +1,67 @@ +package postgres + +import ( + "context" + + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres/migration" +) + +type pgxConn struct { + *pgxpool.Conn +} + +var _ database.Client = (*pgxConn)(nil) + +// Release implements [database.Client]. +func (c *pgxConn) Release(_ context.Context) error { + c.Conn.Release() + return nil +} + +// Begin implements [database.Client]. +func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) { + tx, err := c.BeginTx(ctx, transactionOptionsToPgx(opts)) + if err != nil { + return nil, wrapError(err) + } + return &pgxTx{tx}, nil +} + +// Query implements sql.Client. +// Subtle: this method shadows the method (*Conn).Query of pgxConn.Conn. +func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { + rows, err := c.Conn.Query(ctx, sql, args...) + if err != nil { + return nil, wrapError(err) + } + return &Rows{rows}, nil +} + +// QueryRow implements sql.Client. +// Subtle: this method shadows the method (*Conn).QueryRow of pgxConn.Conn. +func (c *pgxConn) QueryRow(ctx context.Context, sql string, args ...any) database.Row { + return &Row{c.Conn.QueryRow(ctx, sql, args...)} +} + +// Exec implements [database.Pool]. +// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool. +func (c *pgxConn) Exec(ctx context.Context, sql string, args ...any) (int64, error) { + res, err := c.Conn.Exec(ctx, sql, args...) + if err != nil { + return 0, wrapError(err) + } + return res.RowsAffected(), nil +} + +// Migrate implements [database.Migrator]. +func (c *pgxConn) Migrate(ctx context.Context) error { + if isMigrated { + return nil + } + err := migration.Migrate(ctx, c.Conn.Conn()) + isMigrated = err == nil + return wrapError(err) +} diff --git a/backend/v3/storage/database/dialect/postgres/doc.go b/backend/v3/storage/database/dialect/postgres/doc.go new file mode 100644 index 00000000000..8b21e4766d0 --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/doc.go @@ -0,0 +1,2 @@ +// pgxpool v5 implementation of the interfaces defined in the database package. +package postgres diff --git a/backend/v3/storage/database/dialect/postgres/embedded/start.go b/backend/v3/storage/database/dialect/postgres/embedded/start.go new file mode 100644 index 00000000000..9a5f3ea82b9 --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/embedded/start.go @@ -0,0 +1,50 @@ +// embedded is used for testing purposes +package embedded + +import ( + "net" + "os" + + embeddedpostgres "github.com/fergusstrange/embedded-postgres" + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres" +) + +// StartEmbedded starts an embedded postgres v16 instance and returns a database connector and a stop function +// the database is started on a random port and data are stored in a temporary directory +// its used for testing purposes only +func StartEmbedded() (connector database.Connector, stop func(), err error) { + path, err := os.MkdirTemp("", "zitadel-embedded-postgres-*") + logging.OnError(err).Fatal("unable to create temp dir") + + port, close := getPort() + + config := embeddedpostgres.DefaultConfig().Version(embeddedpostgres.V16).Port(uint32(port)).RuntimePath(path) + embedded := embeddedpostgres.NewDatabase(config) + + close() + err = embedded.Start() + logging.OnError(err).Fatal("unable to start db") + + connector, err = postgres.DecodeConfig(config.GetConnectionURL()) + if err != nil { + return nil, nil, err + } + + return connector, func() { + logging.OnError(embedded.Stop()).Error("unable to stop db") + }, nil +} + +// getPort returns a free port and locks it until close is called +func getPort() (port uint16, close func()) { + l, err := net.Listen("tcp", ":0") + logging.OnError(err).Fatal("unable to get port") + port = uint16(l.Addr().(*net.TCPAddr).Port) + logging.WithFields("port", port).Info("Port is available") + return port, func() { + logging.OnError(l.Close()).Error("unable to close port listener") + } +} diff --git a/backend/v3/storage/database/dialect/postgres/error.go b/backend/v3/storage/database/dialect/postgres/error.go new file mode 100644 index 00000000000..89b3f8837a9 --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/error.go @@ -0,0 +1,38 @@ +package postgres + +import ( + "errors" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +func wrapError(err error) error { + if err == nil { + return nil + } + if errors.Is(err, pgx.ErrNoRows) { + return database.NewNoRowFoundError(err) + } + var pgxErr *pgconn.PgError + if !errors.As(err, &pgxErr) { + return database.NewUnknownError(err) + } + switch pgxErr.Code { + // 23514: check_violation - A value violates a CHECK constraint. + case "23514": + return database.NewCheckError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr) + // 23505: unique_violation - A value violates a UNIQUE constraint. + case "23505": + return database.NewUniqueError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr) + // 23503: foreign_key_violation - A value violates a foreign key constraint. + case "23503": + return database.NewForeignKeyError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr) + // 23502: not_null_violation - A value violates a NOT NULL constraint. + case "23502": + return database.NewNotNullError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr) + } + return database.NewUnknownError(err) +} diff --git a/backend/v3/storage/database/dialect/postgres/migration/001_instance_table.go b/backend/v3/storage/database/dialect/postgres/migration/001_instance_table.go new file mode 100644 index 00000000000..6e9f9b4c48a --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/migration/001_instance_table.go @@ -0,0 +1,16 @@ +package migration + +import ( + _ "embed" +) + +var ( + //go:embed 001_instance_table/up.sql + up001InstanceTable string + //go:embed 001_instance_table/down.sql + down001InstanceTable string +) + +func init() { + registerSQLMigration(1, up001InstanceTable, down001InstanceTable) +} diff --git a/backend/v3/storage/database/dialect/postgres/migration/001_instance_table/down.sql b/backend/v3/storage/database/dialect/postgres/migration/001_instance_table/down.sql new file mode 100644 index 00000000000..e2fcd6b7c98 --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/migration/001_instance_table/down.sql @@ -0,0 +1 @@ +DROP TABLE zitadel.instances; \ No newline at end of file diff --git a/backend/v3/storage/database/dialect/postgres/migration/001_instance_table/up.sql b/backend/v3/storage/database/dialect/postgres/migration/001_instance_table/up.sql new file mode 100644 index 00000000000..b8faaedafd7 --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/migration/001_instance_table/up.sql @@ -0,0 +1,25 @@ +CREATE TABLE IF NOT EXISTS zitadel.instances( + id TEXT NOT NULL CHECK (id <> '') PRIMARY KEY, + name TEXT NOT NULL CHECK (name <> ''), + default_org_id TEXT, -- NOT NULL, + iam_project_id TEXT, -- NOT NULL, + console_client_id TEXT, -- NOT NULL, + console_app_id TEXT, -- NOT NULL, + default_language TEXT, -- NOT NULL, + created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL, + updated_at TIMESTAMPTZ DEFAULT NOW() NOT NULL +); + +CREATE OR REPLACE FUNCTION zitadel.set_updated_at() +RETURNS TRIGGER AS $$ +BEGIN + NEW.updated_at := NOW(); + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TRIGGER trigger_set_updated_at +BEFORE UPDATE ON zitadel.instances +FOR EACH ROW +WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at) +EXECUTE FUNCTION zitadel.set_updated_at(); diff --git a/backend/v3/storage/database/dialect/postgres/migration/002_organization_table.go b/backend/v3/storage/database/dialect/postgres/migration/002_organization_table.go new file mode 100644 index 00000000000..2f0a04eee01 --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/migration/002_organization_table.go @@ -0,0 +1,16 @@ +package migration + +import ( + _ "embed" +) + +var ( + //go:embed 002_organization_table/up.sql + up002OrganizationTable string + //go:embed 002_organization_table/down.sql + down002OrganizationTable string +) + +func init() { + registerSQLMigration(2, up002OrganizationTable, down002OrganizationTable) +} diff --git a/backend/v3/storage/database/dialect/postgres/migration/002_organization_table/down.sql b/backend/v3/storage/database/dialect/postgres/migration/002_organization_table/down.sql new file mode 100644 index 00000000000..654858cdac6 --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/migration/002_organization_table/down.sql @@ -0,0 +1,2 @@ +DROP TABLE zitadel.organizations; +DROP Type zitadel.organization_state; diff --git a/backend/v3/storage/database/dialect/postgres/migration/002_organization_table/up.sql b/backend/v3/storage/database/dialect/postgres/migration/002_organization_table/up.sql new file mode 100644 index 00000000000..e8b5acef4b8 --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/migration/002_organization_table/up.sql @@ -0,0 +1,24 @@ +CREATE TYPE zitadel.organization_state AS ENUM ( + 'active', + 'inactive' +); + +CREATE TABLE zitadel.organizations( + id TEXT NOT NULL CHECK (id <> ''), + name TEXT NOT NULL CHECK (name <> ''), + instance_id TEXT NOT NULL REFERENCES zitadel.instances (id) ON DELETE CASCADE, + state zitadel.organization_state NOT NULL, + created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL, + updated_at TIMESTAMPTZ DEFAULT NOW() NOT NULL, + + PRIMARY KEY (instance_id, id) +); + +CREATE UNIQUE INDEX org_unique_instance_id_name_idx + ON zitadel.organizations (instance_id, name); + +CREATE TRIGGER trigger_set_updated_at +BEFORE UPDATE ON zitadel.organizations +FOR EACH ROW +WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at) +EXECUTE FUNCTION zitadel.set_updated_at(); diff --git a/backend/v3/storage/database/dialect/postgres/migration/003_domains_table.go b/backend/v3/storage/database/dialect/postgres/migration/003_domains_table.go new file mode 100644 index 00000000000..c00b7ea2810 --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/migration/003_domains_table.go @@ -0,0 +1,16 @@ +package migration + +import ( + _ "embed" +) + +var ( + //go:embed 003_domains_table/up.sql + up003DomainsTable string + //go:embed 003_domains_table/down.sql + down003DomainsTable string +) + +func init() { + registerSQLMigration(3, up003DomainsTable, down003DomainsTable) +} diff --git a/backend/v3/storage/database/dialect/postgres/migration/003_domains_table/down.sql b/backend/v3/storage/database/dialect/postgres/migration/003_domains_table/down.sql new file mode 100644 index 00000000000..3a53d4173dd --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/migration/003_domains_table/down.sql @@ -0,0 +1,6 @@ +DROP TABLE IF EXISTS zitadel.instance_domains; +DROP TABLE IF EXISTS zitadel.org_domains; +DROP TYPE IF EXISTS zitadel.domain_type; +DROP TYPE IF EXISTS zitadel.domain_validation_type; +DROP FUNCTION IF EXISTS zitadel.check_verified_org_domain(); +DROP FUNCTION IF EXISTS zitadel.ensure_single_primary_instance_domain(); diff --git a/backend/v3/storage/database/dialect/postgres/migration/003_domains_table/up.sql b/backend/v3/storage/database/dialect/postgres/migration/003_domains_table/up.sql new file mode 100644 index 00000000000..9d5132b63de --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/migration/003_domains_table/up.sql @@ -0,0 +1,137 @@ +CREATE TYPE zitadel.domain_validation_type AS ENUM ( + 'dns' + , 'http' +); + +CREATE TYPE zitadel.domain_type AS ENUM ( + 'custom' + , 'trusted' +); + +CREATE TABLE zitadel.instance_domains( + instance_id TEXT NOT NULL + , domain TEXT NOT NULL CHECK (LENGTH(domain) BETWEEN 1 AND 255) + , is_primary BOOLEAN + , is_generated BOOLEAN + , type zitadel.domain_type NOT NULL + + , created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL + , updated_at TIMESTAMPTZ DEFAULT NOW() NOT NULL + + , PRIMARY KEY (domain) + + , FOREIGN KEY (instance_id) REFERENCES zitadel.instances(id) ON DELETE CASCADE + + , CONSTRAINT primary_cannot_be_trusted CHECK (is_primary IS NULL OR type != 'trusted') + , CONSTRAINT generated_cannot_be_trusted CHECK (is_generated IS NULL OR type != 'trusted') + , CONSTRAINT custom_values_set CHECK ((is_primary IS NOT NULL AND is_generated IS NOT NULL) OR type != 'custom') +); + +CREATE INDEX idx_instance_domain_instance ON zitadel.instance_domains(instance_id); + +CREATE TABLE zitadel.org_domains( + instance_id TEXT NOT NULL + , org_id TEXT NOT NULL + , domain TEXT NOT NULL CHECK (LENGTH(domain) BETWEEN 1 AND 255) + , is_verified BOOLEAN NOT NULL DEFAULT FALSE + , is_primary BOOLEAN NOT NULL DEFAULT FALSE + , validation_type zitadel.domain_validation_type + + , created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL + , updated_at TIMESTAMPTZ DEFAULT NOW() NOT NULL + + , PRIMARY KEY (instance_id, org_id, domain) + + , FOREIGN KEY (instance_id, org_id) REFERENCES zitadel.organizations(instance_id, id) ON DELETE CASCADE + + , UNIQUE (instance_id, org_id, domain) +); + +CREATE INDEX idx_org_domain ON zitadel.org_domains(instance_id, domain); + +-- Trigger to update the updated_at timestamp on instance_domains +CREATE TRIGGER trg_set_updated_at_instance_domains + BEFORE UPDATE ON zitadel.instance_domains + FOR EACH ROW + WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at) + EXECUTE FUNCTION zitadel.set_updated_at(); + +-- Trigger to update the updated_at timestamp on org_domains +CREATE TRIGGER trg_set_updated_at_org_domains + BEFORE UPDATE ON zitadel.org_domains + FOR EACH ROW + WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at) + EXECUTE FUNCTION zitadel.set_updated_at(); + +-- Function to check for already verified org domains +CREATE OR REPLACE FUNCTION zitadel.check_verified_org_domain() +RETURNS TRIGGER AS $$ +BEGIN + -- Check if there's already a verified domain within this instance (excluding the current record being updated) + IF EXISTS ( + SELECT 1 + FROM zitadel.org_domains + WHERE instance_id = NEW.instance_id + AND domain = NEW.domain + AND is_verified = TRUE + AND (TG_OP = 'INSERT' OR (org_id != NEW.org_id)) + ) THEN + RAISE EXCEPTION 'org domain is already taken'; + END IF; + + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +-- Trigger to enforce verified domain constraint on org_domains +CREATE TRIGGER trg_check_verified_org_domain + BEFORE INSERT OR UPDATE ON zitadel.org_domains + FOR EACH ROW + WHEN (NEW.is_verified IS TRUE) + EXECUTE FUNCTION zitadel.check_verified_org_domain(); + +-- Function to ensure only one primary domain per instance in instance_domains +CREATE OR REPLACE FUNCTION zitadel.ensure_single_primary_instance_domain() +RETURNS TRIGGER AS $$ +BEGIN + -- If setting this domain as primary, update all other domains in the same instance to non-primary + UPDATE zitadel.instance_domains + SET is_primary = FALSE, updated_at = NOW() + WHERE instance_id = NEW.instance_id + AND domain != NEW.domain + AND is_primary = TRUE + AND type = 'custom'; + + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +-- Trigger to enforce single primary domain constraint on instance_domains +CREATE TRIGGER trg_ensure_single_primary_instance_domain + BEFORE INSERT OR UPDATE ON zitadel.instance_domains + FOR EACH ROW + WHEN (NEW.is_primary IS TRUE) + EXECUTE FUNCTION zitadel.ensure_single_primary_instance_domain(); + +-- Function to ensure only one primary domain per organization in org_domains +CREATE OR REPLACE FUNCTION zitadel.ensure_single_primary_org_domain() +RETURNS TRIGGER AS $$ +BEGIN + -- If setting this domain as primary, update all other domains in the same organization to non-primary + UPDATE zitadel.org_domains + SET is_primary = FALSE, updated_at = NOW() + WHERE instance_id = NEW.instance_id + AND org_id = NEW.org_id + AND domain != NEW.domain + AND is_primary = TRUE; + + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +-- Trigger to enforce single primary domain constraint on org_domains +CREATE TRIGGER trg_ensure_single_primary_org_domain + BEFORE INSERT OR UPDATE ON zitadel.org_domains + FOR EACH ROW + WHEN (NEW.is_primary IS TRUE) + EXECUTE FUNCTION zitadel.ensure_single_primary_org_domain(); \ No newline at end of file diff --git a/backend/v3/storage/database/dialect/postgres/migration/doc.go b/backend/v3/storage/database/dialect/postgres/migration/doc.go new file mode 100644 index 00000000000..306347a641a --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/migration/doc.go @@ -0,0 +1,13 @@ +// This package contains the migration logic for the PostgreSQL dialect. +// It uses the [github.com/jackc/tern/v2/migrate] package to handle the migration process. +// +// **Developer Note**: +// +// Each migration MUST be registered in an init function. +// Create a go file for each migration with the sequence of the migration as prefix and some descriptive name. +// The file name MUST be in the format _.go. +// Each migration SHOULD provide an up and down migration. +// Prefer to write SQL statements instead of funcs if it is reasonable. +// To keep the folder clean create a folder to store the sql files with the following format: _/{up/down}.sql. +// And use the go embed directive to embed the sql files. +package migration diff --git a/backend/v3/storage/database/dialect/postgres/migration/migrationgs.go b/backend/v3/storage/database/dialect/postgres/migration/migrationgs.go new file mode 100644 index 00000000000..96194a756e0 --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/migration/migrationgs.go @@ -0,0 +1,33 @@ +package migration + +import ( + "context" + + "github.com/jackc/pgx/v5" + "github.com/jackc/tern/v2/migrate" +) + +var migrations []*migrate.Migration + +func Migrate(ctx context.Context, conn *pgx.Conn) error { + // we need to ensure that the schema exists before we can run the migration + // because creating the migrations table already required the schema + _, err := conn.Exec(ctx, "CREATE SCHEMA IF NOT EXISTS zitadel") + if err != nil { + return err + } + migrator, err := migrate.NewMigrator(ctx, conn, "zitadel.migrations") + if err != nil { + return err + } + migrator.Migrations = migrations + return migrator.Migrate(ctx) +} + +func registerSQLMigration(sequence int32, up, down string) { + migrations = append(migrations, &migrate.Migration{ + Sequence: sequence, + UpSQL: up, + DownSQL: down, + }) +} diff --git a/backend/v3/storage/database/dialect/postgres/migration/migrationgs_test.go b/backend/v3/storage/database/dialect/postgres/migration/migrationgs_test.go new file mode 100644 index 00000000000..37680cb94f5 --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/migration/migrationgs_test.go @@ -0,0 +1,60 @@ +package migration_test + +import ( + "context" + "testing" + + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres/embedded" +) + +func TestMigrate(t *testing.T) { + tests := []struct { + name string + stmt string + args []any + res []any + }{ + { + name: "schema", + stmt: "SELECT EXISTS(SELECT 1 FROM information_schema.schemata where schema_name = 'zitadel') ;", + res: []any{true}, + }, + { + name: "001", + stmt: "SELECT EXISTS(SELECT 1 FROM pg_catalog.pg_tables WHERE schemaname = 'zitadel' and tablename=$1)", + args: []any{"instances"}, + res: []any{true}, + }, + } + + ctx := context.Background() + + connector, stop, err := embedded.StartEmbedded() + require.NoError(t, err, "failed to start embedded postgres") + defer stop() + + client, err := connector.Connect(ctx) + require.NoError(t, err, "failed to connect to embedded postgres") + + err = client.(database.Migrator).Migrate(ctx) + require.NoError(t, err, "failed to execute migration steps") + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := make([]any, len(tt.res)) + for i := range got { + got[i] = new(any) + tt.res[i] = gu.Ptr(tt.res[i]) + } + + require.NoError(t, client.QueryRow(ctx, tt.stmt, tt.args...).Scan(got...), "failed to execute check query") + + assert.Equal(t, tt.res, got, "query result does not match") + }) + } +} diff --git a/backend/v3/storage/database/dialect/postgres/pool.go b/backend/v3/storage/database/dialect/postgres/pool.go new file mode 100644 index 00000000000..bf4ca44c313 --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/pool.go @@ -0,0 +1,100 @@ +package postgres + +import ( + "context" + + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres/migration" +) + +type pgxPool struct { + *pgxpool.Pool +} + +var _ database.Pool = (*pgxPool)(nil) + +func PGxPool(pool *pgxpool.Pool) *pgxPool { + return &pgxPool{ + Pool: pool, + } +} + +// Acquire implements [database.Pool]. +func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) { + conn, err := c.Pool.Acquire(ctx) + if err != nil { + return nil, wrapError(err) + } + return &pgxConn{Conn: conn}, nil +} + +// Query implements [database.Pool]. +// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool. +func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { + rows, err := c.Pool.Query(ctx, sql, args...) + if err != nil { + return nil, wrapError(err) + } + return &Rows{rows}, nil +} + +// QueryRow implements [database.Pool]. +// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool. +func (c *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row { + return &Row{c.Pool.QueryRow(ctx, sql, args...)} +} + +// Exec implements [database.Pool]. +// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool. +func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) { + res, err := c.Pool.Exec(ctx, sql, args...) + if err != nil { + return 0, wrapError(err) + } + return res.RowsAffected(), nil +} + +// Begin implements [database.Pool]. +func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) { + tx, err := c.BeginTx(ctx, transactionOptionsToPgx(opts)) + if err != nil { + return nil, wrapError(err) + } + return &pgxTx{tx}, nil +} + +// Close implements [database.Pool]. +func (c *pgxPool) Close(_ context.Context) error { + c.Pool.Close() + return nil +} + +// Migrate implements [database.Migrator]. +func (c *pgxPool) Migrate(ctx context.Context) error { + if isMigrated { + return nil + } + + client, err := c.Pool.Acquire(ctx) + if err != nil { + return err + } + + err = migration.Migrate(ctx, client.Conn()) + isMigrated = err == nil + return wrapError(err) +} + +// Migrate implements [database.PoolTest]. +func (c *pgxPool) MigrateTest(ctx context.Context) error { + client, err := c.Pool.Acquire(ctx) + if err != nil { + return err + } + + err = migration.Migrate(ctx, client.Conn()) + isMigrated = err == nil + return err +} diff --git a/backend/v3/storage/database/dialect/postgres/rows.go b/backend/v3/storage/database/dialect/postgres/rows.go new file mode 100644 index 00000000000..d151effd591 --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/rows.go @@ -0,0 +1,77 @@ +package postgres + +import ( + "github.com/georgysavva/scany/v2/pgxscan" + "github.com/jackc/pgx/v5" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +var ( + _ database.Rows = (*Rows)(nil) + _ database.CollectableRows = (*Rows)(nil) + _ database.Row = (*Row)(nil) +) + +type Row struct{ pgx.Row } + +// Scan implements [database.Row]. +// Subtle: this method shadows the method ([pgx.Row]).Scan of Row.Row. +func (r *Row) Scan(dest ...any) error { + return wrapError(r.Row.Scan(dest...)) +} + +type Rows struct{ pgx.Rows } + +// Err implements [database.Rows]. +// Subtle: this method shadows the method ([pgx.Rows]).Err of Rows.Rows. +func (r *Rows) Err() error { + return wrapError(r.Rows.Err()) +} + +func (r *Rows) Scan(dest ...any) error { + return wrapError(r.Rows.Scan(dest...)) +} + +// Collect implements [database.CollectableRows]. +// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details. +func (r *Rows) Collect(dest any) (err error) { + defer func() { + closeErr := r.Close() + if err == nil { + err = closeErr + } + }() + return wrapError(pgxscan.ScanAll(dest, r.Rows)) +} + +// CollectFirst implements [database.CollectableRows]. +// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details. +func (r *Rows) CollectFirst(dest any) (err error) { + defer func() { + closeErr := r.Close() + if err == nil { + err = closeErr + } + }() + return wrapError(pgxscan.ScanRow(dest, r.Rows)) +} + +// CollectExactlyOneRow implements [database.CollectableRows]. +// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details. +func (r *Rows) CollectExactlyOneRow(dest any) (err error) { + defer func() { + closeErr := r.Close() + if err == nil { + err = closeErr + } + }() + return wrapError(pgxscan.ScanOne(dest, r.Rows)) +} + +// Close implements [database.Rows]. +// Subtle: this method shadows the method (Rows).Close of Rows.Rows. +func (r *Rows) Close() error { + r.Rows.Close() + return nil +} diff --git a/backend/v3/storage/database/dialect/postgres/tx.go b/backend/v3/storage/database/dialect/postgres/tx.go new file mode 100644 index 00000000000..6a330c16b9a --- /dev/null +++ b/backend/v3/storage/database/dialect/postgres/tx.go @@ -0,0 +1,107 @@ +package postgres + +import ( + "context" + "errors" + + "github.com/jackc/pgx/v5" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type pgxTx struct{ pgx.Tx } + +var _ database.Transaction = (*pgxTx)(nil) + +// Commit implements [database.Transaction]. +func (tx *pgxTx) Commit(ctx context.Context) error { + err := tx.Tx.Commit(ctx) + return wrapError(err) +} + +// Rollback implements [database.Transaction]. +func (tx *pgxTx) Rollback(ctx context.Context) error { + err := tx.Tx.Rollback(ctx) + return wrapError(err) +} + +// End implements [database.Transaction]. +func (tx *pgxTx) End(ctx context.Context, err error) error { + if err != nil { + rollbackErr := tx.Rollback(ctx) + if rollbackErr != nil { + err = errors.Join(err, rollbackErr) + } + return err + } + return tx.Commit(ctx) +} + +// Query implements [database.Transaction]. +// Subtle: this method shadows the method (Tx).Query of pgxTx.Tx. +func (tx *pgxTx) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { + rows, err := tx.Tx.Query(ctx, sql, args...) + if err != nil { + return nil, wrapError(err) + } + return &Rows{rows}, nil +} + +// QueryRow implements [database.Transaction]. +// Subtle: this method shadows the method (Tx).QueryRow of pgxTx.Tx. +func (tx *pgxTx) QueryRow(ctx context.Context, sql string, args ...any) database.Row { + return &Row{tx.Tx.QueryRow(ctx, sql, args...)} +} + +// Exec implements [database.Transaction]. +// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool. +func (tx *pgxTx) Exec(ctx context.Context, sql string, args ...any) (int64, error) { + res, err := tx.Tx.Exec(ctx, sql, args...) + if err != nil { + return 0, wrapError(err) + } + return res.RowsAffected(), nil +} + +// Begin implements [database.Transaction]. +// As postgres does not support nested transactions we use savepoints to emulate them. +func (tx *pgxTx) Begin(ctx context.Context) (database.Transaction, error) { + savepoint, err := tx.Tx.Begin(ctx) + if err != nil { + return nil, wrapError(err) + } + return &pgxTx{savepoint}, nil +} + +func transactionOptionsToPgx(opts *database.TransactionOptions) pgx.TxOptions { + if opts == nil { + return pgx.TxOptions{} + } + + return pgx.TxOptions{ + IsoLevel: isolationToPgx(opts.IsolationLevel), + AccessMode: accessModeToPgx(opts.AccessMode), + } +} + +func isolationToPgx(isolation database.IsolationLevel) pgx.TxIsoLevel { + switch isolation { + case database.IsolationLevelSerializable: + return pgx.Serializable + case database.IsolationLevelReadCommitted: + return pgx.ReadCommitted + default: + return pgx.Serializable + } +} + +func accessModeToPgx(accessMode database.AccessMode) pgx.TxAccessMode { + switch accessMode { + case database.AccessModeReadWrite: + return pgx.ReadWrite + case database.AccessModeReadOnly: + return pgx.ReadOnly + default: + return pgx.ReadWrite + } +} diff --git a/backend/v3/storage/database/dialect/sql/conn.go b/backend/v3/storage/database/dialect/sql/conn.go new file mode 100644 index 00000000000..40749f690ba --- /dev/null +++ b/backend/v3/storage/database/dialect/sql/conn.go @@ -0,0 +1,60 @@ +package sql + +import ( + "context" + "database/sql" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type sqlConn struct { + *sql.Conn +} + +var _ database.Client = (*sqlConn)(nil) + +// Release implements [database.Client]. +func (c *sqlConn) Release(_ context.Context) error { + return c.Close() +} + +// Begin implements [database.Client]. +func (c *sqlConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) { + tx, err := c.BeginTx(ctx, transactionOptionsToSQL(opts)) + if err != nil { + return nil, wrapError(err) + } + return &sqlTx{tx}, nil +} + +// Query implements sql.Client. +// Subtle: this method shadows the method (*Conn).Query of pgxConn.Conn. +func (c *sqlConn) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { + //nolint:rowserrcheck // Rows.Close is called by the caller + rows, err := c.QueryContext(ctx, sql, args...) + if err != nil { + return nil, wrapError(err) + } + return &Rows{rows}, nil +} + +// QueryRow implements sql.Client. +// Subtle: this method shadows the method (*Conn).QueryRow of pgxConn.Conn. +func (c *sqlConn) QueryRow(ctx context.Context, sql string, args ...any) database.Row { + return &Row{c.QueryRowContext(ctx, sql, args...)} +} + +// Exec implements [database.Pool]. +// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool. +func (c *sqlConn) Exec(ctx context.Context, sql string, args ...any) (int64, error) { + res, err := c.ExecContext(ctx, sql, args...) + if err != nil { + return 0, wrapError(err) + } + return res.RowsAffected() +} + +// Migrate implements [database.Migrator]. +func (c *sqlConn) Migrate(ctx context.Context) error { + return ErrMigrate +} diff --git a/backend/v3/storage/database/dialect/sql/doc.go b/backend/v3/storage/database/dialect/sql/doc.go new file mode 100644 index 00000000000..a6b2782dc53 --- /dev/null +++ b/backend/v3/storage/database/dialect/sql/doc.go @@ -0,0 +1,3 @@ +// [database/sql] implementation of the interfaces defined in the database package. +// This package is used to migrate from event driven to relational Zitadel. +package sql diff --git a/backend/v3/storage/database/dialect/sql/error.go b/backend/v3/storage/database/dialect/sql/error.go new file mode 100644 index 00000000000..f01b1e70e13 --- /dev/null +++ b/backend/v3/storage/database/dialect/sql/error.go @@ -0,0 +1,41 @@ +package sql + +import ( + "database/sql" + "errors" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgconn" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +var ErrMigrate = errors.New("sql does not support migrations, use a different dialect") + +func wrapError(err error) error { + if err == nil { + return nil + } + if errors.Is(err, pgx.ErrNoRows) || errors.Is(err, sql.ErrNoRows) { + return database.NewNoRowFoundError(err) + } + var pgxErr *pgconn.PgError + if !errors.As(err, &pgxErr) { + return database.NewUnknownError(err) + } + switch pgxErr.Code { + // 23514: check_violation - A value violates a CHECK constraint. + case "23514": + return database.NewCheckError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr) + // 23505: unique_violation - A value violates a UNIQUE constraint. + case "23505": + return database.NewUniqueError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr) + // 23503: foreign_key_violation - A value violates a foreign key constraint. + case "23503": + return database.NewForeignKeyError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr) + // 23502: not_null_violation - A value violates a NOT NULL constraint. + case "23502": + return database.NewNotNullError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr) + } + return database.NewUnknownError(err) +} diff --git a/backend/v3/storage/database/dialect/sql/pool.go b/backend/v3/storage/database/dialect/sql/pool.go new file mode 100644 index 00000000000..1bab0dd3970 --- /dev/null +++ b/backend/v3/storage/database/dialect/sql/pool.go @@ -0,0 +1,75 @@ +package sql + +import ( + "context" + "database/sql" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type sqlPool struct { + *sql.DB +} + +var _ database.Pool = (*sqlPool)(nil) + +func SQLPool(db *sql.DB) *sqlPool { + return &sqlPool{ + DB: db, + } +} + +// Acquire implements [database.Pool]. +func (c *sqlPool) Acquire(ctx context.Context) (database.Client, error) { + conn, err := c.Conn(ctx) + if err != nil { + return nil, wrapError(err) + } + return &sqlConn{Conn: conn}, nil +} + +// Query implements [database.Pool]. +// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool. +func (c *sqlPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { + //nolint:rowserrcheck // Rows.Close is called by the caller + rows, err := c.QueryContext(ctx, sql, args...) + if err != nil { + return nil, wrapError(err) + } + return &Rows{rows}, nil +} + +// QueryRow implements [database.Pool]. +// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool. +func (c *sqlPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row { + return &Row{c.QueryRowContext(ctx, sql, args...)} +} + +// Exec implements [database.Pool]. +// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool. +func (c *sqlPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) { + res, err := c.ExecContext(ctx, sql, args...) + if err != nil { + return 0, wrapError(err) + } + return res.RowsAffected() +} + +// Begin implements [database.Pool]. +func (c *sqlPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) { + tx, err := c.BeginTx(ctx, transactionOptionsToSQL(opts)) + if err != nil { + return nil, wrapError(err) + } + return &sqlTx{tx}, nil +} + +// Close implements [database.Pool]. +func (c *sqlPool) Close(_ context.Context) error { + return c.DB.Close() +} + +// Migrate implements [database.Migrator]. +func (c *sqlPool) Migrate(ctx context.Context) error { + return ErrMigrate +} diff --git a/backend/v3/storage/database/dialect/sql/rows.go b/backend/v3/storage/database/dialect/sql/rows.go new file mode 100644 index 00000000000..d49dc346d2d --- /dev/null +++ b/backend/v3/storage/database/dialect/sql/rows.go @@ -0,0 +1,78 @@ +package sql + +import ( + "database/sql" + + pgxscan "github.com/georgysavva/scany/v2/dbscan" + "github.com/jackc/pgx/v5" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +var ( + _ database.Rows = (*Rows)(nil) + _ database.CollectableRows = (*Rows)(nil) + _ database.Row = (*Row)(nil) +) + +type Row struct{ pgx.Row } + +// Scan implements [database.Row]. +// Subtle: this method shadows the method ([pgx.Row]).Scan of Row.Row. +func (r *Row) Scan(dest ...any) error { + return wrapError(r.Row.Scan(dest...)) +} + +type Rows struct{ *sql.Rows } + +// Err implements [database.Rows]. +// Subtle: this method shadows the method ([pgx.Rows]).Err of Rows.Rows. +func (r *Rows) Err() error { + return wrapError(r.Rows.Err()) +} + +func (r *Rows) Scan(dest ...any) error { + return wrapError(r.Rows.Scan(dest...)) +} + +// Collect implements [database.CollectableRows]. +// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details. +func (r *Rows) Collect(dest any) (err error) { + defer func() { + closeErr := r.Close() + if err == nil { + err = closeErr + } + }() + return wrapError(pgxscan.ScanAll(dest, r.Rows)) +} + +// CollectFirst implements [database.CollectableRows]. +// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details. +func (r *Rows) CollectFirst(dest any) (err error) { + defer func() { + closeErr := r.Close() + if err == nil { + err = closeErr + } + }() + return wrapError(pgxscan.ScanRow(dest, r.Rows)) +} + +// CollectExactlyOneRow implements [database.CollectableRows]. +// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details. +func (r *Rows) CollectExactlyOneRow(dest any) (err error) { + defer func() { + closeErr := r.Close() + if err == nil { + err = closeErr + } + }() + return wrapError(pgxscan.ScanOne(dest, r.Rows)) +} + +// Close implements [database.Rows]. +// Subtle: this method shadows the method (Rows).Close of Rows.Rows. +func (r *Rows) Close() error { + return r.Rows.Close() +} diff --git a/backend/v3/storage/database/dialect/sql/savepoint.go b/backend/v3/storage/database/dialect/sql/savepoint.go new file mode 100644 index 00000000000..1933771e9a1 --- /dev/null +++ b/backend/v3/storage/database/dialect/sql/savepoint.go @@ -0,0 +1,69 @@ +package sql + +import ( + "context" + "errors" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +var _ database.Transaction = (*sqlSavepoint)(nil) + +const ( + savepointName = "zitadel_savepoint" + createSavepoint = "SAVEPOINT " + savepointName + rollbackToSavepoint = "ROLLBACK TO SAVEPOINT " + savepointName + commitSavepoint = "RELEASE SAVEPOINT " + savepointName +) + +type sqlSavepoint struct { + parent database.Transaction +} + +// Commit implements [database.Transaction]. +func (s *sqlSavepoint) Commit(ctx context.Context) error { + _, err := s.parent.Exec(ctx, commitSavepoint) + return wrapError(err) +} + +// Rollback implements [database.Transaction]. +func (s *sqlSavepoint) Rollback(ctx context.Context) error { + _, err := s.parent.Exec(ctx, rollbackToSavepoint) + return wrapError(err) +} + +// End implements [database.Transaction]. +func (s *sqlSavepoint) End(ctx context.Context, err error) error { + if err != nil { + rollbackErr := s.Rollback(ctx) + if rollbackErr != nil { + err = errors.Join(err, rollbackErr) + } + return err + } + return s.Commit(ctx) +} + +// Query implements [database.Transaction]. +// Subtle: this method shadows the method (Tx).Query of pgxTx.Tx. +func (s *sqlSavepoint) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { + return s.parent.Query(ctx, sql, args...) +} + +// QueryRow implements [database.Transaction]. +// Subtle: this method shadows the method (Tx).QueryRow of pgxTx.Tx. +func (s *sqlSavepoint) QueryRow(ctx context.Context, sql string, args ...any) database.Row { + return s.parent.QueryRow(ctx, sql, args...) +} + +// Exec implements [database.Transaction]. +// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool. +func (s *sqlSavepoint) Exec(ctx context.Context, sql string, args ...any) (int64, error) { + return s.parent.Exec(ctx, sql, args...) +} + +// Begin implements [database.Transaction]. +// As postgres does not support nested transactions we use savepoints to emulate them. +func (s *sqlSavepoint) Begin(ctx context.Context) (database.Transaction, error) { + return s.parent.Begin(ctx) +} diff --git a/backend/v3/storage/database/dialect/sql/tx.go b/backend/v3/storage/database/dialect/sql/tx.go new file mode 100644 index 00000000000..46b7cf5ad9b --- /dev/null +++ b/backend/v3/storage/database/dialect/sql/tx.go @@ -0,0 +1,104 @@ +package sql + +import ( + "context" + "database/sql" + "errors" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type sqlTx struct{ *sql.Tx } + +var _ database.Transaction = (*sqlTx)(nil) + +func SQLTx(tx *sql.Tx) *sqlTx { + return &sqlTx{ + Tx: tx, + } +} + +// Commit implements [database.Transaction]. +func (tx *sqlTx) Commit(ctx context.Context) error { + return wrapError(tx.Tx.Commit()) +} + +// Rollback implements [database.Transaction]. +func (tx *sqlTx) Rollback(ctx context.Context) error { + return wrapError(tx.Tx.Rollback()) +} + +// End implements [database.Transaction]. +func (tx *sqlTx) End(ctx context.Context, err error) error { + if err != nil { + rollbackErr := tx.Rollback(ctx) + if rollbackErr != nil { + err = errors.Join(err, rollbackErr) + } + return err + } + return tx.Commit(ctx) +} + +// Query implements [database.Transaction]. +// Subtle: this method shadows the method (Tx).Query of pgxTx.Tx. +func (tx *sqlTx) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) { + //nolint:rowserrcheck // Rows.Close is called by the caller + rows, err := tx.QueryContext(ctx, sql, args...) + if err != nil { + return nil, wrapError(err) + } + return &Rows{rows}, nil +} + +// QueryRow implements [database.Transaction]. +// Subtle: this method shadows the method (Tx).QueryRow of pgxTx.Tx. +func (tx *sqlTx) QueryRow(ctx context.Context, sql string, args ...any) database.Row { + return &Row{tx.QueryRowContext(ctx, sql, args...)} +} + +// Exec implements [database.Transaction]. +// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool. +func (tx *sqlTx) Exec(ctx context.Context, sql string, args ...any) (int64, error) { + res, err := tx.ExecContext(ctx, sql, args...) + if err != nil { + return 0, wrapError(err) + } + return res.RowsAffected() +} + +// Begin implements [database.Transaction]. +// As postgres does not support nested transactions we use savepoints to emulate them. +func (tx *sqlTx) Begin(ctx context.Context) (database.Transaction, error) { + _, err := tx.ExecContext(ctx, createSavepoint) + if err != nil { + return nil, wrapError(err) + } + return &sqlSavepoint{tx}, nil +} + +func transactionOptionsToSQL(opts *database.TransactionOptions) *sql.TxOptions { + if opts == nil { + return nil + } + + return &sql.TxOptions{ + Isolation: isolationToSQL(opts.IsolationLevel), + ReadOnly: accessModeToSQL(opts.AccessMode), + } +} + +func isolationToSQL(isolation database.IsolationLevel) sql.IsolationLevel { + switch isolation { + case database.IsolationLevelSerializable: + return sql.LevelSerializable + case database.IsolationLevelReadCommitted: + return sql.LevelReadCommitted + default: + return sql.LevelSerializable + } +} + +func accessModeToSQL(accessMode database.AccessMode) bool { + return accessMode == database.AccessModeReadOnly +} diff --git a/backend/v3/storage/database/errors.go b/backend/v3/storage/database/errors.go new file mode 100644 index 00000000000..d51d28aa531 --- /dev/null +++ b/backend/v3/storage/database/errors.go @@ -0,0 +1,230 @@ +package database + +import ( + "errors" + "fmt" +) + +var ErrNoChanges = errors.New("update must contain a change") + +// NoRowFoundError is returned when QueryRow does not find any row. +// It wraps the dialect specific original error to provide more context. +type NoRowFoundError struct { + original error +} + +func NewNoRowFoundError(original error) error { + return &NoRowFoundError{ + original: original, + } +} + +func (e *NoRowFoundError) Error() string { + return "no row found" +} + +func (e *NoRowFoundError) Is(target error) bool { + _, ok := target.(*NoRowFoundError) + return ok +} + +func (e *NoRowFoundError) Unwrap() error { + return e.original +} + +// MultipleRowsFoundError is returned when QueryRow finds multiple rows. +// It wraps the dialect specific original error to provide more context. +type MultipleRowsFoundError struct { + original error + count int +} + +func NewMultipleRowsFoundError(original error, count int) error { + return &MultipleRowsFoundError{ + original: original, + count: count, + } +} + +func (e *MultipleRowsFoundError) Error() string { + return fmt.Sprintf("multiple rows found: %d", e.count) +} + +func (e *MultipleRowsFoundError) Is(target error) bool { + _, ok := target.(*MultipleRowsFoundError) + return ok +} + +func (e *MultipleRowsFoundError) Unwrap() error { + return e.original +} + +type IntegrityType string + +const ( + IntegrityTypeCheck IntegrityType = "check" + IntegrityTypeUnique IntegrityType = "unique" + IntegrityTypeForeign IntegrityType = "foreign" + IntegrityTypeNotNull IntegrityType = "not null" +) + +// IntegrityViolationError represents a generic integrity violation error. +// It wraps the dialect specific original error to provide more context. +type IntegrityViolationError struct { + integrityType IntegrityType + table string + constraint string + original error +} + +func NewIntegrityViolationError(typ IntegrityType, table, constraint string, original error) error { + return &IntegrityViolationError{ + integrityType: typ, + table: table, + constraint: constraint, + original: original, + } +} + +func (e *IntegrityViolationError) Error() string { + return fmt.Sprintf("integrity violation of type %q on %q (constraint: %q): %v", e.integrityType, e.table, e.constraint, e.original) +} + +func (e *IntegrityViolationError) Is(target error) bool { + _, ok := target.(*IntegrityViolationError) + return ok +} + +// CheckError is returned when a check constraint fails. +// It wraps the [IntegrityViolationError] to provide more context. +// It is used to indicate that a check constraint was violated during an insert or update operation. +type CheckError struct { + IntegrityViolationError +} + +func NewCheckError(table, constraint string, original error) error { + return &CheckError{ + IntegrityViolationError: IntegrityViolationError{ + integrityType: IntegrityTypeCheck, + table: table, + constraint: constraint, + original: original, + }, + } +} + +func (e *CheckError) Is(target error) bool { + _, ok := target.(*CheckError) + return ok +} + +func (e *CheckError) Unwrap() error { + return &e.IntegrityViolationError +} + +// UniqueError is returned when a unique constraint fails. +// It wraps the [IntegrityViolationError] to provide more context. +// It is used to indicate that a unique constraint was violated during an insert or update operation. +type UniqueError struct { + IntegrityViolationError +} + +func NewUniqueError(table, constraint string, original error) error { + return &UniqueError{ + IntegrityViolationError: IntegrityViolationError{ + integrityType: IntegrityTypeUnique, + table: table, + constraint: constraint, + original: original, + }, + } +} + +func (e *UniqueError) Is(target error) bool { + _, ok := target.(*UniqueError) + return ok +} + +func (e *UniqueError) Unwrap() error { + return &e.IntegrityViolationError +} + +// ForeignKeyError is returned when a foreign key constraint fails. +// It wraps the [IntegrityViolationError] to provide more context. +// It is used to indicate that a foreign key constraint was violated during an insert or update operation +type ForeignKeyError struct { + IntegrityViolationError +} + +func NewForeignKeyError(table, constraint string, original error) error { + return &ForeignKeyError{ + IntegrityViolationError: IntegrityViolationError{ + integrityType: IntegrityTypeForeign, + table: table, + constraint: constraint, + original: original, + }, + } +} + +func (e *ForeignKeyError) Is(target error) bool { + _, ok := target.(*ForeignKeyError) + return ok +} + +func (e *ForeignKeyError) Unwrap() error { + return &e.IntegrityViolationError +} + +// NotNullError is returned when a not null constraint fails. +// It wraps the [IntegrityViolationError] to provide more context. +// It is used to indicate that a not null constraint was violated during an insert or update operation. +type NotNullError struct { + IntegrityViolationError +} + +func NewNotNullError(table, constraint string, original error) error { + return &NotNullError{ + IntegrityViolationError: IntegrityViolationError{ + integrityType: IntegrityTypeNotNull, + table: table, + constraint: constraint, + original: original, + }, + } +} + +func (e *NotNullError) Is(target error) bool { + _, ok := target.(*NotNullError) + return ok +} + +func (e *NotNullError) Unwrap() error { + return &e.IntegrityViolationError +} + +// UnknownError is returned when an unknown error occurs. +// It wraps the dialect specific original error to provide more context. +// It is used to indicate that an error occurred that does not fit into any of the other categories. +type UnknownError struct { + original error +} + +func NewUnknownError(original error) error { + return &UnknownError{ + original: original, + } +} + +func (e *UnknownError) Error() string { + return fmt.Sprintf("unknown database error: %v", e.original) +} + +func (e *UnknownError) Is(target error) bool { + _, ok := target.(*UnknownError) + return ok +} + +func (e *UnknownError) Unwrap() error { + return e.original +} diff --git a/backend/v3/storage/database/events_testing/events_test.go b/backend/v3/storage/database/events_testing/events_test.go new file mode 100644 index 00000000000..f2427ba4b59 --- /dev/null +++ b/backend/v3/storage/database/events_testing/events_test.go @@ -0,0 +1,78 @@ +//go:build integration + +package events_test + +import ( + "context" + "log" + "os" + "testing" + "time" + + "github.com/jackc/pgx/v5" + "github.com/jackc/pgx/v5/pgxpool" + + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres" + "github.com/zitadel/zitadel/internal/integration" + v2beta "github.com/zitadel/zitadel/pkg/grpc/instance/v2beta" + v2beta_org "github.com/zitadel/zitadel/pkg/grpc/org/v2beta" + "github.com/zitadel/zitadel/pkg/grpc/system" +) + +const ConnString = "host=localhost port=5432 user=zitadel password=zitadel dbname=zitadel sslmode=disable" + +var ( + dbPool *pgxpool.Pool + CTX context.Context + Instance *integration.Instance + SystemClient system.SystemServiceClient + OrgClient v2beta_org.OrganizationServiceClient +) + +var pool database.Pool + +func TestMain(m *testing.M) { + os.Exit(func() int { + ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute) + defer cancel() + + CTX = integration.WithSystemAuthorization(ctx) + Instance = integration.NewInstance(CTX) + + SystemClient = integration.SystemClient() + OrgClient = Instance.Client.OrgV2beta + + defer func() { + _, err := Instance.Client.InstanceV2Beta.DeleteInstance(CTX, &v2beta.DeleteInstanceRequest{ + InstanceId: Instance.Instance.Id, + }) + if err != nil { + log.Printf("Failed to delete instance on cleanup: %v\n", err) + } + }() + + var err error + dbConfig, err := pgxpool.ParseConfig(ConnString) + if err != nil { + panic(err) + } + dbConfig.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error { + orgState, err := conn.LoadType(ctx, "zitadel.organization_state") + if err != nil { + return err + } + conn.TypeMap().RegisterType(orgState) + return nil + } + + dbPool, err = pgxpool.NewWithConfig(context.Background(), dbConfig) + if err != nil { + panic(err) + } + + pool = postgres.PGxPool(dbPool) + + return m.Run() + }()) +} diff --git a/backend/v3/storage/database/events_testing/instance_domain_test.go b/backend/v3/storage/database/events_testing/instance_domain_test.go new file mode 100644 index 00000000000..db3c0a969c8 --- /dev/null +++ b/backend/v3/storage/database/events_testing/instance_domain_test.go @@ -0,0 +1,310 @@ +//go:build integration + +package events_test + +import ( + "testing" + "time" + + "github.com/brianvoe/gofakeit/v6" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/repository" + "github.com/zitadel/zitadel/internal/integration" + v2beta "github.com/zitadel/zitadel/pkg/grpc/instance/v2beta" + "github.com/zitadel/zitadel/pkg/grpc/system" +) + +func TestServer_TestInstanceDomainReduces(t *testing.T) { + instance := integration.NewInstance(CTX) + + instanceRepo := repository.InstanceRepository(pool) + instanceDomainRepo := instanceRepo.Domains(true) + + t.Cleanup(func() { + _, err := instance.Client.InstanceV2Beta.DeleteInstance(CTX, &v2beta.DeleteInstanceRequest{ + InstanceId: instance.Instance.Id, + }) + if err != nil { + t.Logf("Failed to delete instance on cleanup: %v", err) + } + }) + + // Wait for instance to be created + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(ttt *assert.CollectT) { + _, err := instanceRepo.Get(CTX, + database.WithCondition(instanceRepo.IDCondition(instance.Instance.Id)), + ) + assert.NoError(ttt, err) + }, retryDuration, tick) + + t.Run("test instance custom domain add reduces", func(t *testing.T) { + // Add a domain to the instance + domainName := gofakeit.DomainName() + beforeAdd := time.Now() + _, err := instance.Client.InstanceV2Beta.AddCustomDomain(CTX, &v2beta.AddCustomDomainRequest{ + InstanceId: instance.Instance.Id, + Domain: domainName, + }) + require.NoError(t, err) + afterAdd := time.Now() + + t.Cleanup(func() { + _, err := instance.Client.InstanceV2Beta.RemoveCustomDomain(CTX, &v2beta.RemoveCustomDomainRequest{ + InstanceId: instance.Instance.Id, + Domain: domainName, + }) + if err != nil { + t.Logf("Failed to delete instance domain on cleanup: %v", err) + } + }) + + // Test that domain add reduces + retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(ttt *assert.CollectT) { + domain, err := instanceDomainRepo.Get(CTX, + database.WithCondition( + database.And( + instanceDomainRepo.InstanceIDCondition(instance.Instance.Id), + instanceDomainRepo.DomainCondition(database.TextOperationEqual, domainName), + instanceDomainRepo.TypeCondition(domain.DomainTypeCustom), + ), + ), + ) + require.NoError(ttt, err) + // event instance.domain.added + assert.Equal(ttt, domainName, domain.Domain) + assert.Equal(ttt, instance.Instance.Id, domain.InstanceID) + assert.False(ttt, *domain.IsPrimary) + assert.WithinRange(ttt, domain.CreatedAt, beforeAdd, afterAdd) + assert.WithinRange(ttt, domain.UpdatedAt, beforeAdd, afterAdd) + }, retryDuration, tick) + }) + + t.Run("test instance custom domain set primary reduces", func(t *testing.T) { + // Add a domain to the instance + domainName := gofakeit.DomainName() + _, err := instance.Client.InstanceV2Beta.AddCustomDomain(CTX, &v2beta.AddCustomDomainRequest{ + InstanceId: instance.Instance.Id, + Domain: domainName, + }) + require.NoError(t, err) + + t.Cleanup(func() { + // first we change the primary domain to something else + domain, err := instanceDomainRepo.Get(CTX, + database.WithCondition( + database.And( + instanceDomainRepo.InstanceIDCondition(instance.Instance.Id), + instanceDomainRepo.TypeCondition(domain.DomainTypeCustom), + instanceDomainRepo.IsPrimaryCondition(false), + ), + ), + database.WithLimit(1), + ) + require.NoError(t, err) + _, err = SystemClient.SetPrimaryDomain(CTX, &system.SetPrimaryDomainRequest{ + InstanceId: instance.Instance.Id, + Domain: domain.Domain, + }) + require.NoError(t, err) + + _, err = instance.Client.InstanceV2Beta.RemoveCustomDomain(CTX, &v2beta.RemoveCustomDomainRequest{ + InstanceId: instance.Instance.Id, + Domain: domainName, + }) + if err != nil { + t.Logf("Failed to delete instance domain on cleanup: %v", err) + } + }) + + // Wait for domain to be created + retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(ttt *assert.CollectT) { + domain, err := instanceDomainRepo.Get(CTX, + database.WithCondition( + database.And( + instanceDomainRepo.InstanceIDCondition(instance.Instance.Id), + instanceDomainRepo.DomainCondition(database.TextOperationEqual, domainName), + instanceDomainRepo.TypeCondition(domain.DomainTypeCustom), + ), + ), + ) + require.NoError(ttt, err) + require.False(ttt, *domain.IsPrimary) + assert.Equal(ttt, domainName, domain.Domain) + }, retryDuration, tick) + + // Set domain as primary + beforeSetPrimary := time.Now() + _, err = SystemClient.SetPrimaryDomain(CTX, &system.SetPrimaryDomainRequest{ + InstanceId: instance.Instance.Id, + Domain: domainName, + }) + require.NoError(t, err) + afterSetPrimary := time.Now() + + // Test that set primary reduces + retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(ttt *assert.CollectT) { + domain, err := instanceDomainRepo.Get(CTX, + database.WithCondition( + database.And( + instanceDomainRepo.InstanceIDCondition(instance.Instance.Id), + instanceDomainRepo.IsPrimaryCondition(true), + instanceDomainRepo.TypeCondition(domain.DomainTypeCustom), + ), + ), + ) + require.NoError(ttt, err) + // event instance.domain.primary.set + assert.Equal(ttt, domainName, domain.Domain) + assert.True(ttt, *domain.IsPrimary) + assert.WithinRange(ttt, domain.UpdatedAt, beforeSetPrimary, afterSetPrimary) + }, retryDuration, tick) + }) + + t.Run("test instance custom domain remove reduces", func(t *testing.T) { + // Add a domain to the instance + domainName := gofakeit.DomainName() + _, err := instance.Client.InstanceV2Beta.AddCustomDomain(CTX, &v2beta.AddCustomDomainRequest{ + InstanceId: instance.Instance.Id, + Domain: domainName, + }) + require.NoError(t, err) + + // Wait for domain to be created and verify it exists + retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(ttt *assert.CollectT) { + _, err := instanceDomainRepo.Get(CTX, + database.WithCondition( + database.And( + instanceDomainRepo.InstanceIDCondition(instance.Instance.Id), + instanceDomainRepo.DomainCondition(database.TextOperationEqual, domainName), + instanceDomainRepo.TypeCondition(domain.DomainTypeCustom), + ), + ), + ) + require.NoError(ttt, err) + }, retryDuration, tick) + + // Remove the domain + _, err = instance.Client.InstanceV2Beta.RemoveCustomDomain(CTX, &v2beta.RemoveCustomDomainRequest{ + InstanceId: instance.Instance.Id, + Domain: domainName, + }) + require.NoError(t, err) + + // Test that domain remove reduces + retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(ttt *assert.CollectT) { + domain, err := instanceDomainRepo.Get(CTX, + database.WithCondition( + database.And( + instanceDomainRepo.InstanceIDCondition(instance.Instance.Id), + instanceDomainRepo.DomainCondition(database.TextOperationEqual, domainName), + instanceDomainRepo.TypeCondition(domain.DomainTypeCustom), + ), + ), + ) + // event instance.domain.removed + assert.Nil(ttt, domain) + require.ErrorIs(ttt, err, new(database.NoRowFoundError)) + }, retryDuration, tick) + }) + + t.Run("test instance trusted domain add reduces", func(t *testing.T) { + // Add a domain to the instance + domainName := gofakeit.DomainName() + beforeAdd := time.Now() + _, err := instance.Client.InstanceV2Beta.AddTrustedDomain(CTX, &v2beta.AddTrustedDomainRequest{ + InstanceId: instance.Instance.Id, + Domain: domainName, + }) + require.NoError(t, err) + afterAdd := time.Now() + + t.Cleanup(func() { + _, err := instance.Client.InstanceV2Beta.RemoveTrustedDomain(CTX, &v2beta.RemoveTrustedDomainRequest{ + InstanceId: instance.Instance.Id, + Domain: domainName, + }) + if err != nil { + t.Logf("Failed to delete instance domain on cleanup: %v", err) + } + }) + + // Test that domain add reduces + retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(ttt *assert.CollectT) { + domain, err := instanceDomainRepo.Get(CTX, + database.WithCondition( + database.And( + instanceDomainRepo.InstanceIDCondition(instance.Instance.Id), + instanceDomainRepo.DomainCondition(database.TextOperationEqual, domainName), + instanceDomainRepo.TypeCondition(domain.DomainTypeTrusted), + ), + ), + ) + require.NoError(ttt, err) + // event instance.domain.added + assert.Equal(ttt, domainName, domain.Domain) + assert.Equal(ttt, instance.Instance.Id, domain.InstanceID) + assert.WithinRange(ttt, domain.CreatedAt, beforeAdd, afterAdd) + assert.WithinRange(ttt, domain.UpdatedAt, beforeAdd, afterAdd) + }, retryDuration, tick) + }) + + t.Run("test instance trusted domain remove reduces", func(t *testing.T) { + // Add a domain to the instance + domainName := gofakeit.DomainName() + _, err := instance.Client.InstanceV2Beta.AddTrustedDomain(CTX, &v2beta.AddTrustedDomainRequest{ + InstanceId: instance.Instance.Id, + Domain: domainName, + }) + require.NoError(t, err) + + // Wait for domain to be created and verify it exists + retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(ttt *assert.CollectT) { + _, err := instanceDomainRepo.Get(CTX, + database.WithCondition( + database.And( + instanceDomainRepo.InstanceIDCondition(instance.Instance.Id), + instanceDomainRepo.DomainCondition(database.TextOperationEqual, domainName), + instanceDomainRepo.TypeCondition(domain.DomainTypeTrusted), + ), + ), + ) + require.NoError(ttt, err) + }, retryDuration, tick) + + // Remove the domain + _, err = instance.Client.InstanceV2Beta.RemoveTrustedDomain(CTX, &v2beta.RemoveTrustedDomainRequest{ + InstanceId: instance.Instance.Id, + Domain: domainName, + }) + require.NoError(t, err) + + // Test that domain remove reduces + retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(ttt *assert.CollectT) { + domain, err := instanceDomainRepo.Get(CTX, + database.WithCondition( + database.And( + instanceDomainRepo.InstanceIDCondition(instance.Instance.Id), + instanceDomainRepo.DomainCondition(database.TextOperationEqual, domainName), + instanceDomainRepo.TypeCondition(domain.DomainTypeTrusted), + ), + ), + ) + // event instance.domain.removed + assert.Nil(ttt, domain) + require.ErrorIs(ttt, err, new(database.NoRowFoundError)) + }, retryDuration, tick) + }) +} diff --git a/backend/v3/storage/database/events_testing/instance_test.go b/backend/v3/storage/database/events_testing/instance_test.go new file mode 100644 index 00000000000..b570624ca4e --- /dev/null +++ b/backend/v3/storage/database/events_testing/instance_test.go @@ -0,0 +1,162 @@ +//go:build integration + +package events_test + +import ( + "testing" + "time" + + "github.com/brianvoe/gofakeit/v6" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/repository" + "github.com/zitadel/zitadel/internal/integration" + "github.com/zitadel/zitadel/pkg/grpc/system" +) + +func TestServer_TestInstanceReduces(t *testing.T) { + instanceRepo := repository.InstanceRepository(pool) + + t.Run("test instance add reduces", func(t *testing.T) { + instanceName := gofakeit.Name() + beforeCreate := time.Now() + instance, err := SystemClient.CreateInstance(CTX, &system.CreateInstanceRequest{ + InstanceName: instanceName, + Owner: &system.CreateInstanceRequest_Machine_{ + Machine: &system.CreateInstanceRequest_Machine{ + UserName: "owner", + Name: "owner", + PersonalAccessToken: &system.CreateInstanceRequest_PersonalAccessToken{}, + }, + }, + }) + afterCreate := time.Now() + + require.NoError(t, err) + t.Cleanup(func() { + _, err = SystemClient.RemoveInstance(CTX, &system.RemoveInstanceRequest{ + InstanceId: instance.GetInstanceId(), + }) + if err != nil { + t.Logf("Failed to delete instance on cleanup: %v", err) + } + }) + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(ttt *assert.CollectT) { + instance, err := instanceRepo.Get(CTX, + database.WithCondition(instanceRepo.IDCondition(instance.GetInstanceId())), + ) + require.NoError(ttt, err) + // event instance.added + assert.Equal(ttt, instanceName, instance.Name) + // event instance.default.org.set + assert.NotNil(t, instance.DefaultOrgID) + // event instance.iam.project.set + assert.NotNil(t, instance.IAMProjectID) + // event instance.iam.console.set + assert.NotNil(t, instance.ConsoleAppID) + // event instance.default.language.set + assert.NotNil(t, instance.DefaultLanguage) + // event instance.added + assert.WithinRange(t, instance.CreatedAt, beforeCreate, afterCreate) + // event instance.added + assert.WithinRange(t, instance.UpdatedAt, beforeCreate, afterCreate) + }, retryDuration, tick) + }) + + t.Run("test instance update reduces", func(t *testing.T) { + instanceName := gofakeit.Name() + res, err := SystemClient.CreateInstance(CTX, &system.CreateInstanceRequest{ + InstanceName: instanceName, + Owner: &system.CreateInstanceRequest_Machine_{ + Machine: &system.CreateInstanceRequest_Machine{ + UserName: "owner", + Name: "owner", + PersonalAccessToken: &system.CreateInstanceRequest_PersonalAccessToken{}, + }, + }, + }) + require.NoError(t, err) + t.Cleanup(func() { + _, err = SystemClient.RemoveInstance(CTX, &system.RemoveInstanceRequest{ + InstanceId: res.GetInstanceId(), + }) + if err != nil { + t.Logf("Failed to delete instance on cleanup: %v", err) + } + }) + + // check instance exists + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + instance, err := instanceRepo.Get(CTX, + database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())), + ) + require.NoError(t, err) + assert.Equal(t, instanceName, instance.Name) + }, retryDuration, tick) + + instanceName += "new" + beforeUpdate := time.Now() + _, err = SystemClient.UpdateInstance(CTX, &system.UpdateInstanceRequest{ + InstanceId: res.InstanceId, + InstanceName: instanceName, + }) + afterUpdate := time.Now() + require.NoError(t, err) + + retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + instance, err := instanceRepo.Get(CTX, + database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())), + ) + require.NoError(t, err) + // event instance.changed + assert.Equal(t, instanceName, instance.Name) + assert.WithinRange(t, instance.UpdatedAt, beforeUpdate, afterUpdate) + }, retryDuration, tick) + }) + + t.Run("test instance delete reduces", func(t *testing.T) { + instanceName := gofakeit.Name() + res, err := SystemClient.CreateInstance(CTX, &system.CreateInstanceRequest{ + InstanceName: instanceName, + Owner: &system.CreateInstanceRequest_Machine_{ + Machine: &system.CreateInstanceRequest_Machine{ + UserName: "owner", + Name: "owner", + PersonalAccessToken: &system.CreateInstanceRequest_PersonalAccessToken{}, + }, + }, + }) + require.NoError(t, err) + + // check instance exists + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + instance, err := instanceRepo.Get(CTX, + database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())), + ) + require.NoError(t, err) + assert.Equal(t, instanceName, instance.Name) + }, retryDuration, tick) + + _, err = SystemClient.RemoveInstance(CTX, &system.RemoveInstanceRequest{ + InstanceId: res.InstanceId, + }) + require.NoError(t, err) + + retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + instance, err := instanceRepo.Get(CTX, + database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())), + ) + // event instance.removed + assert.Nil(t, instance) + require.ErrorIs(t, err, new(database.NoRowFoundError)) + }, retryDuration, tick) + }) +} diff --git a/backend/v3/storage/database/events_testing/org_domain_test.go b/backend/v3/storage/database/events_testing/org_domain_test.go new file mode 100644 index 00000000000..edb6fb1217f --- /dev/null +++ b/backend/v3/storage/database/events_testing/org_domain_test.go @@ -0,0 +1,133 @@ +//go:build integration + +package events_test + +import ( + "testing" + "time" + + "github.com/brianvoe/gofakeit/v6" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/repository" + "github.com/zitadel/zitadel/internal/integration" + v2beta "github.com/zitadel/zitadel/pkg/grpc/org/v2beta" +) + +func TestServer_TestOrgDomainReduces(t *testing.T) { + org, err := OrgClient.CreateOrganization(CTX, &v2beta.CreateOrganizationRequest{ + Name: gofakeit.Name(), + }) + require.NoError(t, err) + + orgRepo := repository.OrganizationRepository(pool) + orgDomainRepo := orgRepo.Domains(false) + + t.Cleanup(func() { + _, err := OrgClient.DeleteOrganization(CTX, &v2beta.DeleteOrganizationRequest{ + Id: org.GetId(), + }) + if err != nil { + t.Logf("Failed to delete organization on cleanup: %v", err) + } + }) + + // Wait for org to be created + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(ttt *assert.CollectT) { + _, err := orgRepo.Get(CTX, + database.WithCondition(orgRepo.IDCondition(org.GetId())), + ) + assert.NoError(ttt, err) + }, retryDuration, tick) + + // The API call also sets the domain as primary, so we don't do a separate test for that. + t.Run("test organization domain add reduces", func(t *testing.T) { + // Add a domain to the organization + domainName := gofakeit.DomainName() + beforeAdd := time.Now() + _, err := OrgClient.AddOrganizationDomain(CTX, &v2beta.AddOrganizationDomainRequest{ + OrganizationId: org.GetId(), + Domain: domainName, + }) + require.NoError(t, err) + afterAdd := time.Now() + + t.Cleanup(func() { + _, err := OrgClient.DeleteOrganizationDomain(CTX, &v2beta.DeleteOrganizationDomainRequest{ + OrganizationId: org.GetId(), + Domain: domainName, + }) + if err != nil { + t.Logf("Failed to delete domain on cleanup: %v", err) + } + }) + + // Test that domain add reduces + retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(ttt *assert.CollectT) { + gottenDomain, err := orgDomainRepo.Get(CTX, + database.WithCondition( + database.And( + orgDomainRepo.InstanceIDCondition(Instance.Instance.Id), + orgDomainRepo.OrgIDCondition(org.Id), + orgDomainRepo.DomainCondition(database.TextOperationEqual, domainName), + ), + ), + ) + require.NoError(ttt, err) + // event org.domain.added + assert.Equal(t, domainName, gottenDomain.Domain) + assert.Equal(t, Instance.Instance.Id, gottenDomain.InstanceID) + assert.Equal(t, org.Id, gottenDomain.OrgID) + + assert.WithinRange(t, gottenDomain.CreatedAt, beforeAdd, afterAdd) + assert.WithinRange(t, gottenDomain.UpdatedAt, beforeAdd, afterAdd) + }, retryDuration, tick) + }) + + t.Run("test org domain remove reduces", func(t *testing.T) { + // Add a domain to the organization + domainName := gofakeit.DomainName() + _, err := OrgClient.AddOrganizationDomain(CTX, &v2beta.AddOrganizationDomainRequest{ + OrganizationId: org.GetId(), + Domain: domainName, + }) + require.NoError(t, err) + + t.Cleanup(func() { + _, err := OrgClient.DeleteOrganizationDomain(CTX, &v2beta.DeleteOrganizationDomainRequest{ + OrganizationId: org.GetId(), + Domain: domainName, + }) + if err != nil { + t.Logf("Failed to delete domain on cleanup: %v", err) + } + }) + + // Remove the domain + _, err = OrgClient.DeleteOrganizationDomain(CTX, &v2beta.DeleteOrganizationDomainRequest{ + OrganizationId: org.GetId(), + Domain: domainName, + }) + require.NoError(t, err) + + // Test that domain remove reduces + retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(ttt *assert.CollectT) { + domain, err := orgDomainRepo.Get(CTX, + database.WithCondition( + database.And( + orgDomainRepo.InstanceIDCondition(Instance.Instance.Id), + orgDomainRepo.DomainCondition(database.TextOperationEqual, domainName), + ), + ), + ) + // event instance.domain.removed + assert.Nil(ttt, domain) + require.ErrorIs(ttt, err, new(database.NoRowFoundError)) + }, retryDuration, tick) + }) +} diff --git a/backend/v3/storage/database/events_testing/organization_test.go b/backend/v3/storage/database/events_testing/organization_test.go new file mode 100644 index 00000000000..7c89cfbcd51 --- /dev/null +++ b/backend/v3/storage/database/events_testing/organization_test.go @@ -0,0 +1,269 @@ +//go:build integration + +package events_test + +import ( + "testing" + "time" + + "github.com/brianvoe/gofakeit/v6" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/repository" + "github.com/zitadel/zitadel/internal/integration" + v2beta_org "github.com/zitadel/zitadel/pkg/grpc/org/v2beta" +) + +func TestServer_TestOrganizationReduces(t *testing.T) { + instanceID := Instance.ID() + orgRepo := repository.OrganizationRepository(pool) + + t.Run("test org add reduces", func(t *testing.T) { + beforeCreate := time.Now() + orgName := gofakeit.Name() + + org, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{ + Name: orgName, + }) + require.NoError(t, err) + afterCreate := time.Now() + + t.Cleanup(func() { + _, err = OrgClient.DeleteOrganization(CTX, &v2beta_org.DeleteOrganizationRequest{ + Id: org.GetId(), + }) + if err != nil { + t.Logf("Failed to delete organization on cleanup: %v", err) + } + }) + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(tt *assert.CollectT) { + organization, err := orgRepo.Get(CTX, + database.WithCondition( + database.And( + orgRepo.IDCondition(org.GetId()), + orgRepo.InstanceIDCondition(instanceID), + ), + ), + ) + require.NoError(tt, err) + + // event org.added + assert.NotNil(t, organization.ID) + assert.Equal(t, orgName, organization.Name) + assert.NotNil(t, organization.InstanceID) + assert.Equal(t, domain.OrgStateActive, organization.State) + assert.WithinRange(t, organization.CreatedAt, beforeCreate, afterCreate) + assert.WithinRange(t, organization.UpdatedAt, beforeCreate, afterCreate) + }, retryDuration, tick) + }) + + t.Run("test org change reduces", func(t *testing.T) { + orgName := gofakeit.Name() + + // 1. create org + organization, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{ + Name: orgName, + }) + require.NoError(t, err) + + t.Cleanup(func() { + _, err = OrgClient.DeleteOrganization(CTX, &v2beta_org.DeleteOrganizationRequest{ + Id: organization.Id, + }) + if err != nil { + t.Logf("Failed to delete organization on cleanup: %v", err) + } + }) + + // 2. update org name + beforeUpdate := time.Now() + orgName = orgName + "_new" + _, err = OrgClient.UpdateOrganization(CTX, &v2beta_org.UpdateOrganizationRequest{ + Id: organization.Id, + Name: orgName, + }) + require.NoError(t, err) + afterUpdate := time.Now() + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + organization, err := orgRepo.Get(CTX, + database.WithCondition( + database.And( + orgRepo.IDCondition(organization.Id), + orgRepo.InstanceIDCondition(instanceID), + ), + ), + ) + require.NoError(t, err) + + // event org.changed + assert.Equal(t, orgName, organization.Name) + assert.WithinRange(t, organization.UpdatedAt, beforeUpdate, afterUpdate) + }, retryDuration, tick) + }) + + t.Run("test org deactivate reduces", func(t *testing.T) { + orgName := gofakeit.Name() + + // 1. create org + organization, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{ + Name: orgName, + }) + require.NoError(t, err) + t.Cleanup(func() { + // Cleanup: delete the organization + _, err = OrgClient.DeleteOrganization(CTX, &v2beta_org.DeleteOrganizationRequest{ + Id: organization.Id, + }) + if err != nil { + t.Logf("Failed to delete organization on cleanup: %v", err) + } + }) + + // 2. deactivate org name + beforeDeactivate := time.Now() + _, err = OrgClient.DeactivateOrganization(CTX, &v2beta_org.DeactivateOrganizationRequest{ + Id: organization.Id, + }) + + require.NoError(t, err) + afterDeactivate := time.Now() + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + organization, err := orgRepo.Get(CTX, + database.WithCondition( + database.And( + orgRepo.IDCondition(organization.Id), + orgRepo.InstanceIDCondition(instanceID), + ), + ), + ) + require.NoError(t, err) + + // event org.deactivate + assert.Equal(t, domain.OrgStateInactive, organization.State) + assert.WithinRange(t, organization.UpdatedAt, beforeDeactivate, afterDeactivate) + }, retryDuration, tick) + }) + + t.Run("test org activate reduces", func(t *testing.T) { + orgName := gofakeit.Name() + + // 1. create org + organization, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{ + Name: orgName, + }) + require.NoError(t, err) + t.Cleanup(func() { + // Cleanup: delete the organization + _, err = OrgClient.DeleteOrganization(CTX, &v2beta_org.DeleteOrganizationRequest{ + Id: organization.Id, + }) + if err != nil { + t.Logf("Failed to delete organization on cleanup: %v", err) + } + }) + + // 2. deactivate org name + _, err = OrgClient.DeactivateOrganization(CTX, &v2beta_org.DeactivateOrganizationRequest{ + Id: organization.Id, + }) + require.NoError(t, err) + + orgRepo := repository.OrganizationRepository(pool) + // 3. check org deactivated + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + organization, err := orgRepo.Get(CTX, + database.WithCondition( + database.And( + orgRepo.IDCondition(organization.Id), + orgRepo.InstanceIDCondition(instanceID), + ), + ), + ) + require.NoError(t, err) + assert.Equal(t, domain.OrgStateInactive, organization.State) + }, retryDuration, tick) + + // 4. activate org name + beforeActivate := time.Now() + _, err = OrgClient.ActivateOrganization(CTX, &v2beta_org.ActivateOrganizationRequest{ + Id: organization.Id, + }) + require.NoError(t, err) + afterActivate := time.Now() + + retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + organization, err := orgRepo.Get(CTX, + database.WithCondition( + database.And( + orgRepo.IDCondition(organization.Id), + orgRepo.InstanceIDCondition(instanceID), + ), + ), + ) + require.NoError(t, err) + + // event org.reactivate + assert.Equal(t, orgName, organization.Name) + assert.Equal(t, domain.OrgStateActive, organization.State) + assert.WithinRange(t, organization.UpdatedAt, beforeActivate, afterActivate) + }, retryDuration, tick) + }) + + t.Run("test org remove reduces", func(t *testing.T) { + orgName := gofakeit.Name() + + // 1. create org + organization, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{ + Name: orgName, + }) + require.NoError(t, err) + + // 2. check org retrievable + orgRepo := repository.OrganizationRepository(pool) + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + _, err := orgRepo.Get(CTX, + database.WithCondition( + database.And( + orgRepo.IDCondition(organization.Id), + orgRepo.InstanceIDCondition(instanceID), + ), + ), + ) + require.NoError(t, err) + }, retryDuration, tick) + + // 3. delete org + _, err = OrgClient.DeleteOrganization(CTX, &v2beta_org.DeleteOrganizationRequest{ + Id: organization.Id, + }) + require.NoError(t, err) + + retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute) + assert.EventuallyWithT(t, func(t *assert.CollectT) { + organization, err := orgRepo.Get(CTX, + database.WithCondition( + database.And( + orgRepo.IDCondition(organization.Id), + orgRepo.InstanceIDCondition(instanceID), + ), + ), + ) + require.ErrorIs(t, err, new(database.NoRowFoundError)) + + // event org.remove + assert.Nil(t, organization) + }, retryDuration, tick) + }) +} diff --git a/backend/v3/storage/database/gen_mock.go b/backend/v3/storage/database/gen_mock.go new file mode 100644 index 00000000000..04d204cfa1b --- /dev/null +++ b/backend/v3/storage/database/gen_mock.go @@ -0,0 +1,3 @@ +package database + +//go:generate mockgen -typed -package dbmock -destination ./dbmock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Client,Row,Rows,Transaction diff --git a/backend/v3/storage/database/migration.go b/backend/v3/storage/database/migration.go new file mode 100644 index 00000000000..5a5b5af6fc7 --- /dev/null +++ b/backend/v3/storage/database/migration.go @@ -0,0 +1,9 @@ +package database + +import "context" + +type Migrator interface { + // Migrate executes migrations to setup the database. + // The method can be called once per running Zitadel. + Migrate(ctx context.Context) error +} diff --git a/backend/v3/storage/database/operators.go b/backend/v3/storage/database/operators.go new file mode 100644 index 00000000000..c8820d918db --- /dev/null +++ b/backend/v3/storage/database/operators.go @@ -0,0 +1,136 @@ +package database + +import ( + "time" + + "golang.org/x/exp/constraints" +) + +type Value interface { + Boolean | Number | Text | Instruction +} + +type Operation interface { + BooleanOperation | NumberOperation | TextOperation +} + +type Text interface { + ~string | ~[]byte +} + +// TextOperation are operations that can be performed on text values. +type TextOperation uint8 + +const ( + // TextOperationEqual compares two strings for equality. + TextOperationEqual TextOperation = iota + 1 + // TextOperationEqualIgnoreCase compares two strings for equality, ignoring case. + TextOperationEqualIgnoreCase + // TextOperationNotEqual compares two strings for inequality. + TextOperationNotEqual + // TextOperationNotEqualIgnoreCase compares two strings for inequality, ignoring case. + TextOperationNotEqualIgnoreCase + // TextOperationStartsWith checks if the first string starts with the second. + TextOperationStartsWith + // TextOperationStartsWithIgnoreCase checks if the first string starts with the second, ignoring case. + TextOperationStartsWithIgnoreCase +) + +var textOperations = map[TextOperation]string{ + TextOperationEqual: " = ", + TextOperationEqualIgnoreCase: " LIKE ", + TextOperationNotEqual: " <> ", + TextOperationNotEqualIgnoreCase: " NOT LIKE ", + TextOperationStartsWith: " LIKE ", + TextOperationStartsWithIgnoreCase: " LIKE ", +} + +func writeTextOperation[T Text](builder *StatementBuilder, col Column, op TextOperation, value T) { + switch op { + case TextOperationEqual, TextOperationNotEqual: + col.WriteQualified(builder) + builder.WriteString(textOperations[op]) + builder.WriteArg(value) + case TextOperationEqualIgnoreCase, TextOperationNotEqualIgnoreCase: + builder.WriteString("LOWER(") + col.WriteQualified(builder) + builder.WriteString(")") + + builder.WriteString(textOperations[op]) + builder.WriteString("LOWER(") + builder.WriteArg(value) + builder.WriteString(")") + case TextOperationStartsWith: + col.WriteQualified(builder) + builder.WriteString(textOperations[op]) + builder.WriteArg(value) + builder.WriteString(" || '%'") + case TextOperationStartsWithIgnoreCase: + builder.WriteString("LOWER(") + col.WriteQualified(builder) + builder.WriteString(")") + + builder.WriteString(textOperations[op]) + builder.WriteString("LOWER(") + builder.WriteArg(value) + builder.WriteString(")") + builder.WriteString(" || '%'") + default: + panic("unsupported text operation") + } +} + +type Number interface { + constraints.Integer | constraints.Float | constraints.Complex | time.Time | time.Duration +} + +// NumberOperation are operations that can be performed on number values. +type NumberOperation uint8 + +const ( + // NumberOperationEqual compares two numbers for equality. + NumberOperationEqual NumberOperation = iota + 1 + // NumberOperationNotEqual compares two numbers for inequality. + NumberOperationNotEqual + // NumberOperationLessThan compares two numbers to check if the first is less than the second. + NumberOperationLessThan + // NumberOperationLessThanOrEqual compares two numbers to check if the first is less than or equal to the second. + NumberOperationAtLeast + // NumberOperationGreaterThan compares two numbers to check if the first is greater than the second. + NumberOperationGreaterThan + // NumberOperationGreaterThanOrEqual compares two numbers to check if the first is greater than or equal to the second. + NumberOperationAtMost +) + +var numberOperations = map[NumberOperation]string{ + NumberOperationEqual: " = ", + NumberOperationNotEqual: " <> ", + NumberOperationLessThan: " < ", + NumberOperationAtLeast: " <= ", + NumberOperationGreaterThan: " > ", + NumberOperationAtMost: " >= ", +} + +func writeNumberOperation[T Number](builder *StatementBuilder, col Column, op NumberOperation, value T) { + col.WriteQualified(builder) + builder.WriteString(numberOperations[op]) + builder.WriteArg(value) +} + +type Boolean interface { + ~bool +} + +// BooleanOperation are operations that can be performed on boolean values. +type BooleanOperation uint8 + +const ( + BooleanOperationIsTrue BooleanOperation = iota + 1 + BooleanOperationIsFalse +) + +func writeBooleanOperation[T Boolean](builder *StatementBuilder, col Column, value T) { + col.WriteQualified(builder) + builder.WriteString(" = ") + builder.WriteArg(value) +} diff --git a/backend/v3/storage/database/order.go b/backend/v3/storage/database/order.go new file mode 100644 index 00000000000..ff906d9c18b --- /dev/null +++ b/backend/v3/storage/database/order.go @@ -0,0 +1,21 @@ +package database + +// Order represents a SQL condition. +// Its written after the ORDER BY keyword in a SQL statement. +type Order interface { + Write(builder *StatementBuilder) +} + +type orderBy struct { + column Column +} + +func OrderBy(column Column) Order { + return &orderBy{column: column} +} + +// Write implements [Order]. +func (o *orderBy) Write(builder *StatementBuilder) { + builder.WriteString(" ORDER BY ") + o.column.WriteQualified(builder) +} diff --git a/backend/v3/storage/database/query.go b/backend/v3/storage/database/query.go new file mode 100644 index 00000000000..df3ee0a377e --- /dev/null +++ b/backend/v3/storage/database/query.go @@ -0,0 +1,174 @@ +package database + +type QueryOption func(opts *QueryOpts) + +// WithCondition sets the condition for the query. +func WithCondition(condition Condition) QueryOption { + return func(opts *QueryOpts) { + opts.Condition = condition + } +} + +// WithOrderBy sets the columns to order the results by. +func WithOrderBy(ordering OrderDirection, orderBy ...Column) QueryOption { + return func(opts *QueryOpts) { + opts.OrderBy = orderBy + opts.Ordering = ordering + } +} + +func WithOrderByAscending(columns ...Column) QueryOption { + return WithOrderBy(OrderDirectionAsc, columns...) +} + +func WithOrderByDescending(columns ...Column) QueryOption { + return WithOrderBy(OrderDirectionDesc, columns...) +} + +// WithLimit sets the maximum number of results to return. +func WithLimit(limit uint32) QueryOption { + return func(opts *QueryOpts) { + opts.Limit = limit + } +} + +// WithOffset sets the number of results to skip before returning the results. +func WithOffset(offset uint32) QueryOption { + return func(opts *QueryOpts) { + opts.Offset = offset + } +} + +// WithGroupBy sets the columns to group the results by. +func WithGroupBy(groupBy ...Column) QueryOption { + return func(opts *QueryOpts) { + opts.GroupBy = groupBy + } +} + +// WithLeftJoin adds a LEFT JOIN to the query. +func WithLeftJoin(table string, columns Condition) QueryOption { + return func(opts *QueryOpts) { + opts.Joins = append(opts.Joins, join{ + table: table, + typ: JoinTypeLeft, + columns: columns, + }) + } +} + +type joinType string + +const ( + JoinTypeLeft joinType = "LEFT" +) + +type join struct { + table string + typ joinType + columns Condition +} + +type OrderDirection uint8 + +const ( + OrderDirectionAsc OrderDirection = iota + OrderDirectionDesc +) + +// QueryOpts holds the options for a query. +// It is used to build the SQL SELECT statement. +type QueryOpts struct { + // Condition is the condition to filter the results. + // It is used to build the WHERE clause of the SQL statement. + Condition Condition + // OrderBy is the columns to order the results by. + // It is used to build the ORDER BY clause of the SQL statement. + OrderBy Columns + // Ordering defines if the columns should be ordered ascending or descending. + // Default is ascending. + Ordering OrderDirection + // Limit is the maximum number of results to return. + // It is used to build the LIMIT clause of the SQL statement. + Limit uint32 + // Offset is the number of results to skip before returning the results. + // It is used to build the OFFSET clause of the SQL statement. + Offset uint32 + // GroupBy is the columns to group the results by. + // It is used to build the GROUP BY clause of the SQL statement. + GroupBy Columns + // Joins is a list of joins to be applied to the query. + // It is used to build the JOIN clauses of the SQL statement. + Joins []join +} + +func (opts *QueryOpts) Write(builder *StatementBuilder) { + opts.WriteLeftJoins(builder) + opts.WriteCondition(builder) + opts.WriteGroupBy(builder) + opts.WriteOrderBy(builder) + opts.WriteLimit(builder) + opts.WriteOffset(builder) +} + +func (opts *QueryOpts) WriteCondition(builder *StatementBuilder) { + if opts.Condition == nil { + return + } + builder.WriteString(" WHERE ") + opts.Condition.Write(builder) +} + +func (opts *QueryOpts) WriteOrderBy(builder *StatementBuilder) { + if len(opts.OrderBy) == 0 { + return + } + builder.WriteString(" ORDER BY ") + for i, col := range opts.OrderBy { + if i > 0 { + builder.WriteString(", ") + } + col.WriteQualified(builder) + if opts.Ordering == OrderDirectionDesc { + builder.WriteString(" DESC") + } + } +} + +func (opts *QueryOpts) WriteLimit(builder *StatementBuilder) { + if opts.Limit == 0 { + return + } + builder.WriteString(" LIMIT ") + builder.WriteArg(opts.Limit) +} + +func (opts *QueryOpts) WriteOffset(builder *StatementBuilder) { + if opts.Offset == 0 { + return + } + builder.WriteString(" OFFSET ") + builder.WriteArg(opts.Offset) +} + +func (opts *QueryOpts) WriteGroupBy(builder *StatementBuilder) { + if len(opts.GroupBy) == 0 { + return + } + builder.WriteString(" GROUP BY ") + opts.GroupBy.WriteQualified(builder) +} + +func (opts *QueryOpts) WriteLeftJoins(builder *StatementBuilder) { + if len(opts.Joins) == 0 { + return + } + for _, join := range opts.Joins { + builder.WriteString(" ") + builder.WriteString(string(join.typ)) + builder.WriteString(" JOIN ") + builder.WriteString(join.table) + builder.WriteString(" ON ") + join.columns.Write(builder) + } +} diff --git a/backend/v3/storage/database/repository/doc.go b/backend/v3/storage/database/repository/doc.go new file mode 100644 index 00000000000..bd01ea6dec3 --- /dev/null +++ b/backend/v3/storage/database/repository/doc.go @@ -0,0 +1,5 @@ +// Package implements the repositories defined in the domain package. +// The repositories are used by the domain package to access the database. +// the inheritance.sql file is me over-engineering table inheritance. +// I would create a user table which is inherited by human_user and machine_user and the same for objects like idps. +package repository diff --git a/backend/v3/storage/database/repository/inheritance.sql b/backend/v3/storage/database/repository/inheritance.sql new file mode 100644 index 00000000000..f7fceb7aed8 --- /dev/null +++ b/backend/v3/storage/database/repository/inheritance.sql @@ -0,0 +1,135 @@ +CREATE TABLE objects ( + id SERIAL PRIMARY KEY, + created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, + updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, + deleted_at TIMESTAMP +); + +CREATE OR REPLACE FUNCTION update_updated_at_column() +RETURNS TRIGGER AS $$ +BEGIN + NEW.updated_at = NOW(); + RETURN NEW; +END; +$$ LANGUAGE plpgsql; + +CREATE TABLE instances( + name VARCHAR(50) NOT NULL + , PRIMARY KEY (id) +) INHERITS (objects); + +CREATE TRIGGER set_updated_at +BEFORE UPDATE +ON instances +FOR EACH ROW +EXECUTE FUNCTION update_updated_at_column(); + +CREATE TABLE instance_objects( + instance_id INT NOT NULL + , PRIMARY KEY (instance_id, id) + -- as foreign keys are not inherited we need to define them on the child tables + --, CONSTRAINT fk_instance FOREIGN KEY (instance_id) REFERENCES instances(id) +) INHERITS (objects); + +CREATE TABLE orgs( + name VARCHAR(50) NOT NULL + , PRIMARY KEY (instance_id, id) + , CONSTRAINT fk_instance FOREIGN KEY (instance_id) REFERENCES instances(id) +) INHERITS (instance_objects); + +CREATE TRIGGER set_updated_at +BEFORE UPDATE +ON orgs +FOR EACH ROW +EXECUTE FUNCTION update_updated_at_column(); + +CREATE TABLE org_objects( + org_id INT NOT NULL + , PRIMARY KEY (instance_id, org_id, id) + -- as foreign keys are not inherited we need to define them on the child tables + -- CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id), + -- CONSTRAINT fk_instance FOREIGN KEY (instance_id) REFERENCES instances(id) +) INHERITS (instance_objects); + +CREATE TABLE users ( + username VARCHAR(50) NOT NULL + , PRIMARY KEY (instance_id, org_id, id) + -- as foreign keys are not inherited we need to define them on the child tables + -- , CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id) + -- , CONSTRAINT fk_instances FOREIGN KEY (instance_id) REFERENCES instances(id) +) INHERITS (org_objects); + +CREATE INDEX idx_users_username ON users(username); + +CREATE TRIGGER set_updated_at +BEFORE UPDATE +ON users +FOR EACH ROW +EXECUTE FUNCTION update_updated_at_column(); + +CREATE TABLE human_users( + first_name VARCHAR(50) + , last_name VARCHAR(50) + , PRIMARY KEY (instance_id, org_id, id) + -- CONSTRAINT fk_user FOREIGN KEY (instance_id, org_id, id) REFERENCES users(instance_id, org_id, id), + , CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id) + , CONSTRAINT fk_instances FOREIGN KEY (instance_id) REFERENCES instances(id) +) INHERITS (users); + +CREATE INDEX idx_human_users_username ON human_users(username); + +CREATE TRIGGER set_updated_at +BEFORE UPDATE +ON human_users +FOR EACH ROW +EXECUTE FUNCTION update_updated_at_column(); + +CREATE TABLE machine_users( + description VARCHAR(50) + , PRIMARY KEY (instance_id, org_id, id) + -- , CONSTRAINT fk_user FOREIGN KEY (instance_id, org_id, id) REFERENCES users(instance_id, org_id, id) + , CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id) + , CONSTRAINT fk_instances FOREIGN KEY (instance_id) REFERENCES instances(id) +) INHERITS (users); + +CREATE INDEX idx_machine_users_username ON machine_users(username); + +CREATE TRIGGER set_updated_at +BEFORE UPDATE +ON machine_users +FOR EACH ROW +EXECUTE FUNCTION update_updated_at_column(); + +CREATE VIEW users_view AS ( +SELECT + id + , created_at + , updated_at + , deleted_at + , instance_id + , org_id + , username + , tableoid::regclass::TEXT AS type + , first_name + , last_name + , NULL AS description +FROM + human_users + +UNION + +SELECT + id + , created_at + , updated_at + , deleted_at + , instance_id + , org_id + , username + , tableoid::regclass::TEXT AS type + , NULL AS first_name + , NULL AS last_name + , description +FROM + machine_users +); \ No newline at end of file diff --git a/backend/v3/storage/database/repository/instance.go b/backend/v3/storage/database/repository/instance.go new file mode 100644 index 00000000000..74f4c1f3f9b --- /dev/null +++ b/backend/v3/storage/database/repository/instance.go @@ -0,0 +1,306 @@ +package repository + +import ( + "context" + "database/sql" + "encoding/json" + "time" + + "golang.org/x/text/language" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +var _ domain.InstanceRepository = (*instance)(nil) + +type instance struct { + repository + shouldLoadDomains bool + domainRepo *instanceDomain +} + +func InstanceRepository(client database.QueryExecutor) domain.InstanceRepository { + return &instance{ + repository: repository{ + client: client, + }, + } +} + +// ------------------------------------------------------------- +// repository +// ------------------------------------------------------------- + +const ( + queryInstanceStmt = `SELECT instances.id, instances.name, instances.default_org_id, instances.iam_project_id, instances.console_client_id, instances.console_app_id, instances.default_language, instances.created_at, instances.updated_at` + + ` , CASE WHEN count(instance_domains.domain) > 0 THEN jsonb_agg(json_build_object('domain', instance_domains.domain, 'isPrimary', instance_domains.is_primary, 'isGenerated', instance_domains.is_generated, 'createdAt', instance_domains.created_at, 'updatedAt', instance_domains.updated_at)) ELSE NULL::JSONB END domains` + + ` FROM zitadel.instances` +) + +// Get implements [domain.InstanceRepository]. +func (i *instance) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Instance, error) { + opts = append(opts, + i.joinDomains(), + database.WithGroupBy(i.IDColumn()), + ) + + options := new(database.QueryOpts) + for _, opt := range opts { + opt(options) + } + + var builder database.StatementBuilder + builder.WriteString(queryInstanceStmt) + options.Write(&builder) + + return scanInstance(ctx, i.client, &builder) +} + +// List implements [domain.InstanceRepository]. +func (i *instance) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Instance, error) { + opts = append(opts, + i.joinDomains(), + database.WithGroupBy(i.IDColumn()), + ) + + options := new(database.QueryOpts) + for _, opt := range opts { + opt(options) + } + + var builder database.StatementBuilder + builder.WriteString(queryInstanceStmt) + options.Write(&builder) + + return scanInstances(ctx, i.client, &builder) +} + +func (i *instance) joinDomains() database.QueryOption { + columns := make([]database.Condition, 0, 2) + columns = append(columns, database.NewColumnCondition(i.IDColumn(), i.Domains(false).InstanceIDColumn())) + + // If domains should not be joined, we make sure to return null for the domain columns + // the query optimizer of the dialect should optimize this away if no domains are requested + if !i.shouldLoadDomains { + columns = append(columns, database.IsNull(i.Domains(false).InstanceIDColumn())) + } + + return database.WithLeftJoin( + "zitadel.instance_domains", + database.And(columns...), + ) +} + +// Create implements [domain.InstanceRepository]. +func (i *instance) Create(ctx context.Context, instance *domain.Instance) error { + var ( + builder database.StatementBuilder + createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction + ) + if !instance.CreatedAt.IsZero() { + createdAt = instance.CreatedAt + } + if !instance.UpdatedAt.IsZero() { + updatedAt = instance.UpdatedAt + } + + builder.WriteString(`INSERT INTO zitadel.instances (id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language, created_at, updated_at) VALUES (`) + builder.WriteArgs(instance.ID, instance.Name, instance.DefaultOrgID, instance.IAMProjectID, instance.ConsoleClientID, instance.ConsoleAppID, instance.DefaultLanguage, createdAt, updatedAt) + builder.WriteString(`) RETURNING created_at, updated_at`) + + return i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&instance.CreatedAt, &instance.UpdatedAt) +} + +// Update implements [domain.InstanceRepository]. +func (i instance) Update(ctx context.Context, id string, changes ...database.Change) (int64, error) { + if len(changes) == 0 { + return 0, database.ErrNoChanges + } + var builder database.StatementBuilder + + builder.WriteString(`UPDATE zitadel.instances SET `) + + database.Changes(changes).Write(&builder) + + idCondition := i.IDCondition(id) + writeCondition(&builder, idCondition) + + stmt := builder.String() + + return i.client.Exec(ctx, stmt, builder.Args()...) +} + +// Delete implements [domain.InstanceRepository]. +func (i instance) Delete(ctx context.Context, id string) (int64, error) { + var builder database.StatementBuilder + + builder.WriteString(`DELETE FROM zitadel.instances`) + + idCondition := i.IDCondition(id) + writeCondition(&builder, idCondition) + + return i.client.Exec(ctx, builder.String(), builder.Args()...) +} + +// ------------------------------------------------------------- +// changes +// ------------------------------------------------------------- + +// SetName implements [domain.instanceChanges]. +func (i instance) SetName(name string) database.Change { + return database.NewChange(i.NameColumn(), name) +} + +// SetUpdatedAt implements [domain.instanceChanges]. +func (i instance) SetUpdatedAt(time time.Time) database.Change { + return database.NewChange(i.UpdatedAtColumn(), time) +} + +func (i instance) SetIAMProject(id string) database.Change { + return database.NewChange(i.IAMProjectIDColumn(), id) +} +func (i instance) SetDefaultOrg(id string) database.Change { + return database.NewChange(i.DefaultOrgIDColumn(), id) +} +func (i instance) SetDefaultLanguage(lang language.Tag) database.Change { + return database.NewChange(i.DefaultLanguageColumn(), lang.String()) +} +func (i instance) SetConsoleClientID(id string) database.Change { + return database.NewChange(i.ConsoleClientIDColumn(), id) +} +func (i instance) SetConsoleAppID(id string) database.Change { + return database.NewChange(i.ConsoleAppIDColumn(), id) +} + +// ------------------------------------------------------------- +// conditions +// ------------------------------------------------------------- + +// IDCondition implements [domain.instanceConditions]. +func (i instance) IDCondition(id string) database.Condition { + return database.NewTextCondition(i.IDColumn(), database.TextOperationEqual, id) +} + +// NameCondition implements [domain.instanceConditions]. +func (i instance) NameCondition(op database.TextOperation, name string) database.Condition { + return database.NewTextCondition(i.NameColumn(), op, name) +} + +// ------------------------------------------------------------- +// columns +// ------------------------------------------------------------- + +// IDColumn implements [domain.instanceColumns]. +func (instance) IDColumn() database.Column { + return database.NewColumn("instances", "id") +} + +// NameColumn implements [domain.instanceColumns]. +func (instance) NameColumn() database.Column { + return database.NewColumn("instances", "name") +} + +// CreatedAtColumn implements [domain.instanceColumns]. +func (instance) CreatedAtColumn() database.Column { + return database.NewColumn("instances", "created_at") +} + +// DefaultOrgIdColumn implements [domain.instanceColumns]. +func (instance) DefaultOrgIDColumn() database.Column { + return database.NewColumn("instances", "default_org_id") +} + +// IAMProjectIDColumn implements [domain.instanceColumns]. +func (instance) IAMProjectIDColumn() database.Column { + return database.NewColumn("instances", "iam_project_id") +} + +// ConsoleClientIDColumn implements [domain.instanceColumns]. +func (instance) ConsoleClientIDColumn() database.Column { + return database.NewColumn("instances", "console_client_id") +} + +// ConsoleAppIDColumn implements [domain.instanceColumns]. +func (instance) ConsoleAppIDColumn() database.Column { + return database.NewColumn("instances", "console_app_id") +} + +// DefaultLanguageColumn implements [domain.instanceColumns]. +func (instance) DefaultLanguageColumn() database.Column { + return database.NewColumn("instances", "default_language") +} + +// UpdatedAtColumn implements [domain.instanceColumns]. +func (instance) UpdatedAtColumn() database.Column { + return database.NewColumn("instances", "updated_at") +} + +type rawInstance struct { + *domain.Instance + RawDomains sql.Null[json.RawMessage] `json:"domains,omitzero" db:"domains"` +} + +func scanInstance(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Instance, error) { + rows, err := querier.Query(ctx, builder.String(), builder.Args()...) + if err != nil { + return nil, err + } + + var instance rawInstance + if err := rows.(database.CollectableRows).CollectExactlyOneRow(&instance); err != nil { + return nil, err + } + + if instance.RawDomains.Valid { + if err := json.Unmarshal(instance.RawDomains.V, &instance.Domains); err != nil { + return nil, err + } + } + + return instance.Instance, nil +} + +func scanInstances(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.Instance, error) { + rows, err := querier.Query(ctx, builder.String(), builder.Args()...) + if err != nil { + return nil, err + } + + var rawInstances []*rawInstance + if err := rows.(database.CollectableRows).Collect(&rawInstances); err != nil { + return nil, err + } + + instances := make([]*domain.Instance, len(rawInstances)) + for i, instance := range rawInstances { + if instance.RawDomains.Valid { + if err := json.Unmarshal(instance.RawDomains.V, &instance.Domains); err != nil { + return nil, err + } + } + instances[i] = instance.Instance + } + return instances, nil +} + +// ------------------------------------------------------------- +// sub repositories +// ------------------------------------------------------------- + +// Domains implements [domain.InstanceRepository]. +func (i *instance) Domains(shouldLoad bool) domain.InstanceDomainRepository { + if !i.shouldLoadDomains { + i.shouldLoadDomains = shouldLoad + } + + if i.domainRepo != nil { + return i.domainRepo + } + + i.domainRepo = &instanceDomain{ + repository: i.repository, + instance: i, + } + return i.domainRepo +} diff --git a/backend/v3/storage/database/repository/instance_domain.go b/backend/v3/storage/database/repository/instance_domain.go new file mode 100644 index 00000000000..3aed56238ea --- /dev/null +++ b/backend/v3/storage/database/repository/instance_domain.go @@ -0,0 +1,214 @@ +package repository + +import ( + "context" + "time" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +var _ domain.InstanceDomainRepository = (*instanceDomain)(nil) + +type instanceDomain struct { + repository + *instance +} + +// ------------------------------------------------------------- +// repository +// ------------------------------------------------------------- + +const queryInstanceDomainStmt = `SELECT instance_domains.instance_id, instance_domains.domain, instance_domains.is_primary, instance_domains.created_at, instance_domains.updated_at ` + + `FROM zitadel.instance_domains` + +// Get implements [domain.InstanceDomainRepository]. +// Subtle: this method shadows the method ([domain.InstanceRepository]).Get of instanceDomain.instance. +func (i *instanceDomain) Get(ctx context.Context, opts ...database.QueryOption) (*domain.InstanceDomain, error) { + options := new(database.QueryOpts) + for _, opt := range opts { + opt(options) + } + + var builder database.StatementBuilder + builder.WriteString(queryInstanceDomainStmt) + options.Write(&builder) + + return scanInstanceDomain(ctx, i.client, &builder) +} + +// List implements [domain.InstanceDomainRepository]. +// Subtle: this method shadows the method ([domain.InstanceRepository]).List of instanceDomain.instance. +func (i *instanceDomain) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.InstanceDomain, error) { + options := new(database.QueryOpts) + for _, opt := range opts { + opt(options) + } + + var builder database.StatementBuilder + builder.WriteString(queryInstanceDomainStmt) + options.Write(&builder) + + return scanInstanceDomains(ctx, i.client, &builder) +} + +// Add implements [domain.InstanceDomainRepository]. +func (i *instanceDomain) Add(ctx context.Context, domain *domain.AddInstanceDomain) error { + var ( + builder database.StatementBuilder + createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction + ) + if !domain.CreatedAt.IsZero() { + createdAt = domain.CreatedAt + } + if !domain.UpdatedAt.IsZero() { + updatedAt = domain.UpdatedAt + } + + builder.WriteString(`INSERT INTO zitadel.instance_domains (instance_id, domain, is_primary, is_generated, type, created_at, updated_at) VALUES (`) + builder.WriteArgs(domain.InstanceID, domain.Domain, domain.IsPrimary, domain.IsGenerated, domain.Type, createdAt, updatedAt) + builder.WriteString(`) RETURNING created_at, updated_at`) + + return i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt) +} + +// Update implements [domain.InstanceDomainRepository]. +// Subtle: this method shadows the method ([domain.InstanceRepository]).Update of instanceDomain.instance. +func (i *instanceDomain) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) { + if len(changes) == 0 { + return 0, database.ErrNoChanges + } + var builder database.StatementBuilder + + builder.WriteString(`UPDATE zitadel.instance_domains SET `) + database.Changes(changes).Write(&builder) + + writeCondition(&builder, condition) + + return i.client.Exec(ctx, builder.String(), builder.Args()...) +} + +// Remove implements [domain.InstanceDomainRepository]. +func (i *instanceDomain) Remove(ctx context.Context, condition database.Condition) (int64, error) { + var builder database.StatementBuilder + + builder.WriteString(`DELETE FROM zitadel.instance_domains WHERE `) + condition.Write(&builder) + + return i.client.Exec(ctx, builder.String(), builder.Args()...) +} + +// ------------------------------------------------------------- +// changes +// ------------------------------------------------------------- + +// SetPrimary implements [domain.InstanceDomainRepository]. +func (i instanceDomain) SetPrimary() database.Change { + return database.NewChange(i.IsPrimaryColumn(), true) +} + +// SetUpdatedAt implements [domain.OrganizationDomainRepository]. +func (i instanceDomain) SetUpdatedAt(updatedAt time.Time) database.Change { + return database.NewChange(i.UpdatedAtColumn(), updatedAt) +} + +// SetType implements [domain.InstanceDomainRepository]. +func (i instanceDomain) SetType(typ domain.DomainType) database.Change { + return database.NewChange(i.TypeColumn(), typ) +} + +// ------------------------------------------------------------- +// conditions +// ------------------------------------------------------------- + +// DomainCondition implements [domain.InstanceDomainRepository]. +func (i instanceDomain) DomainCondition(op database.TextOperation, domain string) database.Condition { + return database.NewTextCondition(i.DomainColumn(), op, domain) +} + +// InstanceIDCondition implements [domain.InstanceDomainRepository]. +func (i instanceDomain) InstanceIDCondition(instanceID string) database.Condition { + return database.NewTextCondition(i.InstanceIDColumn(), database.TextOperationEqual, instanceID) +} + +// IsPrimaryCondition implements [domain.InstanceDomainRepository]. +func (i instanceDomain) IsPrimaryCondition(isPrimary bool) database.Condition { + return database.NewBooleanCondition(i.IsPrimaryColumn(), isPrimary) +} + +// TypeCondition implements [domain.InstanceDomainRepository]. +func (i instanceDomain) TypeCondition(typ domain.DomainType) database.Condition { + return database.NewTextCondition(i.TypeColumn(), database.TextOperationEqual, typ.String()) +} + +// ------------------------------------------------------------- +// columns +// ------------------------------------------------------------- + +// CreatedAtColumn implements [domain.InstanceDomainRepository]. +// Subtle: this method shadows the method ([domain.InstanceRepository]).CreatedAtColumn of instanceDomain.instance. +func (instanceDomain) CreatedAtColumn() database.Column { + return database.NewColumn("instance_domains", "created_at") +} + +// DomainColumn implements [domain.InstanceDomainRepository]. +func (instanceDomain) DomainColumn() database.Column { + return database.NewColumn("instance_domains", "domain") +} + +// InstanceIDColumn implements [domain.InstanceDomainRepository]. +func (instanceDomain) InstanceIDColumn() database.Column { + return database.NewColumn("instance_domains", "instance_id") +} + +// IsPrimaryColumn implements [domain.InstanceDomainRepository]. +func (instanceDomain) IsPrimaryColumn() database.Column { + return database.NewColumn("instance_domains", "is_primary") +} + +// UpdatedAtColumn implements [domain.InstanceDomainRepository]. +// Subtle: this method shadows the method ([domain.InstanceRepository]).UpdatedAtColumn of instanceDomain.instance. +func (instanceDomain) UpdatedAtColumn() database.Column { + return database.NewColumn("instance_domains", "updated_at") +} + +// IsGeneratedColumn implements [domain.InstanceDomainRepository]. +func (instanceDomain) IsGeneratedColumn() database.Column { + return database.NewColumn("instance_domains", "is_generated") +} + +// TypeColumn implements [domain.InstanceDomainRepository]. +func (instanceDomain) TypeColumn() database.Column { + return database.NewColumn("instance_domains", "type") +} + +// ------------------------------------------------------------- +// scanners +// ------------------------------------------------------------- + +func scanInstanceDomains(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.InstanceDomain, error) { + rows, err := querier.Query(ctx, builder.String(), builder.Args()...) + if err != nil { + return nil, err + } + + var domains []*domain.InstanceDomain + if err := rows.(database.CollectableRows).Collect(&domains); err != nil { + return nil, err + } + + return domains, nil +} + +func scanInstanceDomain(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.InstanceDomain, error) { + rows, err := querier.Query(ctx, builder.String(), builder.Args()...) + if err != nil { + return nil, err + } + domain := new(domain.InstanceDomain) + if err := rows.(database.CollectableRows).CollectExactlyOneRow(domain); err != nil { + return nil, err + } + + return domain, nil +} diff --git a/backend/v3/storage/database/repository/instance_domain_test.go b/backend/v3/storage/database/repository/instance_domain_test.go new file mode 100644 index 00000000000..cc18ada380f --- /dev/null +++ b/backend/v3/storage/database/repository/instance_domain_test.go @@ -0,0 +1,701 @@ +package repository_test + +import ( + "context" + "testing" + "time" + + "github.com/brianvoe/gofakeit/v6" + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/repository" +) + +func TestAddInstanceDomain(t *testing.T) { + // create instance + instanceID := gofakeit.UUID() + instance := domain.Instance{ + ID: instanceID, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleClient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err := instanceRepo.Create(t.Context(), &instance) + require.NoError(t, err) + + tests := []struct { + name string + testFunc func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain + instanceDomain domain.AddInstanceDomain + err error + }{ + { + name: "happy path custom domain", + instanceDomain: domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: gofakeit.DomainName(), + Type: domain.DomainTypeCustom, + IsPrimary: gu.Ptr(false), + IsGenerated: gu.Ptr(false), + }, + }, + { + name: "happy path trusted domain", + instanceDomain: domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: gofakeit.DomainName(), + Type: domain.DomainTypeTrusted, + }, + }, + { + name: "add primary domain", + instanceDomain: domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: gofakeit.DomainName(), + Type: domain.DomainTypeCustom, + IsPrimary: gu.Ptr(true), + IsGenerated: gu.Ptr(false), + }, + }, + { + name: "add custom domain without domain name", + instanceDomain: domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: "", + Type: domain.DomainTypeCustom, + IsPrimary: gu.Ptr(false), + IsGenerated: gu.Ptr(false), + }, + err: new(database.CheckError), + }, + { + name: "add trusted domain without domain name", + instanceDomain: domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: "", + Type: domain.DomainTypeTrusted, + }, + err: new(database.CheckError), + }, + { + name: "add custom domain with same domain twice", + testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain { + domainName := gofakeit.DomainName() + + instanceDomain := &domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: domainName, + Type: domain.DomainTypeCustom, + IsPrimary: gu.Ptr(false), + IsGenerated: gu.Ptr(false), + } + + err := domainRepo.Add(ctx, instanceDomain) + require.NoError(t, err) + + // return same domain again + return &domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: domainName, + Type: domain.DomainTypeCustom, + IsPrimary: gu.Ptr(false), + IsGenerated: gu.Ptr(false), + } + }, + err: new(database.UniqueError), + }, + { + name: "add trusted domain with same domain twice", + testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain { + domainName := gofakeit.DomainName() + + instanceDomain := &domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: domainName, + Type: domain.DomainTypeTrusted, + } + + err := domainRepo.Add(ctx, instanceDomain) + require.NoError(t, err) + + // return same domain again + return &domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: domainName, + Type: domain.DomainTypeTrusted, + } + }, + err: new(database.UniqueError), + }, + { + name: "add domain with non-existent instance", + instanceDomain: domain.AddInstanceDomain{ + InstanceID: "non-existent-instance", + Domain: gofakeit.DomainName(), + Type: domain.DomainTypeCustom, + IsPrimary: gu.Ptr(false), + IsGenerated: gu.Ptr(false), + }, + err: new(database.ForeignKeyError), + }, + { + name: "add domain without instance id", + instanceDomain: domain.AddInstanceDomain{ + Domain: gofakeit.DomainName(), + Type: domain.DomainTypeCustom, + IsPrimary: gu.Ptr(false), + IsGenerated: gu.Ptr(false), + }, + err: new(database.ForeignKeyError), + }, + { + name: "add custom domain without primary", + instanceDomain: domain.AddInstanceDomain{ + Domain: gofakeit.DomainName(), + Type: domain.DomainTypeCustom, + IsGenerated: gu.Ptr(false), + }, + err: new(database.CheckError), + }, + { + name: "add custom domain without generated", + instanceDomain: domain.AddInstanceDomain{ + Domain: gofakeit.DomainName(), + Type: domain.DomainTypeCustom, + IsPrimary: gu.Ptr(false), + }, + err: new(database.CheckError), + }, + { + name: "add trusted domain with primary", + instanceDomain: domain.AddInstanceDomain{ + Domain: gofakeit.DomainName(), + Type: domain.DomainTypeTrusted, + IsPrimary: gu.Ptr(false), + }, + err: new(database.CheckError), + }, + { + name: "add trusted domain with generated", + instanceDomain: domain.AddInstanceDomain{ + Domain: gofakeit.DomainName(), + Type: domain.DomainTypeTrusted, + IsGenerated: gu.Ptr(false), + }, + err: new(database.CheckError), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := t.Context() + + // we take now here because the timestamp of the transaction is used to set the createdAt and updatedAt fields + beforeAdd := time.Now() + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + require.NoError(t, tx.Rollback(t.Context())) + }() + instanceRepo := repository.InstanceRepository(tx) + domainRepo := instanceRepo.Domains(false) + + var instanceDomain *domain.AddInstanceDomain + if test.testFunc != nil { + instanceDomain = test.testFunc(ctx, t, domainRepo) + } else { + instanceDomain = &test.instanceDomain + } + + err = domainRepo.Add(ctx, instanceDomain) + afterAdd := time.Now() + if test.err != nil { + assert.ErrorIs(t, err, test.err) + return + } + + require.NoError(t, err) + assert.NotZero(t, instanceDomain.CreatedAt) + assert.NotZero(t, instanceDomain.UpdatedAt) + assert.WithinRange(t, instanceDomain.CreatedAt, beforeAdd, afterAdd) + assert.WithinRange(t, instanceDomain.UpdatedAt, beforeAdd, afterAdd) + }) + } +} + +func TestGetInstanceDomain(t *testing.T) { + // create instance + instanceID := gofakeit.UUID() + instance := domain.Instance{ + ID: instanceID, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleClient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + require.NoError(t, tx.Rollback(t.Context())) + }() + instanceRepo := repository.InstanceRepository(tx) + err = instanceRepo.Create(t.Context(), &instance) + require.NoError(t, err) + + // add domains + domainRepo := instanceRepo.Domains(false) + domainName1 := gofakeit.DomainName() + domainName2 := gofakeit.DomainName() + + domain1 := &domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: domainName1, + IsPrimary: gu.Ptr(true), + IsGenerated: gu.Ptr(false), + Type: domain.DomainTypeCustom, + } + domain2 := &domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: domainName2, + IsPrimary: gu.Ptr(false), + IsGenerated: gu.Ptr(false), + Type: domain.DomainTypeCustom, + } + + err = domainRepo.Add(t.Context(), domain1) + require.NoError(t, err) + err = domainRepo.Add(t.Context(), domain2) + require.NoError(t, err) + + tests := []struct { + name string + opts []database.QueryOption + expected *domain.InstanceDomain + err error + }{ + { + name: "get primary domain", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.IsPrimaryCondition(true)), + }, + expected: &domain.InstanceDomain{ + InstanceID: instanceID, + Domain: domainName1, + IsPrimary: gu.Ptr(true), + }, + }, + { + name: "get by domain name", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, domainName2)), + }, + expected: &domain.InstanceDomain{ + InstanceID: instanceID, + Domain: domainName2, + IsPrimary: gu.Ptr(false), + }, + }, + { + name: "get non-existent domain", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com")), + }, + err: new(database.NoRowFoundError), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := t.Context() + + result, err := domainRepo.Get(ctx, test.opts...) + if test.err != nil { + assert.ErrorIs(t, err, test.err) + return + } + + require.NoError(t, err) + assert.Equal(t, test.expected.InstanceID, result.InstanceID) + assert.Equal(t, test.expected.Domain, result.Domain) + assert.Equal(t, test.expected.IsPrimary, result.IsPrimary) + assert.NotEmpty(t, result.CreatedAt) + assert.NotEmpty(t, result.UpdatedAt) + }) + } +} + +func TestListInstanceDomains(t *testing.T) { + // create instance + instanceID := gofakeit.UUID() + instance := domain.Instance{ + ID: instanceID, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleClient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + require.NoError(t, tx.Rollback(t.Context())) + }() + + instanceRepo := repository.InstanceRepository(tx) + err = instanceRepo.Create(t.Context(), &instance) + require.NoError(t, err) + + // add multiple domains + domainRepo := instanceRepo.Domains(false) + domains := []domain.AddInstanceDomain{ + { + InstanceID: instanceID, + Domain: gofakeit.DomainName(), + IsPrimary: gu.Ptr(true), + IsGenerated: gu.Ptr(false), + Type: domain.DomainTypeCustom, + }, + { + InstanceID: instanceID, + Domain: gofakeit.DomainName(), + IsPrimary: gu.Ptr(false), + IsGenerated: gu.Ptr(false), + Type: domain.DomainTypeCustom, + }, + { + InstanceID: instanceID, + Domain: gofakeit.DomainName(), + Type: domain.DomainTypeTrusted, + }, + } + + for i := range domains { + err = domainRepo.Add(t.Context(), &domains[i]) + require.NoError(t, err) + } + + tests := []struct { + name string + opts []database.QueryOption + expectedCount int + }{ + { + name: "list all domains", + opts: []database.QueryOption{}, + expectedCount: 3, + }, + { + name: "list primary domains", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.IsPrimaryCondition(true)), + }, + expectedCount: 1, + }, + { + name: "list by instance", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.InstanceIDCondition(instanceID)), + }, + expectedCount: 3, + }, + { + name: "list non-existent instance", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.InstanceIDCondition("non-existent")), + }, + expectedCount: 0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := t.Context() + + results, err := domainRepo.List(ctx, test.opts...) + require.NoError(t, err) + assert.Len(t, results, test.expectedCount) + + for _, result := range results { + assert.Equal(t, instanceID, result.InstanceID) + assert.NotEmpty(t, result.Domain) + assert.NotEmpty(t, result.CreatedAt) + assert.NotEmpty(t, result.UpdatedAt) + } + }) + } +} + +func TestUpdateInstanceDomain(t *testing.T) { + // create instance + instanceID := gofakeit.UUID() + instance := domain.Instance{ + ID: instanceID, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleClient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + require.NoError(t, tx.Rollback(t.Context())) + }() + + instanceRepo := repository.InstanceRepository(tx) + err = instanceRepo.Create(t.Context(), &instance) + require.NoError(t, err) + + // add domain + domainRepo := instanceRepo.Domains(false) + domainName := gofakeit.DomainName() + instanceDomain := &domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: domainName, + IsPrimary: gu.Ptr(false), + IsGenerated: gu.Ptr(false), + Type: domain.DomainTypeCustom, + } + + err = domainRepo.Add(t.Context(), instanceDomain) + require.NoError(t, err) + + tests := []struct { + name string + condition database.Condition + changes []database.Change + expected int64 + err error + }{ + { + name: "set primary", + condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName), + changes: []database.Change{domainRepo.SetPrimary()}, + expected: 1, + }, + { + name: "update non-existent domain", + condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"), + changes: []database.Change{domainRepo.SetPrimary()}, + expected: 0, + }, + { + name: "no changes", + condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName), + changes: []database.Change{}, + expected: 0, + err: database.ErrNoChanges, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := t.Context() + + rowsAffected, err := domainRepo.Update(ctx, test.condition, test.changes...) + if test.err != nil { + assert.ErrorIs(t, err, test.err) + return + } + + require.NoError(t, err) + assert.Equal(t, test.expected, rowsAffected) + + // verify changes were applied if rows were affected + if rowsAffected > 0 && len(test.changes) > 0 { + result, err := domainRepo.Get(ctx, database.WithCondition(test.condition)) + require.NoError(t, err) + + // We know changes were applied since rowsAffected > 0 + // The specific verification of what changed is less important + // than knowing the operation succeeded + assert.NotNil(t, result) + } + }) + } +} + +func TestRemoveInstanceDomain(t *testing.T) { + // create instance + instanceID := gofakeit.UUID() + instance := domain.Instance{ + ID: instanceID, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleClient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + require.NoError(t, tx.Rollback(t.Context())) + }() + instanceRepo := repository.InstanceRepository(tx) + err = instanceRepo.Create(t.Context(), &instance) + require.NoError(t, err) + + // add domains + domainRepo := instanceRepo.Domains(false) + domainName1 := gofakeit.DomainName() + + domain1 := &domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: domainName1, + IsPrimary: gu.Ptr(true), + IsGenerated: gu.Ptr(false), + Type: domain.DomainTypeCustom, + } + domain2 := &domain.AddInstanceDomain{ + InstanceID: instanceID, + Domain: gofakeit.DomainName(), + IsPrimary: gu.Ptr(false), + IsGenerated: gu.Ptr(false), + Type: domain.DomainTypeCustom, + } + + err = domainRepo.Add(t.Context(), domain1) + require.NoError(t, err) + err = domainRepo.Add(t.Context(), domain2) + require.NoError(t, err) + + tests := []struct { + name string + condition database.Condition + expected int64 + }{ + { + name: "remove by domain name", + condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName1), + expected: 1, + }, + { + name: "remove by primary condition", + condition: domainRepo.IsPrimaryCondition(false), + expected: 1, // domain2 should still exist and be non-primary + }, + { + name: "remove non-existent domain", + condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"), + expected: 0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := t.Context() + + // count before removal + beforeCount, err := domainRepo.List(ctx) + require.NoError(t, err) + + rowsAffected, err := domainRepo.Remove(ctx, test.condition) + require.NoError(t, err) + assert.Equal(t, test.expected, rowsAffected) + + // verify removal + afterCount, err := domainRepo.List(ctx) + require.NoError(t, err) + assert.Equal(t, len(beforeCount)-int(test.expected), len(afterCount)) + }) + } +} + +func TestInstanceDomainConditions(t *testing.T) { + instanceRepo := repository.InstanceRepository(pool) + domainRepo := instanceRepo.Domains(false) + + tests := []struct { + name string + condition database.Condition + expected string + }{ + { + name: "domain condition equal", + condition: domainRepo.DomainCondition(database.TextOperationEqual, "example.com"), + expected: "instance_domains.domain = $1", + }, + { + name: "domain condition starts with", + condition: domainRepo.DomainCondition(database.TextOperationStartsWith, "example"), + expected: "instance_domains.domain LIKE $1 || '%'", + }, + { + name: "instance id condition", + condition: domainRepo.InstanceIDCondition("instance-123"), + expected: "instance_domains.instance_id = $1", + }, + { + name: "is primary true", + condition: domainRepo.IsPrimaryCondition(true), + expected: "instance_domains.is_primary = $1", + }, + { + name: "is primary false", + condition: domainRepo.IsPrimaryCondition(false), + expected: "instance_domains.is_primary = $1", + }, + { + name: "is type custom", + condition: domainRepo.TypeCondition(domain.DomainTypeCustom), + expected: "instance_domains.type = $1", + }, + { + name: "is type trusted", + condition: domainRepo.TypeCondition(domain.DomainTypeTrusted), + expected: "instance_domains.type = $1", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var builder database.StatementBuilder + test.condition.Write(&builder) + assert.Equal(t, test.expected, builder.String()) + }) + } +} + +func TestInstanceDomainChanges(t *testing.T) { + instanceRepo := repository.InstanceRepository(pool) + domainRepo := instanceRepo.Domains(false) + + tests := []struct { + name string + change database.Change + expected string + }{ + { + name: "set primary", + change: domainRepo.SetPrimary(), + expected: "is_primary = $1", + }, + { + name: "set type", + change: domainRepo.SetType(domain.DomainTypeCustom), + expected: "type = $1", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var builder database.StatementBuilder + test.change.Write(&builder) + assert.Equal(t, test.expected, builder.String()) + }) + } +} diff --git a/backend/v3/storage/database/repository/instance_test.go b/backend/v3/storage/database/repository/instance_test.go new file mode 100644 index 00000000000..728cbdd753e --- /dev/null +++ b/backend/v3/storage/database/repository/instance_test.go @@ -0,0 +1,723 @@ +package repository_test + +import ( + "context" + "testing" + "time" + + "github.com/brianvoe/gofakeit/v6" + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/repository" +) + +func TestCreateInstance(t *testing.T) { + tests := []struct { + name string + testFunc func(ctx context.Context, t *testing.T) *domain.Instance + instance domain.Instance + err error + }{ + { + name: "happy path", + instance: func() domain.Instance { + instanceId := gofakeit.Name() + instanceName := gofakeit.Name() + instance := domain.Instance{ + ID: instanceId, + Name: instanceName, + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + return instance + }(), + }, + { + name: "create instance without name", + instance: func() domain.Instance { + instanceId := gofakeit.Name() + // instanceName := gofakeit.Name() + instance := domain.Instance{ + ID: instanceId, + Name: "", + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + return instance + }(), + err: new(database.CheckError), + }, + { + name: "adding same instance twice", + testFunc: func(ctx context.Context, t *testing.T) *domain.Instance { + instanceRepo := repository.InstanceRepository(pool) + instanceId := gofakeit.Name() + instanceName := gofakeit.Name() + + inst := domain.Instance{ + ID: instanceId, + Name: instanceName, + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + + err := instanceRepo.Create(ctx, &inst) + // change the name to make sure same only the id clashes + inst.Name = gofakeit.Name() + require.NoError(t, err) + return &inst + }, + err: new(database.UniqueError), + }, + func() struct { + name string + testFunc func(ctx context.Context, t *testing.T) *domain.Instance + instance domain.Instance + err error + } { + instanceId := gofakeit.Name() + instanceName := gofakeit.Name() + return struct { + name string + testFunc func(ctx context.Context, t *testing.T) *domain.Instance + instance domain.Instance + err error + }{ + name: "adding instance with same name twice", + testFunc: func(ctx context.Context, t *testing.T) *domain.Instance { + instanceRepo := repository.InstanceRepository(pool) + + inst := domain.Instance{ + ID: gofakeit.Name(), + Name: instanceName, + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + + err := instanceRepo.Create(ctx, &inst) + require.NoError(t, err) + + // change the id + inst.ID = instanceId + inst.CreatedAt = time.Time{} + inst.UpdatedAt = time.Time{} + return &inst + }, + instance: domain.Instance{ + ID: instanceId, + Name: instanceName, + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + }, + // two instances can have the sane name + err: nil, + } + }(), + { + name: "adding instance with no id", + instance: func() domain.Instance { + // instanceId := gofakeit.Name() + instanceName := gofakeit.Name() + instance := domain.Instance{ + // ID: instanceId, + Name: instanceName, + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + return instance + }(), + err: new(database.CheckError), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + + var instance *domain.Instance + if tt.testFunc != nil { + instance = tt.testFunc(ctx, t) + } else { + instance = &tt.instance + } + instanceRepo := repository.InstanceRepository(pool) + + // create instance + beforeCreate := time.Now() + err := instanceRepo.Create(ctx, instance) + assert.ErrorIs(t, err, tt.err) + if err != nil { + return + } + afterCreate := time.Now() + + // check instance values + instance, err = instanceRepo.Get(ctx, + database.WithCondition( + instanceRepo.IDCondition(instance.ID), + ), + ) + require.NoError(t, err) + + assert.Equal(t, tt.instance.ID, instance.ID) + assert.Equal(t, tt.instance.Name, instance.Name) + assert.Equal(t, tt.instance.DefaultOrgID, instance.DefaultOrgID) + assert.Equal(t, tt.instance.IAMProjectID, instance.IAMProjectID) + assert.Equal(t, tt.instance.ConsoleClientID, instance.ConsoleClientID) + assert.Equal(t, tt.instance.ConsoleAppID, instance.ConsoleAppID) + assert.Equal(t, tt.instance.DefaultLanguage, instance.DefaultLanguage) + assert.WithinRange(t, instance.CreatedAt, beforeCreate, afterCreate) + assert.WithinRange(t, instance.UpdatedAt, beforeCreate, afterCreate) + }) + } +} + +func TestUpdateInstance(t *testing.T) { + tests := []struct { + name string + testFunc func(ctx context.Context, t *testing.T) *domain.Instance + rowsAffected int64 + getErr error + }{ + { + name: "happy path", + testFunc: func(ctx context.Context, t *testing.T) *domain.Instance { + instanceRepo := repository.InstanceRepository(pool) + instanceId := gofakeit.Name() + instanceName := gofakeit.Name() + + inst := domain.Instance{ + ID: instanceId, + Name: instanceName, + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + + // create instance + err := instanceRepo.Create(ctx, &inst) + require.NoError(t, err) + return &inst + }, + rowsAffected: 1, + }, + { + name: "update deleted instance", + testFunc: func(ctx context.Context, t *testing.T) *domain.Instance { + instanceRepo := repository.InstanceRepository(pool) + instanceId := gofakeit.Name() + instanceName := gofakeit.Name() + + inst := domain.Instance{ + ID: instanceId, + Name: instanceName, + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + + // create instance + err := instanceRepo.Create(ctx, &inst) + require.NoError(t, err) + + // delete instance + affectedRows, err := instanceRepo.Delete(ctx, + inst.ID, + ) + require.NoError(t, err) + assert.Equal(t, int64(1), affectedRows) + + return &inst + }, + rowsAffected: 0, + }, + { + name: "update non existent instance", + testFunc: func(ctx context.Context, t *testing.T) *domain.Instance { + instanceId := gofakeit.Name() + + inst := domain.Instance{ + ID: instanceId, + } + return &inst + }, + rowsAffected: 0, + getErr: new(database.NoRowFoundError), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + instanceRepo := repository.InstanceRepository(pool) + + instance := tt.testFunc(ctx, t) + + beforeUpdate := time.Now() + // update name + newName := "new_" + instance.Name + rowsAffected, err := instanceRepo.Update(ctx, + instance.ID, + instanceRepo.SetName(newName), + ) + afterUpdate := time.Now() + require.NoError(t, err) + + assert.Equal(t, tt.rowsAffected, rowsAffected) + + if rowsAffected == 0 { + return + } + + // check instance values + instance, err = instanceRepo.Get(ctx, + database.WithCondition( + instanceRepo.IDCondition(instance.ID), + ), + ) + require.Equal(t, tt.getErr, err) + + assert.Equal(t, newName, instance.Name) + assert.WithinRange(t, instance.UpdatedAt, beforeUpdate, afterUpdate) + }) + } +} + +func TestGetInstance(t *testing.T) { + instanceRepo := repository.InstanceRepository(pool) + type test struct { + name string + testFunc func(ctx context.Context, t *testing.T) *domain.Instance + err error + } + + tests := []test{ + func() test { + instanceId := gofakeit.Name() + return test{ + name: "happy path get using id", + testFunc: func(ctx context.Context, t *testing.T) *domain.Instance { + instanceName := gofakeit.Name() + + inst := domain.Instance{ + ID: instanceId, + Name: instanceName, + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + + // create instance + err := instanceRepo.Create(ctx, &inst) + require.NoError(t, err) + return &inst + }, + } + }(), + { + name: "happy path including domains", + testFunc: func(ctx context.Context, t *testing.T) *domain.Instance { + instanceRepo := repository.InstanceRepository(pool) + instanceId := gofakeit.Name() + instanceName := gofakeit.Name() + + inst := domain.Instance{ + ID: instanceId, + Name: instanceName, + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + + // create instance + err := instanceRepo.Create(ctx, &inst) + require.NoError(t, err) + + domainRepo := instanceRepo.Domains(false) + d := &domain.AddInstanceDomain{ + InstanceID: inst.ID, + Domain: gofakeit.DomainName(), + IsPrimary: gu.Ptr(true), + IsGenerated: gu.Ptr(false), + Type: domain.DomainTypeCustom, + } + err = domainRepo.Add(ctx, d) + require.NoError(t, err) + + inst.Domains = append(inst.Domains, &domain.InstanceDomain{ + InstanceID: d.InstanceID, + Domain: d.Domain, + IsPrimary: d.IsPrimary, + IsGenerated: d.IsGenerated, + Type: d.Type, + CreatedAt: d.CreatedAt, + UpdatedAt: d.UpdatedAt, + }) + + return &inst + }, + }, + { + name: "get non existent instance", + testFunc: func(ctx context.Context, t *testing.T) *domain.Instance { + inst := domain.Instance{ + ID: "get non existent instance", + } + return &inst + }, + err: new(database.NoRowFoundError), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + instanceRepo := repository.InstanceRepository(pool) + + var instance *domain.Instance + if tt.testFunc != nil { + instance = tt.testFunc(ctx, t) + } + + // check instance values + returnedInstance, err := instanceRepo.Get(ctx, + database.WithCondition( + instanceRepo.IDCondition(instance.ID), + ), + ) + if tt.err != nil { + require.ErrorIs(t, err, tt.err) + return + } + + if instance.ID == "get non existent instance" { + assert.Nil(t, returnedInstance) + return + } + + assert.Equal(t, returnedInstance.ID, instance.ID) + assert.Equal(t, returnedInstance.Name, instance.Name) + assert.Equal(t, returnedInstance.DefaultOrgID, instance.DefaultOrgID) + assert.Equal(t, returnedInstance.IAMProjectID, instance.IAMProjectID) + assert.Equal(t, returnedInstance.ConsoleClientID, instance.ConsoleClientID) + assert.Equal(t, returnedInstance.ConsoleAppID, instance.ConsoleAppID) + assert.Equal(t, returnedInstance.DefaultLanguage, instance.DefaultLanguage) + }) + } +} + +func TestListInstance(t *testing.T) { + ctx := context.Background() + pool, stop, err := newEmbeddedDB(ctx) + require.NoError(t, err) + defer stop() + + type test struct { + name string + testFunc func(ctx context.Context, t *testing.T) []*domain.Instance + conditionClauses []database.Condition + noInstanceReturned bool + } + tests := []test{ + { + name: "happy path single instance no filter", + testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance { + instanceRepo := repository.InstanceRepository(pool) + noOfInstances := 1 + instances := make([]*domain.Instance, noOfInstances) + for i := range noOfInstances { + + inst := domain.Instance{ + ID: gofakeit.Name(), + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + + // create instance + err := instanceRepo.Create(ctx, &inst) + require.NoError(t, err) + + instances[i] = &inst + } + + return instances + }, + }, + { + name: "happy path multiple instance no filter", + testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance { + instanceRepo := repository.InstanceRepository(pool) + noOfInstances := 5 + instances := make([]*domain.Instance, noOfInstances) + for i := range noOfInstances { + + inst := domain.Instance{ + ID: gofakeit.Name(), + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + + // create instance + err := instanceRepo.Create(ctx, &inst) + require.NoError(t, err) + + instances[i] = &inst + } + + return instances + }, + }, + func() test { + instanceRepo := repository.InstanceRepository(pool) + instanceId := gofakeit.Name() + return test{ + name: "instance filter on id", + testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance { + noOfInstances := 1 + instances := make([]*domain.Instance, noOfInstances) + for i := range noOfInstances { + + inst := domain.Instance{ + ID: instanceId, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + + // create instance + err := instanceRepo.Create(ctx, &inst) + require.NoError(t, err) + + instances[i] = &inst + } + + return instances + }, + conditionClauses: []database.Condition{instanceRepo.IDCondition(instanceId)}, + } + }(), + func() test { + instanceRepo := repository.InstanceRepository(pool) + instanceName := gofakeit.Name() + return test{ + name: "multiple instance filter on name", + testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance { + noOfInstances := 5 + instances := make([]*domain.Instance, noOfInstances) + for i := range noOfInstances { + + inst := domain.Instance{ + ID: gofakeit.Name(), + Name: instanceName, + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + + // create instance + err := instanceRepo.Create(ctx, &inst) + require.NoError(t, err) + + instances[i] = &inst + } + + return instances + }, + conditionClauses: []database.Condition{instanceRepo.NameCondition(database.TextOperationEqual, instanceName)}, + } + }(), + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Cleanup(func() { + _, err := pool.Exec(ctx, "DELETE FROM zitadel.instances") + require.NoError(t, err) + }) + + instances := tt.testFunc(ctx, t) + + instanceRepo := repository.InstanceRepository(pool) + + var condition database.Condition + if len(tt.conditionClauses) > 0 { + condition = database.And(tt.conditionClauses...) + } + + // check instance values + returnedInstances, err := instanceRepo.List(ctx, + database.WithCondition(condition), + database.WithOrderByAscending(instanceRepo.CreatedAtColumn()), + ) + require.NoError(t, err) + if tt.noInstanceReturned { + assert.Nil(t, returnedInstances) + return + } + + assert.Equal(t, len(instances), len(returnedInstances)) + for i, instance := range instances { + assert.Equal(t, returnedInstances[i].ID, instance.ID) + assert.Equal(t, returnedInstances[i].Name, instance.Name) + assert.Equal(t, returnedInstances[i].DefaultOrgID, instance.DefaultOrgID) + assert.Equal(t, returnedInstances[i].IAMProjectID, instance.IAMProjectID) + assert.Equal(t, returnedInstances[i].ConsoleClientID, instance.ConsoleClientID) + assert.Equal(t, returnedInstances[i].ConsoleAppID, instance.ConsoleAppID) + assert.Equal(t, returnedInstances[i].DefaultLanguage, instance.DefaultLanguage) + } + }) + } +} + +func TestDeleteInstance(t *testing.T) { + type test struct { + name string + testFunc func(ctx context.Context, t *testing.T) + instanceID string + noOfDeletedRows int64 + } + tests := []test{ + func() test { + instanceRepo := repository.InstanceRepository(pool) + instanceId := gofakeit.Name() + var noOfInstances int64 = 1 + return test{ + name: "happy path delete single instance filter id", + testFunc: func(ctx context.Context, t *testing.T) { + instances := make([]*domain.Instance, noOfInstances) + for i := range noOfInstances { + + inst := domain.Instance{ + ID: instanceId, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + + // create instance + err := instanceRepo.Create(ctx, &inst) + require.NoError(t, err) + + instances[i] = &inst + } + }, + instanceID: instanceId, + noOfDeletedRows: noOfInstances, + } + }(), + func() test { + non_existent_instance_name := gofakeit.Name() + return test{ + name: "delete non existent instance", + instanceID: non_existent_instance_name, + } + }(), + func() test { + instanceRepo := repository.InstanceRepository(pool) + instanceName := gofakeit.Name() + return test{ + name: "deleted already deleted instance", + testFunc: func(ctx context.Context, t *testing.T) { + noOfInstances := 1 + instances := make([]*domain.Instance, noOfInstances) + for i := range noOfInstances { + + inst := domain.Instance{ + ID: gofakeit.Name(), + Name: instanceName, + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + + // create instance + err := instanceRepo.Create(ctx, &inst) + require.NoError(t, err) + + instances[i] = &inst + } + + // delete instance + affectedRows, err := instanceRepo.Delete(ctx, + instances[0].ID, + ) + require.NoError(t, err) + assert.Equal(t, int64(1), affectedRows) + }, + instanceID: instanceName, + // this test should return 0 affected rows as the instance was already deleted + noOfDeletedRows: 0, + } + }(), + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + instanceRepo := repository.InstanceRepository(pool) + + if tt.testFunc != nil { + tt.testFunc(ctx, t) + } + + // delete instance + noOfDeletedRows, err := instanceRepo.Delete(ctx, + tt.instanceID, + ) + require.NoError(t, err) + assert.Equal(t, noOfDeletedRows, tt.noOfDeletedRows) + + // check instance was deleted + instance, err := instanceRepo.Get(ctx, + database.WithCondition( + instanceRepo.IDCondition(tt.instanceID), + ), + ) + require.ErrorIs(t, err, new(database.NoRowFoundError)) + assert.Nil(t, instance) + }) + } +} diff --git a/backend/v3/storage/database/repository/org.go b/backend/v3/storage/database/repository/org.go new file mode 100644 index 00000000000..49bff75ca87 --- /dev/null +++ b/backend/v3/storage/database/repository/org.go @@ -0,0 +1,282 @@ +package repository + +import ( + "context" + "encoding/json" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +// ------------------------------------------------------------- +// repository +// ------------------------------------------------------------- + +var _ domain.OrganizationRepository = (*org)(nil) + +type org struct { + repository + shouldLoadDomains bool + domainRepo domain.OrganizationDomainRepository +} + +func OrganizationRepository(client database.QueryExecutor) domain.OrganizationRepository { + return &org{ + repository: repository{ + client: client, + }, + } +} + +const queryOrganizationStmt = `SELECT organizations.id, organizations.name, organizations.instance_id, organizations.state, organizations.created_at, organizations.updated_at` + + ` , CASE WHEN count(org_domains.domain) > 0 THEN jsonb_agg(json_build_object('domain', org_domains.domain, 'isVerified', org_domains.is_verified, 'isPrimary', org_domains.is_primary, 'validationType', org_domains.validation_type, 'createdAt', org_domains.created_at, 'updatedAt', org_domains.updated_at)) ELSE NULL::JSONB END domains` + + ` FROM zitadel.organizations` + +// Get implements [domain.OrganizationRepository]. +func (o *org) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Organization, error) { + opts = append(opts, + o.joinDomains(), + database.WithGroupBy(o.InstanceIDColumn(), o.IDColumn()), + ) + + options := new(database.QueryOpts) + for _, opt := range opts { + opt(options) + } + + var builder database.StatementBuilder + builder.WriteString(queryOrganizationStmt) + options.Write(&builder) + + return scanOrganization(ctx, o.client, &builder) +} + +// List implements [domain.OrganizationRepository]. +func (o *org) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Organization, error) { + opts = append(opts, + o.joinDomains(), + database.WithGroupBy(o.InstanceIDColumn(), o.IDColumn()), + ) + + options := new(database.QueryOpts) + for _, opt := range opts { + opt(options) + } + + var builder database.StatementBuilder + builder.WriteString(queryOrganizationStmt) + options.Write(&builder) + + return scanOrganizations(ctx, o.client, &builder) +} + +func (o *org) joinDomains() database.QueryOption { + columns := make([]database.Condition, 0, 3) + columns = append(columns, + database.NewColumnCondition(o.InstanceIDColumn(), o.Domains(false).InstanceIDColumn()), + database.NewColumnCondition(o.IDColumn(), o.Domains(false).OrgIDColumn()), + ) + + // If domains should not be joined, we make sure to return null for the domain columns + // the query optimizer of the dialect should optimize this away if no domains are requested + if !o.shouldLoadDomains { + columns = append(columns, database.IsNull(o.domainRepo.OrgIDColumn())) + } + + return database.WithLeftJoin( + "zitadel.org_domains", + database.And(columns...), + ) +} + +const createOrganizationStmt = `INSERT INTO zitadel.organizations (id, name, instance_id, state)` + + ` VALUES ($1, $2, $3, $4)` + + ` RETURNING created_at, updated_at` + +// Create implements [domain.OrganizationRepository]. +func (o *org) Create(ctx context.Context, organization *domain.Organization) error { + builder := database.StatementBuilder{} + builder.AppendArgs(organization.ID, organization.Name, organization.InstanceID, organization.State) + builder.WriteString(createOrganizationStmt) + + return o.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&organization.CreatedAt, &organization.UpdatedAt) +} + +// Update implements [domain.OrganizationRepository]. +func (o *org) Update(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string, changes ...database.Change) (int64, error) { + if len(changes) == 0 { + return 0, database.ErrNoChanges + } + builder := database.StatementBuilder{} + builder.WriteString(`UPDATE zitadel.organizations SET `) + + instanceIDCondition := o.InstanceIDCondition(instanceID) + + conditions := []database.Condition{id, instanceIDCondition} + database.Changes(changes).Write(&builder) + writeCondition(&builder, database.And(conditions...)) + + stmt := builder.String() + + rowsAffected, err := o.client.Exec(ctx, stmt, builder.Args()...) + return rowsAffected, err +} + +// Delete implements [domain.OrganizationRepository]. +func (o *org) Delete(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string) (int64, error) { + builder := database.StatementBuilder{} + + builder.WriteString(`DELETE FROM zitadel.organizations`) + + instanceIDCondition := o.InstanceIDCondition(instanceID) + + conditions := []database.Condition{id, instanceIDCondition} + writeCondition(&builder, database.And(conditions...)) + + return o.client.Exec(ctx, builder.String(), builder.Args()...) +} + +// ------------------------------------------------------------- +// changes +// ------------------------------------------------------------- + +// SetName implements [domain.organizationChanges]. +func (o org) SetName(name string) database.Change { + return database.NewChange(o.NameColumn(), name) +} + +// SetState implements [domain.organizationChanges]. +func (o org) SetState(state domain.OrgState) database.Change { + return database.NewChange(o.StateColumn(), state) +} + +// ------------------------------------------------------------- +// conditions +// ------------------------------------------------------------- + +// IDCondition implements [domain.organizationConditions]. +func (o org) IDCondition(id string) domain.OrgIdentifierCondition { + return database.NewTextCondition(o.IDColumn(), database.TextOperationEqual, id) +} + +// NameCondition implements [domain.organizationConditions]. +func (o org) NameCondition(name string) domain.OrgIdentifierCondition { + return database.NewTextCondition(o.NameColumn(), database.TextOperationEqual, name) +} + +// InstanceIDCondition implements [domain.organizationConditions]. +func (o org) InstanceIDCondition(instanceID string) database.Condition { + return database.NewTextCondition(o.InstanceIDColumn(), database.TextOperationEqual, instanceID) +} + +// StateCondition implements [domain.organizationConditions]. +func (o org) StateCondition(state domain.OrgState) database.Condition { + return database.NewTextCondition(o.StateColumn(), database.TextOperationEqual, state.String()) +} + +// ------------------------------------------------------------- +// columns +// ------------------------------------------------------------- + +// IDColumn implements [domain.organizationColumns]. +func (org) IDColumn() database.Column { + return database.NewColumn("organizations", "id") +} + +// NameColumn implements [domain.organizationColumns]. +func (org) NameColumn() database.Column { + return database.NewColumn("organizations", "name") +} + +// InstanceIDColumn implements [domain.organizationColumns]. +func (org) InstanceIDColumn() database.Column { + return database.NewColumn("organizations", "instance_id") +} + +// StateColumn implements [domain.organizationColumns]. +func (org) StateColumn() database.Column { + return database.NewColumn("organizations", "state") +} + +// CreatedAtColumn implements [domain.organizationColumns]. +func (org) CreatedAtColumn() database.Column { + return database.NewColumn("organizations", "created_at") +} + +// UpdatedAtColumn implements [domain.organizationColumns]. +func (org) UpdatedAtColumn() database.Column { + return database.NewColumn("organizations", "updated_at") +} + +// ------------------------------------------------------------- +// scanners +// ------------------------------------------------------------- + +type rawOrganization struct { + *domain.Organization + RawDomains json.RawMessage `json:"domains,omitempty" db:"domains"` +} + +func scanOrganization(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Organization, error) { + rows, err := querier.Query(ctx, builder.String(), builder.Args()...) + if err != nil { + return nil, err + } + + var org rawOrganization + if err := rows.(database.CollectableRows).CollectExactlyOneRow(&org); err != nil { + return nil, err + } + if len(org.RawDomains) > 0 { + if err := json.Unmarshal(org.RawDomains, &org.Domains); err != nil { + return nil, err + } + } + + return org.Organization, nil +} + +func scanOrganizations(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.Organization, error) { + rows, err := querier.Query(ctx, builder.String(), builder.Args()...) + if err != nil { + return nil, err + } + + var rawOrgs []*rawOrganization + if err := rows.(database.CollectableRows).Collect(&rawOrgs); err != nil { + return nil, err + } + + organizations := make([]*domain.Organization, len(rawOrgs)) + for i, org := range rawOrgs { + if len(org.RawDomains) > 0 { + if err := json.Unmarshal(org.RawDomains, &org.Domains); err != nil { + return nil, err + } + } + organizations[i] = org.Organization + } + return organizations, nil +} + +// ------------------------------------------------------------- +// sub repositories +// ------------------------------------------------------------- + +// Domains implements [domain.OrganizationRepository]. +func (o *org) Domains(shouldLoad bool) domain.OrganizationDomainRepository { + if !o.shouldLoadDomains { + o.shouldLoadDomains = shouldLoad + } + + if o.domainRepo != nil { + return o.domainRepo + } + + o.domainRepo = &orgDomain{ + repository: o.repository, + org: o, + } + + return o.domainRepo +} diff --git a/backend/v3/storage/database/repository/org_domain.go b/backend/v3/storage/database/repository/org_domain.go new file mode 100644 index 00000000000..5b2e91acf28 --- /dev/null +++ b/backend/v3/storage/database/repository/org_domain.go @@ -0,0 +1,230 @@ +package repository + +import ( + "context" + "time" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +var _ domain.OrganizationDomainRepository = (*orgDomain)(nil) + +type orgDomain struct { + repository + *org +} + +// ------------------------------------------------------------- +// repository +// ------------------------------------------------------------- + +const queryOrganizationDomainStmt = `SELECT instance_id, org_id, domain, is_verified, is_primary, validation_type, created_at, updated_at ` + + `FROM zitadel.org_domains` + +// Get implements [domain.OrganizationDomainRepository]. +// Subtle: this method shadows the method ([domain.OrganizationRepository]).Get of orgDomain.org. +func (o *orgDomain) Get(ctx context.Context, opts ...database.QueryOption) (*domain.OrganizationDomain, error) { + options := new(database.QueryOpts) + for _, opt := range opts { + opt(options) + } + + var builder database.StatementBuilder + builder.WriteString(queryOrganizationDomainStmt) + options.Write(&builder) + + return scanOrganizationDomain(ctx, o.client, &builder) +} + +// List implements [domain.OrganizationDomainRepository]. +// Subtle: this method shadows the method ([domain.OrganizationRepository]).List of orgDomain.org. +func (o *orgDomain) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.OrganizationDomain, error) { + options := new(database.QueryOpts) + for _, opt := range opts { + opt(options) + } + + var builder database.StatementBuilder + builder.WriteString(queryOrganizationDomainStmt) + options.Write(&builder) + + return scanOrganizationDomains(ctx, o.client, &builder) +} + +// Add implements [domain.OrganizationDomainRepository]. +func (o *orgDomain) Add(ctx context.Context, domain *domain.AddOrganizationDomain) error { + var ( + builder database.StatementBuilder + createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction + ) + if !domain.CreatedAt.IsZero() { + createdAt = domain.CreatedAt + } + if !domain.UpdatedAt.IsZero() { + updatedAt = domain.UpdatedAt + } + + builder.WriteString(`INSERT INTO zitadel.org_domains (instance_id, org_id, domain, is_verified, is_primary, validation_type, created_at, updated_at) VALUES (`) + builder.WriteArgs(domain.InstanceID, domain.OrgID, domain.Domain, domain.IsVerified, domain.IsPrimary, domain.ValidationType, createdAt, updatedAt) + builder.WriteString(`) RETURNING created_at, updated_at`) + + return o.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt) +} + +// Update implements [domain.OrganizationDomainRepository]. +// Subtle: this method shadows the method ([domain.OrganizationRepository]).Update of orgDomain.org. +func (o *orgDomain) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) { + if len(changes) == 0 { + return 0, database.ErrNoChanges + } + + var builder database.StatementBuilder + + builder.WriteString(`UPDATE zitadel.org_domains SET `) + database.Changes(changes).Write(&builder) + writeCondition(&builder, condition) + + return o.client.Exec(ctx, builder.String(), builder.Args()...) +} + +// Remove implements [domain.OrganizationDomainRepository]. +func (o *orgDomain) Remove(ctx context.Context, condition database.Condition) (int64, error) { + var builder database.StatementBuilder + + builder.WriteString(`DELETE FROM zitadel.org_domains `) + writeCondition(&builder, condition) + + return o.client.Exec(ctx, builder.String(), builder.Args()...) +} + +// ------------------------------------------------------------- +// changes +// ------------------------------------------------------------- + +// SetPrimary implements [domain.OrganizationDomainRepository]. +func (o orgDomain) SetPrimary() database.Change { + return database.NewChange(o.IsPrimaryColumn(), true) +} + +// SetValidationType implements [domain.OrganizationDomainRepository]. +func (o orgDomain) SetValidationType(verificationType domain.DomainValidationType) database.Change { + return database.NewChange(o.ValidationTypeColumn(), verificationType) +} + +// SetVerified implements [domain.OrganizationDomainRepository]. +func (o orgDomain) SetVerified() database.Change { + return database.NewChange(o.IsVerifiedColumn(), true) +} + +// SetUpdatedAt implements [domain.OrganizationDomainRepository]. +func (o orgDomain) SetUpdatedAt(updatedAt time.Time) database.Change { + return database.NewChange(o.UpdatedAtColumn(), updatedAt) +} + +// ------------------------------------------------------------- +// conditions +// ------------------------------------------------------------- + +// DomainCondition implements [domain.OrganizationDomainRepository]. +func (o orgDomain) DomainCondition(op database.TextOperation, domain string) database.Condition { + return database.NewTextCondition(o.DomainColumn(), op, domain) +} + +// InstanceIDCondition implements [domain.OrganizationDomainRepository]. +// Subtle: this method shadows the method ([domain.OrganizationRepository]).InstanceIDCondition of orgDomain.org. +func (o orgDomain) InstanceIDCondition(instanceID string) database.Condition { + return database.NewTextCondition(o.InstanceIDColumn(), database.TextOperationEqual, instanceID) +} + +// IsPrimaryCondition implements [domain.OrganizationDomainRepository]. +func (o orgDomain) IsPrimaryCondition(isPrimary bool) database.Condition { + return database.NewBooleanCondition(o.IsPrimaryColumn(), isPrimary) +} + +// IsVerifiedCondition implements [domain.OrganizationDomainRepository]. +func (o orgDomain) IsVerifiedCondition(isVerified bool) database.Condition { + return database.NewBooleanCondition(o.IsVerifiedColumn(), isVerified) +} + +// OrgIDCondition implements [domain.OrganizationDomainRepository]. +func (o orgDomain) OrgIDCondition(orgID string) database.Condition { + return database.NewTextCondition(o.OrgIDColumn(), database.TextOperationEqual, orgID) +} + +// ------------------------------------------------------------- +// columns +// ------------------------------------------------------------- + +// CreatedAtColumn implements [domain.OrganizationDomainRepository]. +// Subtle: this method shadows the method ([domain.OrganizationRepository]).CreatedAtColumn of orgDomain.org. +func (orgDomain) CreatedAtColumn() database.Column { + return database.NewColumn("org_domains", "created_at") +} + +// DomainColumn implements [domain.OrganizationDomainRepository]. +func (orgDomain) DomainColumn() database.Column { + return database.NewColumn("org_domains", "domain") +} + +// InstanceIDColumn implements [domain.OrganizationDomainRepository]. +// Subtle: this method shadows the method ([domain.OrganizationRepository]).InstanceIDColumn of orgDomain.org. +func (orgDomain) InstanceIDColumn() database.Column { + return database.NewColumn("org_domains", "instance_id") +} + +// IsPrimaryColumn implements [domain.OrganizationDomainRepository]. +func (orgDomain) IsPrimaryColumn() database.Column { + return database.NewColumn("org_domains", "is_primary") +} + +// IsVerifiedColumn implements [domain.OrganizationDomainRepository]. +func (orgDomain) IsVerifiedColumn() database.Column { + return database.NewColumn("org_domains", "is_verified") +} + +// OrgIDColumn implements [domain.OrganizationDomainRepository]. +func (orgDomain) OrgIDColumn() database.Column { + return database.NewColumn("org_domains", "org_id") +} + +// UpdatedAtColumn implements [domain.OrganizationDomainRepository]. +// Subtle: this method shadows the method ([domain.OrganizationRepository]).UpdatedAtColumn of orgDomain.org. +func (orgDomain) UpdatedAtColumn() database.Column { + return database.NewColumn("org_domains", "updated_at") +} + +// ValidationTypeColumn implements [domain.OrganizationDomainRepository]. +func (orgDomain) ValidationTypeColumn() database.Column { + return database.NewColumn("org_domains", "validation_type") +} + +// ------------------------------------------------------------- +// scanners +// ------------------------------------------------------------- + +func scanOrganizationDomain(ctx context.Context, client database.Querier, builder *database.StatementBuilder) (*domain.OrganizationDomain, error) { + rows, err := client.Query(ctx, builder.String(), builder.Args()...) + if err != nil { + return nil, err + } + + domain := &domain.OrganizationDomain{} + if err := rows.(database.CollectableRows).CollectExactlyOneRow(domain); err != nil { + return nil, err + } + return domain, nil +} + +func scanOrganizationDomains(ctx context.Context, client database.Querier, builder *database.StatementBuilder) ([]*domain.OrganizationDomain, error) { + rows, err := client.Query(ctx, builder.String(), builder.Args()...) + if err != nil { + return nil, err + } + + var domains []*domain.OrganizationDomain + if err := rows.(database.CollectableRows).Collect(&domains); err != nil { + return nil, err + } + return domains, nil +} diff --git a/backend/v3/storage/database/repository/org_domain_test.go b/backend/v3/storage/database/repository/org_domain_test.go new file mode 100644 index 00000000000..a06d9eeee14 --- /dev/null +++ b/backend/v3/storage/database/repository/org_domain_test.go @@ -0,0 +1,972 @@ +package repository_test + +import ( + "context" + "testing" + + "github.com/brianvoe/gofakeit/v6" + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/repository" +) + +func TestAddOrganizationDomain(t *testing.T) { + // create instance + instanceID := gofakeit.UUID() + instance := domain.Instance{ + ID: instanceID, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleClient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err := instanceRepo.Create(t.Context(), &instance) + require.NoError(t, err) + + // create organization + orgID := gofakeit.UUID() + organization := domain.Organization{ + ID: orgID, + Name: gofakeit.Name(), + InstanceID: instanceID, + State: domain.OrgStateActive, + } + + tests := []struct { + name string + testFunc func(ctx context.Context, t *testing.T, domainRepo domain.OrganizationDomainRepository) *domain.AddOrganizationDomain + organizationDomain domain.AddOrganizationDomain + err error + }{ + { + name: "happy path", + organizationDomain: domain.AddOrganizationDomain{ + InstanceID: instanceID, + OrgID: orgID, + Domain: gofakeit.DomainName(), + IsVerified: false, + IsPrimary: false, + ValidationType: gu.Ptr(domain.DomainValidationTypeDNS), + }, + }, + { + name: "add verified domain", + organizationDomain: domain.AddOrganizationDomain{ + InstanceID: instanceID, + OrgID: orgID, + Domain: gofakeit.DomainName(), + IsVerified: true, + IsPrimary: false, + ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP), + }, + }, + { + name: "add primary domain", + organizationDomain: domain.AddOrganizationDomain{ + InstanceID: instanceID, + OrgID: orgID, + Domain: gofakeit.DomainName(), + IsVerified: true, + IsPrimary: true, + ValidationType: gu.Ptr(domain.DomainValidationTypeDNS), + }, + }, + { + name: "add domain without domain name", + organizationDomain: domain.AddOrganizationDomain{ + InstanceID: instanceID, + OrgID: orgID, + Domain: "", + IsVerified: false, + IsPrimary: false, + ValidationType: gu.Ptr(domain.DomainValidationTypeDNS), + }, + err: new(database.CheckError), + }, + { + name: "add domain with same domain twice", + testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.OrganizationDomainRepository) *domain.AddOrganizationDomain { + domainName := gofakeit.DomainName() + + organizationDomain := &domain.AddOrganizationDomain{ + InstanceID: instanceID, + OrgID: orgID, + Domain: domainName, + IsVerified: false, + IsPrimary: false, + ValidationType: gu.Ptr(domain.DomainValidationTypeDNS), + } + + err := domainRepo.Add(ctx, organizationDomain) + require.NoError(t, err) + + // return same domain again + return &domain.AddOrganizationDomain{ + InstanceID: instanceID, + OrgID: orgID, + Domain: domainName, + IsVerified: true, + IsPrimary: true, + ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP), + } + }, + err: new(database.UniqueError), + }, + { + name: "add domain with non-existent instance", + organizationDomain: domain.AddOrganizationDomain{ + InstanceID: "non-existent-instance", + OrgID: orgID, + Domain: gofakeit.DomainName(), + IsVerified: false, + IsPrimary: false, + ValidationType: gu.Ptr(domain.DomainValidationTypeDNS), + }, + err: new(database.ForeignKeyError), + }, + { + name: "add domain with non-existent organization", + organizationDomain: domain.AddOrganizationDomain{ + InstanceID: instanceID, + OrgID: "non-existent-org", + Domain: gofakeit.DomainName(), + IsVerified: false, + IsPrimary: false, + ValidationType: gu.Ptr(domain.DomainValidationTypeDNS), + }, + err: new(database.ForeignKeyError), + }, + { + name: "add domain without instance id", + organizationDomain: domain.AddOrganizationDomain{ + OrgID: orgID, + Domain: gofakeit.DomainName(), + IsVerified: false, + IsPrimary: false, + ValidationType: gu.Ptr(domain.DomainValidationTypeDNS), + }, + err: new(database.ForeignKeyError), + }, + { + name: "add domain without org id", + organizationDomain: domain.AddOrganizationDomain{ + InstanceID: instanceID, + Domain: gofakeit.DomainName(), + IsVerified: false, + IsPrimary: false, + ValidationType: gu.Ptr(domain.DomainValidationTypeDNS), + }, + err: new(database.ForeignKeyError), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := t.Context() + + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + require.NoError(t, tx.Rollback(t.Context())) + }() + + orgRepo := repository.OrganizationRepository(tx) + err = orgRepo.Create(t.Context(), &organization) + require.NoError(t, err) + + domainRepo := orgRepo.Domains(false) + + var organizationDomain *domain.AddOrganizationDomain + if test.testFunc != nil { + organizationDomain = test.testFunc(ctx, t, domainRepo) + } else { + organizationDomain = &test.organizationDomain + } + + err = domainRepo.Add(ctx, organizationDomain) + if test.err != nil { + assert.ErrorIs(t, err, test.err) + return + } + + require.NoError(t, err) + assert.NotZero(t, organizationDomain.CreatedAt) + assert.NotZero(t, organizationDomain.UpdatedAt) + }) + } +} + +func TestGetOrganizationDomain(t *testing.T) { + // create instance + instanceID := gofakeit.UUID() + instance := domain.Instance{ + ID: instanceID, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleClient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + + // create organization + orgID := gofakeit.UUID() + organization := domain.Organization{ + ID: orgID, + Name: gofakeit.Name(), + InstanceID: instanceID, + State: domain.OrgStateActive, + } + + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + require.NoError(t, tx.Rollback(t.Context())) + }() + + instanceRepo := repository.InstanceRepository(tx) + err = instanceRepo.Create(t.Context(), &instance) + require.NoError(t, err) + + orgRepo := repository.OrganizationRepository(tx) + err = orgRepo.Create(t.Context(), &organization) + require.NoError(t, err) + + // add domains + domainRepo := orgRepo.Domains(false) + domainName1 := gofakeit.DomainName() + domainName2 := gofakeit.DomainName() + + domain1 := &domain.AddOrganizationDomain{ + InstanceID: instanceID, + OrgID: orgID, + Domain: domainName1, + IsVerified: true, + IsPrimary: true, + ValidationType: gu.Ptr(domain.DomainValidationTypeDNS), + } + domain2 := &domain.AddOrganizationDomain{ + InstanceID: instanceID, + OrgID: orgID, + Domain: domainName2, + IsVerified: false, + IsPrimary: false, + ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP), + } + + err = domainRepo.Add(t.Context(), domain1) + require.NoError(t, err) + err = domainRepo.Add(t.Context(), domain2) + require.NoError(t, err) + + tests := []struct { + name string + opts []database.QueryOption + expected *domain.OrganizationDomain + err error + }{ + { + name: "get primary domain", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.IsPrimaryCondition(true)), + }, + expected: &domain.OrganizationDomain{ + InstanceID: instanceID, + OrgID: orgID, + Domain: domainName1, + IsVerified: true, + IsPrimary: true, + ValidationType: gu.Ptr(domain.DomainValidationTypeDNS), + }, + }, + { + name: "get by domain name", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, domainName2)), + }, + expected: &domain.OrganizationDomain{ + InstanceID: instanceID, + OrgID: orgID, + Domain: domainName2, + IsVerified: false, + IsPrimary: false, + ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP), + }, + }, + { + name: "get by org ID", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.OrgIDCondition(orgID)), + database.WithCondition(domainRepo.IsPrimaryCondition(true)), + }, + expected: &domain.OrganizationDomain{ + InstanceID: instanceID, + OrgID: orgID, + Domain: domainName1, + IsVerified: true, + IsPrimary: true, + ValidationType: gu.Ptr(domain.DomainValidationTypeDNS), + }, + }, + { + name: "get verified domain", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.IsVerifiedCondition(true)), + }, + expected: &domain.OrganizationDomain{ + InstanceID: instanceID, + OrgID: orgID, + Domain: domainName1, + IsVerified: true, + IsPrimary: true, + ValidationType: gu.Ptr(domain.DomainValidationTypeDNS), + }, + }, + { + name: "get non-existent domain", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com")), + }, + err: new(database.NoRowFoundError), + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := t.Context() + + result, err := domainRepo.Get(ctx, test.opts...) + if test.err != nil { + assert.ErrorIs(t, err, test.err) + return + } + + require.NoError(t, err) + assert.Equal(t, test.expected.InstanceID, result.InstanceID) + assert.Equal(t, test.expected.OrgID, result.OrgID) + assert.Equal(t, test.expected.Domain, result.Domain) + assert.Equal(t, test.expected.IsVerified, result.IsVerified) + assert.Equal(t, test.expected.IsPrimary, result.IsPrimary) + assert.Equal(t, test.expected.ValidationType, result.ValidationType) + assert.NotEmpty(t, result.CreatedAt) + assert.NotEmpty(t, result.UpdatedAt) + }) + } +} + +func TestListOrganizationDomains(t *testing.T) { + // create instance + instanceID := gofakeit.UUID() + instance := domain.Instance{ + ID: instanceID, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleClient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + + // create organization + orgID := gofakeit.UUID() + organization := domain.Organization{ + ID: orgID, + Name: gofakeit.Name(), + InstanceID: instanceID, + State: domain.OrgStateActive, + } + + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + require.NoError(t, tx.Rollback(t.Context())) + }() + + instanceRepo := repository.InstanceRepository(tx) + err = instanceRepo.Create(t.Context(), &instance) + require.NoError(t, err) + + orgRepo := repository.OrganizationRepository(tx) + err = orgRepo.Create(t.Context(), &organization) + require.NoError(t, err) + + // add multiple domains + domainRepo := orgRepo.Domains(false) + domains := []domain.AddOrganizationDomain{ + { + InstanceID: instanceID, + OrgID: orgID, + Domain: gofakeit.DomainName(), + IsVerified: true, + IsPrimary: true, + ValidationType: gu.Ptr(domain.DomainValidationTypeDNS), + }, + { + InstanceID: instanceID, + OrgID: orgID, + Domain: gofakeit.DomainName(), + IsVerified: false, + IsPrimary: false, + ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP), + }, + { + InstanceID: instanceID, + OrgID: orgID, + Domain: gofakeit.DomainName(), + IsVerified: true, + IsPrimary: false, + ValidationType: gu.Ptr(domain.DomainValidationTypeDNS), + }, + } + + for i := range domains { + err = domainRepo.Add(t.Context(), &domains[i]) + require.NoError(t, err) + } + + tests := []struct { + name string + opts []database.QueryOption + expectedCount int + }{ + { + name: "list all domains", + opts: []database.QueryOption{}, + expectedCount: 3, + }, + { + name: "list verified domains", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.IsVerifiedCondition(true)), + }, + expectedCount: 2, + }, + { + name: "list primary domains", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.IsPrimaryCondition(true)), + }, + expectedCount: 1, + }, + { + name: "list by organization", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.OrgIDCondition(orgID)), + }, + expectedCount: 3, + }, + { + name: "list by instance", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.InstanceIDCondition(instanceID)), + }, + expectedCount: 3, + }, + { + name: "list non-existent organization", + opts: []database.QueryOption{ + database.WithCondition(domainRepo.OrgIDCondition("non-existent")), + }, + expectedCount: 0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := t.Context() + + results, err := domainRepo.List(ctx, test.opts...) + require.NoError(t, err) + assert.Len(t, results, test.expectedCount) + + for _, result := range results { + assert.Equal(t, instanceID, result.InstanceID) + assert.Equal(t, orgID, result.OrgID) + assert.NotEmpty(t, result.Domain) + assert.NotEmpty(t, result.CreatedAt) + assert.NotEmpty(t, result.UpdatedAt) + } + }) + } +} + +func TestUpdateOrganizationDomain(t *testing.T) { + // create instance + instanceID := gofakeit.UUID() + instance := domain.Instance{ + ID: instanceID, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleClient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + + // create organization + orgID := gofakeit.UUID() + organization := domain.Organization{ + ID: orgID, + Name: gofakeit.Name(), + InstanceID: instanceID, + State: domain.OrgStateActive, + } + + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + require.NoError(t, tx.Rollback(t.Context())) + }() + + instanceRepo := repository.InstanceRepository(tx) + err = instanceRepo.Create(t.Context(), &instance) + require.NoError(t, err) + + orgRepo := repository.OrganizationRepository(tx) + err = orgRepo.Create(t.Context(), &organization) + require.NoError(t, err) + + // add domain + domainRepo := orgRepo.Domains(false) + domainName := gofakeit.DomainName() + organizationDomain := &domain.AddOrganizationDomain{ + InstanceID: instanceID, + OrgID: orgID, + Domain: domainName, + IsVerified: false, + IsPrimary: false, + ValidationType: gu.Ptr(domain.DomainValidationTypeDNS), + } + + err = domainRepo.Add(t.Context(), organizationDomain) + require.NoError(t, err) + + tests := []struct { + name string + condition database.Condition + changes []database.Change + expected int64 + err error + }{ + { + name: "set verified", + condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName), + changes: []database.Change{domainRepo.SetVerified()}, + expected: 1, + }, + { + name: "set primary", + condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName), + changes: []database.Change{domainRepo.SetPrimary()}, + expected: 1, + }, + { + name: "set validation type", + condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName), + changes: []database.Change{domainRepo.SetValidationType(domain.DomainValidationTypeHTTP)}, + expected: 1, + }, + { + name: "multiple changes", + condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName), + changes: []database.Change{ + domainRepo.SetVerified(), + domainRepo.SetPrimary(), + domainRepo.SetValidationType(domain.DomainValidationTypeDNS), + }, + expected: 1, + }, + { + name: "update by org ID and domain", + condition: database.And(domainRepo.OrgIDCondition(orgID), domainRepo.DomainCondition(database.TextOperationEqual, domainName)), + changes: []database.Change{domainRepo.SetVerified()}, + expected: 1, + }, + { + name: "update non-existent domain", + condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"), + changes: []database.Change{domainRepo.SetVerified()}, + expected: 0, + }, + { + name: "no changes", + condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName), + changes: []database.Change{}, + expected: 0, + err: database.ErrNoChanges, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := t.Context() + + rowsAffected, err := domainRepo.Update(ctx, test.condition, test.changes...) + if test.err != nil { + assert.ErrorIs(t, err, test.err) + return + } + + require.NoError(t, err) + assert.Equal(t, test.expected, rowsAffected) + + // verify changes were applied if rows were affected + if rowsAffected > 0 && len(test.changes) > 0 { + result, err := domainRepo.Get(ctx, database.WithCondition(test.condition)) + require.NoError(t, err) + + // We know changes were applied since rowsAffected > 0 + // The specific verification of what changed is less important + // than knowing the operation succeeded + assert.NotNil(t, result) + } + }) + } +} + +func TestRemoveOrganizationDomain(t *testing.T) { + // create instance + instanceID := gofakeit.UUID() + instance := domain.Instance{ + ID: instanceID, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleClient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + + // create organization + orgID := gofakeit.UUID() + organization := domain.Organization{ + ID: orgID, + Name: gofakeit.Name(), + InstanceID: instanceID, + State: domain.OrgStateActive, + } + + tx, err := pool.Begin(t.Context(), nil) + require.NoError(t, err) + defer func() { + require.NoError(t, tx.Rollback(t.Context())) + }() + + instanceRepo := repository.InstanceRepository(tx) + err = instanceRepo.Create(t.Context(), &instance) + require.NoError(t, err) + + orgRepo := repository.OrganizationRepository(tx) + err = orgRepo.Create(t.Context(), &organization) + require.NoError(t, err) + + // add domains + domainRepo := orgRepo.Domains(false) + domainName1 := gofakeit.DomainName() + domainName2 := gofakeit.DomainName() + + domain1 := &domain.AddOrganizationDomain{ + InstanceID: instanceID, + OrgID: orgID, + Domain: domainName1, + IsVerified: true, + IsPrimary: true, + ValidationType: gu.Ptr(domain.DomainValidationTypeDNS), + } + domain2 := &domain.AddOrganizationDomain{ + InstanceID: instanceID, + OrgID: orgID, + Domain: domainName2, + IsVerified: false, + IsPrimary: false, + ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP), + } + + err = domainRepo.Add(t.Context(), domain1) + require.NoError(t, err) + err = domainRepo.Add(t.Context(), domain2) + require.NoError(t, err) + + tests := []struct { + name string + condition database.Condition + expected int64 + }{ + { + name: "remove by domain name", + condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName1), + expected: 1, + }, + { + name: "remove by primary condition", + condition: domainRepo.IsPrimaryCondition(false), + expected: 1, // domain2 should still exist and be non-primary + }, + { + name: "remove by org ID and domain", + condition: database.And(domainRepo.OrgIDCondition(orgID), domainRepo.DomainCondition(database.TextOperationEqual, domainName2)), + expected: 1, + }, + { + name: "remove non-existent domain", + condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"), + expected: 0, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + ctx := t.Context() + + snapshot, err := tx.Begin(ctx) + require.NoError(t, err) + defer func() { + require.NoError(t, snapshot.Rollback(ctx)) + }() + + orgRepo := repository.OrganizationRepository(snapshot) + domainRepo := orgRepo.Domains(false) + + // count before removal + beforeCount, err := domainRepo.List(ctx) + require.NoError(t, err) + + rowsAffected, err := domainRepo.Remove(ctx, test.condition) + require.NoError(t, err) + assert.Equal(t, test.expected, rowsAffected) + + // verify removal + afterCount, err := domainRepo.List(ctx) + require.NoError(t, err) + assert.Equal(t, len(beforeCount)-int(test.expected), len(afterCount)) + }) + } +} + +func TestOrganizationDomainConditions(t *testing.T) { + orgRepo := repository.OrganizationRepository(pool) + domainRepo := orgRepo.Domains(false) + + tests := []struct { + name string + condition database.Condition + expected string + }{ + { + name: "domain condition equal", + condition: domainRepo.DomainCondition(database.TextOperationEqual, "example.com"), + expected: "org_domains.domain = $1", + }, + { + name: "domain condition starts with", + condition: domainRepo.DomainCondition(database.TextOperationStartsWith, "example"), + expected: "org_domains.domain LIKE $1 || '%'", + }, + { + name: "instance id condition", + condition: domainRepo.InstanceIDCondition("instance-123"), + expected: "org_domains.instance_id = $1", + }, + { + name: "org id condition", + condition: domainRepo.OrgIDCondition("org-123"), + expected: "org_domains.org_id = $1", + }, + { + name: "is primary true", + condition: domainRepo.IsPrimaryCondition(true), + expected: "org_domains.is_primary = $1", + }, + { + name: "is primary false", + condition: domainRepo.IsPrimaryCondition(false), + expected: "org_domains.is_primary = $1", + }, + { + name: "is verified true", + condition: domainRepo.IsVerifiedCondition(true), + expected: "org_domains.is_verified = $1", + }, + { + name: "is verified false", + condition: domainRepo.IsVerifiedCondition(false), + expected: "org_domains.is_verified = $1", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var builder database.StatementBuilder + test.condition.Write(&builder) + assert.Equal(t, test.expected, builder.String()) + }) + } +} + +func TestOrganizationDomainChanges(t *testing.T) { + orgRepo := repository.OrganizationRepository(pool) + domainRepo := orgRepo.Domains(false) + + tests := []struct { + name string + change database.Change + expected string + }{ + { + name: "set verified", + change: domainRepo.SetVerified(), + expected: "is_verified = $1", + }, + { + name: "set primary", + change: domainRepo.SetPrimary(), + expected: "is_primary = $1", + }, + { + name: "set validation type DNS", + change: domainRepo.SetValidationType(domain.DomainValidationTypeDNS), + expected: "validation_type = $1", + }, + { + name: "set validation type HTTP", + change: domainRepo.SetValidationType(domain.DomainValidationTypeHTTP), + expected: "validation_type = $1", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var builder database.StatementBuilder + test.change.Write(&builder) + assert.Equal(t, test.expected, builder.String()) + }) + } +} + +func TestOrganizationDomainColumns(t *testing.T) { + orgRepo := repository.OrganizationRepository(pool) + domainRepo := orgRepo.Domains(false) + + tests := []struct { + name string + column database.Column + qualified bool + expected string + }{ + { + name: "instance id column qualified", + column: domainRepo.InstanceIDColumn(), + qualified: true, + expected: "org_domains.instance_id", + }, + { + name: "instance id column unqualified", + column: domainRepo.InstanceIDColumn(), + qualified: false, + expected: "instance_id", + }, + { + name: "org id column qualified", + column: domainRepo.OrgIDColumn(), + qualified: true, + expected: "org_domains.org_id", + }, + { + name: "org id column unqualified", + column: domainRepo.OrgIDColumn(), + qualified: false, + expected: "org_id", + }, + { + name: "domain column qualified", + column: domainRepo.DomainColumn(), + qualified: true, + expected: "org_domains.domain", + }, + { + name: "domain column unqualified", + column: domainRepo.DomainColumn(), + qualified: false, + expected: "domain", + }, + { + name: "is verified column qualified", + column: domainRepo.IsVerifiedColumn(), + qualified: true, + expected: "org_domains.is_verified", + }, + { + name: "is verified column unqualified", + column: domainRepo.IsVerifiedColumn(), + qualified: false, + expected: "is_verified", + }, + { + name: "is primary column qualified", + column: domainRepo.IsPrimaryColumn(), + qualified: true, + expected: "org_domains.is_primary", + }, + { + name: "is primary column unqualified", + column: domainRepo.IsPrimaryColumn(), + qualified: false, + expected: "is_primary", + }, + { + name: "validation type column qualified", + column: domainRepo.ValidationTypeColumn(), + qualified: true, + expected: "org_domains.validation_type", + }, + { + name: "validation type column unqualified", + column: domainRepo.ValidationTypeColumn(), + qualified: false, + expected: "validation_type", + }, + { + name: "created at column qualified", + column: domainRepo.CreatedAtColumn(), + qualified: true, + expected: "org_domains.created_at", + }, + { + name: "created at column unqualified", + column: domainRepo.CreatedAtColumn(), + qualified: false, + expected: "created_at", + }, + { + name: "updated at column qualified", + column: domainRepo.UpdatedAtColumn(), + qualified: true, + expected: "org_domains.updated_at", + }, + { + name: "updated at column unqualified", + column: domainRepo.UpdatedAtColumn(), + qualified: false, + expected: "updated_at", + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + var builder database.StatementBuilder + if test.qualified { + test.column.WriteQualified(&builder) + } else { + test.column.WriteUnqualified(&builder) + } + assert.Equal(t, test.expected, builder.String()) + }) + } +} diff --git a/backend/v3/storage/database/repository/org_member.go b/backend/v3/storage/database/repository/org_member.go new file mode 100644 index 00000000000..1720ed76301 --- /dev/null +++ b/backend/v3/storage/database/repository/org_member.go @@ -0,0 +1,28 @@ +package repository + +import ( + "context" + + "github.com/zitadel/zitadel/backend/v3/domain" +) + +type orgMember struct { + *org +} + +// AddMember implements [domain.MemberRepository]. +func (o *orgMember) AddMember(ctx context.Context, orgID string, userID string, roles []string) error { + return nil +} + +// RemoveMember implements [domain.MemberRepository]. +func (o *orgMember) RemoveMember(ctx context.Context, orgID string, userID string) error { + return nil +} + +// SetMemberRoles implements [domain.MemberRepository]. +func (o *orgMember) SetMemberRoles(ctx context.Context, orgID string, userID string, roles []string) error { + return nil +} + +var _ domain.MemberRepository = (*orgMember)(nil) diff --git a/backend/v3/storage/database/repository/org_test.go b/backend/v3/storage/database/repository/org_test.go new file mode 100644 index 00000000000..baaa02cff9c --- /dev/null +++ b/backend/v3/storage/database/repository/org_test.go @@ -0,0 +1,1008 @@ +package repository_test + +import ( + "context" + "testing" + "time" + + "github.com/brianvoe/gofakeit/v6" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/repository" +) + +func TestCreateOrganization(t *testing.T) { + // create instance + instanceId := gofakeit.Name() + instance := domain.Instance{ + ID: instanceId, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err := instanceRepo.Create(t.Context(), &instance) + require.NoError(t, err) + + tests := []struct { + name string + testFunc func(ctx context.Context, t *testing.T) *domain.Organization + organization domain.Organization + err error + }{ + { + name: "happy path", + organization: func() domain.Organization { + organizationId := gofakeit.Name() + organizationName := gofakeit.Name() + organization := domain.Organization{ + ID: organizationId, + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateActive, + } + return organization + }(), + }, + { + name: "create organization without name", + organization: func() domain.Organization { + organizationId := gofakeit.Name() + // organizationName := gofakeit.Name() + organization := domain.Organization{ + ID: organizationId, + Name: "", + InstanceID: instanceId, + State: domain.OrgStateActive, + } + return organization + }(), + err: new(database.CheckError), + }, + { + name: "adding org with same id twice", + testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + organizationRepo := repository.OrganizationRepository(pool) + organizationId := gofakeit.Name() + organizationName := gofakeit.Name() + + org := domain.Organization{ + ID: organizationId, + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateActive, + } + + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + // change the name to make sure same only the id clashes + org.Name = gofakeit.Name() + return &org + }, + err: new(database.UniqueError), + }, + { + name: "adding org with same name twice", + testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + organizationRepo := repository.OrganizationRepository(pool) + organizationId := gofakeit.Name() + organizationName := gofakeit.Name() + + org := domain.Organization{ + ID: organizationId, + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateActive, + } + + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + // change the id to make sure same name+instance causes an error + org.ID = gofakeit.Name() + return &org + }, + err: new(database.UniqueError), + }, + func() struct { + name string + testFunc func(ctx context.Context, t *testing.T) *domain.Organization + organization domain.Organization + err error + } { + orgID := gofakeit.Name() + organizationName := gofakeit.Name() + + return struct { + name string + testFunc func(ctx context.Context, t *testing.T) *domain.Organization + organization domain.Organization + err error + }{ + name: "adding org with same name, different instance", + testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + // create instance + instId := gofakeit.Name() + instance := domain.Instance{ + ID: instId, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err := instanceRepo.Create(ctx, &instance) + assert.Nil(t, err) + + organizationRepo := repository.OrganizationRepository(pool) + + org := domain.Organization{ + ID: gofakeit.Name(), + Name: organizationName, + InstanceID: instId, + State: domain.OrgStateActive, + } + + err = organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + // change the id to make it unique + org.ID = orgID + // change the instanceID to a different instance + org.InstanceID = instanceId + return &org + }, + organization: domain.Organization{ + ID: orgID, + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateActive, + }, + } + }(), + { + name: "adding organization with no id", + organization: func() domain.Organization { + // organizationId := gofakeit.Name() + organizationName := gofakeit.Name() + organization := domain.Organization{ + // ID: organizationId, + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateActive, + } + return organization + }(), + err: new(database.CheckError), + }, + { + name: "adding organization with no instance id", + organization: func() domain.Organization { + organizationId := gofakeit.Name() + organizationName := gofakeit.Name() + organization := domain.Organization{ + ID: organizationId, + Name: organizationName, + State: domain.OrgStateActive, + } + return organization + }(), + err: new(database.ForeignKeyError), + }, + { + name: "adding organization with non existent instance id", + organization: func() domain.Organization { + organizationId := gofakeit.Name() + organizationName := gofakeit.Name() + organization := domain.Organization{ + ID: organizationId, + Name: organizationName, + InstanceID: gofakeit.Name(), + State: domain.OrgStateActive, + } + return organization + }(), + err: new(database.ForeignKeyError), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + + var organization *domain.Organization + if tt.testFunc != nil { + organization = tt.testFunc(ctx, t) + } else { + organization = &tt.organization + } + organizationRepo := repository.OrganizationRepository(pool) + + // create organization + beforeCreate := time.Now() + err = organizationRepo.Create(ctx, organization) + assert.ErrorIs(t, err, tt.err) + if err != nil { + return + } + afterCreate := time.Now() + + // check organization values + organization, err = organizationRepo.Get(ctx, + database.WithCondition( + database.And( + organizationRepo.IDCondition(organization.ID), + organizationRepo.InstanceIDCondition(organization.InstanceID), + ), + ), + ) + require.NoError(t, err) + + assert.Equal(t, tt.organization.ID, organization.ID) + assert.Equal(t, tt.organization.Name, organization.Name) + assert.Equal(t, tt.organization.InstanceID, organization.InstanceID) + assert.Equal(t, tt.organization.State, organization.State) + assert.WithinRange(t, organization.CreatedAt, beforeCreate, afterCreate) + assert.WithinRange(t, organization.UpdatedAt, beforeCreate, afterCreate) + }) + } +} + +func TestUpdateOrganization(t *testing.T) { + // create instance + instanceId := gofakeit.Name() + instance := domain.Instance{ + ID: instanceId, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err := instanceRepo.Create(t.Context(), &instance) + require.NoError(t, err) + organizationRepo := repository.OrganizationRepository(pool) + + tests := []struct { + name string + testFunc func(ctx context.Context, t *testing.T) *domain.Organization + update []database.Change + rowsAffected int64 + }{ + { + name: "happy path update name", + testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + organizationId := gofakeit.Name() + organizationName := gofakeit.Name() + + org := domain.Organization{ + ID: organizationId, + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateActive, + } + + // create organization + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + // update with updated value + org.Name = "new_name" + return &org + }, + update: []database.Change{organizationRepo.SetName("new_name")}, + rowsAffected: 1, + }, + { + name: "update deleted organization", + testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + organizationId := gofakeit.Name() + organizationName := gofakeit.Name() + + org := domain.Organization{ + ID: organizationId, + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateActive, + } + + // create organization + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + // delete instance + _, err = organizationRepo.Delete(ctx, + organizationRepo.IDCondition(org.ID), + org.InstanceID, + ) + require.NoError(t, err) + + return &org + }, + update: []database.Change{organizationRepo.SetName("new_name")}, + rowsAffected: 0, + }, + { + name: "happy path change state", + testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + organizationId := gofakeit.Name() + organizationName := gofakeit.Name() + + org := domain.Organization{ + ID: organizationId, + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateActive, + } + + // create organization + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + // update with updated value + org.State = domain.OrgStateInactive + return &org + }, + update: []database.Change{organizationRepo.SetState(domain.OrgStateInactive)}, + rowsAffected: 1, + }, + { + name: "update non existent organization", + testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + organizationId := gofakeit.Name() + + org := domain.Organization{ + ID: organizationId, + } + return &org + }, + update: []database.Change{organizationRepo.SetName("new_name")}, + rowsAffected: 0, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + organizationRepo := repository.OrganizationRepository(pool) + + createdOrg := tt.testFunc(ctx, t) + + // update org + beforeUpdate := time.Now() + rowsAffected, err := organizationRepo.Update(ctx, + organizationRepo.IDCondition(createdOrg.ID), + createdOrg.InstanceID, + tt.update..., + ) + afterUpdate := time.Now() + require.NoError(t, err) + + assert.Equal(t, tt.rowsAffected, rowsAffected) + + if rowsAffected == 0 { + return + } + + // check organization values + organization, err := organizationRepo.Get(ctx, + database.WithCondition( + database.And( + organizationRepo.IDCondition(createdOrg.ID), + organizationRepo.InstanceIDCondition(createdOrg.InstanceID), + ), + ), + ) + require.NoError(t, err) + + assert.Equal(t, createdOrg.ID, organization.ID) + assert.Equal(t, createdOrg.Name, organization.Name) + assert.Equal(t, createdOrg.State, organization.State) + assert.WithinRange(t, organization.UpdatedAt, beforeUpdate, afterUpdate) + }) + } +} + +func TestGetOrganization(t *testing.T) { + // create instance + instanceId := gofakeit.Name() + instance := domain.Instance{ + ID: instanceId, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err := instanceRepo.Create(t.Context(), &instance) + require.NoError(t, err) + + orgRepo := repository.OrganizationRepository(pool) + + // create organization + // this org is created as an additional org which should NOT + // be returned in the results of the tests + org := domain.Organization{ + ID: gofakeit.Name(), + Name: gofakeit.Name(), + InstanceID: instanceId, + State: domain.OrgStateActive, + } + err = orgRepo.Create(t.Context(), &org) + require.NoError(t, err) + + type test struct { + name string + testFunc func(ctx context.Context, t *testing.T) *domain.Organization + orgIdentifierCondition domain.OrgIdentifierCondition + err error + } + + tests := []test{ + func() test { + organizationId := gofakeit.Name() + return test{ + name: "happy path get using id", + testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + organizationName := gofakeit.Name() + + org := domain.Organization{ + ID: organizationId, + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateInactive, + } + + // create organization + err := orgRepo.Create(ctx, &org) + require.NoError(t, err) + + return &org + }, + orgIdentifierCondition: orgRepo.IDCondition(organizationId), + } + }(), + func() test { + organizationId := gofakeit.Name() + return test{ + name: "happy path get using id including domain", + testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + organizationName := gofakeit.Name() + + org := domain.Organization{ + ID: organizationId, + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateActive, + } + + // create organization + err := orgRepo.Create(ctx, &org) + require.NoError(t, err) + + d := &domain.AddOrganizationDomain{ + InstanceID: org.InstanceID, + OrgID: org.ID, + Domain: gofakeit.DomainName(), + IsVerified: true, + IsPrimary: true, + } + err = orgRepo.Domains(false).Add(ctx, d) + require.NoError(t, err) + + org.Domains = []*domain.OrganizationDomain{ + { + InstanceID: d.InstanceID, + OrgID: d.OrgID, + ValidationType: d.ValidationType, + Domain: d.Domain, + IsPrimary: d.IsPrimary, + IsVerified: d.IsVerified, + CreatedAt: d.CreatedAt, + UpdatedAt: d.UpdatedAt, + }, + } + + return &org + }, + orgIdentifierCondition: orgRepo.IDCondition(organizationId), + } + }(), + func() test { + organizationName := gofakeit.Name() + return test{ + name: "happy path get using name", + testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + organizationId := gofakeit.Name() + + org := domain.Organization{ + ID: organizationId, + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateActive, + } + + // create organization + err := orgRepo.Create(ctx, &org) + require.NoError(t, err) + + return &org + }, + orgIdentifierCondition: orgRepo.NameCondition(organizationName), + } + }(), + { + name: "get non existent organization", + testFunc: func(ctx context.Context, t *testing.T) *domain.Organization { + org := domain.Organization{ + ID: "non existent org", + Name: "non existent org", + } + return &org + }, + orgIdentifierCondition: orgRepo.NameCondition("non-existent-instance-name"), + err: new(database.NoRowFoundError), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + orgRepo := repository.OrganizationRepository(pool) + + var org *domain.Organization + if tt.testFunc != nil { + org = tt.testFunc(ctx, t) + } + + // get org values + returnedOrg, err := orgRepo.Get(ctx, + database.WithCondition( + database.And( + tt.orgIdentifierCondition, + orgRepo.InstanceIDCondition(org.InstanceID), + ), + ), + ) + if tt.err != nil { + require.ErrorIs(t, tt.err, err) + return + } + require.NoError(t, err) + + if org.Name == "non existent org" { + assert.Nil(t, returnedOrg) + return + } + + assert.Equal(t, returnedOrg.ID, org.ID) + assert.Equal(t, returnedOrg.Name, org.Name) + assert.Equal(t, returnedOrg.InstanceID, org.InstanceID) + assert.Equal(t, returnedOrg.State, org.State) + }) + } +} + +func TestListOrganization(t *testing.T) { + ctx := t.Context() + pool, stop, err := newEmbeddedDB(ctx) + require.NoError(t, err) + defer stop() + organizationRepo := repository.OrganizationRepository(pool) + + // create instance + instanceId := gofakeit.Name() + instance := domain.Instance{ + ID: instanceId, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err = instanceRepo.Create(ctx, &instance) + require.NoError(t, err) + + type test struct { + name string + testFunc func(ctx context.Context, t *testing.T) []*domain.Organization + conditionClauses []database.Condition + noOrganizationReturned bool + } + tests := []test{ + { + name: "happy path single organization no filter", + testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization { + noOfOrganizations := 1 + organizations := make([]*domain.Organization, noOfOrganizations) + for i := range noOfOrganizations { + + org := domain.Organization{ + ID: gofakeit.Name(), + Name: gofakeit.Name(), + InstanceID: instanceId, + State: domain.OrgStateActive, + } + + // create organization + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + organizations[i] = &org + } + + return organizations + }, + }, + { + name: "happy path multiple organization no filter", + testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization { + noOfOrganizations := 5 + organizations := make([]*domain.Organization, noOfOrganizations) + for i := range noOfOrganizations { + + org := domain.Organization{ + ID: gofakeit.Name(), + Name: gofakeit.Name(), + InstanceID: instanceId, + State: domain.OrgStateActive, + } + + // create organization + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + organizations[i] = &org + } + + return organizations + }, + }, + func() test { + organizationId := gofakeit.Name() + return test{ + name: "organization filter on id", + testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization { + // create organization + // this org is created as an additional org which should NOT + // be returned in the results of this test case + org := domain.Organization{ + ID: gofakeit.Name(), + Name: gofakeit.Name(), + InstanceID: instanceId, + State: domain.OrgStateActive, + } + err = organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + noOfOrganizations := 1 + organizations := make([]*domain.Organization, noOfOrganizations) + for i := range noOfOrganizations { + + org := domain.Organization{ + ID: organizationId, + Name: gofakeit.Name(), + InstanceID: instanceId, + State: domain.OrgStateActive, + } + + // create organization + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + organizations[i] = &org + } + + return organizations + }, + conditionClauses: []database.Condition{organizationRepo.IDCondition(organizationId)}, + } + }(), + { + name: "multiple organization filter on state", + testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization { + // create organization + // this org is created as an additional org which should NOT + // be returned in the results of this test case + org := domain.Organization{ + ID: gofakeit.Name(), + Name: gofakeit.Name(), + InstanceID: instanceId, + State: domain.OrgStateActive, + } + err = organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + noOfOrganizations := 5 + organizations := make([]*domain.Organization, noOfOrganizations) + for i := range noOfOrganizations { + + org := domain.Organization{ + ID: gofakeit.Name(), + Name: gofakeit.Name(), + InstanceID: instanceId, + State: domain.OrgStateInactive, + } + + // create organization + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + organizations[i] = &org + } + + return organizations + }, + conditionClauses: []database.Condition{organizationRepo.StateCondition(domain.OrgStateInactive)}, + }, + func() test { + instanceId_2 := gofakeit.Name() + return test{ + name: "multiple organization filter on instance", + testFunc: func(ctx context.Context, t *testing.T) []*domain.Organization { + // create instance 1 + instanceId_1 := gofakeit.Name() + instance := domain.Instance{ + ID: instanceId_1, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err = instanceRepo.Create(ctx, &instance) + assert.Nil(t, err) + + // create organization + // this org is created as an additional org which should NOT + // be returned in the results of this test case + org := domain.Organization{ + ID: gofakeit.Name(), + Name: gofakeit.Name(), + InstanceID: instanceId_1, + State: domain.OrgStateActive, + } + err = organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + // create instance 2 + instance_2 := domain.Instance{ + ID: instanceId_2, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + err = instanceRepo.Create(ctx, &instance_2) + assert.Nil(t, err) + + noOfOrganizations := 5 + organizations := make([]*domain.Organization, noOfOrganizations) + for i := range noOfOrganizations { + + org := domain.Organization{ + ID: gofakeit.Name(), + Name: gofakeit.Name(), + InstanceID: instanceId_2, + State: domain.OrgStateActive, + } + + // create organization + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + organizations[i] = &org + } + + return organizations + }, + conditionClauses: []database.Condition{organizationRepo.InstanceIDCondition(instanceId_2)}, + } + }(), + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + t.Cleanup(func() { + _, err := pool.Exec(ctx, "DELETE FROM zitadel.organizations") + require.NoError(t, err) + }) + + organizations := tt.testFunc(ctx, t) + + var condition database.Condition + if len(tt.conditionClauses) > 0 { + condition = database.And(tt.conditionClauses...) + } + + // check organization values + returnedOrgs, err := organizationRepo.List(ctx, + database.WithCondition(condition), + database.WithOrderByAscending(organizationRepo.CreatedAtColumn()), + ) + require.NoError(t, err) + if tt.noOrganizationReturned { + assert.Nil(t, returnedOrgs) + return + } + + assert.Equal(t, len(organizations), len(returnedOrgs)) + for i, org := range organizations { + assert.Equal(t, returnedOrgs[i].ID, org.ID) + assert.Equal(t, returnedOrgs[i].Name, org.Name) + assert.Equal(t, returnedOrgs[i].InstanceID, org.InstanceID) + assert.Equal(t, returnedOrgs[i].State, org.State) + } + }) + } +} + +func TestDeleteOrganization(t *testing.T) { + // create instance + instanceId := gofakeit.Name() + instance := domain.Instance{ + ID: instanceId, + Name: gofakeit.Name(), + DefaultOrgID: "defaultOrgId", + IAMProjectID: "iamProject", + ConsoleClientID: "consoleCLient", + ConsoleAppID: "consoleApp", + DefaultLanguage: "defaultLanguage", + } + instanceRepo := repository.InstanceRepository(pool) + err := instanceRepo.Create(t.Context(), &instance) + require.NoError(t, err) + + type test struct { + name string + testFunc func(ctx context.Context, t *testing.T) + orgIdentifierCondition domain.OrgIdentifierCondition + noOfDeletedRows int64 + } + tests := []test{ + func() test { + organizationRepo := repository.OrganizationRepository(pool) + organizationId := gofakeit.Name() + var noOfOrganizations int64 = 1 + return test{ + name: "happy path delete organization filter id", + testFunc: func(ctx context.Context, t *testing.T) { + organizations := make([]*domain.Organization, noOfOrganizations) + for i := range noOfOrganizations { + + org := domain.Organization{ + ID: organizationId, + Name: gofakeit.Name(), + InstanceID: instanceId, + State: domain.OrgStateActive, + } + + // create organization + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + organizations[i] = &org + } + }, + orgIdentifierCondition: organizationRepo.IDCondition(organizationId), + noOfDeletedRows: noOfOrganizations, + } + }(), + func() test { + organizationRepo := repository.OrganizationRepository(pool) + organizationName := gofakeit.Name() + var noOfOrganizations int64 = 1 + return test{ + name: "happy path delete organization filter name", + testFunc: func(ctx context.Context, t *testing.T) { + organizations := make([]*domain.Organization, noOfOrganizations) + for i := range noOfOrganizations { + + org := domain.Organization{ + ID: gofakeit.Name(), + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateActive, + } + + // create organization + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + organizations[i] = &org + } + }, + orgIdentifierCondition: organizationRepo.NameCondition(organizationName), + noOfDeletedRows: noOfOrganizations, + } + }(), + func() test { + organizationRepo := repository.OrganizationRepository(pool) + non_existent_organization_name := gofakeit.Name() + return test{ + name: "delete non existent organization", + orgIdentifierCondition: organizationRepo.NameCondition(non_existent_organization_name), + } + }(), + func() test { + organizationRepo := repository.OrganizationRepository(pool) + organizationName := gofakeit.Name() + return test{ + name: "deleted already deleted organization", + testFunc: func(ctx context.Context, t *testing.T) { + noOfOrganizations := 1 + organizations := make([]*domain.Organization, noOfOrganizations) + for i := range noOfOrganizations { + + org := domain.Organization{ + ID: gofakeit.Name(), + Name: organizationName, + InstanceID: instanceId, + State: domain.OrgStateActive, + } + + // create organization + err := organizationRepo.Create(ctx, &org) + require.NoError(t, err) + + organizations[i] = &org + } + + // delete organization + affectedRows, err := organizationRepo.Delete(ctx, + organizationRepo.NameCondition(organizationName), + organizations[0].InstanceID, + ) + assert.Equal(t, int64(1), affectedRows) + require.NoError(t, err) + }, + orgIdentifierCondition: organizationRepo.NameCondition(organizationName), + // this test should return 0 affected rows as the org was already deleted + noOfDeletedRows: 0, + } + }(), + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + organizationRepo := repository.OrganizationRepository(pool) + + if tt.testFunc != nil { + tt.testFunc(ctx, t) + } + + // delete organization + noOfDeletedRows, err := organizationRepo.Delete(ctx, + tt.orgIdentifierCondition, + instanceId, + ) + require.NoError(t, err) + assert.Equal(t, noOfDeletedRows, tt.noOfDeletedRows) + + // check organization was deleted + organization, err := organizationRepo.Get(ctx, + database.WithCondition( + database.And( + tt.orgIdentifierCondition, + organizationRepo.InstanceIDCondition(instanceId), + ), + ), + ) + require.ErrorIs(t, err, new(database.NoRowFoundError)) + assert.Nil(t, organization) + }) + } +} diff --git a/backend/v3/storage/database/repository/repository.go b/backend/v3/storage/database/repository/repository.go new file mode 100644 index 00000000000..c5b9ff81f09 --- /dev/null +++ b/backend/v3/storage/database/repository/repository.go @@ -0,0 +1,20 @@ +package repository + +import ( + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type repository struct { + client database.QueryExecutor +} + +func writeCondition( + builder *database.StatementBuilder, + condition database.Condition, +) { + if condition == nil { + return + } + builder.WriteString(" WHERE ") + condition.Write(builder) +} diff --git a/backend/v3/storage/database/repository/repository_test.go b/backend/v3/storage/database/repository/repository_test.go new file mode 100644 index 00000000000..36c28c22e2e --- /dev/null +++ b/backend/v3/storage/database/repository/repository_test.go @@ -0,0 +1,51 @@ +package repository_test + +import ( + "context" + "fmt" + "log" + "os" + "testing" + + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres/embedded" +) + +func TestMain(m *testing.M) { + os.Exit(runTests(m)) +} + +var pool database.PoolTest + +func runTests(m *testing.M) int { + var stop func() + var err error + ctx := context.Background() + pool, stop, err = newEmbeddedDB(ctx) + if err != nil { + log.Printf("error with embedded postgres database: %v", err) + return 1 + } + defer stop() + + return m.Run() +} + +func newEmbeddedDB(ctx context.Context) (pool database.PoolTest, stop func(), err error) { + connector, stop, err := embedded.StartEmbedded() + if err != nil { + return nil, nil, fmt.Errorf("unable to start embedded postgres: %w", err) + } + + pool_, err := connector.Connect(ctx) + if err != nil { + return nil, nil, fmt.Errorf("unable to connect to embedded postgres: %w", err) + } + pool = pool_.(database.PoolTest) + + err = pool.MigrateTest(ctx) + if err != nil { + return nil, nil, fmt.Errorf("unable to migrate database: %w", err) + } + return pool, stop, err +} diff --git a/backend/v3/storage/database/repository/user.go b/backend/v3/storage/database/repository/user.go new file mode 100644 index 00000000000..5953af78572 --- /dev/null +++ b/backend/v3/storage/database/repository/user.go @@ -0,0 +1,282 @@ +package repository + +import ( + "context" + "time" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +const queryUserStmt = `SELECT instance_id, org_id, id, username, type, created_at, updated_at, deleted_at,` + + ` first_name, last_name, email_address, email_verified_at, phone_number, phone_verified_at, description` + + ` FROM users_view users` + +type user struct { + repository +} + +func UserRepository(client database.QueryExecutor) domain.UserRepository { + return &user{ + repository: repository{ + client: client, + }, + } +} + +var _ domain.UserRepository = (*user)(nil) + +// ------------------------------------------------------------- +// repository +// ------------------------------------------------------------- + +// Human implements [domain.UserRepository]. +func (u *user) Human() domain.HumanRepository { + return &userHuman{user: u} +} + +// Machine implements [domain.UserRepository]. +func (u *user) Machine() domain.MachineRepository { + return &userMachine{user: u} +} + +// List implements [domain.UserRepository]. +func (u *user) List(ctx context.Context, opts ...database.QueryOption) (users []*domain.User, err error) { + options := new(database.QueryOpts) + for _, opt := range opts { + opt(options) + } + + builder := database.StatementBuilder{} + builder.WriteString(queryUserStmt) + options.WriteCondition(&builder) + options.WriteOrderBy(&builder) + options.WriteLimit(&builder) + options.WriteOffset(&builder) + + rows, err := u.client.Query(ctx, builder.String(), builder.Args()...) + if err != nil { + return nil, err + } + defer func() { + closeErr := rows.Close() + if err != nil { + return + } + err = closeErr + }() + for rows.Next() { + user, err := scanUser(rows) + if err != nil { + return nil, err + } + users = append(users, user) + } + if err := rows.Err(); err != nil { + return nil, err + } + return users, nil +} + +// Get implements [domain.UserRepository]. +func (u *user) Get(ctx context.Context, opts ...database.QueryOption) (*domain.User, error) { + options := new(database.QueryOpts) + for _, opt := range opts { + opt(options) + } + + builder := database.StatementBuilder{} + builder.WriteString(queryUserStmt) + options.WriteCondition(&builder) + options.WriteOrderBy(&builder) + options.WriteLimit(&builder) + options.WriteOffset(&builder) + + return scanUser(u.client.QueryRow(ctx, builder.String(), builder.Args()...)) +} + +const ( + createHumanStmt = `INSERT INTO human_users (instance_id, org_id, user_id, username, first_name, last_name, email_address, email_verified_at, phone_number, phone_verified_at)` + + ` VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)` + + ` RETURNING created_at, updated_at` + createMachineStmt = `INSERT INTO user_machines (instance_id, org_id, user_id, username, description)` + + ` VALUES ($1, $2, $3, $4, $5)` + + ` RETURNING created_at, updated_at` +) + +// Create implements [domain.UserRepository]. +func (u *user) Create(ctx context.Context, user *domain.User) error { + builder := database.StatementBuilder{} + builder.AppendArgs(user.InstanceID, user.OrgID, user.ID, user.Username, user.Traits.Type()) + switch trait := user.Traits.(type) { + case *domain.Human: + builder.WriteString(createHumanStmt) + builder.AppendArgs(trait.FirstName, trait.LastName, trait.Email.Address, trait.Email.VerifiedAt, trait.Phone.Number, trait.Phone.VerifiedAt) + case *domain.Machine: + builder.WriteString(createMachineStmt) + builder.AppendArgs(trait.Description) + } + return u.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&user.CreatedAt, &user.UpdatedAt) +} + +// Delete implements [domain.UserRepository]. +func (u *user) Delete(ctx context.Context, condition database.Condition) error { + builder := database.StatementBuilder{} + builder.WriteString("DELETE FROM users") + writeCondition(&builder, condition) + _, err := u.client.Exec(ctx, builder.String(), builder.Args()...) + return err +} + +// ------------------------------------------------------------- +// changes +// ------------------------------------------------------------- + +// SetUsername implements [domain.userChanges]. +func (u user) SetUsername(username string) database.Change { + return database.NewChange(u.UsernameColumn(), username) +} + +// ------------------------------------------------------------- +// conditions +// ------------------------------------------------------------- + +// InstanceIDCondition implements [domain.userConditions]. +func (u user) InstanceIDCondition(instanceID string) database.Condition { + return database.NewTextCondition(u.InstanceIDColumn(), database.TextOperationEqual, instanceID) +} + +// OrgIDCondition implements [domain.userConditions]. +func (u user) OrgIDCondition(orgID string) database.Condition { + return database.NewTextCondition(u.OrgIDColumn(), database.TextOperationEqual, orgID) +} + +// IDCondition implements [domain.userConditions]. +func (u user) IDCondition(userID string) database.Condition { + return database.NewTextCondition(u.IDColumn(), database.TextOperationEqual, userID) +} + +// UsernameCondition implements [domain.userConditions]. +func (u user) UsernameCondition(op database.TextOperation, username string) database.Condition { + return database.NewTextCondition(u.UsernameColumn(), op, username) +} + +// CreatedAtCondition implements [domain.userConditions]. +func (u user) CreatedAtCondition(op database.NumberOperation, createdAt time.Time) database.Condition { + return database.NewNumberCondition(u.CreatedAtColumn(), op, createdAt) +} + +// UpdatedAtCondition implements [domain.userConditions]. +func (u user) UpdatedAtCondition(op database.NumberOperation, updatedAt time.Time) database.Condition { + return database.NewNumberCondition(u.UpdatedAtColumn(), op, updatedAt) +} + +// DeletedCondition implements [domain.userConditions]. +func (u user) DeletedCondition(isDeleted bool) database.Condition { + if isDeleted { + return database.IsNotNull(u.DeletedAtColumn()) + } + return database.IsNull(u.DeletedAtColumn()) +} + +// DeletedAtCondition implements [domain.userConditions]. +func (u user) DeletedAtCondition(op database.NumberOperation, deletedAt time.Time) database.Condition { + return database.NewNumberCondition(u.DeletedAtColumn(), op, deletedAt) +} + +// ------------------------------------------------------------- +// columns +// ------------------------------------------------------------- + +// InstanceIDColumn implements [domain.userColumns]. +func (user) InstanceIDColumn() database.Column { + return database.NewColumn("users", "instance_id") +} + +// OrgIDColumn implements [domain.userColumns]. +func (user) OrgIDColumn() database.Column { + return database.NewColumn("users", "org_id") +} + +// IDColumn implements [domain.userColumns]. +func (user) IDColumn() database.Column { + return database.NewColumn("users", "id") +} + +// UsernameColumn implements [domain.userColumns]. +func (user) UsernameColumn() database.Column { + return database.NewColumn("users", "username") +} + +// FirstNameColumn implements [domain.userColumns]. +func (user) CreatedAtColumn() database.Column { + return database.NewColumn("users", "created_at") +} + +// UpdatedAtColumn implements [domain.userColumns]. +func (user) UpdatedAtColumn() database.Column { + return database.NewColumn("users", "updated_at") +} + +// DeletedAtColumn implements [domain.userColumns]. +func (user) DeletedAtColumn() database.Column { + return database.NewColumn("users", "deleted_at") +} + +func (u user) columns() database.Columns { + return database.Columns{ + u.InstanceIDColumn(), + u.OrgIDColumn(), + u.IDColumn(), + u.UsernameColumn(), + u.CreatedAtColumn(), + u.UpdatedAtColumn(), + u.DeletedAtColumn(), + } +} + +func scanUser(scanner database.Scanner) (*domain.User, error) { + var ( + user domain.User + human domain.Human + email domain.Email + phone domain.Phone + machine domain.Machine + typ domain.UserType + ) + err := scanner.Scan( + &user.InstanceID, + &user.OrgID, + &user.ID, + &user.Username, + &typ, + &user.CreatedAt, + &user.UpdatedAt, + &user.DeletedAt, + &human.FirstName, + &human.LastName, + &email.Address, + &email.VerifiedAt, + &phone.Number, + &phone.VerifiedAt, + &machine.Description, + ) + if err != nil { + return nil, err + } + + switch typ { + case domain.UserTypeHuman: + if email.Address != "" { + human.Email = &email + } + if phone.Number != "" { + human.Phone = &phone + } + user.Traits = &human + case domain.UserTypeMachine: + user.Traits = &machine + } + + return &user, nil +} diff --git a/backend/v3/storage/database/repository/user_human.go b/backend/v3/storage/database/repository/user_human.go new file mode 100644 index 00000000000..ae7643c53c8 --- /dev/null +++ b/backend/v3/storage/database/repository/user_human.go @@ -0,0 +1,208 @@ +package repository + +import ( + "context" + "time" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +// ------------------------------------------------------------- +// repository +// ------------------------------------------------------------- + +type userHuman struct { + *user +} + +var _ domain.HumanRepository = (*userHuman)(nil) + +const userEmailQuery = `SELECT h.email_address, h.email_verified_at FROM user_humans h` + +// GetEmail implements [domain.HumanRepository]. +func (u *userHuman) GetEmail(ctx context.Context, condition database.Condition) (*domain.Email, error) { + var email domain.Email + + builder := database.StatementBuilder{} + builder.WriteString(userEmailQuery) + writeCondition(&builder, condition) + + err := u.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan( + &email.Address, + &email.VerifiedAt, + ) + if err != nil { + return nil, err + } + return &email, nil +} + +// Update implements [domain.HumanRepository]. +func (h userHuman) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error { + builder := database.StatementBuilder{} + builder.WriteString(`UPDATE human_users SET `) + database.Changes(changes).Write(&builder) + writeCondition(&builder, condition) + + stmt := builder.String() + + _, err := h.client.Exec(ctx, stmt, builder.Args()...) + return err +} + +// ------------------------------------------------------------- +// changes +// ------------------------------------------------------------- + +// SetFirstName implements [domain.humanChanges]. +func (h userHuman) SetFirstName(firstName string) database.Change { + return database.NewChange(h.FirstNameColumn(), firstName) +} + +// SetLastName implements [domain.humanChanges]. +func (h userHuman) SetLastName(lastName string) database.Change { + return database.NewChange(h.LastNameColumn(), lastName) +} + +// SetEmail implements [domain.humanChanges]. +func (h userHuman) SetEmail(address string, verified *time.Time) database.Change { + return database.NewChanges( + h.SetEmailAddress(address), + database.NewChangePtr(h.EmailVerifiedAtColumn(), verified), + ) +} + +// SetEmailAddress implements [domain.humanChanges]. +func (h userHuman) SetEmailAddress(address string) database.Change { + return database.NewChange(h.EmailAddressColumn(), address) +} + +// SetEmailVerifiedAt implements [domain.humanChanges]. +func (h userHuman) SetEmailVerifiedAt(at time.Time) database.Change { + if at.IsZero() { + return database.NewChange(h.EmailVerifiedAtColumn(), database.NowInstruction) + } + return database.NewChange(h.EmailVerifiedAtColumn(), at) +} + +// SetPhone implements [domain.humanChanges]. +func (h userHuman) SetPhone(number string, verifiedAt *time.Time) database.Change { + return database.NewChanges( + h.SetPhoneNumber(number), + database.NewChangePtr(h.PhoneVerifiedAtColumn(), verifiedAt), + ) +} + +// SetPhoneNumber implements [domain.humanChanges]. +func (h userHuman) SetPhoneNumber(number string) database.Change { + return database.NewChange(h.PhoneNumberColumn(), number) +} + +// SetPhoneVerifiedAt implements [domain.humanChanges]. +func (h userHuman) SetPhoneVerifiedAt(at time.Time) database.Change { + if at.IsZero() { + return database.NewChange(h.PhoneVerifiedAtColumn(), database.NowInstruction) + } + return database.NewChange(h.PhoneVerifiedAtColumn(), at) +} + +// ------------------------------------------------------------- +// conditions +// ------------------------------------------------------------- + +// FirstNameCondition implements [domain.humanConditions]. +func (h userHuman) FirstNameCondition(op database.TextOperation, firstName string) database.Condition { + return database.NewTextCondition(h.FirstNameColumn(), op, firstName) +} + +// LastNameCondition implements [domain.humanConditions]. +func (h userHuman) LastNameCondition(op database.TextOperation, lastName string) database.Condition { + return database.NewTextCondition(h.LastNameColumn(), op, lastName) +} + +// EmailAddressCondition implements [domain.humanConditions]. +func (h userHuman) EmailAddressCondition(op database.TextOperation, email string) database.Condition { + return database.NewTextCondition(h.EmailAddressColumn(), op, email) +} + +// EmailVerifiedCondition implements [domain.humanConditions]. +func (h *userHuman) EmailVerifiedCondition(isVerified bool) database.Condition { + if isVerified { + return database.IsNotNull(h.EmailVerifiedAtColumn()) + } + return database.IsNull(h.EmailVerifiedAtColumn()) +} + +// EmailVerifiedAtCondition implements [domain.humanConditions]. +func (h userHuman) EmailVerifiedAtCondition(op database.NumberOperation, verifiedAt time.Time) database.Condition { + return database.NewNumberCondition(h.EmailVerifiedAtColumn(), op, verifiedAt) +} + +// PhoneNumberCondition implements [domain.humanConditions]. +func (h userHuman) PhoneNumberCondition(op database.TextOperation, phoneNumber string) database.Condition { + return database.NewTextCondition(h.PhoneNumberColumn(), op, phoneNumber) +} + +// PhoneVerifiedCondition implements [domain.humanConditions]. +func (h userHuman) PhoneVerifiedCondition(isVerified bool) database.Condition { + if isVerified { + return database.IsNotNull(h.PhoneVerifiedAtColumn()) + } + return database.IsNull(h.PhoneVerifiedAtColumn()) +} + +// PhoneVerifiedAtCondition implements [domain.humanConditions]. +func (h userHuman) PhoneVerifiedAtCondition(op database.NumberOperation, verifiedAt time.Time) database.Condition { + return database.NewNumberCondition(h.PhoneVerifiedAtColumn(), op, verifiedAt) +} + +// ------------------------------------------------------------- +// columns +// ------------------------------------------------------------- + +// FirstNameColumn implements [domain.humanColumns]. +func (h userHuman) FirstNameColumn() database.Column { + return database.NewColumn("user_humans", "first_name") +} + +// LastNameColumn implements [domain.humanColumns]. +func (h userHuman) LastNameColumn() database.Column { + return database.NewColumn("user_humans", "last_name") +} + +// EmailAddressColumn implements [domain.humanColumns]. +func (h userHuman) EmailAddressColumn() database.Column { + return database.NewColumn("user_humans", "email_address") +} + +// EmailVerifiedAtColumn implements [domain.humanColumns]. +func (h userHuman) EmailVerifiedAtColumn() database.Column { + return database.NewColumn("user_humans", "email_verified_at") +} + +// PhoneNumberColumn implements [domain.humanColumns]. +func (h userHuman) PhoneNumberColumn() database.Column { + return database.NewColumn("user_humans", "phone_number") +} + +// PhoneVerifiedAtColumn implements [domain.humanColumns]. +func (h userHuman) PhoneVerifiedAtColumn() database.Column { + return database.NewColumn("user_humans", "phone_verified_at") +} + +// func (h userHuman) columns() database.Columns { +// return append(h.user.columns(), +// h.FirstNameColumn(), +// h.LastNameColumn(), +// h.EmailAddressColumn(), +// h.EmailVerifiedAtColumn(), +// h.PhoneNumberColumn(), +// h.PhoneVerifiedAtColumn(), +// ) +// } + +// func (h userHuman) writeReturning(builder *database.StatementBuilder) { +// builder.WriteString(" RETURNING ") +// h.columns().Write(builder) +// } diff --git a/backend/v3/storage/database/repository/user_machine.go b/backend/v3/storage/database/repository/user_machine.go new file mode 100644 index 00000000000..2bda09d0e5d --- /dev/null +++ b/backend/v3/storage/database/repository/user_machine.go @@ -0,0 +1,67 @@ +package repository + +import ( + "context" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type userMachine struct { + *user +} + +var _ domain.MachineRepository = (*userMachine)(nil) + +// ------------------------------------------------------------- +// repository +// ------------------------------------------------------------- + +// Update implements [domain.MachineRepository]. +func (m userMachine) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error { + builder := database.StatementBuilder{} + builder.WriteString("UPDATE user_machines SET ") + database.Changes(changes).Write(&builder) + writeCondition(&builder, condition) + m.writeReturning() + + _, err := m.client.Exec(ctx, builder.String(), builder.Args()...) + return err +} + +// ------------------------------------------------------------- +// changes +// ------------------------------------------------------------- + +// SetDescription implements [domain.machineChanges]. +func (m userMachine) SetDescription(description string) database.Change { + return database.NewChange(m.DescriptionColumn(), description) +} + +// ------------------------------------------------------------- +// conditions +// ------------------------------------------------------------- + +// DescriptionCondition implements [domain.machineConditions]. +func (m userMachine) DescriptionCondition(op database.TextOperation, description string) database.Condition { + return database.NewTextCondition(m.DescriptionColumn(), op, description) +} + +// ------------------------------------------------------------- +// columns +// ------------------------------------------------------------- + +// DescriptionColumn implements [domain.machineColumns]. +func (m userMachine) DescriptionColumn() database.Column { + return database.NewColumn("user_machines", "description") +} + +func (m userMachine) columns() database.Columns { + return append(m.user.columns(), m.DescriptionColumn()) +} + +func (m *userMachine) writeReturning() { + builder := database.StatementBuilder{} + builder.WriteString(" RETURNING ") + m.columns().WriteQualified(&builder) +} diff --git a/backend/v3/storage/database/statement.go b/backend/v3/storage/database/statement.go new file mode 100644 index 00000000000..2858feae434 --- /dev/null +++ b/backend/v3/storage/database/statement.go @@ -0,0 +1,68 @@ +package database + +import ( + "strconv" + "strings" +) + +type Instruction string + +const ( + DefaultInstruction Instruction = "DEFAULT" + NowInstruction Instruction = "NOW()" + NullInstruction Instruction = "NULL" +) + +// StatementBuilder is a helper to build SQL statement. +type StatementBuilder struct { + strings.Builder + args []any + existingArgs map[any]string +} + +// WriteArgs adds the argument to the statement and writes the placeholder to the query. +func (b *StatementBuilder) WriteArg(arg any) { + b.WriteString(b.AppendArg(arg)) +} + +// WriteArgs adds the arguments to the statement and writes the placeholders to the query. +// The placeholders are comma separated. +func (b *StatementBuilder) WriteArgs(args ...any) { + for i, arg := range args { + if i > 0 { + b.WriteString(", ") + } + b.WriteArg(arg) + } +} + +// AppendArg adds the argument to the statement and returns the placeholder. +func (b *StatementBuilder) AppendArg(arg any) (placeholder string) { + if b.existingArgs == nil { + b.existingArgs = make(map[any]string) + } + if placeholder, ok := b.existingArgs[arg]; ok { + return placeholder + } + if instruction, ok := arg.(Instruction); ok { + return string(instruction) + } + + b.args = append(b.args, arg) + placeholder = "$" + strconv.Itoa(len(b.args)) + b.existingArgs[arg] = placeholder + return placeholder +} + +// AppendArgs adds the arguments to the statement and doesn't return the placeholders. +// If an argument is already added, it will not be added again. +func (b *StatementBuilder) AppendArgs(args ...any) { + for _, arg := range args { + b.AppendArg(arg) + } +} + +// Args returns the arguments added to the statement. +func (b *StatementBuilder) Args() []any { + return b.args +} diff --git a/backend/v3/storage/database/tx.go b/backend/v3/storage/database/tx.go new file mode 100644 index 00000000000..a8f7adab582 --- /dev/null +++ b/backend/v3/storage/database/tx.go @@ -0,0 +1,38 @@ +package database + +import "context" + +// Transaction is an SQL transaction. +type Transaction interface { + Commit(ctx context.Context) error + Rollback(ctx context.Context) error + End(ctx context.Context, err error) error + + Begin(ctx context.Context) (Transaction, error) + + QueryExecutor +} + +// Beginner can start a new transaction. +type Beginner interface { + Begin(ctx context.Context, opts *TransactionOptions) (Transaction, error) +} + +type TransactionOptions struct { + IsolationLevel IsolationLevel + AccessMode AccessMode +} + +type IsolationLevel uint8 + +const ( + IsolationLevelSerializable IsolationLevel = iota + IsolationLevelReadCommitted +) + +type AccessMode uint8 + +const ( + AccessModeReadWrite AccessMode = iota + AccessModeReadOnly +) diff --git a/backend/v3/storage/eventstore/event.go b/backend/v3/storage/eventstore/event.go new file mode 100644 index 00000000000..4a52bc860c0 --- /dev/null +++ b/backend/v3/storage/eventstore/event.go @@ -0,0 +1,24 @@ +package eventstore + +import ( + "context" + + "github.com/zitadel/zitadel/backend/v3/storage/database" +) + +type Event struct { + AggregateType string `json:"aggregateType"` + AggregateID string `json:"aggregateId"` + Type string `json:"type"` + Payload any `json:"payload,omitempty"` +} + +func Publish(ctx context.Context, events []*Event, db database.Executor) error { + for _, event := range events { + _, err := db.Exec(ctx, `INSERT INTO events (aggregate_type, aggregate_id) VALUES ($1, $2)`, event.AggregateType, event.AggregateID) + if err != nil { + return err + } + } + return nil +} diff --git a/backend/v3/telemetry/logging/logger.go b/backend/v3/telemetry/logging/logger.go new file mode 100644 index 00000000000..a7d724daffa --- /dev/null +++ b/backend/v3/telemetry/logging/logger.go @@ -0,0 +1,12 @@ +package logging + +import "log/slog" + +// Logger abstracts [slog.Logger] not sure if thats needed +type Logger struct { + *slog.Logger +} + +func NewLogger(logger *slog.Logger) *Logger { + return &Logger{Logger: logger} +} diff --git a/backend/v3/telemetry/metric/doc.go b/backend/v3/telemetry/metric/doc.go new file mode 100644 index 00000000000..e4a8b3bd63a --- /dev/null +++ b/backend/v3/telemetry/metric/doc.go @@ -0,0 +1,2 @@ +// implementation of otel metrics +package metric diff --git a/backend/v3/telemetry/tracing/tracer.go b/backend/v3/telemetry/tracing/tracer.go new file mode 100644 index 00000000000..025a0c11168 --- /dev/null +++ b/backend/v3/telemetry/tracing/tracer.go @@ -0,0 +1,24 @@ +package tracing + +import ( + "context" + + "go.opentelemetry.io/otel/trace" + "go.opentelemetry.io/otel/trace/noop" +) + +// Tracer is a wrapper around the OpenTelemetry Tracer interface. +type Tracer struct { + trace.Tracer +} + +var noopTracer = Tracer{ + Tracer: noop.NewTracerProvider().Tracer(""), +} + +func (t *Tracer) Start(ctx context.Context, spanName string, opts ...trace.SpanStartOption) (context.Context, trace.Span) { + if t.Tracer == nil { + return noopTracer.Start(ctx, spanName, opts...) + } + return t.Tracer.Start(ctx, spanName, opts...) +} diff --git a/cmd/setup/setup.go b/cmd/setup/setup.go index fd4cbda285e..9bc0adc5cea 100644 --- a/cmd/setup/setup.go +++ b/cmd/setup/setup.go @@ -286,6 +286,9 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string) ExternalSecure: config.ExternalSecure, defaults: config.SystemDefaults, }, + &TransactionalTables{ + dbClient: dbClient, + }, &projectionTables{ es: eventstoreClient, Version: build.Version(), diff --git a/cmd/setup/transactional_tables.go b/cmd/setup/transactional_tables.go new file mode 100644 index 00000000000..ca5df0a0a83 --- /dev/null +++ b/cmd/setup/transactional_tables.go @@ -0,0 +1,32 @@ +package setup + +import ( + "context" + _ "embed" + + "github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres" + "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/eventstore" +) + +type TransactionalTables struct { + dbClient *database.DB +} + +func (mig *TransactionalTables) Execute(ctx context.Context, _ eventstore.Event) error { + config := &postgres.Config{Pool: mig.dbClient.Pool} + pool, err := config.Connect(ctx) + if err != nil { + return err + } + + return pool.Migrate(ctx) +} + +func (mig *TransactionalTables) String() string { + return "repeatable_transactional_tables" +} + +func (mig *TransactionalTables) Check(lastRun map[string]interface{}) bool { + return true +} diff --git a/go.mod b/go.mod index 438f77ecb72..77041c405cb 100644 --- a/go.mod +++ b/go.mod @@ -31,6 +31,7 @@ require ( github.com/fatih/color v1.18.0 github.com/fergusstrange/embedded-postgres v1.30.0 github.com/gabriel-vasile/mimetype v1.4.9 + github.com/georgysavva/scany/v2 v2.1.4 github.com/go-chi/chi/v5 v5.2.2 github.com/go-jose/go-jose/v4 v4.1.0 github.com/go-ldap/ldap/v3 v3.4.11 @@ -51,6 +52,7 @@ require ( github.com/improbable-eng/grpc-web v0.15.0 github.com/jackc/pgx-shopspring-decimal v0.0.0-20220624020537-1d36b5a1853e github.com/jackc/pgx/v5 v5.7.5 + github.com/jackc/tern/v2 v2.3.3 github.com/jarcoal/jpath v0.0.0-20140328210829-f76b8b2dbf52 github.com/jinzhu/gorm v1.9.16 github.com/k3a/html2text v1.2.1 @@ -73,7 +75,7 @@ require ( github.com/robfig/cron/v3 v3.0.1 github.com/rs/cors v1.11.1 github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 - github.com/shopspring/decimal v1.3.1 + github.com/shopspring/decimal v1.4.0 github.com/sony/gobreaker/v2 v2.1.0 github.com/sony/sonyflake v1.2.1 github.com/spf13/cobra v1.9.1 @@ -119,6 +121,9 @@ require ( github.com/GoogleCloudPlatform/opentelemetry-operations-go/detectors/gcp v1.27.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/exporter/metric v0.51.0 // indirect github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.51.0 // indirect + github.com/Masterminds/goutils v1.1.1 // indirect + github.com/Masterminds/semver/v3 v3.3.0 // indirect + github.com/Masterminds/sprig/v3 v3.3.0 // indirect github.com/alicebob/gopher-json v0.0.0-20230218143504-906a9b012302 // indirect github.com/bmatcuk/doublestar/v4 v4.9.0 // indirect github.com/cncf/xds/go v0.0.0-20250121191232-2f005788dc42 // indirect @@ -140,6 +145,7 @@ require ( github.com/google/s2a-go v0.1.9 // indirect github.com/hashicorp/errwrap v1.1.0 // indirect github.com/hashicorp/go-multierror v1.1.1 // indirect + github.com/huandu/xstrings v1.5.0 // indirect github.com/jackc/puddle/v2 v2.2.2 // indirect github.com/klauspost/cpuid/v2 v2.2.10 // indirect github.com/lib/pq v1.10.9 // indirect @@ -147,6 +153,8 @@ require ( github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/minio/crc64nvme v1.0.1 // indirect + github.com/mitchellh/copystructure v1.2.0 // indirect + github.com/mitchellh/reflectwalk v1.0.2 // indirect github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 // indirect github.com/pelletier/go-toml/v2 v2.2.3 // indirect github.com/pkg/errors v0.9.1 // indirect diff --git a/go.sum b/go.sum index 50855532b9c..ab66c29fa7f 100644 --- a/go.sum +++ b/go.sum @@ -50,8 +50,12 @@ github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/cloudmock v0 github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.51.0 h1:6/0iUd0xrnX7qt+mLNRwg5c0PGv8wpE8K90ryANQwMI= github.com/GoogleCloudPlatform/opentelemetry-operations-go/internal/resourcemapping v0.51.0/go.mod h1:otE2jQekW/PqXk1Awf5lmfokJx4uwuqcj1ab5SpGeW0= github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= -github.com/Masterminds/semver/v3 v3.2.1 h1:RN9w6+7QoMeJVGyfmbcgs28Br8cvmnucEXnY0rYXWg0= -github.com/Masterminds/semver/v3 v3.2.1/go.mod h1:qvl/7zhW3nngYb5+80sSMF+FG2BjYrf8m9wsX0PNOMQ= +github.com/Masterminds/goutils v1.1.1 h1:5nUrii3FMTL5diU80unEVvNevw1nH4+ZV4DSLVJLSYI= +github.com/Masterminds/goutils v1.1.1/go.mod h1:8cTjp+g8YejhMuvIA5y2vz3BpJxksy863GQaJW2MFNU= +github.com/Masterminds/semver/v3 v3.3.0 h1:B8LGeaivUe71a5qox1ICM/JLl0NqZSW5CHyL+hmvYS0= +github.com/Masterminds/semver/v3 v3.3.0/go.mod h1:4V+yj/TJE1HU9XfppCwVMZq3I84lprf4nC11bSS5beM= +github.com/Masterminds/sprig/v3 v3.3.0 h1:mQh0Yrg1XPo6vjYXgtf5OtijNAKJRNcTdOOGZe3tPhs= +github.com/Masterminds/sprig/v3 v3.3.0/go.mod h1:Zy1iXRYNqNLUolqCpL4uhk6SHUMAOSCzdgBfDb35Lz0= github.com/Masterminds/squirrel v1.5.4 h1:uUcX/aBc8O7Fg9kaISIUsHXdKuqehiXAMQTYX8afzqM= github.com/Masterminds/squirrel v1.5.4/go.mod h1:NNaOrjSoIDfDA40n7sr2tPNZRfjzjA400rg+riTZj10= github.com/PuerkitoBio/goquery v1.5.1/go.mod h1:GsLWisAFVj4WgDibEWF4pvYnkVQBpKBKeU+7zCJoLcc= @@ -233,6 +237,8 @@ github.com/fxamacker/cbor/v2 v2.7.0 h1:iM5WgngdRBanHcxugY4JySA0nk1wZorNOpTgCMedv github.com/fxamacker/cbor/v2 v2.7.0/go.mod h1:pxXPTn3joSm21Gbwsv0w9OSA2y1HFR9qXEeXQVeNoDQ= github.com/gabriel-vasile/mimetype v1.4.9 h1:5k+WDwEsD9eTLL8Tz3L0VnmVh9QxGjRmjBvAG7U/oYY= github.com/gabriel-vasile/mimetype v1.4.9/go.mod h1:WnSQhFKJuBlRyLiKohA/2DtIlPFAbguNaG7QCHcyGok= +github.com/georgysavva/scany/v2 v2.1.4 h1:nrzHEJ4oQVRoiKmocRqA1IyGOmM/GQOEsg9UjMR5Ip4= +github.com/georgysavva/scany/v2 v2.1.4/go.mod h1:fqp9yHZzM/PFVa3/rYEC57VmDx+KDch0LoqrJzkvtos= github.com/ghodss/yaml v1.0.0/go.mod h1:4dBDuWmgqj2HViK6kFavaiC9ZROes6MMH2rRYeMEF04= github.com/gin-contrib/sse v0.1.0/go.mod h1:RHrZQHXnP2xjPF+u1gW/2HnVO7nvIa9PG3Gm+fLHvGI= github.com/gin-gonic/gin v1.6.3/go.mod h1:75u5sXoLsGZoRN5Sgbi1eraJ4GU3++wFwWzhwvtwp4M= @@ -308,8 +314,9 @@ github.com/golang-jwt/jwt/v4 v4.5.2 h1:YtQM7lnr8iZ+j5q71MGKkNw9Mn7AjHM68uc9g5fXe github.com/golang-jwt/jwt/v4 v4.5.2/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0= github.com/golang-jwt/jwt/v5 v5.2.2 h1:Rl4B7itRWVtYIHFrSNd7vhTiz9UpLdi6gZhZ3wEeDy8= github.com/golang-jwt/jwt/v5 v5.2.2/go.mod h1:pqrtFR0X4osieyHYxtmOUWsAWrfe1Q5UVIyoH402zdk= -github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe h1:lXe2qZdvpiX5WZkZR4hgp4KJVfY3nMkvmwbVkpv1rVY= github.com/golang-sql/civil v0.0.0-20190719163853-cb61b32ac6fe/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9 h1:au07oEsX2xN0ktxqI+Sida1w446QrXBRJ0nee3SNZlA= +github.com/golang-sql/civil v0.0.0-20220223132316-b832511892a9/go.mod h1:8vg3r2VgvsThLBIFL93Qb5yWzgyZWhEmBwUJWevAkK0= github.com/golang/freetype v0.0.0-20170609003504-e2365dfdc4a0/go.mod h1:E/TSTwGwJL78qG/PmXZO1EjYhfJinVAhrmmHX6Z8B9k= github.com/golang/geo v0.0.0-20190916061304-5b978397cfec/go.mod h1:QZ0nwyI2jOfgRAoBvP+ab5aRr7c9x7lhGEJrKvBwjWI= github.com/golang/geo v0.0.0-20200319012246-673a6f80352d/go.mod h1:QZ0nwyI2jOfgRAoBvP+ab5aRr7c9x7lhGEJrKvBwjWI= @@ -435,6 +442,8 @@ github.com/hashicorp/serf v0.8.2/go.mod h1:6hOLApaqBFA1NXqRQAsxw9QxuDEvNxSQRwA/J github.com/hexops/gotextdiff v1.0.3 h1:gitA9+qJrrTCsiCl7+kh75nPqQt1cx4ZkudSTLoUqJM= github.com/hexops/gotextdiff v1.0.3/go.mod h1:pSWU5MAI3yDq+fZBTazCSJysOMbxWL1BSow5/V2vxeg= github.com/hpcloud/tail v1.0.0/go.mod h1:ab1qPbhIpdTxEkNHXyeSf5vhxWSCs/tWer42PpOxQnU= +github.com/huandu/xstrings v1.5.0 h1:2ag3IFq9ZDANvthTwTiqSSZLjDc+BedvHPAp5tJy2TI= +github.com/huandu/xstrings v1.5.0/go.mod h1:y5/lhBue+AyNmUVz9RLU9xbLR0o4KIIExikq4ovT0aE= github.com/hudl/fargo v1.3.0/go.mod h1:y3CKSmjA+wD2gak7sUSXTAoopbhU08POFhmITJgmKTg= github.com/improbable-eng/grpc-web v0.15.0 h1:BN+7z6uNXZ1tQGcNAuaU1YjsLTApzkjt2tzCixLaUPQ= github.com/improbable-eng/grpc-web v0.15.0/go.mod h1:1sy9HKV4Jt9aEs9JSnkWlRJPuPtwNr0l57L4f878wP8= @@ -454,6 +463,8 @@ github.com/jackc/pgx/v5 v5.7.5 h1:JHGfMnQY+IEtGM63d+NGMjoRpysB2JBwDr5fsngwmJs= github.com/jackc/pgx/v5 v5.7.5/go.mod h1:aruU7o91Tc2q2cFp5h4uP3f6ztExVpyVv88Xl/8Vl8M= github.com/jackc/puddle/v2 v2.2.2 h1:PR8nw+E/1w0GLuRFSmiioY6UooMp6KJv0/61nB7icHo= github.com/jackc/puddle/v2 v2.2.2/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= +github.com/jackc/tern/v2 v2.3.3 h1:d6QNRyjk9HttJtSF5pUB8UaXrHwCgEai3/yxYjgci/k= +github.com/jackc/tern/v2 v2.3.3/go.mod h1:0/9jqEreuC+ywjB7C5ta6Xkhl+HSaxFmCAggEDcp6v0= github.com/jarcoal/jpath v0.0.0-20140328210829-f76b8b2dbf52 h1:jny9eqYPwkG8IVy7foUoRjQmFLcArCSz+uPsL6KS0HQ= github.com/jarcoal/jpath v0.0.0-20140328210829-f76b8b2dbf52/go.mod h1:RDZ+4PR3mDOtTpVbI0qBE+rdhmtIrtbssiNn38/1OWA= github.com/jcmturner/aescts/v2 v2.0.0 h1:9YKLH6ey7H4eDBXW8khjYslgyqG2xZikXP0EQFKrle8= @@ -559,6 +570,8 @@ github.com/minio/md5-simd v1.1.2/go.mod h1:MzdKDxYpY2BT9XQFocsiZf/NKVtR7nkE4RoEp github.com/minio/minio-go/v7 v7.0.91 h1:tWLZnEfo3OZl5PoXQwcwTAPNNrjyWwOh6cbZitW5JQc= github.com/minio/minio-go/v7 v7.0.91/go.mod h1:uvMUcGrpgeSAAI6+sD3818508nUyMULw94j2Nxku/Go= github.com/mitchellh/cli v1.0.0/go.mod h1:hNIlj7HEI86fIcpObd7a0FcrxTWetlwJDGcceTlRvqc= +github.com/mitchellh/copystructure v1.2.0 h1:vpKXTN4ewci03Vljg/q9QvCGUDttBOGBIa15WveJJGw= +github.com/mitchellh/copystructure v1.2.0/go.mod h1:qLl+cE2AmVv+CoeAwDPye/v+N2HKCj9FbZEVFJRxO9s= github.com/mitchellh/go-homedir v1.0.0/go.mod h1:SfyaCUpYCn1Vlf4IUYiD9fPX4A5wJrkLzIz1N1q0pr0= github.com/mitchellh/go-testing-interface v1.0.0/go.mod h1:kRemZodwjscx+RGhAo8eIhFbs2+BFgRtFPeD/KE+zxI= github.com/mitchellh/gox v0.4.0/go.mod h1:Sd9lOJ0+aimLBi73mGofS1ycjY8lL3uZM3JPS42BGNg= @@ -567,6 +580,8 @@ github.com/mitchellh/mapstructure v0.0.0-20160808181253-ca63d7c062ee/go.mod h1:F github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY= github.com/mitchellh/mapstructure v1.5.0/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo= +github.com/mitchellh/reflectwalk v1.0.2 h1:G2LzWKi524PWgd3mLHV8Y5k7s6XUvT0Gef6zxSIeXaQ= +github.com/mitchellh/reflectwalk v1.0.2/go.mod h1:mSTlrgnPZtwu0c4WaC2kGObEpuNDbx0jmZXqmk4esnw= github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/concurrent v0.0.0-20180306012644-bacd9c7ef1dd/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= github.com/modern-go/reflect2 v0.0.0-20180701023420-4b7aa43c6742/go.mod h1:bx2lNnkwVCuqBIxFjflWJWanXIb3RllmbCylyMrvgv0= @@ -715,8 +730,8 @@ github.com/samuel/go-zookeeper v0.0.0-20190923202752-2cc03de413da/go.mod h1:gi+0 github.com/santhosh-tekuri/jsonschema/v5 v5.3.1 h1:lZUw3E0/J3roVtGQ+SCrUrg3ON6NgVqpn3+iol9aGu4= github.com/santhosh-tekuri/jsonschema/v5 v5.3.1/go.mod h1:uToXkOrWAZ6/Oc07xWQrPOhJotwFIyu2bBVN41fcDUY= github.com/sean-/seed v0.0.0-20170313163322-e2103e2c3529/go.mod h1:DxrIzT+xaE7yg65j358z/aeFdxmN0P9QXhEzd20vsDc= -github.com/shopspring/decimal v1.3.1 h1:2Usl1nmF/WZucqkFZhnfFYxxxu8LG21F6nPQBE5gKV8= -github.com/shopspring/decimal v1.3.1/go.mod h1:DKyhrW/HYNuLGql+MJL6WCR6knT2jwCFRcu2hWCYk4o= +github.com/shopspring/decimal v1.4.0 h1:bxl37RwXBklmTi0C79JfXCEBD1cqqHt0bbgBAGFp81k= +github.com/shopspring/decimal v1.4.0/go.mod h1:gawqmDU56v4yIKSwfBSFip1HdCCXN8/+DMd9qYNcwME= github.com/shurcooL/sanitized_anchor_name v1.0.0/go.mod h1:1NzhyTcUVG4SuEtjjoZeVRXNmyL/1OwPU0+IJeTBvfc= github.com/sirupsen/logrus v1.2.0/go.mod h1:LxeOpSwHxABJmUn/MG1IvRgCAasNZTLOkJPxbbu5VWo= github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE= @@ -756,6 +771,8 @@ github.com/streadway/amqp v0.0.0-20190827072141-edfb9018d271/go.mod h1:AZpEONHx3 github.com/streadway/handy v0.0.0-20190108123426-d5acb3125c2a/go.mod h1:qNTQ5P5JnDBl6z3cMAg/SywNDC5ABu5ApDIw6lUbRmI= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/objx v0.5.2 h1:xuMeJ0Sdp5ZMRXx/aWO6RZxdr3beISkG5/G/aIRr3pY= +github.com/stretchr/objx v0.5.2/go.mod h1:FRsXN1f5AsAjCGJKqEizvkpNtU+EGNCLh3NxZ/8L+MA= github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= diff --git a/internal/command/user_machine_model.go b/internal/command/user_machine_model.go index 1ed6c8ca58f..e905eba4da7 100644 --- a/internal/command/user_machine_model.go +++ b/internal/command/user_machine_model.go @@ -2,6 +2,7 @@ package command import ( "context" + "github.com/zitadel/zitadel/internal/crypto" "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore" diff --git a/internal/crypto/crypto.go b/internal/crypto/crypto.go index ff3b6e24186..0aadba35c8b 100644 --- a/internal/crypto/crypto.go +++ b/internal/crypto/crypto.go @@ -25,6 +25,8 @@ type EncryptionAlgorithm interface { DecryptString(hashed []byte, keyID string) (string, error) } +// CryptoValue is a struct that can be used to store encrypted values in a database. +// The struct is compatible with the [driver.Valuer] and database/sql.Scanner interfaces. type CryptoValue struct { CryptoType CryptoType Algorithm string diff --git a/internal/integration/config/zitadel.yaml b/internal/integration/config/zitadel.yaml index 00d55d0f6da..74956afe658 100644 --- a/internal/integration/config/zitadel.yaml +++ b/internal/integration/config/zitadel.yaml @@ -62,10 +62,10 @@ Projections: RequeueEvery: 20s Customizations: NotificationsQuotas: - RequeueEvery: 1s + RequeueEvery: 5s telemetry: HandleActiveInstances: 60s - RequeueEvery: 1s + RequeueEvery: 5s DefaultInstance: LoginPolicy: diff --git a/internal/query/projection/instance_domain_relational.go b/internal/query/projection/instance_domain_relational.go new file mode 100644 index 00000000000..0deb4e82a45 --- /dev/null +++ b/internal/query/projection/instance_domain_relational.go @@ -0,0 +1,178 @@ +package projection + +import ( + "context" + "database/sql" + + "github.com/muhlemmer/gu" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" + v3_sql "github.com/zitadel/zitadel/backend/v3/storage/database/dialect/sql" + "github.com/zitadel/zitadel/backend/v3/storage/database/repository" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/eventstore/handler/v2" + "github.com/zitadel/zitadel/internal/repository/instance" + "github.com/zitadel/zitadel/internal/zerrors" +) + +type instanceDomainRelationalProjection struct{} + +func newInstanceDomainRelationalProjection(ctx context.Context, config handler.Config) *handler.Handler { + return handler.NewHandler(ctx, &config, new(instanceDomainRelationalProjection)) +} + +func (*instanceDomainRelationalProjection) Name() string { + return "zitadel.instance_domains" +} + +func (p *instanceDomainRelationalProjection) Reducers() []handler.AggregateReducer { + return []handler.AggregateReducer{ + { + Aggregate: instance.AggregateType, + EventReducers: []handler.EventReducer{ + { + Event: instance.InstanceDomainAddedEventType, + Reduce: p.reduceCustomDomainAdded, + }, + { + Event: instance.InstanceDomainPrimarySetEventType, + Reduce: p.reduceDomainPrimarySet, + }, + { + Event: instance.InstanceDomainRemovedEventType, + Reduce: p.reduceCustomDomainRemoved, + }, + { + Event: instance.TrustedDomainAddedEventType, + Reduce: p.reduceTrustedDomainAdded, + }, + { + Event: instance.TrustedDomainRemovedEventType, + Reduce: p.reduceTrustedDomainRemoved, + }, + }, + }, + } +} + +func (p *instanceDomainRelationalProjection) reduceCustomDomainAdded(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*instance.DomainAddedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-DU0xF", "reduce.wrong.event.type %s", instance.InstanceDomainAddedEventType) + } + return handler.NewStatement(e, func(ctx context.Context, ex handler.Executer, projectionName string) error { + tx, ok := ex.(*sql.Tx) + if !ok { + return zerrors.ThrowInvalidArgumentf(nil, "HANDL-bXCa6", "reduce.wrong.db.pool %T", ex) + } + return repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false).Add(ctx, &domain.AddInstanceDomain{ + InstanceID: e.Aggregate().InstanceID, + Domain: e.Domain, + IsPrimary: gu.Ptr(false), + IsGenerated: &e.Generated, + Type: domain.DomainTypeCustom, + CreatedAt: e.CreationDate(), + UpdatedAt: e.CreationDate(), + }) + }), nil +} + +func (p *instanceDomainRelationalProjection) reduceDomainPrimarySet(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*instance.DomainPrimarySetEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-TdEWA", "reduce.wrong.event.type %s", instance.InstanceDomainPrimarySetEventType) + } + return handler.NewStatement(e, func(ctx context.Context, ex handler.Executer, projectionName string) error { + tx, ok := ex.(*sql.Tx) + if !ok { + return zerrors.ThrowInvalidArgumentf(nil, "HANDL-QnjHo", "reduce.wrong.db.pool %T", ex) + } + domainRepo := repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false) + + condition := database.And( + domainRepo.InstanceIDCondition(e.Aggregate().InstanceID), + domainRepo.DomainCondition(database.TextOperationEqual, e.Domain), + domainRepo.TypeCondition(domain.DomainTypeCustom), + ) + + _, err := domainRepo.Update(ctx, + condition, + domainRepo.SetPrimary(), + ) + if err != nil { + return err + } + // we need to split the update into two statements because multiple events can have the same creation date + // therefore we first do not set the updated_at timestamp + _, err = domainRepo.Update(ctx, + condition, + domainRepo.SetUpdatedAt(e.CreationDate()), + ) + return err + }), nil +} + +func (p *instanceDomainRelationalProjection) reduceCustomDomainRemoved(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*instance.DomainRemovedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-Hhcdl", "reduce.wrong.event.type %s", instance.InstanceDomainRemovedEventType) + } + return handler.NewStatement(e, func(ctx context.Context, ex handler.Executer, projectionName string) error { + tx, ok := ex.(*sql.Tx) + if !ok { + return zerrors.ThrowInvalidArgumentf(nil, "HANDL-58ghE", "reduce.wrong.db.pool %T", ex) + } + domainRepo := repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false) + _, err := domainRepo.Remove(ctx, + database.And( + domainRepo.InstanceIDCondition(e.Aggregate().InstanceID), + domainRepo.DomainCondition(database.TextOperationEqual, e.Domain), + domainRepo.TypeCondition(domain.DomainTypeCustom), + ), + ) + return err + }), nil +} + +func (p *instanceDomainRelationalProjection) reduceTrustedDomainAdded(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*instance.TrustedDomainAddedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-svHDh", "reduce.wrong.event.type %s", instance.TrustedDomainAddedEventType) + } + return handler.NewStatement(e, func(ctx context.Context, ex handler.Executer, projectionName string) error { + tx, ok := ex.(*sql.Tx) + if !ok { + return zerrors.ThrowInvalidArgumentf(nil, "HANDL-gx7tQ", "reduce.wrong.db.pool %T", ex) + } + return repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false).Add(ctx, &domain.AddInstanceDomain{ + InstanceID: e.Aggregate().InstanceID, + Domain: e.Domain, + Type: domain.DomainTypeTrusted, + CreatedAt: e.CreationDate(), + UpdatedAt: e.CreationDate(), + }) + }), nil +} + +func (p *instanceDomainRelationalProjection) reduceTrustedDomainRemoved(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*instance.TrustedDomainRemovedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-4K74E", "reduce.wrong.event.type %s", instance.TrustedDomainRemovedEventType) + } + return handler.NewStatement(e, func(ctx context.Context, ex handler.Executer, projectionName string) error { + tx, ok := ex.(*sql.Tx) + if !ok { + return zerrors.ThrowInvalidArgumentf(nil, "HANDL-D68ap", "reduce.wrong.db.pool %T", ex) + } + domainRepo := repository.InstanceRepository(v3_sql.SQLTx(tx)).Domains(false) + _, err := domainRepo.Remove(ctx, + database.And( + domainRepo.InstanceIDCondition(e.Aggregate().InstanceID), + domainRepo.DomainCondition(database.TextOperationEqual, e.Domain), + domainRepo.TypeCondition(domain.DomainTypeTrusted), + ), + ) + return err + }), nil +} diff --git a/internal/query/projection/instance_relational.go b/internal/query/projection/instance_relational.go new file mode 100644 index 00000000000..f79127bef86 --- /dev/null +++ b/internal/query/projection/instance_relational.go @@ -0,0 +1,200 @@ +package projection + +import ( + "context" + "database/sql" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" + v3_sql "github.com/zitadel/zitadel/backend/v3/storage/database/dialect/sql" + "github.com/zitadel/zitadel/backend/v3/storage/database/repository" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/eventstore/handler/v2" + "github.com/zitadel/zitadel/internal/repository/instance" + "github.com/zitadel/zitadel/internal/zerrors" +) + +const InstanceRelationalProjectionTable = "zitadel.instances" + +type instanceRelationalProjection struct{} + +func newInstanceRelationalProjection(ctx context.Context, config handler.Config) *handler.Handler { + return handler.NewHandler(ctx, &config, new(instanceRelationalProjection)) +} + +func (*instanceRelationalProjection) Name() string { + return InstanceRelationalProjectionTable +} + +func (p *instanceRelationalProjection) Reducers() []handler.AggregateReducer { + return []handler.AggregateReducer{ + { + Aggregate: instance.AggregateType, + EventReducers: []handler.EventReducer{ + { + Event: instance.InstanceAddedEventType, + Reduce: p.reduceInstanceAdded, + }, + { + Event: instance.InstanceChangedEventType, + Reduce: p.reduceInstanceChanged, + }, + { + Event: instance.InstanceRemovedEventType, + Reduce: p.reduceInstanceDelete, + }, + { + Event: instance.DefaultOrgSetEventType, + Reduce: p.reduceDefaultOrgSet, + }, + { + Event: instance.ProjectSetEventType, + Reduce: p.reduceIAMProjectSet, + }, + { + Event: instance.ConsoleSetEventType, + Reduce: p.reduceConsoleSet, + }, + { + Event: instance.DefaultLanguageSetEventType, + Reduce: p.reduceDefaultLanguageSet, + }, + }, + }, + } +} + +func (p *instanceRelationalProjection) reduceInstanceAdded(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*instance.InstanceAddedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-29nRr", "reduce.wrong.event.type %s", instance.InstanceAddedEventType) + } + return handler.NewStatement(e, func(ctx context.Context, ex handler.Executer, projectionName string) error { + tx, ok := ex.(*sql.Tx) + if !ok { + return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex) + } + return repository.InstanceRepository(v3_sql.SQLTx(tx)).Create(ctx, &domain.Instance{ + ID: e.Aggregate().ID, + Name: e.Name, + CreatedAt: e.CreationDate(), + UpdatedAt: e.CreationDate(), + }) + }), nil +} + +func (p *instanceRelationalProjection) reduceInstanceChanged(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*instance.InstanceChangedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-so2am1", "reduce.wrong.event.type %s", instance.InstanceChangedEventType) + } + return handler.NewStatement(e, func(ctx context.Context, ex handler.Executer, projectionName string) error { + tx, ok := ex.(*sql.Tx) + if !ok { + return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex) + } + repo := repository.InstanceRepository(v3_sql.SQLTx(tx)) + return p.updateInstance(ctx, event, repo, repo.SetName(e.Name)) + }), nil +} + +func (p *instanceRelationalProjection) reduceInstanceDelete(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*instance.InstanceRemovedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-so2am1", "reduce.wrong.event.type %s", instance.InstanceChangedEventType) + } + return handler.NewStatement(e, func(ctx context.Context, ex handler.Executer, projectionName string) error { + tx, ok := ex.(*sql.Tx) + if !ok { + return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex) + } + _, err := repository.InstanceRepository(v3_sql.SQLTx(tx)).Delete(ctx, e.Aggregate().ID) + return err + }), nil +} + +func (p *instanceRelationalProjection) reduceDefaultOrgSet(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*instance.DefaultOrgSetEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-2n9f2", "reduce.wrong.event.type %s", instance.DefaultOrgSetEventType) + } + + return handler.NewStatement(e, func(ctx context.Context, ex handler.Executer, projectionName string) error { + tx, ok := ex.(*sql.Tx) + if !ok { + return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex) + } + repo := repository.InstanceRepository(v3_sql.SQLTx(tx)) + return p.updateInstance(ctx, event, repo, repo.SetDefaultOrg(e.OrgID)) + }), nil +} + +func (p *instanceRelationalProjection) reduceIAMProjectSet(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*instance.ProjectSetEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-30o0e", "reduce.wrong.event.type %s", instance.ProjectSetEventType) + } + + return handler.NewStatement(e, func(ctx context.Context, ex handler.Executer, projectionName string) error { + tx, ok := ex.(*sql.Tx) + if !ok { + return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex) + } + repo := repository.InstanceRepository(v3_sql.SQLTx(tx)) + return p.updateInstance(ctx, event, repo, repo.SetIAMProject(e.ProjectID)) + }), nil +} + +func (p *instanceRelationalProjection) reduceConsoleSet(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*instance.ConsoleSetEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-Dgf11", "reduce.wrong.event.type %s", instance.ConsoleSetEventType) + } + + return handler.NewStatement(e, func(ctx context.Context, ex handler.Executer, projectionName string) error { + tx, ok := ex.(*sql.Tx) + if !ok { + return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex) + } + repo := repository.InstanceRepository(v3_sql.SQLTx(tx)) + return p.updateInstance(ctx, event, repo, repo.SetConsoleClientID(e.ClientID), repo.SetConsoleAppID(e.AppID)) + }), nil +} + +func (p *instanceRelationalProjection) reduceDefaultLanguageSet(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*instance.DefaultLanguageSetEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-30o0e", "reduce.wrong.event.type %s", instance.DefaultLanguageSetEventType) + } + + return handler.NewStatement(e, func(ctx context.Context, ex handler.Executer, projectionName string) error { + tx, ok := ex.(*sql.Tx) + if !ok { + return zerrors.ThrowInvalidArgumentf(nil, "HANDL-rVUyy", "reduce.wrong.db.pool %T", ex) + } + repo := repository.InstanceRepository(v3_sql.SQLTx(tx)) + return p.updateInstance(ctx, event, repo, repo.SetDefaultLanguage(e.Language)) + }), nil +} + +func (p *instanceRelationalProjection) updateInstance(ctx context.Context, event eventstore.Event, repo domain.InstanceRepository, changes ...database.Change) error { + _, err := repo.Update(ctx, event.Aggregate().ID, changes...) + if err != nil { + return err + } + + instance, err := repo.Get(ctx, database.WithCondition(repo.IDCondition(event.Aggregate().ID))) + if err != nil { + return err + } + if instance.UpdatedAt.Equal(event.CreatedAt()) { + return nil + } + // we need to split the update into two statements because multiple events can have the same creation date + // therefore we first do not set the updated_at timestamp + _, err = repo.Update(ctx, + event.Aggregate().ID, + repo.SetUpdatedAt(event.CreatedAt()), + ) + return err +} diff --git a/internal/query/projection/org_domain_relational.go b/internal/query/projection/org_domain_relational.go new file mode 100644 index 00000000000..747a21a2cb0 --- /dev/null +++ b/internal/query/projection/org_domain_relational.go @@ -0,0 +1,210 @@ +package projection + +import ( + "context" + "database/sql" + + "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/backend/v3/storage/database" + v3_sql "github.com/zitadel/zitadel/backend/v3/storage/database/dialect/sql" + "github.com/zitadel/zitadel/backend/v3/storage/database/repository" + old_domain "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/eventstore/handler/v2" + "github.com/zitadel/zitadel/internal/repository/org" + "github.com/zitadel/zitadel/internal/zerrors" +) + +type orgDomainRelationalProjection struct{} + +func newOrgDomainRelationalProjection(ctx context.Context, config handler.Config) *handler.Handler { + return handler.NewHandler(ctx, &config, new(orgDomainRelationalProjection)) +} + +func (*orgDomainRelationalProjection) Name() string { + return "zitadel.org_domains" +} + +func (p *orgDomainRelationalProjection) Reducers() []handler.AggregateReducer { + return []handler.AggregateReducer{ + { + Aggregate: org.AggregateType, + EventReducers: []handler.EventReducer{ + { + Event: org.OrgDomainAddedEventType, + Reduce: p.reduceAdded, + }, + { + Event: org.OrgDomainPrimarySetEventType, + Reduce: p.reducePrimarySet, + }, + { + Event: org.OrgDomainRemovedEventType, + Reduce: p.reduceRemoved, + }, + { + Event: org.OrgDomainVerificationAddedEventType, + Reduce: p.reduceVerificationAdded, + }, + { + Event: org.OrgDomainVerifiedEventType, + Reduce: p.reduceVerified, + }, + }, + }, + } +} + +func (p *orgDomainRelationalProjection) reduceAdded(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*org.DomainAddedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-ZX9Fw", "reduce.wrong.event.type %s", org.OrgDomainAddedEventType) + } + return handler.NewStatement(e, func(ctx context.Context, ex handler.Executer, projectionName string) error { + tx, ok := ex.(*sql.Tx) + if !ok { + return zerrors.ThrowInvalidArgumentf(nil, "HANDL-kGokE", "reduce.wrong.db.pool %T", ex) + } + return repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false).Add(ctx, &domain.AddOrganizationDomain{ + InstanceID: e.Aggregate().InstanceID, + OrgID: e.Aggregate().ResourceOwner, + Domain: e.Domain, + CreatedAt: e.CreationDate(), + UpdatedAt: e.CreationDate(), + }) + }), nil +} + +func (p *orgDomainRelationalProjection) reducePrimarySet(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*org.DomainPrimarySetEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-dmFdb", "reduce.wrong.event.type %s", org.OrgDomainPrimarySetEventType) + } + return handler.NewStatement(e, func(ctx context.Context, ex handler.Executer, projectionName string) error { + tx, ok := ex.(*sql.Tx) + if !ok { + return zerrors.ThrowInvalidArgumentf(nil, "HANDL-h6xF0", "reduce.wrong.db.pool %T", ex) + } + domainRepo := repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false) + condition := database.And( + domainRepo.InstanceIDCondition(e.Aggregate().InstanceID), + domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner), + domainRepo.DomainCondition(database.TextOperationEqual, e.Domain), + ) + _, err := domainRepo.Update(ctx, + condition, + domainRepo.SetPrimary(), + ) + if err != nil { + return err + } + // we need to split the update into two statements because multiple events can have the same creation date + // therefore we first do not set the updated_at timestamp + _, err = domainRepo.Update(ctx, + condition, + domainRepo.SetUpdatedAt(e.CreationDate()), + ) + return err + }), nil +} + +func (p *orgDomainRelationalProjection) reduceRemoved(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*org.DomainRemovedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-MzC0n", "reduce.wrong.event.type %s", org.OrgDomainRemovedEventType) + } + return handler.NewStatement(e, func(ctx context.Context, ex handler.Executer, projectionName string) error { + tx, ok := ex.(*sql.Tx) + if !ok { + return zerrors.ThrowInvalidArgumentf(nil, "HANDL-X8oS8", "reduce.wrong.db.pool %T", ex) + } + domainRepo := repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false) + _, err := domainRepo.Remove(ctx, + database.And( + domainRepo.InstanceIDCondition(e.Aggregate().InstanceID), + domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner), + domainRepo.DomainCondition(database.TextOperationEqual, e.Domain), + ), + ) + return err + }), nil +} + +func (p *orgDomainRelationalProjection) reduceVerificationAdded(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*org.DomainVerificationAddedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-oGzip", "reduce.wrong.event.type %s", org.OrgDomainVerificationAddedEventType) + } + var validationType domain.DomainValidationType + switch e.ValidationType { + case old_domain.OrgDomainValidationTypeDNS: + validationType = domain.DomainValidationTypeDNS + case old_domain.OrgDomainValidationTypeHTTP: + validationType = domain.DomainValidationTypeHTTP + case old_domain.OrgDomainValidationTypeUnspecified: + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-FJfKB", "reduce.unsupported.validation.type %v", e.ValidationType) + } + return handler.NewStatement(e, func(ctx context.Context, ex handler.Executer, projectionName string) error { + tx, ok := ex.(*sql.Tx) + if !ok { + return zerrors.ThrowInvalidArgumentf(nil, "HANDL-yF03i", "reduce.wrong.db.pool %T", ex) + } + domainRepo := repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false) + condition := database.And( + domainRepo.InstanceIDCondition(e.Aggregate().InstanceID), + domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner), + domainRepo.DomainCondition(database.TextOperationEqual, e.Domain), + ) + + _, err := domainRepo.Update(ctx, + condition, + domainRepo.SetValidationType(validationType), + ) + if err != nil { + return err + } + // we need to split the update into two statements because multiple events can have the same creation date + // therefore we first do not set the updated_at timestamp + _, err = domainRepo.Update(ctx, + condition, + domainRepo.SetUpdatedAt(e.CreationDate()), + ) + return err + }), nil +} + +func (p *orgDomainRelationalProjection) reduceVerified(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*org.DomainVerifiedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-7WrI2", "reduce.wrong.event.type %s", org.OrgDomainVerifiedEventType) + } + return handler.NewStatement(e, func(ctx context.Context, ex handler.Executer, projectionName string) error { + tx, ok := ex.(*sql.Tx) + if !ok { + return zerrors.ThrowInvalidArgumentf(nil, "HANDL-0ZGqC", "reduce.wrong.db.pool %T", ex) + } + domainRepo := repository.OrganizationRepository(v3_sql.SQLTx(tx)).Domains(false) + + condition := database.And( + domainRepo.InstanceIDCondition(e.Aggregate().InstanceID), + domainRepo.OrgIDCondition(e.Aggregate().ResourceOwner), + domainRepo.DomainCondition(database.TextOperationEqual, e.Domain), + ) + + _, err := domainRepo.Update(ctx, + condition, + domainRepo.SetVerified(), + domainRepo.SetUpdatedAt(e.CreationDate()), + ) + if err != nil { + return err + } + // we need to split the update into two statements because multiple events can have the same creation date + // therefore we first do not set the updated_at timestamp + _, err = domainRepo.Update(ctx, + condition, + domainRepo.SetUpdatedAt(e.CreationDate()), + ) + return err + }), nil +} diff --git a/internal/query/projection/org_relational.go b/internal/query/projection/org_relational.go new file mode 100644 index 00000000000..3cc469a4dff --- /dev/null +++ b/internal/query/projection/org_relational.go @@ -0,0 +1,181 @@ +package projection + +import ( + "context" + + repoDomain "github.com/zitadel/zitadel/backend/v3/domain" + "github.com/zitadel/zitadel/internal/eventstore" + "github.com/zitadel/zitadel/internal/eventstore/handler/v2" + "github.com/zitadel/zitadel/internal/repository/instance" + "github.com/zitadel/zitadel/internal/repository/org" + "github.com/zitadel/zitadel/internal/zerrors" +) + +const ( + OrgRelationProjectionTable = "zitadel.organizations" +) + +type orgRelationalProjection struct{} + +func (*orgRelationalProjection) Name() string { + return OrgRelationProjectionTable +} + +func newOrgRelationalProjection(ctx context.Context, config handler.Config) *handler.Handler { + return handler.NewHandler(ctx, &config, new(orgRelationalProjection)) +} + +func (p *orgRelationalProjection) Reducers() []handler.AggregateReducer { + return []handler.AggregateReducer{ + { + Aggregate: org.AggregateType, + EventReducers: []handler.EventReducer{ + { + Event: org.OrgAddedEventType, + Reduce: p.reduceOrgRelationalAdded, + }, + { + Event: org.OrgChangedEventType, + Reduce: p.reduceOrgRelationalChanged, + }, + { + Event: org.OrgDeactivatedEventType, + Reduce: p.reduceOrgRelationalDeactivated, + }, + { + Event: org.OrgReactivatedEventType, + Reduce: p.reduceOrgRelationalReactivated, + }, + { + Event: org.OrgRemovedEventType, + Reduce: p.reduceOrgRelationalRemoved, + }, + // TODO + // { + // Event: org.OrgDomainPrimarySetEventType, + // Reduce: p.reducePrimaryDomainSetRelational, + // }, + }, + }, + { + Aggregate: instance.AggregateType, + EventReducers: []handler.EventReducer{ + { + Event: instance.InstanceRemovedEventType, + Reduce: reduceInstanceRemovedHelper(OrgColumnInstanceID), + }, + }, + }, + } +} + +func (p *orgRelationalProjection) reduceOrgRelationalAdded(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*org.OrgAddedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-uYq5R", "reduce.wrong.event.type %s", org.OrgAddedEventType) + } + + return handler.NewCreateStatement( + e, + []handler.Column{ + handler.NewCol(OrgColumnID, e.Aggregate().ID), + handler.NewCol(OrgColumnName, e.Name), + handler.NewCol(OrgColumnInstanceID, e.Aggregate().InstanceID), + handler.NewCol(State, repoDomain.OrgStateActive), + handler.NewCol(CreatedAt, e.CreationDate()), + handler.NewCol(UpdatedAt, e.CreationDate()), + }, + ), nil +} + +func (p *orgRelationalProjection) reduceOrgRelationalChanged(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*org.OrgChangedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-Bg9om", "reduce.wrong.event.type %s", org.OrgChangedEventType) + } + if e.Name == "" { + return handler.NewNoOpStatement(e), nil + } + return handler.NewUpdateStatement( + e, + []handler.Column{ + handler.NewCol(OrgColumnName, e.Name), + handler.NewCol(UpdatedAt, e.CreationDate()), + }, + []handler.Condition{ + handler.NewCond(OrgColumnID, e.Aggregate().ID), + handler.NewCond(OrgColumnInstanceID, e.Aggregate().InstanceID), + }, + ), nil +} + +func (p *orgRelationalProjection) reduceOrgRelationalDeactivated(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*org.OrgDeactivatedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-BApK5", "reduce.wrong.event.type %s", org.OrgDeactivatedEventType) + } + + return handler.NewUpdateStatement( + e, + []handler.Column{ + handler.NewCol(State, repoDomain.OrgStateInactive), + handler.NewCol(UpdatedAt, e.CreationDate()), + }, + []handler.Condition{ + handler.NewCond(OrgColumnID, e.Aggregate().ID), + handler.NewCond(OrgColumnInstanceID, e.Aggregate().InstanceID), + }, + ), nil +} + +func (p *orgRelationalProjection) reduceOrgRelationalReactivated(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*org.OrgReactivatedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-o38DE", "reduce.wrong.event.type %s", org.OrgReactivatedEventType) + } + return handler.NewUpdateStatement( + e, + []handler.Column{ + handler.NewCol(State, repoDomain.OrgStateActive), + handler.NewCol(UpdatedAt, e.CreationDate()), + }, + []handler.Condition{ + handler.NewCond(OrgColumnID, e.Aggregate().ID), + handler.NewCond(OrgColumnInstanceID, e.Aggregate().InstanceID), + }, + ), nil +} + +// TODO +// func (p *orgRelationalProjection) reducePrimaryDomainSetRelational(event eventstore.Event) (*handler.Statement, error) { +// e, ok := event.(*org.DomainPrimarySetEvent) +// if !ok { +// return nil, zerrors.ThrowInvalidArgumentf(nil, "HANDL-3Tbkt", "reduce.wrong.event.type %s", org.OrgDomainPrimarySetEventType) +// } +// return handler.NewUpdateStatement( +// e, +// []handler.Column{ +// handler.NewCol(OrgColumnChangeDate, e.CreationDate()), +// handler.NewCol(OrgColumnSequence, e.Sequence()), +// handler.NewCol(OrgColumnDomain, e.Domain), +// }, +// []handler.Condition{ +// handler.NewCond(OrgColumnID, e.Aggregate().ID), +// handler.NewCond(OrgColumnInstanceID, e.Aggregate().InstanceID), +// }, +// ), nil +// } + +func (p *orgRelationalProjection) reduceOrgRelationalRemoved(event eventstore.Event) (*handler.Statement, error) { + e, ok := event.(*org.OrgRemovedEvent) + if !ok { + return nil, zerrors.ThrowInvalidArgumentf(nil, "PROJE-DGm9g", "reduce.wrong.event.type %s", org.OrgRemovedEventType) + } + return handler.NewDeleteStatement( + e, + []handler.Condition{ + handler.NewCond(OrgColumnID, e.Aggregate().ID), + handler.NewCond(OrgColumnInstanceID, e.Aggregate().InstanceID), + }, + ), nil +} diff --git a/internal/query/projection/projection.go b/internal/query/projection/projection.go index 3b22f6c7d7d..cddf65114f8 100644 --- a/internal/query/projection/projection.go +++ b/internal/query/projection/projection.go @@ -89,6 +89,11 @@ var ( HostedLoginTranslationProjection *handler.Handler OrganizationSettingsProjection *handler.Handler + InstanceRelationalProjection *handler.Handler + OrganizationRelationalProjection *handler.Handler + InstanceDomainRelationalProjection *handler.Handler + OrganizationDomainRelationalProjection *handler.Handler + ProjectGrantFields *handler.FieldHandler OrgDomainVerifiedFields *handler.FieldHandler InstanceDomainFields *handler.FieldHandler @@ -199,6 +204,11 @@ func Create(ctx context.Context, sqlClient *database.DB, es handler.EventStore, PermissionFields = newFillPermissionFields(applyCustomConfig(projectionConfig, config.Customizations[fieldsPermission])) // Don't forget to add the new field handler to [ProjectInstanceFields] + InstanceRelationalProjection = newInstanceRelationalProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["instances_relational"])) + OrganizationRelationalProjection = newOrgRelationalProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["organizations_relational"])) + InstanceDomainRelationalProjection = newInstanceDomainRelationalProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["instance_domains_relational"])) + OrganizationDomainRelationalProjection = newOrgDomainRelationalProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["organization_domains_relational"])) + newProjectionsList() newFieldsList() return nil @@ -380,5 +390,10 @@ func newProjectionsList() { DebugEventsProjection, HostedLoginTranslationProjection, OrganizationSettingsProjection, + + InstanceRelationalProjection, + OrganizationRelationalProjection, + InstanceDomainRelationalProjection, + OrganizationDomainRelationalProjection, } } diff --git a/internal/query/projection/relational_common.go b/internal/query/projection/relational_common.go new file mode 100644 index 00000000000..bdb1fdd10e8 --- /dev/null +++ b/internal/query/projection/relational_common.go @@ -0,0 +1,7 @@ +package projection + +const ( + State = "state" + CreatedAt = "created_at" + UpdatedAt = "updated_at" +) diff --git a/internal/repository/org/domain.go b/internal/repository/org/domain.go index 0b722b3ca0b..85b8a939f69 100644 --- a/internal/repository/org/domain.go +++ b/internal/repository/org/domain.go @@ -117,7 +117,8 @@ func NewDomainVerificationAddedEvent( aggregate *eventstore.Aggregate, domain string, validationType domain.OrgDomainValidationType, - validationCode *crypto.CryptoValue) *DomainVerificationAddedEvent { + validationCode *crypto.CryptoValue, +) *DomainVerificationAddedEvent { return &DomainVerificationAddedEvent{ BaseEvent: *eventstore.NewBaseEventForPush( ctx,