mirror of
https://github.com/zitadel/zitadel.git
synced 2025-12-24 03:57:13 +00:00
feat(backend): state persisted objects (#9870)
This PR initiates the rework of Zitadel's backend to state-persisted objects. This change is a step towards a more scalable and maintainable architecture. ## Changes * **New `/backend/v3` package**: A new package structure has been introduced to house the reworked backend logic. This includes: * `domain`: Contains the core business logic, commands, and repository interfaces. * `storage`: Implements the repository interfaces for database interactions with new transactional tables. * `telemetry`: Provides logging and tracing capabilities. * **Transactional Tables**: New database tables have been defined for `instances`, `instance_domains`, `organizations`, and `org_domains`. * **Projections**: New projections have been created to populate the new relational tables from the existing event store, ensuring data consistency during the migration. * **Repositories**: New repositories provide an abstraction layer for accessing and manipulating the data in the new tables. * **Setup**: A new setup step for `TransactionalTables` has been added to manage the database migrations for the new tables. This PR lays the foundation for future work to fully transition to state-persisted objects for these components, which will improve performance and simplify data access patterns. This PR initiates the rework of ZITADEL's backend to state-persisted objects. This is a foundational step towards a new architecture that will improve performance and maintainability. The following objects are migrated from event-sourced aggregates to state-persisted objects: * Instances * incl. Domains * Orgs * incl. Domains The structure of the new backend implementation follows the software architecture defined in this [wiki page](https://github.com/zitadel/zitadel/wiki/Software-Architecturel). This PR includes: * The initial implementation of the new transactional repositories for the objects listed above. * Projections to populate the new relational tables from the existing event store. * Adjustments to the build and test process to accommodate the new backend structure. This is a work in progress and further changes will be made to complete the migration. --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Iraq Jaber <iraq+github@zitadel.com> Co-authored-by: Iraq <66622793+kkrime@users.noreply.github.com> Co-authored-by: Tim Möhlmann <tim+github@zitadel.com>
This commit is contained in:
4
Makefile
4
Makefile
@@ -138,7 +138,7 @@ core_integration_server_start: core_integration_setup
|
||||
|
||||
.PHONY: core_integration_test_packages
|
||||
core_integration_test_packages:
|
||||
go test -race -count 1 -tags integration -timeout 60m -parallel 1 $$(go list -tags integration ./... | grep "integration_test")
|
||||
go test -race -count 1 -tags integration -timeout 5m -parallel 1 $$(go list -tags integration ./... | grep -e "integration_test" -e "events_testing") -run ^TestServer_TestInstanceReduces$
|
||||
|
||||
.PHONY: core_integration_server_stop
|
||||
core_integration_server_stop:
|
||||
@@ -152,7 +152,7 @@ core_integration_server_stop:
|
||||
|
||||
.PHONY: core_integration_reports
|
||||
core_integration_reports:
|
||||
go tool covdata textfmt -i=tmp/coverage -pkg=github.com/zitadel/zitadel/internal/...,github.com/zitadel/zitadel/cmd/... -o profile.cov
|
||||
go tool covdata textfmt -i=tmp/coverage -pkg=github.com/zitadel/zitadel/internal/...,github.com/zitadel/zitadel/cmd/...,github.com/zitadel/zitadel/backend/... -o profile.cov
|
||||
|
||||
.PHONY: core_integration_test
|
||||
core_integration_test: core_integration_server_start core_integration_test_packages core_integration_server_stop core_integration_reports
|
||||
|
||||
10
backend/main.go
Normal file
10
backend/main.go
Normal file
@@ -0,0 +1,10 @@
|
||||
// huhu thanks for looking at this.
|
||||
// you can find more comments and doc.go files in the v3 package.
|
||||
// to get an overview i would start in the api package.
|
||||
package main
|
||||
|
||||
// import "github.com/zitadel/zitadel/backend/cmd"
|
||||
|
||||
// func main() {
|
||||
// cmd.Execute()
|
||||
// }
|
||||
21
backend/v3/api/instance/v2/server.go
Normal file
21
backend/v3/api/instance/v2/server.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package v2
|
||||
|
||||
// this file has been commented out to pass the linter
|
||||
|
||||
// import (
|
||||
// "github.com/zitadel/zitadel/backend/v3/telemetry/logging"
|
||||
// "github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
|
||||
// )
|
||||
|
||||
// var (
|
||||
// logger logging.Logger
|
||||
// tracer tracing.Tracer
|
||||
// )
|
||||
|
||||
// func SetLogger(l logging.Logger) {
|
||||
// logger = l
|
||||
// }
|
||||
|
||||
// func SetTracer(t tracing.Tracer) {
|
||||
// tracer = t
|
||||
// }
|
||||
33
backend/v3/api/org/v2/org.go
Normal file
33
backend/v3/api/org/v2/org.go
Normal file
@@ -0,0 +1,33 @@
|
||||
package orgv2
|
||||
|
||||
// import (
|
||||
// "context"
|
||||
|
||||
// "github.com/zitadel/zitadel/backend/v3/domain"
|
||||
// "github.com/zitadel/zitadel/pkg/grpc/org/v2"
|
||||
// )
|
||||
|
||||
// func CreateOrg(ctx context.Context, req *org.AddOrganizationRequest) (resp *org.AddOrganizationResponse, err error) {
|
||||
// cmd := domain.NewAddOrgCommand(
|
||||
// req.GetName(),
|
||||
// addOrgAdminToCommand(req.GetAdmins()...)...,
|
||||
// )
|
||||
// err = domain.Invoke(ctx, cmd)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// return &org.AddOrganizationResponse{
|
||||
// OrganizationId: cmd.ID,
|
||||
// }, nil
|
||||
// }
|
||||
|
||||
// func addOrgAdminToCommand(admins ...*org.AddOrganizationRequest_Admin) []*domain.AddMemberCommand {
|
||||
// cmds := make([]*domain.AddMemberCommand, len(admins))
|
||||
// for i, admin := range admins {
|
||||
// cmds[i] = &domain.AddMemberCommand{
|
||||
// UserID: admin.GetUserId(),
|
||||
// Roles: admin.GetRoles(),
|
||||
// }
|
||||
// }
|
||||
// return cmds
|
||||
// }
|
||||
21
backend/v3/api/org/v2/server.go
Normal file
21
backend/v3/api/org/v2/server.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package orgv2
|
||||
|
||||
// this file has been commented out to pass the linter
|
||||
|
||||
// import (
|
||||
// "github.com/zitadel/zitadel/backend/v3/telemetry/logging"
|
||||
// "github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
|
||||
// )
|
||||
|
||||
// var (
|
||||
// logger logging.Logger
|
||||
// tracer tracing.Tracer
|
||||
// )
|
||||
|
||||
// func SetLogger(l logging.Logger) {
|
||||
// logger = l
|
||||
// }
|
||||
|
||||
// func SetTracer(t tracing.Tracer) {
|
||||
// tracer = t
|
||||
// }
|
||||
93
backend/v3/api/user/v2/email.go
Normal file
93
backend/v3/api/user/v2/email.go
Normal file
@@ -0,0 +1,93 @@
|
||||
package userv2
|
||||
|
||||
// import (
|
||||
// "context"
|
||||
|
||||
// "github.com/zitadel/zitadel/backend/v3/domain"
|
||||
// "github.com/zitadel/zitadel/pkg/grpc/user/v2"
|
||||
// )
|
||||
|
||||
// func SetEmail(ctx context.Context, req *user.SetEmailRequest) (resp *user.SetEmailResponse, err error) {
|
||||
// var (
|
||||
// verification domain.SetEmailOpt
|
||||
// returnCode *domain.ReturnCodeCommand
|
||||
// )
|
||||
|
||||
// switch req.GetVerification().(type) {
|
||||
// case *user.SetEmailRequest_IsVerified:
|
||||
// verification = domain.NewEmailVerifiedCommand(req.GetUserId(), req.GetIsVerified())
|
||||
// case *user.SetEmailRequest_SendCode:
|
||||
// verification = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
|
||||
// case *user.SetEmailRequest_ReturnCode:
|
||||
// returnCode = domain.NewReturnCodeCommand(req.GetUserId())
|
||||
// verification = returnCode
|
||||
// default:
|
||||
// verification = domain.NewSendCodeCommand(req.GetUserId(), nil)
|
||||
// }
|
||||
|
||||
// err = domain.Invoke(ctx, domain.NewSetEmailCommand(req.GetUserId(), req.GetEmail(), verification))
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
|
||||
// var code *string
|
||||
// if returnCode != nil && returnCode.Code != "" {
|
||||
// code = &returnCode.Code
|
||||
// }
|
||||
|
||||
// return &user.SetEmailResponse{
|
||||
// VerificationCode: code,
|
||||
// }, nil
|
||||
// }
|
||||
|
||||
// func SendEmailCode(ctx context.Context, req *user.SendEmailCodeRequest) (resp *user.SendEmailCodeResponse, err error) {
|
||||
// var (
|
||||
// returnCode *domain.ReturnCodeCommand
|
||||
// cmd domain.Commander
|
||||
// )
|
||||
|
||||
// switch req.GetVerification().(type) {
|
||||
// case *user.SendEmailCodeRequest_SendCode:
|
||||
// cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
|
||||
// case *user.SendEmailCodeRequest_ReturnCode:
|
||||
// returnCode = domain.NewReturnCodeCommand(req.GetUserId())
|
||||
// cmd = returnCode
|
||||
// default:
|
||||
// cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
|
||||
// }
|
||||
// err = domain.Invoke(ctx, cmd)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// resp = new(user.SendEmailCodeResponse)
|
||||
// if returnCode != nil {
|
||||
// resp.VerificationCode = &returnCode.Code
|
||||
// }
|
||||
// return resp, nil
|
||||
// }
|
||||
|
||||
// func ResendEmailCode(ctx context.Context, req *user.ResendEmailCodeRequest) (resp *user.SendEmailCodeResponse, err error) {
|
||||
// var (
|
||||
// returnCode *domain.ReturnCodeCommand
|
||||
// cmd domain.Commander
|
||||
// )
|
||||
|
||||
// switch req.GetVerification().(type) {
|
||||
// case *user.ResendEmailCodeRequest_SendCode:
|
||||
// cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
|
||||
// case *user.ResendEmailCodeRequest_ReturnCode:
|
||||
// returnCode = domain.NewReturnCodeCommand(req.GetUserId())
|
||||
// cmd = returnCode
|
||||
// default:
|
||||
// cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
|
||||
// }
|
||||
// err = domain.Invoke(ctx, cmd)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// resp = new(user.SendEmailCodeResponse)
|
||||
// if returnCode != nil {
|
||||
// resp.VerificationCode = &returnCode.Code
|
||||
// }
|
||||
// return resp, nil
|
||||
// }
|
||||
19
backend/v3/api/user/v2/server.go
Normal file
19
backend/v3/api/user/v2/server.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package userv2
|
||||
|
||||
// this file has been commented out to pass the linter
|
||||
|
||||
// import (
|
||||
// "github.com/zitadel/zitadel/backend/v3/telemetry/logging"
|
||||
// "github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
|
||||
// )
|
||||
|
||||
// logger logging.Logger
|
||||
// var tracer tracing.Tracer
|
||||
|
||||
// func SetLogger(l logging.Logger) {
|
||||
// logger = l
|
||||
// }
|
||||
|
||||
// func SetTracer(t tracing.Tracer) {
|
||||
// tracer = t
|
||||
// }
|
||||
21
backend/v3/doc.go
Normal file
21
backend/v3/doc.go
Normal file
@@ -0,0 +1,21 @@
|
||||
// the test used the manly relies on the following patterns:
|
||||
// - api:
|
||||
// - some example stubs for the grpc api, it maps the calls and responses to the domain objects
|
||||
//
|
||||
// - domain:
|
||||
// - hexagonal architecture, it defines its dependencies as interfaces and the dependencies must use the objects defined by this package
|
||||
// - command pattern which implements the changes
|
||||
// - the invoker decorates the commands by checking for events, tracing, logging, potentially caching, etc.
|
||||
// - the database connections are manged in this package
|
||||
// - the database connections are passed to the repositories
|
||||
//
|
||||
// - storage:
|
||||
// - repository pattern, the repositories are defined as interfaces and the implementations are in the storage package
|
||||
// - the repositories are used by the domain package to access the database
|
||||
// - the eventstore to store events. At the beginning it writes to the same events table as the /internal package, afterwards it writes to a different table
|
||||
//
|
||||
// - telemetry:
|
||||
// - logging for standard output
|
||||
// - tracing for distributed tracing
|
||||
// - metrics for monitoring
|
||||
package v3
|
||||
131
backend/v3/domain/command.go
Normal file
131
backend/v3/domain/command.go
Normal file
@@ -0,0 +1,131 @@
|
||||
package domain
|
||||
|
||||
// import (
|
||||
// "context"
|
||||
// "fmt"
|
||||
|
||||
// "github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
// )
|
||||
|
||||
// // Commander is the all it needs to implement the command pattern.
|
||||
// // It is the interface all manipulations need to implement.
|
||||
// // If possible it should also be used for queries. We will find out if this is possible in the future.
|
||||
// type Commander interface {
|
||||
// Execute(ctx context.Context, opts *CommandOpts) (err error)
|
||||
// fmt.Stringer
|
||||
// }
|
||||
|
||||
// // Invoker is part of the command pattern.
|
||||
// // It is the interface that is used to execute commands.
|
||||
// type Invoker interface {
|
||||
// Invoke(ctx context.Context, command Commander, opts *CommandOpts) error
|
||||
// }
|
||||
|
||||
// // CommandOpts are passed to each command
|
||||
// // the provide common fields used by commands like the database client.
|
||||
// type CommandOpts struct {
|
||||
// DB database.QueryExecutor
|
||||
// Invoker Invoker
|
||||
// }
|
||||
|
||||
// type ensureTxOpts struct {
|
||||
// *database.TransactionOptions
|
||||
// }
|
||||
|
||||
// type EnsureTransactionOpt func(*ensureTxOpts)
|
||||
|
||||
// // EnsureTx ensures that the DB is a transaction. If it is not, it will start a new transaction.
|
||||
// // The returned close function will end the transaction. If the DB is already a transaction, the close function
|
||||
// // will do nothing because another [Commander] is already responsible for ending the transaction.
|
||||
// func (o *CommandOpts) EnsureTx(ctx context.Context, opts ...EnsureTransactionOpt) (close func(context.Context, error) error, err error) {
|
||||
// beginner, ok := o.DB.(database.Beginner)
|
||||
// if !ok {
|
||||
// // db is already a transaction
|
||||
// return func(_ context.Context, err error) error {
|
||||
// return err
|
||||
// }, nil
|
||||
// }
|
||||
|
||||
// txOpts := &ensureTxOpts{
|
||||
// TransactionOptions: new(database.TransactionOptions),
|
||||
// }
|
||||
// for _, opt := range opts {
|
||||
// opt(txOpts)
|
||||
// }
|
||||
|
||||
// tx, err := beginner.Begin(ctx, txOpts.TransactionOptions)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// o.DB = tx
|
||||
|
||||
// return func(ctx context.Context, err error) error {
|
||||
// return tx.End(ctx, err)
|
||||
// }, nil
|
||||
// }
|
||||
|
||||
// // EnsureClient ensures that the o.DB is a client. If it is not, it will get a new client from the [database.Pool].
|
||||
// // The returned close function will release the client. If the o.DB is already a client or transaction, the close function
|
||||
// // will do nothing because another [Commander] is already responsible for releasing the client.
|
||||
// func (o *CommandOpts) EnsureClient(ctx context.Context) (close func(_ context.Context) error, err error) {
|
||||
// pool, ok := o.DB.(database.Pool)
|
||||
// if !ok {
|
||||
// // o.DB is already a client
|
||||
// return func(_ context.Context) error {
|
||||
// return nil
|
||||
// }, nil
|
||||
// }
|
||||
// client, err := pool.Acquire(ctx)
|
||||
// if err != nil {
|
||||
// return nil, err
|
||||
// }
|
||||
// o.DB = client
|
||||
// return func(ctx context.Context) error {
|
||||
// return client.Release(ctx)
|
||||
// }, nil
|
||||
// }
|
||||
|
||||
// func (o *CommandOpts) Invoke(ctx context.Context, command Commander) error {
|
||||
// if o.Invoker == nil {
|
||||
// return command.Execute(ctx, o)
|
||||
// }
|
||||
// return o.Invoker.Invoke(ctx, command, o)
|
||||
// }
|
||||
|
||||
// func DefaultOpts(invoker Invoker) *CommandOpts {
|
||||
// if invoker == nil {
|
||||
// invoker = &noopInvoker{}
|
||||
// }
|
||||
// return &CommandOpts{
|
||||
// DB: pool,
|
||||
// Invoker: invoker,
|
||||
// }
|
||||
// }
|
||||
|
||||
// // commandBatch is a batch of commands.
|
||||
// // It uses the [Invoker] provided by the opts to execute each command.
|
||||
// type commandBatch struct {
|
||||
// Commands []Commander
|
||||
// }
|
||||
|
||||
// func BatchCommands(cmds ...Commander) *commandBatch {
|
||||
// return &commandBatch{
|
||||
// Commands: cmds,
|
||||
// }
|
||||
// }
|
||||
|
||||
// // String implements [Commander].
|
||||
// func (cmd *commandBatch) String() string {
|
||||
// return "commandBatch"
|
||||
// }
|
||||
|
||||
// func (b *commandBatch) Execute(ctx context.Context, opts *CommandOpts) (err error) {
|
||||
// for _, cmd := range b.Commands {
|
||||
// if err = opts.Invoke(ctx, cmd); err != nil {
|
||||
// return err
|
||||
// }
|
||||
// }
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// var _ Commander = (*commandBatch)(nil)
|
||||
90
backend/v3/domain/create_user.go
Normal file
90
backend/v3/domain/create_user.go
Normal file
@@ -0,0 +1,90 @@
|
||||
package domain
|
||||
|
||||
// import (
|
||||
// "context"
|
||||
|
||||
// "github.com/zitadel/zitadel/backend/v3/storage/eventstore"
|
||||
// )
|
||||
|
||||
// // CreateUserCommand adds a new user including the email verification for humans.
|
||||
// // In the future it might make sense to separate the command into two commands:
|
||||
// // - CreateHumanCommand: creates a new human user
|
||||
// // - CreateMachineCommand: creates a new machine user
|
||||
// type CreateUserCommand struct {
|
||||
// user *User
|
||||
// email *SetEmailCommand
|
||||
// }
|
||||
|
||||
// var (
|
||||
// _ Commander = (*CreateUserCommand)(nil)
|
||||
// _ eventer = (*CreateUserCommand)(nil)
|
||||
// )
|
||||
|
||||
// // opts heavily reduces the complexity for email verification because each type of verification is a simple option which implements the [Commander] interface.
|
||||
// func NewCreateHumanCommand(username string, opts ...CreateHumanOpt) *CreateUserCommand {
|
||||
// cmd := &CreateUserCommand{
|
||||
// user: &User{
|
||||
// Username: username,
|
||||
// Traits: &Human{},
|
||||
// },
|
||||
// }
|
||||
|
||||
// for _, opt := range opts {
|
||||
// opt.applyOnCreateHuman(cmd)
|
||||
// }
|
||||
// return cmd
|
||||
// }
|
||||
|
||||
// // String implements [Commander].
|
||||
// func (cmd *CreateUserCommand) String() string {
|
||||
// return "CreateUserCommand"
|
||||
// }
|
||||
|
||||
// // Events implements [eventer].
|
||||
// func (c *CreateUserCommand) Events() []*eventstore.Event {
|
||||
// return []*eventstore.Event{
|
||||
// {
|
||||
// AggregateType: "user",
|
||||
// AggregateID: c.user.ID,
|
||||
// Type: "user.added",
|
||||
// Payload: c.user,
|
||||
// },
|
||||
// }
|
||||
// }
|
||||
|
||||
// // Execute implements [Commander].
|
||||
// func (c *CreateUserCommand) Execute(ctx context.Context, opts *CommandOpts) error {
|
||||
// if err := c.ensureUserID(); err != nil {
|
||||
// return err
|
||||
// }
|
||||
// c.email.UserID = c.user.ID
|
||||
// if err := opts.Invoke(ctx, c.email); err != nil {
|
||||
// return err
|
||||
// }
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// type CreateHumanOpt interface {
|
||||
// applyOnCreateHuman(*CreateUserCommand)
|
||||
// }
|
||||
|
||||
// type createHumanIDOpt string
|
||||
|
||||
// // applyOnCreateHuman implements [CreateHumanOpt].
|
||||
// func (c createHumanIDOpt) applyOnCreateHuman(cmd *CreateUserCommand) {
|
||||
// cmd.user.ID = string(c)
|
||||
// }
|
||||
|
||||
// var _ CreateHumanOpt = (*createHumanIDOpt)(nil)
|
||||
|
||||
// func CreateHumanWithID(id string) CreateHumanOpt {
|
||||
// return createHumanIDOpt(id)
|
||||
// }
|
||||
|
||||
// func (c *CreateUserCommand) ensureUserID() (err error) {
|
||||
// if c.user.ID != "" {
|
||||
// return nil
|
||||
// }
|
||||
// c.user.ID, err = generateID()
|
||||
// return err
|
||||
// }
|
||||
37
backend/v3/domain/crypto.go
Normal file
37
backend/v3/domain/crypto.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package domain
|
||||
|
||||
// import (
|
||||
// "context"
|
||||
|
||||
// "github.com/zitadel/zitadel/internal/crypto"
|
||||
// )
|
||||
|
||||
// type generateCodeCommand struct {
|
||||
// code string
|
||||
// value *crypto.CryptoValue
|
||||
// }
|
||||
|
||||
// // I didn't update this repository to the solution proposed please view one of the following interfaces for correct usage:
|
||||
// // - [UserRepository]
|
||||
// // - [InstanceRepository]
|
||||
// // - [OrgRepository]
|
||||
// type CryptoRepository interface {
|
||||
// GetEncryptionConfig(ctx context.Context) (*crypto.GeneratorConfig, error)
|
||||
// }
|
||||
|
||||
// // String implements [Commander].
|
||||
// func (cmd *generateCodeCommand) String() string {
|
||||
// return "generateCodeCommand"
|
||||
// }
|
||||
|
||||
// func (cmd *generateCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error {
|
||||
// config, err := cryptoRepo(opts.DB).GetEncryptionConfig(ctx)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// generator := crypto.NewEncryptionGenerator(*config, userCodeAlgorithm)
|
||||
// cmd.value, cmd.code, err = crypto.NewCode(generator)
|
||||
// return err
|
||||
// }
|
||||
|
||||
// var _ Commander = (*generateCodeCommand)(nil)
|
||||
127
backend/v3/domain/domain.go
Normal file
127
backend/v3/domain/domain.go
Normal file
@@ -0,0 +1,127 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
//go:generate enumer -type DomainValidationType -transform lower -trimprefix DomainValidationType -sql
|
||||
type DomainValidationType uint8
|
||||
|
||||
const (
|
||||
DomainValidationTypeDNS DomainValidationType = iota
|
||||
DomainValidationTypeHTTP
|
||||
)
|
||||
|
||||
//go:generate enumer -type DomainType -transform lower -trimprefix DomainType -sql
|
||||
type DomainType uint8
|
||||
|
||||
const (
|
||||
DomainTypeCustom DomainType = iota
|
||||
DomainTypeTrusted
|
||||
)
|
||||
|
||||
type domainColumns interface {
|
||||
// InstanceIDColumn returns the column for the instance id field.
|
||||
InstanceIDColumn() database.Column
|
||||
// DomainColumn returns the column for the domain field.
|
||||
DomainColumn() database.Column
|
||||
// IsPrimaryColumn returns the column for the is primary field.
|
||||
IsPrimaryColumn() database.Column
|
||||
// CreatedAtColumn returns the column for the created at field.
|
||||
CreatedAtColumn() database.Column
|
||||
// UpdatedAtColumn returns the column for the updated at field.
|
||||
UpdatedAtColumn() database.Column
|
||||
}
|
||||
|
||||
type domainConditions interface {
|
||||
// InstanceIDCondition returns a filter on the instance id field.
|
||||
InstanceIDCondition(instanceID string) database.Condition
|
||||
// DomainCondition returns a filter on the domain field.
|
||||
DomainCondition(op database.TextOperation, domain string) database.Condition
|
||||
// IsPrimaryCondition returns a filter on the is primary field.
|
||||
IsPrimaryCondition(isPrimary bool) database.Condition
|
||||
}
|
||||
|
||||
type domainChanges interface {
|
||||
// SetPrimary sets a domain as primary based on the condition.
|
||||
// All other domains will be set to non-primary.
|
||||
//
|
||||
// An error is returned if:
|
||||
// - The condition identifies multiple domains.
|
||||
// - The condition does not identify any domain.
|
||||
//
|
||||
// This is a no-op if:
|
||||
// - The domain is already primary.
|
||||
// - No domain matches the condition.
|
||||
SetPrimary() database.Change
|
||||
// SetUpdatedAt sets the updated at column.
|
||||
// This is used for reducing events.
|
||||
SetUpdatedAt(t time.Time) database.Change
|
||||
}
|
||||
|
||||
// import (
|
||||
// "math/rand/v2"
|
||||
// "strconv"
|
||||
|
||||
// "github.com/zitadel/zitadel/backend/v3/storage/cache"
|
||||
// "github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
|
||||
// // "github.com/zitadel/zitadel/backend/v3/telemetry/logging"
|
||||
// "github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
|
||||
// "github.com/zitadel/zitadel/internal/crypto"
|
||||
// )
|
||||
|
||||
// // The variables could also be moved to a struct.
|
||||
// // I just started with the singleton pattern and kept it like this.
|
||||
// var (
|
||||
// pool database.Pool
|
||||
// userCodeAlgorithm crypto.EncryptionAlgorithm
|
||||
// tracer tracing.Tracer
|
||||
// // logger logging.Logger
|
||||
|
||||
// userRepo func(database.QueryExecutor) UserRepository
|
||||
// // instanceRepo func(database.QueryExecutor) InstanceRepository
|
||||
// cryptoRepo func(database.QueryExecutor) CryptoRepository
|
||||
// orgRepo func(database.QueryExecutor) OrgRepository
|
||||
|
||||
// // instanceCache cache.Cache[instanceCacheIndex, string, *Instance]
|
||||
// orgCache cache.Cache[orgCacheIndex, string, *Org]
|
||||
|
||||
// generateID func() (string, error) = func() (string, error) {
|
||||
// return strconv.FormatUint(rand.Uint64(), 10), nil
|
||||
// }
|
||||
// )
|
||||
|
||||
// func SetPool(p database.Pool) {
|
||||
// pool = p
|
||||
// }
|
||||
|
||||
// func SetUserCodeAlgorithm(algorithm crypto.EncryptionAlgorithm) {
|
||||
// userCodeAlgorithm = algorithm
|
||||
// }
|
||||
|
||||
// func SetTracer(t tracing.Tracer) {
|
||||
// tracer = t
|
||||
// }
|
||||
|
||||
// // func SetLogger(l logging.Logger) {
|
||||
// // logger = l
|
||||
// // }
|
||||
|
||||
// func SetUserRepository(repo func(database.QueryExecutor) UserRepository) {
|
||||
// userRepo = repo
|
||||
// }
|
||||
|
||||
// func SetOrgRepository(repo func(database.QueryExecutor) OrgRepository) {
|
||||
// orgRepo = repo
|
||||
// }
|
||||
|
||||
// // func SetInstanceRepository(repo func(database.QueryExecutor) InstanceRepository) {
|
||||
// // instanceRepo = repo
|
||||
// // }
|
||||
|
||||
// func SetCryptoRepository(repo func(database.QueryExecutor) CryptoRepository) {
|
||||
// cryptoRepo = repo
|
||||
// }
|
||||
67
backend/v3/domain/domain_test.go
Normal file
67
backend/v3/domain/domain_test.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package domain_test
|
||||
|
||||
// import (
|
||||
// "context"
|
||||
// "log/slog"
|
||||
// "testing"
|
||||
|
||||
// "github.com/stretchr/testify/assert"
|
||||
// "github.com/stretchr/testify/require"
|
||||
// "go.opentelemetry.io/otel"
|
||||
// "go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
|
||||
// sdktrace "go.opentelemetry.io/otel/sdk/trace"
|
||||
// "go.uber.org/mock/gomock"
|
||||
|
||||
// . "github.com/zitadel/zitadel/backend/v3/domain"
|
||||
// "github.com/zitadel/zitadel/backend/v3/storage/database/dbmock"
|
||||
// "github.com/zitadel/zitadel/backend/v3/storage/database/repository"
|
||||
// "github.com/zitadel/zitadel/backend/v3/telemetry/logging"
|
||||
// "github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
|
||||
// )
|
||||
|
||||
// These tests give an overview of how to use the domain package.
|
||||
// func TestExample(t *testing.T) {
|
||||
// t.Skip("skip example test because it is not a real test")
|
||||
// ctx := context.Background()
|
||||
|
||||
// ctrl := gomock.NewController(t)
|
||||
// pool := dbmock.NewMockPool(ctrl)
|
||||
// tx := dbmock.NewMockTransaction(ctrl)
|
||||
|
||||
// pool.EXPECT().Begin(gomock.Any(), gomock.Any()).Return(tx, nil)
|
||||
// tx.EXPECT().End(gomock.Any(), gomock.Any()).Return(nil)
|
||||
// SetPool(pool)
|
||||
|
||||
// exporter, err := stdouttrace.New(stdouttrace.WithPrettyPrint())
|
||||
// require.NoError(t, err)
|
||||
// tracerProvider := sdktrace.NewTracerProvider(
|
||||
// sdktrace.WithSyncer(exporter),
|
||||
// )
|
||||
// otel.SetTracerProvider(tracerProvider)
|
||||
// SetTracer(tracing.Tracer{Tracer: tracerProvider.Tracer("test")})
|
||||
// defer func() { assert.NoError(t, tracerProvider.Shutdown(ctx)) }()
|
||||
|
||||
// SetLogger(logging.Logger{Logger: slog.Default()})
|
||||
|
||||
// SetUserRepository(repository.UserRepository)
|
||||
// SetOrgRepository(repository.OrgRepository)
|
||||
// // SetInstanceRepository(repository.Instance)
|
||||
// // SetCryptoRepository(repository.Crypto)
|
||||
|
||||
// t.Run("create org", func(t *testing.T) {
|
||||
// org := NewAddOrgCommand("testorg", NewAddMemberCommand("testuser", "ORG_OWNER"))
|
||||
// user := NewCreateHumanCommand("testuser")
|
||||
// err := Invoke(ctx, BatchCommands(org, user))
|
||||
// assert.NoError(t, err)
|
||||
// })
|
||||
|
||||
// t.Run("verified email", func(t *testing.T) {
|
||||
// err := Invoke(ctx, NewSetEmailCommand("u1", "test@example.com", NewEmailVerifiedCommand("u1", true)))
|
||||
// assert.NoError(t, err)
|
||||
// })
|
||||
|
||||
// t.Run("unverified email", func(t *testing.T) {
|
||||
// err := Invoke(ctx, NewSetEmailCommand("u2", "test2@example.com", NewEmailVerifiedCommand("u2", false)))
|
||||
// assert.NoError(t, err)
|
||||
// })
|
||||
// }
|
||||
109
backend/v3/domain/domaintype_enumer.go
Normal file
109
backend/v3/domain/domaintype_enumer.go
Normal file
@@ -0,0 +1,109 @@
|
||||
// Code generated by "enumer -type DomainType -transform lower -trimprefix DomainType -sql"; DO NOT EDIT.
|
||||
|
||||
package domain
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const _DomainTypeName = "customtrusted"
|
||||
|
||||
var _DomainTypeIndex = [...]uint8{0, 6, 13}
|
||||
|
||||
const _DomainTypeLowerName = "customtrusted"
|
||||
|
||||
func (i DomainType) String() string {
|
||||
if i >= DomainType(len(_DomainTypeIndex)-1) {
|
||||
return fmt.Sprintf("DomainType(%d)", i)
|
||||
}
|
||||
return _DomainTypeName[_DomainTypeIndex[i]:_DomainTypeIndex[i+1]]
|
||||
}
|
||||
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
func _DomainTypeNoOp() {
|
||||
var x [1]struct{}
|
||||
_ = x[DomainTypeCustom-(0)]
|
||||
_ = x[DomainTypeTrusted-(1)]
|
||||
}
|
||||
|
||||
var _DomainTypeValues = []DomainType{DomainTypeCustom, DomainTypeTrusted}
|
||||
|
||||
var _DomainTypeNameToValueMap = map[string]DomainType{
|
||||
_DomainTypeName[0:6]: DomainTypeCustom,
|
||||
_DomainTypeLowerName[0:6]: DomainTypeCustom,
|
||||
_DomainTypeName[6:13]: DomainTypeTrusted,
|
||||
_DomainTypeLowerName[6:13]: DomainTypeTrusted,
|
||||
}
|
||||
|
||||
var _DomainTypeNames = []string{
|
||||
_DomainTypeName[0:6],
|
||||
_DomainTypeName[6:13],
|
||||
}
|
||||
|
||||
// DomainTypeString retrieves an enum value from the enum constants string name.
|
||||
// Throws an error if the param is not part of the enum.
|
||||
func DomainTypeString(s string) (DomainType, error) {
|
||||
if val, ok := _DomainTypeNameToValueMap[s]; ok {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
if val, ok := _DomainTypeNameToValueMap[strings.ToLower(s)]; ok {
|
||||
return val, nil
|
||||
}
|
||||
return 0, fmt.Errorf("%s does not belong to DomainType values", s)
|
||||
}
|
||||
|
||||
// DomainTypeValues returns all values of the enum
|
||||
func DomainTypeValues() []DomainType {
|
||||
return _DomainTypeValues
|
||||
}
|
||||
|
||||
// DomainTypeStrings returns a slice of all String values of the enum
|
||||
func DomainTypeStrings() []string {
|
||||
strs := make([]string, len(_DomainTypeNames))
|
||||
copy(strs, _DomainTypeNames)
|
||||
return strs
|
||||
}
|
||||
|
||||
// IsADomainType returns "true" if the value is listed in the enum definition. "false" otherwise
|
||||
func (i DomainType) IsADomainType() bool {
|
||||
for _, v := range _DomainTypeValues {
|
||||
if i == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (i DomainType) Value() (driver.Value, error) {
|
||||
return i.String(), nil
|
||||
}
|
||||
|
||||
func (i *DomainType) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var str string
|
||||
switch v := value.(type) {
|
||||
case []byte:
|
||||
str = string(v)
|
||||
case string:
|
||||
str = v
|
||||
case fmt.Stringer:
|
||||
str = v.String()
|
||||
default:
|
||||
return fmt.Errorf("invalid value of DomainType: %[1]T(%[1]v)", value)
|
||||
}
|
||||
|
||||
val, err := DomainTypeString(str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*i = val
|
||||
return nil
|
||||
}
|
||||
109
backend/v3/domain/domainvalidationtype_enumer.go
Normal file
109
backend/v3/domain/domainvalidationtype_enumer.go
Normal file
@@ -0,0 +1,109 @@
|
||||
// Code generated by "enumer -type DomainValidationType -transform lower -trimprefix DomainValidationType -sql"; DO NOT EDIT.
|
||||
|
||||
package domain
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const _DomainValidationTypeName = "dnshttp"
|
||||
|
||||
var _DomainValidationTypeIndex = [...]uint8{0, 3, 7}
|
||||
|
||||
const _DomainValidationTypeLowerName = "dnshttp"
|
||||
|
||||
func (i DomainValidationType) String() string {
|
||||
if i >= DomainValidationType(len(_DomainValidationTypeIndex)-1) {
|
||||
return fmt.Sprintf("DomainValidationType(%d)", i)
|
||||
}
|
||||
return _DomainValidationTypeName[_DomainValidationTypeIndex[i]:_DomainValidationTypeIndex[i+1]]
|
||||
}
|
||||
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
func _DomainValidationTypeNoOp() {
|
||||
var x [1]struct{}
|
||||
_ = x[DomainValidationTypeDNS-(0)]
|
||||
_ = x[DomainValidationTypeHTTP-(1)]
|
||||
}
|
||||
|
||||
var _DomainValidationTypeValues = []DomainValidationType{DomainValidationTypeDNS, DomainValidationTypeHTTP}
|
||||
|
||||
var _DomainValidationTypeNameToValueMap = map[string]DomainValidationType{
|
||||
_DomainValidationTypeName[0:3]: DomainValidationTypeDNS,
|
||||
_DomainValidationTypeLowerName[0:3]: DomainValidationTypeDNS,
|
||||
_DomainValidationTypeName[3:7]: DomainValidationTypeHTTP,
|
||||
_DomainValidationTypeLowerName[3:7]: DomainValidationTypeHTTP,
|
||||
}
|
||||
|
||||
var _DomainValidationTypeNames = []string{
|
||||
_DomainValidationTypeName[0:3],
|
||||
_DomainValidationTypeName[3:7],
|
||||
}
|
||||
|
||||
// DomainValidationTypeString retrieves an enum value from the enum constants string name.
|
||||
// Throws an error if the param is not part of the enum.
|
||||
func DomainValidationTypeString(s string) (DomainValidationType, error) {
|
||||
if val, ok := _DomainValidationTypeNameToValueMap[s]; ok {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
if val, ok := _DomainValidationTypeNameToValueMap[strings.ToLower(s)]; ok {
|
||||
return val, nil
|
||||
}
|
||||
return 0, fmt.Errorf("%s does not belong to DomainValidationType values", s)
|
||||
}
|
||||
|
||||
// DomainValidationTypeValues returns all values of the enum
|
||||
func DomainValidationTypeValues() []DomainValidationType {
|
||||
return _DomainValidationTypeValues
|
||||
}
|
||||
|
||||
// DomainValidationTypeStrings returns a slice of all String values of the enum
|
||||
func DomainValidationTypeStrings() []string {
|
||||
strs := make([]string, len(_DomainValidationTypeNames))
|
||||
copy(strs, _DomainValidationTypeNames)
|
||||
return strs
|
||||
}
|
||||
|
||||
// IsADomainValidationType returns "true" if the value is listed in the enum definition. "false" otherwise
|
||||
func (i DomainValidationType) IsADomainValidationType() bool {
|
||||
for _, v := range _DomainValidationTypeValues {
|
||||
if i == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (i DomainValidationType) Value() (driver.Value, error) {
|
||||
return i.String(), nil
|
||||
}
|
||||
|
||||
func (i *DomainValidationType) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var str string
|
||||
switch v := value.(type) {
|
||||
case []byte:
|
||||
str = string(v)
|
||||
case string:
|
||||
str = v
|
||||
case fmt.Stringer:
|
||||
str = v.String()
|
||||
default:
|
||||
return fmt.Errorf("invalid value of DomainValidationType: %[1]T(%[1]v)", value)
|
||||
}
|
||||
|
||||
val, err := DomainValidationTypeString(str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*i = val
|
||||
return nil
|
||||
}
|
||||
175
backend/v3/domain/email_verification.go
Normal file
175
backend/v3/domain/email_verification.go
Normal file
@@ -0,0 +1,175 @@
|
||||
package domain
|
||||
|
||||
// import (
|
||||
// "context"
|
||||
// "time"
|
||||
// )
|
||||
|
||||
// // EmailVerifiedCommand verifies an email address for a user.
|
||||
// type EmailVerifiedCommand struct {
|
||||
// UserID string `json:"userId"`
|
||||
// Email *Email `json:"email"`
|
||||
// }
|
||||
|
||||
// func NewEmailVerifiedCommand(userID string, isVerified bool) *EmailVerifiedCommand {
|
||||
// return &EmailVerifiedCommand{
|
||||
// UserID: userID,
|
||||
// Email: &Email{
|
||||
// VerifiedAt: time.Time{},
|
||||
// },
|
||||
// }
|
||||
// }
|
||||
|
||||
// // String implements [Commander].
|
||||
// func (cmd *EmailVerifiedCommand) String() string {
|
||||
// return "EmailVerifiedCommand"
|
||||
// }
|
||||
|
||||
// var (
|
||||
// _ Commander = (*EmailVerifiedCommand)(nil)
|
||||
// _ SetEmailOpt = (*EmailVerifiedCommand)(nil)
|
||||
// )
|
||||
|
||||
// // Execute implements [Commander]
|
||||
// func (cmd *EmailVerifiedCommand) Execute(ctx context.Context, opts *CommandOpts) error {
|
||||
// repo := userRepo(opts.DB).Human()
|
||||
// return repo.Update(ctx, repo.IDCondition(cmd.UserID), repo.SetEmailVerifiedAt(time.Time{}))
|
||||
// }
|
||||
|
||||
// // applyOnSetEmail implements [SetEmailOpt]
|
||||
// func (cmd *EmailVerifiedCommand) applyOnSetEmail(setEmailCmd *SetEmailCommand) {
|
||||
// cmd.UserID = setEmailCmd.UserID
|
||||
// cmd.Email.Address = setEmailCmd.Email
|
||||
// setEmailCmd.verification = cmd
|
||||
// }
|
||||
|
||||
// // SendCodeCommand sends a verification code to the user's email address.
|
||||
// // If the URLTemplate is not set it will use the default of the organization / instance.
|
||||
// type SendCodeCommand struct {
|
||||
// UserID string `json:"userId"`
|
||||
// Email string `json:"email"`
|
||||
// URLTemplate *string `json:"urlTemplate"`
|
||||
// generator *generateCodeCommand
|
||||
// }
|
||||
|
||||
// var (
|
||||
// _ Commander = (*SendCodeCommand)(nil)
|
||||
// _ SetEmailOpt = (*SendCodeCommand)(nil)
|
||||
// )
|
||||
|
||||
// func NewSendCodeCommand(userID string, urlTemplate *string) *SendCodeCommand {
|
||||
// return &SendCodeCommand{
|
||||
// UserID: userID,
|
||||
// generator: &generateCodeCommand{},
|
||||
// URLTemplate: urlTemplate,
|
||||
// }
|
||||
// }
|
||||
|
||||
// // String implements [Commander].
|
||||
// func (cmd *SendCodeCommand) String() string {
|
||||
// return "SendCodeCommand"
|
||||
// }
|
||||
|
||||
// // Execute implements [Commander]
|
||||
// func (cmd *SendCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error {
|
||||
// if err := cmd.ensureEmail(ctx, opts); err != nil {
|
||||
// return err
|
||||
// }
|
||||
// if err := cmd.ensureURL(ctx, opts); err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
// if err := opts.Invoker.Invoke(ctx, cmd.generator, opts); err != nil {
|
||||
// return err
|
||||
// }
|
||||
// // TODO: queue notification
|
||||
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// func (cmd *SendCodeCommand) ensureEmail(ctx context.Context, opts *CommandOpts) error {
|
||||
// if cmd.Email != "" {
|
||||
// return nil
|
||||
// }
|
||||
// repo := userRepo(opts.DB).Human()
|
||||
// email, err := repo.GetEmail(ctx, repo.IDCondition(cmd.UserID))
|
||||
// if err != nil || !email.VerifiedAt.IsZero() {
|
||||
// return err
|
||||
// }
|
||||
// cmd.Email = email.Address
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// func (cmd *SendCodeCommand) ensureURL(ctx context.Context, opts *CommandOpts) error {
|
||||
// if cmd.URLTemplate != nil && *cmd.URLTemplate != "" {
|
||||
// return nil
|
||||
// }
|
||||
// _, _ = ctx, opts
|
||||
// // TODO: load default template
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// // applyOnSetEmail implements [SetEmailOpt]
|
||||
// func (cmd *SendCodeCommand) applyOnSetEmail(setEmailCmd *SetEmailCommand) {
|
||||
// cmd.UserID = setEmailCmd.UserID
|
||||
// cmd.Email = setEmailCmd.Email
|
||||
// setEmailCmd.verification = cmd
|
||||
// }
|
||||
|
||||
// // ReturnCodeCommand creates the code and returns it to the caller.
|
||||
// // The caller gets the code by calling the Code field after the command got executed.
|
||||
// type ReturnCodeCommand struct {
|
||||
// UserID string `json:"userId"`
|
||||
// Email string `json:"email"`
|
||||
// Code string `json:"code"`
|
||||
// generator *generateCodeCommand
|
||||
// }
|
||||
|
||||
// var (
|
||||
// _ Commander = (*ReturnCodeCommand)(nil)
|
||||
// _ SetEmailOpt = (*ReturnCodeCommand)(nil)
|
||||
// )
|
||||
|
||||
// func NewReturnCodeCommand(userID string) *ReturnCodeCommand {
|
||||
// return &ReturnCodeCommand{
|
||||
// UserID: userID,
|
||||
// generator: &generateCodeCommand{},
|
||||
// }
|
||||
// }
|
||||
|
||||
// // String implements [Commander].
|
||||
// func (cmd *ReturnCodeCommand) String() string {
|
||||
// return "ReturnCodeCommand"
|
||||
// }
|
||||
|
||||
// // Execute implements [Commander]
|
||||
// func (cmd *ReturnCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error {
|
||||
// if err := cmd.ensureEmail(ctx, opts); err != nil {
|
||||
// return err
|
||||
// }
|
||||
// if err := opts.Invoker.Invoke(ctx, cmd.generator, opts); err != nil {
|
||||
// return err
|
||||
// }
|
||||
// cmd.Code = cmd.generator.code
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// func (cmd *ReturnCodeCommand) ensureEmail(ctx context.Context, opts *CommandOpts) error {
|
||||
// if cmd.Email != "" {
|
||||
// return nil
|
||||
// }
|
||||
// repo := userRepo(opts.DB).Human()
|
||||
// email, err := repo.GetEmail(ctx, repo.IDCondition(cmd.UserID))
|
||||
// if err != nil || !email.VerifiedAt.IsZero() {
|
||||
// return err
|
||||
// }
|
||||
// cmd.Email = email.Address
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// // applyOnSetEmail implements [SetEmailOpt]
|
||||
// func (cmd *ReturnCodeCommand) applyOnSetEmail(setEmailCmd *SetEmailCommand) {
|
||||
// cmd.UserID = setEmailCmd.UserID
|
||||
// cmd.Email = setEmailCmd.Email
|
||||
// setEmailCmd.verification = cmd
|
||||
// }
|
||||
7
backend/v3/domain/errors.go
Normal file
7
backend/v3/domain/errors.go
Normal file
@@ -0,0 +1,7 @@
|
||||
package domain
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrNoAdminSpecified = errors.New("at least one admin must be specified")
|
||||
)
|
||||
117
backend/v3/domain/instance.go
Normal file
117
backend/v3/domain/instance.go
Normal file
@@ -0,0 +1,117 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/cache"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type Instance struct {
|
||||
ID string `json:"id,omitempty" db:"id"`
|
||||
Name string `json:"name,omitempty" db:"name"`
|
||||
DefaultOrgID string `json:"defaultOrgId,omitempty" db:"default_org_id"`
|
||||
IAMProjectID string `json:"iamProjectId,omitempty" db:"iam_project_id"`
|
||||
ConsoleClientID string `json:"consoleClientId,omitempty" db:"console_client_id"`
|
||||
ConsoleAppID string `json:"consoleAppId,omitempty" db:"console_app_id"`
|
||||
DefaultLanguage string `json:"defaultLanguage,omitempty" db:"default_language"`
|
||||
CreatedAt time.Time `json:"createdAt" db:"created_at"`
|
||||
UpdatedAt time.Time `json:"updatedAt" db:"updated_at"`
|
||||
|
||||
Domains []*InstanceDomain `json:"domains,omitempty" db:"-"`
|
||||
}
|
||||
|
||||
type instanceCacheIndex uint8
|
||||
|
||||
const (
|
||||
instanceCacheIndexUndefined instanceCacheIndex = iota
|
||||
instanceCacheIndexID
|
||||
)
|
||||
|
||||
// Keys implements the [cache.Entry].
|
||||
func (i *Instance) Keys(index instanceCacheIndex) (key []string) {
|
||||
if index == instanceCacheIndexID {
|
||||
return []string{i.ID}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ cache.Entry[instanceCacheIndex, string] = (*Instance)(nil)
|
||||
|
||||
// instanceColumns define all the columns of the instance table.
|
||||
type instanceColumns interface {
|
||||
// IDColumn returns the column for the id field.
|
||||
IDColumn() database.Column
|
||||
// NameColumn returns the column for the name field.
|
||||
NameColumn() database.Column
|
||||
// DefaultOrgIDColumn returns the column for the default org id field
|
||||
DefaultOrgIDColumn() database.Column
|
||||
// IAMProjectIDColumn returns the column for the default IAM org id field
|
||||
IAMProjectIDColumn() database.Column
|
||||
// ConsoleClientIDColumn returns the column for the default IAM org id field
|
||||
ConsoleClientIDColumn() database.Column
|
||||
// ConsoleAppIDColumn returns the column for the console client id field
|
||||
ConsoleAppIDColumn() database.Column
|
||||
// DefaultLanguageColumn returns the column for the default language field
|
||||
DefaultLanguageColumn() database.Column
|
||||
// CreatedAtColumn returns the column for the created at field.
|
||||
CreatedAtColumn() database.Column
|
||||
// UpdatedAtColumn returns the column for the updated at field.
|
||||
UpdatedAtColumn() database.Column
|
||||
}
|
||||
|
||||
// instanceConditions define all the conditions for the instance table.
|
||||
type instanceConditions interface {
|
||||
// IDCondition returns an equal filter on the id field.
|
||||
IDCondition(instanceID string) database.Condition
|
||||
// NameCondition returns a filter on the name field.
|
||||
NameCondition(op database.TextOperation, name string) database.Condition
|
||||
}
|
||||
|
||||
// instanceChanges define all the changes for the instance table.
|
||||
type instanceChanges interface {
|
||||
// SetName sets the name column.
|
||||
SetName(name string) database.Change
|
||||
// SetUpdatedAt sets the updated at column.
|
||||
SetUpdatedAt(time time.Time) database.Change
|
||||
// SetIAMProject sets the iam project column.
|
||||
SetIAMProject(id string) database.Change
|
||||
// SetDefaultOrg sets the default org column.
|
||||
SetDefaultOrg(id string) database.Change
|
||||
// SetDefaultLanguage sets the default language column.
|
||||
SetDefaultLanguage(language language.Tag) database.Change
|
||||
// SetConsoleClientID sets the console client id column.
|
||||
SetConsoleClientID(id string) database.Change
|
||||
// SetConsoleAppID sets the console app id column.
|
||||
SetConsoleAppID(id string) database.Change
|
||||
}
|
||||
|
||||
// InstanceRepository is the interface for the instance repository.
|
||||
type InstanceRepository interface {
|
||||
instanceColumns
|
||||
instanceConditions
|
||||
instanceChanges
|
||||
|
||||
// TODO
|
||||
// Member returns the member repository which is a sub repository of the instance repository.
|
||||
// Member() MemberRepository
|
||||
|
||||
Get(ctx context.Context, opts ...database.QueryOption) (*Instance, error)
|
||||
List(ctx context.Context, opts ...database.QueryOption) ([]*Instance, error)
|
||||
|
||||
Create(ctx context.Context, instance *Instance) error
|
||||
Update(ctx context.Context, id string, changes ...database.Change) (int64, error)
|
||||
Delete(ctx context.Context, id string) (int64, error)
|
||||
|
||||
// Domains returns the domain sub repository for the instance.
|
||||
// If shouldLoad is true, the domains will be loaded from the database and written to the [Instance].Domains field.
|
||||
// If shouldLoad is set to true once, the Domains field will be set even if shouldLoad is false in the future.
|
||||
Domains(shouldLoad bool) InstanceDomainRepository
|
||||
}
|
||||
|
||||
type CreateInstance struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
79
backend/v3/domain/instance_domain.go
Normal file
79
backend/v3/domain/instance_domain.go
Normal file
@@ -0,0 +1,79 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type InstanceDomain struct {
|
||||
InstanceID string `json:"instanceId,omitempty" db:"instance_id"`
|
||||
Domain string `json:"domain,omitempty" db:"domain"`
|
||||
// IsPrimary indicates if the domain is the primary domain of the instance.
|
||||
// It is only set for custom domains.
|
||||
IsPrimary *bool `json:"isPrimary,omitempty" db:"is_primary"`
|
||||
// IsGenerated indicates if the domain is a generated domain.
|
||||
// It is only set for custom domains.
|
||||
IsGenerated *bool `json:"isGenerated,omitempty" db:"is_generated"`
|
||||
Type DomainType `json:"type,omitempty" db:"type"`
|
||||
|
||||
CreatedAt time.Time `json:"createdAt,omitzero" db:"created_at"`
|
||||
UpdatedAt time.Time `json:"updatedAt,omitzero" db:"updated_at"`
|
||||
}
|
||||
|
||||
type AddInstanceDomain struct {
|
||||
InstanceID string `json:"instanceId,omitempty" db:"instance_id"`
|
||||
Domain string `json:"domain,omitempty" db:"domain"`
|
||||
IsPrimary *bool `json:"isPrimary,omitempty" db:"is_primary"`
|
||||
IsGenerated *bool `json:"isGenerated,omitempty" db:"is_generated"`
|
||||
Type DomainType `json:"type,omitempty" db:"type"`
|
||||
|
||||
// CreatedAt is the time when the domain was added.
|
||||
// It is set by the repository and should not be set by the caller.
|
||||
CreatedAt time.Time `json:"createdAt,omitzero" db:"created_at"`
|
||||
// UpdatedAt is the time when the domain was last updated.
|
||||
// It is set by the repository and should not be set by the caller.
|
||||
UpdatedAt time.Time `json:"updatedAt,omitzero" db:"updated_at"`
|
||||
}
|
||||
|
||||
type instanceDomainColumns interface {
|
||||
domainColumns
|
||||
// IsGeneratedColumn returns the column for the is generated field.
|
||||
IsGeneratedColumn() database.Column
|
||||
// TypeColumn returns the column for the type field.
|
||||
TypeColumn() database.Column
|
||||
}
|
||||
|
||||
type instanceDomainConditions interface {
|
||||
domainConditions
|
||||
// TypeCondition returns a filter for the type field.
|
||||
TypeCondition(typ DomainType) database.Condition
|
||||
}
|
||||
|
||||
type instanceDomainChanges interface {
|
||||
domainChanges
|
||||
// SetType sets the type column.
|
||||
SetType(typ DomainType) database.Change
|
||||
}
|
||||
|
||||
type InstanceDomainRepository interface {
|
||||
instanceDomainColumns
|
||||
instanceDomainConditions
|
||||
instanceDomainChanges
|
||||
|
||||
// Get returns a single domain based on the criteria.
|
||||
// If no domain is found, it returns an error of type [database.ErrNotFound].
|
||||
// If multiple domains are found, it returns an error of type [database.ErrMultipleRows].
|
||||
Get(ctx context.Context, opts ...database.QueryOption) (*InstanceDomain, error)
|
||||
// List returns a list of domains based on the criteria.
|
||||
// If no domains are found, it returns an empty slice.
|
||||
List(ctx context.Context, opts ...database.QueryOption) ([]*InstanceDomain, error)
|
||||
|
||||
// Add adds a new domain to the instance.
|
||||
Add(ctx context.Context, domain *AddInstanceDomain) error
|
||||
// Update updates an existing domain in the instance.
|
||||
Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error)
|
||||
// Remove removes a domain from the instance.
|
||||
Remove(ctx context.Context, condition database.Condition) (int64, error)
|
||||
}
|
||||
158
backend/v3/domain/invoke.go
Normal file
158
backend/v3/domain/invoke.go
Normal file
@@ -0,0 +1,158 @@
|
||||
package domain
|
||||
|
||||
// import (
|
||||
// "context"
|
||||
// "fmt"
|
||||
|
||||
// "github.com/zitadel/zitadel/backend/v3/storage/eventstore"
|
||||
// )
|
||||
|
||||
// // Invoke provides a way to execute commands within the domain package.
|
||||
// // It uses a chain of responsibility pattern to handle the command execution.
|
||||
// // The default chain includes logging, tracing, and event publishing.
|
||||
// // If you want to invoke multiple commands in a single transaction, you can use the [commandBatch].
|
||||
// func Invoke(ctx context.Context, cmd Commander) error {
|
||||
// invoker := newEventStoreInvoker(newLoggingInvoker(newTraceInvoker(nil)))
|
||||
// opts := &CommandOpts{
|
||||
// Invoker: invoker.collector,
|
||||
// DB: pool,
|
||||
// }
|
||||
// return invoker.Invoke(ctx, cmd, opts)
|
||||
// }
|
||||
|
||||
// // eventStoreInvoker checks if the command implements the [eventer] interface.
|
||||
// // If it does, it collects the events and publishes them to the event store.
|
||||
// type eventStoreInvoker struct {
|
||||
// collector *eventCollector
|
||||
// }
|
||||
|
||||
// func newEventStoreInvoker(next Invoker) *eventStoreInvoker {
|
||||
// return &eventStoreInvoker{collector: &eventCollector{next: next}}
|
||||
// }
|
||||
|
||||
// func (i *eventStoreInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
|
||||
// err = i.collector.Invoke(ctx, command, opts)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// if len(i.collector.events) > 0 {
|
||||
// err = eventstore.Publish(ctx, i.collector.events, opts.DB)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// }
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// // eventCollector collects events from all commands. The [eventStoreInvoker] pushes the collected events after all commands are executed.
|
||||
// type eventCollector struct {
|
||||
// next Invoker
|
||||
// events []*eventstore.Event
|
||||
// }
|
||||
|
||||
// type eventer interface {
|
||||
// Events() []*eventstore.Event
|
||||
// }
|
||||
|
||||
// func (i *eventCollector) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
|
||||
// if e, ok := command.(eventer); ok && len(e.Events()) > 0 {
|
||||
// // we need to ensure all commands are executed in the same transaction
|
||||
// close, err := opts.EnsureTx(ctx)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// defer func() { err = close(ctx, err) }()
|
||||
|
||||
// i.events = append(i.events, e.Events()...)
|
||||
// }
|
||||
// if i.next != nil {
|
||||
// return i.next.Invoke(ctx, command, opts)
|
||||
// }
|
||||
// return command.Execute(ctx, opts)
|
||||
// }
|
||||
|
||||
// // traceInvoker decorates each command with tracing.
|
||||
// type traceInvoker struct {
|
||||
// next Invoker
|
||||
// }
|
||||
|
||||
// func newTraceInvoker(next Invoker) *traceInvoker {
|
||||
// return &traceInvoker{next: next}
|
||||
// }
|
||||
|
||||
// func (i *traceInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
|
||||
// ctx, span := tracer.Start(ctx, fmt.Sprintf("%T", command))
|
||||
// defer func() {
|
||||
// if err != nil {
|
||||
// span.RecordError(err)
|
||||
// }
|
||||
// span.End()
|
||||
// }()
|
||||
|
||||
// if i.next != nil {
|
||||
// return i.next.Invoke(ctx, command, opts)
|
||||
// }
|
||||
// return command.Execute(ctx, opts)
|
||||
// }
|
||||
|
||||
// // loggingInvoker decorates each command with logging.
|
||||
// // It is an example implementation and logs the command name at the beginning and success or failure after the command got executed.
|
||||
// type loggingInvoker struct {
|
||||
// next Invoker
|
||||
// }
|
||||
|
||||
// func newLoggingInvoker(next Invoker) *loggingInvoker {
|
||||
// return &loggingInvoker{next: next}
|
||||
// }
|
||||
|
||||
// func (i *loggingInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
|
||||
// logger.InfoContext(ctx, "Invoking command", "command", command.String())
|
||||
|
||||
// if i.next != nil {
|
||||
// err = i.next.Invoke(ctx, command, opts)
|
||||
// } else {
|
||||
// err = command.Execute(ctx, opts)
|
||||
// }
|
||||
|
||||
// if err != nil {
|
||||
// logger.ErrorContext(ctx, "Command invocation failed", "command", command.String(), "error", err)
|
||||
// return err
|
||||
// }
|
||||
// logger.InfoContext(ctx, "Command invocation succeeded", "command", command.String())
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// type noopInvoker struct {
|
||||
// next Invoker
|
||||
// }
|
||||
|
||||
// func (i *noopInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) error {
|
||||
// if i.next != nil {
|
||||
// return i.next.Invoke(ctx, command, opts)
|
||||
// }
|
||||
// return command.Execute(ctx, opts)
|
||||
// }
|
||||
|
||||
// // cacheInvoker could be used in the future to do the caching.
|
||||
// // My goal would be to have two interfaces:
|
||||
// // - cacheSetter: which caches an object
|
||||
// // - cacheGetter: which gets an object from the cache, this should also skip the command execution
|
||||
// type cacheInvoker struct {
|
||||
// next Invoker
|
||||
// }
|
||||
|
||||
// type cacher interface {
|
||||
// Cache(opts *CommandOpts)
|
||||
// }
|
||||
|
||||
// func (i *cacheInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
|
||||
// if c, ok := command.(cacher); ok {
|
||||
// c.Cache(opts)
|
||||
// }
|
||||
// if i.next != nil {
|
||||
// err = i.next.Invoke(ctx, command, opts)
|
||||
// } else {
|
||||
// err = command.Execute(ctx, opts)
|
||||
// }
|
||||
// return err
|
||||
// }
|
||||
137
backend/v3/domain/org_add.go
Normal file
137
backend/v3/domain/org_add.go
Normal file
@@ -0,0 +1,137 @@
|
||||
package domain
|
||||
|
||||
// import (
|
||||
// "context"
|
||||
|
||||
// "github.com/zitadel/zitadel/backend/v3/storage/eventstore"
|
||||
// )
|
||||
|
||||
// // AddOrgCommand adds a new organization.
|
||||
// // I'm unsure if we should add the Admins here or if this should be a separate command.
|
||||
// type AddOrgCommand struct {
|
||||
// ID string `json:"id"`
|
||||
// Name string `json:"name"`
|
||||
// Admins []*AddMemberCommand `json:"admins"`
|
||||
// }
|
||||
|
||||
// func NewAddOrgCommand(name string, admins ...*AddMemberCommand) *AddOrgCommand {
|
||||
// return &AddOrgCommand{
|
||||
// Name: name,
|
||||
// Admins: admins,
|
||||
// }
|
||||
// }
|
||||
|
||||
// // String implements [Commander].
|
||||
// func (cmd *AddOrgCommand) String() string {
|
||||
// return "AddOrgCommand"
|
||||
// }
|
||||
|
||||
// // Execute implements Commander.
|
||||
// func (cmd *AddOrgCommand) Execute(ctx context.Context, opts *CommandOpts) (err error) {
|
||||
// if len(cmd.Admins) == 0 {
|
||||
// return ErrNoAdminSpecified
|
||||
// }
|
||||
// if err = cmd.ensureID(); err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
// close, err := opts.EnsureTx(ctx)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// defer func() { err = close(ctx, err) }()
|
||||
// err = orgRepo(opts.DB).Create(ctx, &Org{
|
||||
// ID: cmd.ID,
|
||||
// Name: cmd.Name,
|
||||
// })
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
// for _, admin := range cmd.Admins {
|
||||
// admin.orgID = cmd.ID
|
||||
// if err = opts.Invoke(ctx, admin); err != nil {
|
||||
// return err
|
||||
// }
|
||||
// }
|
||||
|
||||
// orgCache.Set(ctx, &Org{
|
||||
// ID: cmd.ID,
|
||||
// Name: cmd.Name,
|
||||
// })
|
||||
|
||||
// return nil
|
||||
// }
|
||||
|
||||
// // Events implements [eventer].
|
||||
// func (cmd *AddOrgCommand) Events() []*eventstore.Event {
|
||||
// return []*eventstore.Event{
|
||||
// {
|
||||
// AggregateType: "org",
|
||||
// AggregateID: cmd.ID,
|
||||
// Type: "org.added",
|
||||
// Payload: cmd,
|
||||
// },
|
||||
// }
|
||||
// }
|
||||
|
||||
// var (
|
||||
// _ Commander = (*AddOrgCommand)(nil)
|
||||
// _ eventer = (*AddOrgCommand)(nil)
|
||||
// )
|
||||
|
||||
// func (cmd *AddOrgCommand) ensureID() (err error) {
|
||||
// if cmd.ID != "" {
|
||||
// return nil
|
||||
// }
|
||||
// cmd.ID, err = generateID()
|
||||
// return err
|
||||
// }
|
||||
|
||||
// // AddMemberCommand adds a new member to an organization.
|
||||
// // I'm not sure if we should make it more generic to also use it for instances.
|
||||
// type AddMemberCommand struct {
|
||||
// orgID string
|
||||
// UserID string `json:"userId"`
|
||||
// Roles []string `json:"roles"`
|
||||
// }
|
||||
|
||||
// func NewAddMemberCommand(userID string, roles ...string) *AddMemberCommand {
|
||||
// return &AddMemberCommand{
|
||||
// UserID: userID,
|
||||
// Roles: roles,
|
||||
// }
|
||||
// }
|
||||
|
||||
// // String implements [Commander].
|
||||
// func (cmd *AddMemberCommand) String() string {
|
||||
// return "AddMemberCommand"
|
||||
// }
|
||||
|
||||
// // Execute implements Commander.
|
||||
// func (a *AddMemberCommand) Execute(ctx context.Context, opts *CommandOpts) (err error) {
|
||||
// close, err := opts.EnsureTx(ctx)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// defer func() { err = close(ctx, err) }()
|
||||
|
||||
// return orgRepo(opts.DB).Member().AddMember(ctx, a.orgID, a.UserID, a.Roles)
|
||||
// }
|
||||
|
||||
// // Events implements [eventer].
|
||||
// func (a *AddMemberCommand) Events() []*eventstore.Event {
|
||||
// return []*eventstore.Event{
|
||||
// {
|
||||
// AggregateType: "org",
|
||||
// AggregateID: a.UserID,
|
||||
// Type: "member.added",
|
||||
// Payload: a,
|
||||
// },
|
||||
// }
|
||||
// }
|
||||
|
||||
// var (
|
||||
// _ Commander = (*AddMemberCommand)(nil)
|
||||
// _ eventer = (*AddMemberCommand)(nil)
|
||||
// )
|
||||
100
backend/v3/domain/organization.go
Normal file
100
backend/v3/domain/organization.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
//go:generate enumer -type OrgState -transform lower -trimprefix OrgState -sql
|
||||
type OrgState uint8
|
||||
|
||||
const (
|
||||
OrgStateActive OrgState = iota
|
||||
OrgStateInactive
|
||||
)
|
||||
|
||||
type Organization struct {
|
||||
ID string `json:"id,omitempty" db:"id"`
|
||||
Name string `json:"name,omitempty" db:"name"`
|
||||
InstanceID string `json:"instanceId,omitempty" db:"instance_id"`
|
||||
State OrgState `json:"state,omitempty" db:"state"`
|
||||
CreatedAt time.Time `json:"createdAt,omitzero" db:"created_at"`
|
||||
UpdatedAt time.Time `json:"updatedAt,omitzero" db:"updated_at"`
|
||||
|
||||
Domains []*OrganizationDomain `json:"domains,omitempty" db:"-"` // domains need to be handled separately
|
||||
}
|
||||
|
||||
// OrgIdentifierCondition is used to help specify a single Organization,
|
||||
// it will either be used as the organization ID or organization name,
|
||||
// as organizations can be identified either using (instanceID + ID) OR (instanceID + name)
|
||||
type OrgIdentifierCondition interface {
|
||||
database.Condition
|
||||
}
|
||||
|
||||
// organizationColumns define all the columns of the instance table.
|
||||
type organizationColumns interface {
|
||||
// IDColumn returns the column for the id field.
|
||||
IDColumn() database.Column
|
||||
// NameColumn returns the column for the name field.
|
||||
NameColumn() database.Column
|
||||
// InstanceIDColumn returns the column for the default org id field
|
||||
InstanceIDColumn() database.Column
|
||||
// StateColumn returns the column for the name field.
|
||||
StateColumn() database.Column
|
||||
// CreatedAtColumn returns the column for the created at field.
|
||||
CreatedAtColumn() database.Column
|
||||
// UpdatedAtColumn returns the column for the updated at field.
|
||||
UpdatedAtColumn() database.Column
|
||||
}
|
||||
|
||||
// organizationConditions define all the conditions for the instance table.
|
||||
type organizationConditions interface {
|
||||
// IDCondition returns an equal filter on the id field.
|
||||
IDCondition(instanceID string) OrgIdentifierCondition
|
||||
// NameCondition returns a filter on the name field.
|
||||
NameCondition(name string) OrgIdentifierCondition
|
||||
// InstanceIDCondition returns a filter on the instance id field.
|
||||
InstanceIDCondition(instanceID string) database.Condition
|
||||
// StateCondition returns a filter on the name field.
|
||||
StateCondition(state OrgState) database.Condition
|
||||
}
|
||||
|
||||
// organizationChanges define all the changes for the instance table.
|
||||
type organizationChanges interface {
|
||||
// SetName sets the name column.
|
||||
SetName(name string) database.Change
|
||||
// SetState sets the name column.
|
||||
SetState(state OrgState) database.Change
|
||||
}
|
||||
|
||||
// OrganizationRepository is the interface for the instance repository.
|
||||
type OrganizationRepository interface {
|
||||
organizationColumns
|
||||
organizationConditions
|
||||
organizationChanges
|
||||
|
||||
Get(ctx context.Context, opts ...database.QueryOption) (*Organization, error)
|
||||
List(ctx context.Context, opts ...database.QueryOption) ([]*Organization, error)
|
||||
|
||||
Create(ctx context.Context, instance *Organization) error
|
||||
Update(ctx context.Context, id OrgIdentifierCondition, instance_id string, changes ...database.Change) (int64, error)
|
||||
Delete(ctx context.Context, id OrgIdentifierCondition, instance_id string) (int64, error)
|
||||
|
||||
// Domains returns the domain sub repository for the organization.
|
||||
// If shouldLoad is true, the domains will be loaded from the database and written to the [Instance].Domains field.
|
||||
// If shouldLoad is set to true once, the Domains field will be set event if shouldLoad is false in the future.
|
||||
Domains(shouldLoad bool) OrganizationDomainRepository
|
||||
}
|
||||
|
||||
type CreateOrganization struct {
|
||||
Name string `json:"name"`
|
||||
}
|
||||
|
||||
// MemberRepository is a sub repository of the org repository and maybe the instance repository.
|
||||
type MemberRepository interface {
|
||||
AddMember(ctx context.Context, orgID, userID string, roles []string) error
|
||||
SetMemberRoles(ctx context.Context, orgID, userID string, roles []string) error
|
||||
RemoveMember(ctx context.Context, orgID, userID string) error
|
||||
}
|
||||
84
backend/v3/domain/organization_domain.go
Normal file
84
backend/v3/domain/organization_domain.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type OrganizationDomain struct {
|
||||
InstanceID string `json:"instanceId,omitempty" db:"instance_id"`
|
||||
OrgID string `json:"orgId,omitempty" db:"org_id"`
|
||||
Domain string `json:"domain,omitempty" db:"domain"`
|
||||
IsVerified bool `json:"isVerified,omitempty" db:"is_verified"`
|
||||
IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"`
|
||||
ValidationType *DomainValidationType `json:"validationType,omitempty" db:"validation_type"`
|
||||
|
||||
CreatedAt time.Time `json:"createdAt,omitzero" db:"created_at"`
|
||||
UpdatedAt time.Time `json:"updatedAt,omitzero" db:"updated_at"`
|
||||
}
|
||||
|
||||
type AddOrganizationDomain struct {
|
||||
InstanceID string `json:"instanceId,omitempty" db:"instance_id"`
|
||||
OrgID string `json:"orgId,omitempty" db:"org_id"`
|
||||
Domain string `json:"domain,omitempty" db:"domain"`
|
||||
IsVerified bool `json:"isVerified,omitempty" db:"is_verified"`
|
||||
IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"`
|
||||
ValidationType *DomainValidationType `json:"validationType,omitempty" db:"validation_type"`
|
||||
|
||||
// CreatedAt is the time when the domain was added.
|
||||
// It is set by the repository and should not be set by the caller.
|
||||
CreatedAt time.Time `json:"createdAt,omitzero" db:"created_at"`
|
||||
// UpdatedAt is the time when the domain was added.
|
||||
// It is set by the repository and should not be set by the caller.
|
||||
UpdatedAt time.Time `json:"updatedAt,omitzero" db:"updated_at"`
|
||||
}
|
||||
|
||||
type organizationDomainColumns interface {
|
||||
domainColumns
|
||||
// OrgIDColumn returns the column for the org id field.
|
||||
OrgIDColumn() database.Column
|
||||
// IsVerifiedColumn returns the column for the is verified field.
|
||||
IsVerifiedColumn() database.Column
|
||||
// ValidationTypeColumn returns the column for the verification type field.
|
||||
ValidationTypeColumn() database.Column
|
||||
}
|
||||
|
||||
type organizationDomainConditions interface {
|
||||
domainConditions
|
||||
// OrgIDCondition returns a filter on the org id field.
|
||||
OrgIDCondition(orgID string) database.Condition
|
||||
// IsVerifiedCondition returns a filter on the is verified field.
|
||||
IsVerifiedCondition(isVerified bool) database.Condition
|
||||
}
|
||||
|
||||
type organizationDomainChanges interface {
|
||||
domainChanges
|
||||
// SetVerified sets the is verified column to true.
|
||||
SetVerified() database.Change
|
||||
// SetValidationType sets the verification type column.
|
||||
// If the domain is already verified, this is a no-op.
|
||||
SetValidationType(verificationType DomainValidationType) database.Change
|
||||
}
|
||||
|
||||
type OrganizationDomainRepository interface {
|
||||
organizationDomainColumns
|
||||
organizationDomainConditions
|
||||
organizationDomainChanges
|
||||
|
||||
// Get returns a single domain based on the criteria.
|
||||
// If no domain is found, it returns an error of type [database.ErrNotFound].
|
||||
// If multiple domains are found, it returns an error of type [database.ErrMultipleRows].
|
||||
Get(ctx context.Context, opts ...database.QueryOption) (*OrganizationDomain, error)
|
||||
// List returns a list of domains based on the criteria.
|
||||
// If no domains are found, it returns an empty slice.
|
||||
List(ctx context.Context, opts ...database.QueryOption) ([]*OrganizationDomain, error)
|
||||
|
||||
// Add adds a new domain to the organization.
|
||||
Add(ctx context.Context, domain *AddOrganizationDomain) error
|
||||
// Update updates an existing domain in the organization.
|
||||
Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error)
|
||||
// Remove removes a domain from the organization.
|
||||
Remove(ctx context.Context, condition database.Condition) (int64, error)
|
||||
}
|
||||
109
backend/v3/domain/orgstate_enumer.go
Normal file
109
backend/v3/domain/orgstate_enumer.go
Normal file
@@ -0,0 +1,109 @@
|
||||
// Code generated by "enumer -type OrgState -transform lower -trimprefix OrgState -sql"; DO NOT EDIT.
|
||||
|
||||
package domain
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const _OrgStateName = "activeinactive"
|
||||
|
||||
var _OrgStateIndex = [...]uint8{0, 6, 14}
|
||||
|
||||
const _OrgStateLowerName = "activeinactive"
|
||||
|
||||
func (i OrgState) String() string {
|
||||
if i >= OrgState(len(_OrgStateIndex)-1) {
|
||||
return fmt.Sprintf("OrgState(%d)", i)
|
||||
}
|
||||
return _OrgStateName[_OrgStateIndex[i]:_OrgStateIndex[i+1]]
|
||||
}
|
||||
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
func _OrgStateNoOp() {
|
||||
var x [1]struct{}
|
||||
_ = x[OrgStateActive-(0)]
|
||||
_ = x[OrgStateInactive-(1)]
|
||||
}
|
||||
|
||||
var _OrgStateValues = []OrgState{OrgStateActive, OrgStateInactive}
|
||||
|
||||
var _OrgStateNameToValueMap = map[string]OrgState{
|
||||
_OrgStateName[0:6]: OrgStateActive,
|
||||
_OrgStateLowerName[0:6]: OrgStateActive,
|
||||
_OrgStateName[6:14]: OrgStateInactive,
|
||||
_OrgStateLowerName[6:14]: OrgStateInactive,
|
||||
}
|
||||
|
||||
var _OrgStateNames = []string{
|
||||
_OrgStateName[0:6],
|
||||
_OrgStateName[6:14],
|
||||
}
|
||||
|
||||
// OrgStateString retrieves an enum value from the enum constants string name.
|
||||
// Throws an error if the param is not part of the enum.
|
||||
func OrgStateString(s string) (OrgState, error) {
|
||||
if val, ok := _OrgStateNameToValueMap[s]; ok {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
if val, ok := _OrgStateNameToValueMap[strings.ToLower(s)]; ok {
|
||||
return val, nil
|
||||
}
|
||||
return 0, fmt.Errorf("%s does not belong to OrgState values", s)
|
||||
}
|
||||
|
||||
// OrgStateValues returns all values of the enum
|
||||
func OrgStateValues() []OrgState {
|
||||
return _OrgStateValues
|
||||
}
|
||||
|
||||
// OrgStateStrings returns a slice of all String values of the enum
|
||||
func OrgStateStrings() []string {
|
||||
strs := make([]string, len(_OrgStateNames))
|
||||
copy(strs, _OrgStateNames)
|
||||
return strs
|
||||
}
|
||||
|
||||
// IsAOrgState returns "true" if the value is listed in the enum definition. "false" otherwise
|
||||
func (i OrgState) IsAOrgState() bool {
|
||||
for _, v := range _OrgStateValues {
|
||||
if i == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (i OrgState) Value() (driver.Value, error) {
|
||||
return i.String(), nil
|
||||
}
|
||||
|
||||
func (i *OrgState) Scan(value interface{}) error {
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var str string
|
||||
switch v := value.(type) {
|
||||
case []byte:
|
||||
str = string(v)
|
||||
case string:
|
||||
str = v
|
||||
case fmt.Stringer:
|
||||
str = v.String()
|
||||
default:
|
||||
return fmt.Errorf("invalid value of OrgState: %[1]T(%[1]v)", value)
|
||||
}
|
||||
|
||||
val, err := OrgStateString(str)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
*i = val
|
||||
return nil
|
||||
}
|
||||
74
backend/v3/domain/set_email.go
Normal file
74
backend/v3/domain/set_email.go
Normal file
@@ -0,0 +1,74 @@
|
||||
package domain
|
||||
|
||||
// import (
|
||||
// "context"
|
||||
|
||||
// "github.com/zitadel/zitadel/backend/v3/storage/eventstore"
|
||||
// )
|
||||
|
||||
// // SetEmailCommand sets the email address of a user.
|
||||
// // If allows verification as a sub command.
|
||||
// // The verification command is executed after the email address is set.
|
||||
// // The verification command is executed in the same transaction as the email address update.
|
||||
// type SetEmailCommand struct {
|
||||
// UserID string `json:"userId"`
|
||||
// Email string `json:"email"`
|
||||
// verification Commander
|
||||
// }
|
||||
|
||||
// var (
|
||||
// _ Commander = (*SetEmailCommand)(nil)
|
||||
// _ eventer = (*SetEmailCommand)(nil)
|
||||
// _ CreateHumanOpt = (*SetEmailCommand)(nil)
|
||||
// )
|
||||
|
||||
// type SetEmailOpt interface {
|
||||
// applyOnSetEmail(*SetEmailCommand)
|
||||
// }
|
||||
|
||||
// func NewSetEmailCommand(userID, email string, verificationType SetEmailOpt) *SetEmailCommand {
|
||||
// cmd := &SetEmailCommand{
|
||||
// UserID: userID,
|
||||
// Email: email,
|
||||
// }
|
||||
// verificationType.applyOnSetEmail(cmd)
|
||||
// return cmd
|
||||
// }
|
||||
|
||||
// // String implements [Commander].
|
||||
// func (cmd *SetEmailCommand) String() string {
|
||||
// return "SetEmailCommand"
|
||||
// }
|
||||
|
||||
// func (cmd *SetEmailCommand) Execute(ctx context.Context, opts *CommandOpts) error {
|
||||
// close, err := opts.EnsureTx(ctx)
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
// defer func() { err = close(ctx, err) }()
|
||||
// // userStatement(opts.DB).Human().ByID(cmd.UserID).SetEmail(ctx, cmd.Email)
|
||||
// repo := userRepo(opts.DB).Human()
|
||||
// err = repo.Update(ctx, repo.IDCondition(cmd.UserID), repo.SetEmailAddress(cmd.Email))
|
||||
// if err != nil {
|
||||
// return err
|
||||
// }
|
||||
|
||||
// return opts.Invoke(ctx, cmd.verification)
|
||||
// }
|
||||
|
||||
// // Events implements [eventer].
|
||||
// func (cmd *SetEmailCommand) Events() []*eventstore.Event {
|
||||
// return []*eventstore.Event{
|
||||
// {
|
||||
// AggregateType: "user",
|
||||
// AggregateID: cmd.UserID,
|
||||
// Type: "user.email.set",
|
||||
// Payload: cmd,
|
||||
// },
|
||||
// }
|
||||
// }
|
||||
|
||||
// // applyOnCreateHuman implements [CreateHumanOpt].
|
||||
// func (cmd *SetEmailCommand) applyOnCreateHuman(createUserCmd *CreateUserCommand) {
|
||||
// createUserCmd.email = cmd
|
||||
// }
|
||||
241
backend/v3/domain/user.go
Normal file
241
backend/v3/domain/user.go
Normal file
@@ -0,0 +1,241 @@
|
||||
package domain
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
// userColumns define all the columns of the user table.
|
||||
type userColumns interface {
|
||||
// InstanceIDColumn returns the column for the instance id field.
|
||||
InstanceIDColumn() database.Column
|
||||
// OrgIDColumn returns the column for the org id field.
|
||||
OrgIDColumn() database.Column
|
||||
// IDColumn returns the column for the id field.
|
||||
IDColumn() database.Column
|
||||
// UsernameColumn returns the column for the username field.
|
||||
UsernameColumn() database.Column
|
||||
// CreatedAtColumn returns the column for the created at field.
|
||||
CreatedAtColumn() database.Column
|
||||
// UpdatedAtColumn returns the column for the updated at field.
|
||||
UpdatedAtColumn() database.Column
|
||||
// DeletedAtColumn returns the column for the deleted at field.
|
||||
DeletedAtColumn() database.Column
|
||||
}
|
||||
|
||||
// userConditions define all the conditions for the user table.
|
||||
type userConditions interface {
|
||||
// InstanceIDCondition returns an equal filter on the instance id field.
|
||||
InstanceIDCondition(instanceID string) database.Condition
|
||||
// OrgIDCondition returns an equal filter on the org id field.
|
||||
OrgIDCondition(orgID string) database.Condition
|
||||
// IDCondition returns an equal filter on the id field.
|
||||
IDCondition(userID string) database.Condition
|
||||
// UsernameCondition returns a filter on the username field.
|
||||
UsernameCondition(op database.TextOperation, username string) database.Condition
|
||||
// CreatedAtCondition returns a filter on the created at field.
|
||||
CreatedAtCondition(op database.NumberOperation, createdAt time.Time) database.Condition
|
||||
// UpdatedAtCondition returns a filter on the updated at field.
|
||||
UpdatedAtCondition(op database.NumberOperation, updatedAt time.Time) database.Condition
|
||||
// DeletedAtCondition filters for deleted users is isDeleted is set to true otherwise only not deleted users must be filtered.
|
||||
DeletedCondition(isDeleted bool) database.Condition
|
||||
// DeletedAtCondition filters for deleted users based on the given parameters.
|
||||
DeletedAtCondition(op database.NumberOperation, deletedAt time.Time) database.Condition
|
||||
}
|
||||
|
||||
// userChanges define all the changes for the user table.
|
||||
type userChanges interface {
|
||||
// SetUsername sets the username column.
|
||||
SetUsername(username string) database.Change
|
||||
}
|
||||
|
||||
// UserRepository is the interface for the user repository.
|
||||
type UserRepository interface {
|
||||
userColumns
|
||||
userConditions
|
||||
userChanges
|
||||
// Get returns a user based on the given condition.
|
||||
Get(ctx context.Context, opts ...database.QueryOption) (*User, error)
|
||||
// List returns a list of users based on the given condition.
|
||||
List(ctx context.Context, opts ...database.QueryOption) ([]*User, error)
|
||||
// Create creates a new user.
|
||||
Create(ctx context.Context, user *User) error
|
||||
// Delete removes users based on the given condition.
|
||||
Delete(ctx context.Context, condition database.Condition) error
|
||||
// Human returns the [HumanRepository].
|
||||
Human() HumanRepository
|
||||
// Machine returns the [MachineRepository].
|
||||
Machine() MachineRepository
|
||||
}
|
||||
|
||||
// humanColumns define all the columns of the human table which inherits the user table.
|
||||
type humanColumns interface {
|
||||
userColumns
|
||||
// FirstNameColumn returns the column for the first name field.
|
||||
FirstNameColumn() database.Column
|
||||
// LastNameColumn returns the column for the last name field.
|
||||
LastNameColumn() database.Column
|
||||
// EmailAddressColumn returns the column for the email address field.
|
||||
EmailAddressColumn() database.Column
|
||||
// EmailVerifiedAtColumn returns the column for the email verified at field.
|
||||
EmailVerifiedAtColumn() database.Column
|
||||
// PhoneNumberColumn returns the column for the phone number field.
|
||||
PhoneNumberColumn() database.Column
|
||||
// PhoneVerifiedAtColumn returns the column for the phone verified at field.
|
||||
PhoneVerifiedAtColumn() database.Column
|
||||
}
|
||||
|
||||
// humanConditions define all the conditions for the human table which inherits the user table.
|
||||
type humanConditions interface {
|
||||
userConditions
|
||||
// FirstNameCondition returns a filter on the first name field.
|
||||
FirstNameCondition(op database.TextOperation, firstName string) database.Condition
|
||||
// LastNameCondition returns a filter on the last name field.
|
||||
LastNameCondition(op database.TextOperation, lastName string) database.Condition
|
||||
// EmailAddressCondition returns a filter on the email address field.
|
||||
EmailAddressCondition(op database.TextOperation, email string) database.Condition
|
||||
// EmailVerifiedCondition returns a filter that checks if the email is verified or not.
|
||||
EmailVerifiedCondition(isVerified bool) database.Condition
|
||||
// EmailVerifiedAtCondition returns a filter on the email verified at field.
|
||||
EmailVerifiedAtCondition(op database.NumberOperation, emailVerifiedAt time.Time) database.Condition
|
||||
|
||||
// PhoneNumberCondition returns a filter on the phone number field.
|
||||
PhoneNumberCondition(op database.TextOperation, phoneNumber string) database.Condition
|
||||
// PhoneVerifiedCondition returns a filter that checks if the phone is verified or not.
|
||||
PhoneVerifiedCondition(isVerified bool) database.Condition
|
||||
// PhoneVerifiedAtCondition returns a filter on the phone verified at field.
|
||||
PhoneVerifiedAtCondition(op database.NumberOperation, phoneVerifiedAt time.Time) database.Condition
|
||||
}
|
||||
|
||||
// humanChanges define all the changes for the human table which inherits the user table.
|
||||
type humanChanges interface {
|
||||
userChanges
|
||||
// SetFirstName sets the first name field of the human.
|
||||
SetFirstName(firstName string) database.Change
|
||||
// SetLastName sets the last name field of the human.
|
||||
SetLastName(lastName string) database.Change
|
||||
|
||||
// SetEmail sets the email address and verified field of the email
|
||||
// if verifiedAt is nil the email is not verified
|
||||
SetEmail(address string, verifiedAt *time.Time) database.Change
|
||||
// SetEmailAddress sets the email address field of the email
|
||||
SetEmailAddress(email string) database.Change
|
||||
// SetEmailVerifiedAt sets the verified column of the email
|
||||
// if at is zero the statement uses the database timestamp
|
||||
SetEmailVerifiedAt(at time.Time) database.Change
|
||||
|
||||
// SetPhone sets the phone number and verified field
|
||||
// if verifiedAt is nil the phone is not verified
|
||||
SetPhone(number string, verifiedAt *time.Time) database.Change
|
||||
// SetPhoneNumber sets the phone number field
|
||||
SetPhoneNumber(phoneNumber string) database.Change
|
||||
// SetPhoneVerifiedAt sets the verified field of the phone
|
||||
// if at is zero the statement uses the database timestamp
|
||||
SetPhoneVerifiedAt(at time.Time) database.Change
|
||||
}
|
||||
|
||||
// HumanRepository is the interface for the human repository it inherits the user repository.
|
||||
type HumanRepository interface {
|
||||
humanColumns
|
||||
humanConditions
|
||||
humanChanges
|
||||
|
||||
// Get returns an email based on the given condition.
|
||||
GetEmail(ctx context.Context, condition database.Condition) (*Email, error)
|
||||
// Update updates human users based on the given condition and changes.
|
||||
Update(ctx context.Context, condition database.Condition, changes ...database.Change) error
|
||||
}
|
||||
|
||||
// machineColumns define all the columns of the machine table which inherits the user table.
|
||||
type machineColumns interface {
|
||||
userColumns
|
||||
// DescriptionColumn returns the column for the description field.
|
||||
DescriptionColumn() database.Column
|
||||
}
|
||||
|
||||
// machineConditions define all the conditions for the machine table which inherits the user table.
|
||||
type machineConditions interface {
|
||||
userConditions
|
||||
// DescriptionCondition returns a filter on the description field.
|
||||
DescriptionCondition(op database.TextOperation, description string) database.Condition
|
||||
}
|
||||
|
||||
// machineChanges define all the changes for the machine table which inherits the user table.
|
||||
type machineChanges interface {
|
||||
userChanges
|
||||
// SetDescription sets the description field of the machine.
|
||||
SetDescription(description string) database.Change
|
||||
}
|
||||
|
||||
// MachineRepository is the interface for the machine repository it inherits the user repository.
|
||||
type MachineRepository interface {
|
||||
// Update updates machine users based on the given condition and changes.
|
||||
Update(ctx context.Context, condition database.Condition, changes ...database.Change) error
|
||||
|
||||
machineColumns
|
||||
machineConditions
|
||||
machineChanges
|
||||
}
|
||||
|
||||
// UserTraits is implemented by [Human] and [Machine].
|
||||
type UserTraits interface {
|
||||
Type() UserType
|
||||
}
|
||||
|
||||
type UserType string
|
||||
|
||||
const (
|
||||
UserTypeHuman UserType = "human"
|
||||
UserTypeMachine UserType = "machine"
|
||||
)
|
||||
|
||||
type User struct {
|
||||
InstanceID string
|
||||
OrgID string
|
||||
ID string
|
||||
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
DeletedAt time.Time
|
||||
|
||||
Username string
|
||||
|
||||
Traits UserTraits
|
||||
}
|
||||
|
||||
type Human struct {
|
||||
FirstName string `json:"firstName"`
|
||||
LastName string `json:"lastName"`
|
||||
Email *Email `json:"email,omitempty"`
|
||||
Phone *Phone `json:"phone,omitempty"`
|
||||
}
|
||||
|
||||
// Type implements [UserTraits].
|
||||
func (h *Human) Type() UserType {
|
||||
return UserTypeHuman
|
||||
}
|
||||
|
||||
var _ UserTraits = (*Human)(nil)
|
||||
|
||||
type Email struct {
|
||||
Address string `json:"address"`
|
||||
VerifiedAt time.Time `json:"verifiedAt"`
|
||||
}
|
||||
|
||||
type Phone struct {
|
||||
Number string `json:"number"`
|
||||
VerifiedAt time.Time `json:"verifiedAt"`
|
||||
}
|
||||
|
||||
type Machine struct {
|
||||
Description string `json:"description"`
|
||||
}
|
||||
|
||||
// Type implements [UserTraits].
|
||||
func (m *Machine) Type() UserType {
|
||||
return UserTypeMachine
|
||||
}
|
||||
|
||||
var _ UserTraits = (*Machine)(nil)
|
||||
112
backend/v3/storage/cache/cache.go
vendored
Normal file
112
backend/v3/storage/cache/cache.go
vendored
Normal file
@@ -0,0 +1,112 @@
|
||||
// Package cache provides abstraction of cache implementations that can be used by zitadel.
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
)
|
||||
|
||||
// Purpose describes which object types are stored by a cache.
|
||||
type Purpose int
|
||||
|
||||
//go:generate enumer -type Purpose -transform snake -trimprefix Purpose
|
||||
const (
|
||||
PurposeUnspecified Purpose = iota
|
||||
PurposeAuthzInstance
|
||||
PurposeMilestones
|
||||
PurposeOrganization
|
||||
PurposeIdPFormCallback
|
||||
)
|
||||
|
||||
// Cache stores objects with a value of type `V`.
|
||||
// Objects may be referred to by one or more indices.
|
||||
// Implementations may encode the value for storage.
|
||||
// This means non-exported fields may be lost and objects
|
||||
// with function values may fail to encode.
|
||||
// See https://pkg.go.dev/encoding/json#Marshal for example.
|
||||
//
|
||||
// `I` is the type by which indices are identified,
|
||||
// typically an enum for type-safe access.
|
||||
// Indices are defined when calling the constructor of an implementation of this interface.
|
||||
// It is illegal to refer to an idex not defined during construction.
|
||||
//
|
||||
// `K` is the type used as key in each index.
|
||||
// Due to the limitations in type constraints, all indices use the same key type.
|
||||
//
|
||||
// Implementations are free to use stricter type constraints or fixed typing.
|
||||
type Cache[I, K comparable, V Entry[I, K]] interface {
|
||||
// Get an object through specified index.
|
||||
// An [IndexUnknownError] may be returned if the index is unknown.
|
||||
// [ErrCacheMiss] is returned if the key was not found in the index,
|
||||
// or the object is not valid.
|
||||
Get(ctx context.Context, index I, key K) (V, bool)
|
||||
|
||||
// Set an object.
|
||||
// Keys are created on each index based in the [Entry.Keys] method.
|
||||
// If any key maps to an existing object, the object is invalidated,
|
||||
// regardless if the object has other keys defined in the new entry.
|
||||
// This to prevent ghost objects when an entry reduces the amount of keys
|
||||
// for a given index.
|
||||
Set(ctx context.Context, value V)
|
||||
|
||||
// Invalidate an object through specified index.
|
||||
// Implementations may choose to instantly delete the object,
|
||||
// defer until prune or a separate cleanup routine.
|
||||
// Invalidated object are no longer returned from Get.
|
||||
// It is safe to call Invalidate multiple times or on non-existing entries.
|
||||
Invalidate(ctx context.Context, index I, key ...K) error
|
||||
|
||||
// Delete one or more keys from a specific index.
|
||||
// An [IndexUnknownError] may be returned if the index is unknown.
|
||||
// The referred object is not invalidated and may still be accessible though
|
||||
// other indices and keys.
|
||||
// It is safe to call Delete multiple times or on non-existing entries
|
||||
Delete(ctx context.Context, index I, key ...K) error
|
||||
|
||||
// Truncate deletes all cached objects.
|
||||
Truncate(ctx context.Context) error
|
||||
}
|
||||
|
||||
// Entry contains a value of type `V` to be cached.
|
||||
//
|
||||
// `I` is the type by which indices are identified,
|
||||
// typically an enum for type-safe access.
|
||||
//
|
||||
// `K` is the type used as key in an index.
|
||||
// Due to the limitations in type constraints, all indices use the same key type.
|
||||
type Entry[I, K comparable] interface {
|
||||
// Keys returns which keys map to the object in a specified index.
|
||||
// May return nil if the index in unknown or when there are no keys.
|
||||
Keys(index I) (key []K)
|
||||
}
|
||||
|
||||
type Connector int
|
||||
|
||||
//go:generate enumer -type Connector -transform snake -trimprefix Connector -linecomment -text
|
||||
const (
|
||||
// Empty line comment ensures empty string for unspecified value
|
||||
ConnectorUnspecified Connector = iota //
|
||||
ConnectorMemory
|
||||
ConnectorPostgres
|
||||
ConnectorRedis
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Connector Connector
|
||||
|
||||
// Age since an object was added to the cache,
|
||||
// after which the object is considered invalid.
|
||||
// 0 disables max age checks.
|
||||
MaxAge time.Duration
|
||||
|
||||
// Age since last use (Get) of an object,
|
||||
// after which the object is considered invalid.
|
||||
// 0 disables last use age checks.
|
||||
LastUseAge time.Duration
|
||||
|
||||
// Log allows logging of the specific cache.
|
||||
// By default only errors are logged to stdout.
|
||||
Log *logging.Config
|
||||
}
|
||||
49
backend/v3/storage/cache/connector/connector.go
vendored
Normal file
49
backend/v3/storage/cache/connector/connector.go
vendored
Normal file
@@ -0,0 +1,49 @@
|
||||
// Package connector provides glue between the [cache.Cache] interface and implementations from the connector sub-packages.
|
||||
package connector
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/cache"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/cache/connector/gomap"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/cache/connector/noop"
|
||||
)
|
||||
|
||||
type CachesConfig struct {
|
||||
Connectors struct {
|
||||
Memory gomap.Config
|
||||
}
|
||||
Instance *cache.Config
|
||||
Milestones *cache.Config
|
||||
Organization *cache.Config
|
||||
IdPFormCallbacks *cache.Config
|
||||
}
|
||||
|
||||
type Connectors struct {
|
||||
Config CachesConfig
|
||||
Memory *gomap.Connector
|
||||
}
|
||||
|
||||
func StartConnectors(conf *CachesConfig) (Connectors, error) {
|
||||
if conf == nil {
|
||||
return Connectors{}, nil
|
||||
}
|
||||
return Connectors{
|
||||
Config: *conf,
|
||||
Memory: gomap.NewConnector(conf.Connectors.Memory),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func StartCache[I ~int, K ~string, V cache.Entry[I, K]](background context.Context, indices []I, purpose cache.Purpose, conf *cache.Config, connectors Connectors) (cache.Cache[I, K, V], error) {
|
||||
if conf == nil || conf.Connector == cache.ConnectorUnspecified {
|
||||
return noop.NewCache[I, K, V](), nil
|
||||
}
|
||||
if conf.Connector == cache.ConnectorMemory && connectors.Memory != nil {
|
||||
c := gomap.NewCache[I, K, V](background, indices, *conf)
|
||||
connectors.Memory.Config.StartAutoPrune(background, c, purpose)
|
||||
return c, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("cache connector %q not enabled", conf.Connector)
|
||||
}
|
||||
23
backend/v3/storage/cache/connector/gomap/connector.go
vendored
Normal file
23
backend/v3/storage/cache/connector/gomap/connector.go
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
package gomap
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/cache"
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
Enabled bool
|
||||
AutoPrune cache.AutoPruneConfig
|
||||
}
|
||||
|
||||
type Connector struct {
|
||||
Config cache.AutoPruneConfig
|
||||
}
|
||||
|
||||
func NewConnector(config Config) *Connector {
|
||||
if !config.Enabled {
|
||||
return nil
|
||||
}
|
||||
return &Connector{
|
||||
Config: config.AutoPrune,
|
||||
}
|
||||
}
|
||||
200
backend/v3/storage/cache/connector/gomap/gomap.go
vendored
Normal file
200
backend/v3/storage/cache/connector/gomap/gomap.go
vendored
Normal file
@@ -0,0 +1,200 @@
|
||||
package gomap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"log/slog"
|
||||
"maps"
|
||||
"os"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/cache"
|
||||
)
|
||||
|
||||
type mapCache[I, K comparable, V cache.Entry[I, K]] struct {
|
||||
config *cache.Config
|
||||
indexMap map[I]*index[K, V]
|
||||
logger *slog.Logger
|
||||
}
|
||||
|
||||
// NewCache returns an in-memory Cache implementation based on the builtin go map type.
|
||||
// Object values are stored as-is and there is no encoding or decoding involved.
|
||||
func NewCache[I, K comparable, V cache.Entry[I, K]](background context.Context, indices []I, config cache.Config) cache.PrunerCache[I, K, V] {
|
||||
m := &mapCache[I, K, V]{
|
||||
config: &config,
|
||||
indexMap: make(map[I]*index[K, V], len(indices)),
|
||||
logger: slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
|
||||
AddSource: true,
|
||||
Level: slog.LevelError,
|
||||
})),
|
||||
}
|
||||
if config.Log != nil {
|
||||
m.logger = config.Log.Slog()
|
||||
}
|
||||
m.logger.InfoContext(background, "map cache logging enabled")
|
||||
|
||||
for _, name := range indices {
|
||||
m.indexMap[name] = &index[K, V]{
|
||||
config: m.config,
|
||||
entries: make(map[K]*entry[V]),
|
||||
}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Get(ctx context.Context, index I, key K) (value V, ok bool) {
|
||||
i, ok := c.indexMap[index]
|
||||
if !ok {
|
||||
c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key)
|
||||
return value, false
|
||||
}
|
||||
entry, err := i.Get(key)
|
||||
if err == nil {
|
||||
c.logger.DebugContext(ctx, "map cache get", "index", index, "key", key)
|
||||
return entry.value, true
|
||||
}
|
||||
if errors.Is(err, cache.ErrCacheMiss) {
|
||||
c.logger.InfoContext(ctx, "map cache get", "err", err, "index", index, "key", key)
|
||||
return value, false
|
||||
}
|
||||
c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key)
|
||||
return value, false
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Set(ctx context.Context, value V) {
|
||||
now := time.Now()
|
||||
entry := &entry[V]{
|
||||
value: value,
|
||||
created: now,
|
||||
}
|
||||
entry.lastUse.Store(now.UnixMicro())
|
||||
|
||||
for name, i := range c.indexMap {
|
||||
keys := value.Keys(name)
|
||||
i.Set(keys, entry)
|
||||
c.logger.DebugContext(ctx, "map cache set", "index", name, "keys", keys)
|
||||
}
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Invalidate(ctx context.Context, index I, keys ...K) error {
|
||||
i, ok := c.indexMap[index]
|
||||
if !ok {
|
||||
return cache.NewIndexUnknownErr(index)
|
||||
}
|
||||
i.Invalidate(keys)
|
||||
c.logger.DebugContext(ctx, "map cache invalidate", "index", index, "keys", keys)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Delete(ctx context.Context, index I, keys ...K) error {
|
||||
i, ok := c.indexMap[index]
|
||||
if !ok {
|
||||
return cache.NewIndexUnknownErr(index)
|
||||
}
|
||||
i.Delete(keys)
|
||||
c.logger.DebugContext(ctx, "map cache delete", "index", index, "keys", keys)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Prune(ctx context.Context) error {
|
||||
for name, index := range c.indexMap {
|
||||
index.Prune()
|
||||
c.logger.DebugContext(ctx, "map cache prune", "index", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *mapCache[I, K, V]) Truncate(ctx context.Context) error {
|
||||
for name, index := range c.indexMap {
|
||||
index.Truncate()
|
||||
c.logger.DebugContext(ctx, "map cache truncate", "index", name)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type index[K comparable, V any] struct {
|
||||
mutex sync.RWMutex
|
||||
config *cache.Config
|
||||
entries map[K]*entry[V]
|
||||
}
|
||||
|
||||
func (i *index[K, V]) Get(key K) (*entry[V], error) {
|
||||
i.mutex.RLock()
|
||||
entry, ok := i.entries[key]
|
||||
i.mutex.RUnlock()
|
||||
if ok && entry.isValid(i.config) {
|
||||
return entry, nil
|
||||
}
|
||||
return nil, cache.ErrCacheMiss
|
||||
}
|
||||
|
||||
func (c *index[K, V]) Set(keys []K, entry *entry[V]) {
|
||||
c.mutex.Lock()
|
||||
for _, key := range keys {
|
||||
c.entries[key] = entry
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (i *index[K, V]) Invalidate(keys []K) {
|
||||
i.mutex.RLock()
|
||||
for _, key := range keys {
|
||||
if entry, ok := i.entries[key]; ok {
|
||||
entry.invalid.Store(true)
|
||||
}
|
||||
}
|
||||
i.mutex.RUnlock()
|
||||
}
|
||||
|
||||
func (c *index[K, V]) Delete(keys []K) {
|
||||
c.mutex.Lock()
|
||||
for _, key := range keys {
|
||||
delete(c.entries, key)
|
||||
}
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (c *index[K, V]) Prune() {
|
||||
c.mutex.Lock()
|
||||
maps.DeleteFunc(c.entries, func(_ K, entry *entry[V]) bool {
|
||||
return !entry.isValid(c.config)
|
||||
})
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
func (c *index[K, V]) Truncate() {
|
||||
c.mutex.Lock()
|
||||
c.entries = make(map[K]*entry[V])
|
||||
c.mutex.Unlock()
|
||||
}
|
||||
|
||||
type entry[V any] struct {
|
||||
value V
|
||||
created time.Time
|
||||
invalid atomic.Bool
|
||||
lastUse atomic.Int64 // UnixMicro time
|
||||
}
|
||||
|
||||
func (e *entry[V]) isValid(c *cache.Config) bool {
|
||||
if e.invalid.Load() {
|
||||
return false
|
||||
}
|
||||
now := time.Now()
|
||||
if c.MaxAge > 0 {
|
||||
if e.created.Add(c.MaxAge).Before(now) {
|
||||
e.invalid.Store(true)
|
||||
return false
|
||||
}
|
||||
}
|
||||
if c.LastUseAge > 0 {
|
||||
lastUse := e.lastUse.Load()
|
||||
if time.UnixMicro(lastUse).Add(c.LastUseAge).Before(now) {
|
||||
e.invalid.Store(true)
|
||||
return false
|
||||
}
|
||||
e.lastUse.CompareAndSwap(lastUse, now.UnixMicro())
|
||||
}
|
||||
return true
|
||||
}
|
||||
329
backend/v3/storage/cache/connector/gomap/gomap_test.go
vendored
Normal file
329
backend/v3/storage/cache/connector/gomap/gomap_test.go
vendored
Normal file
@@ -0,0 +1,329 @@
|
||||
package gomap
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/cache"
|
||||
)
|
||||
|
||||
type testIndex int
|
||||
|
||||
const (
|
||||
testIndexID testIndex = iota
|
||||
testIndexName
|
||||
)
|
||||
|
||||
var testIndices = []testIndex{
|
||||
testIndexID,
|
||||
testIndexName,
|
||||
}
|
||||
|
||||
type testObject struct {
|
||||
id string
|
||||
names []string
|
||||
}
|
||||
|
||||
func (o *testObject) Keys(index testIndex) []string {
|
||||
switch index {
|
||||
case testIndexID:
|
||||
return []string{o.id}
|
||||
case testIndexName:
|
||||
return o.names
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func Test_mapCache_Get(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
obj := &testObject{
|
||||
id: "id",
|
||||
names: []string{"foo", "bar"},
|
||||
}
|
||||
c.Set(context.Background(), obj)
|
||||
|
||||
type args struct {
|
||||
index testIndex
|
||||
key string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want *testObject
|
||||
wantOk bool
|
||||
}{
|
||||
{
|
||||
name: "ok",
|
||||
args: args{
|
||||
index: testIndexID,
|
||||
key: "id",
|
||||
},
|
||||
want: obj,
|
||||
wantOk: true,
|
||||
},
|
||||
{
|
||||
name: "miss",
|
||||
args: args{
|
||||
index: testIndexID,
|
||||
key: "spanac",
|
||||
},
|
||||
want: nil,
|
||||
wantOk: false,
|
||||
},
|
||||
{
|
||||
name: "unknown index",
|
||||
args: args{
|
||||
index: 99,
|
||||
key: "id",
|
||||
},
|
||||
want: nil,
|
||||
wantOk: false,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, ok := c.Get(context.Background(), tt.args.index, tt.args.key)
|
||||
assert.Equal(t, tt.want, got)
|
||||
assert.Equal(t, tt.wantOk, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_mapCache_Invalidate(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
obj := &testObject{
|
||||
id: "id",
|
||||
names: []string{"foo", "bar"},
|
||||
}
|
||||
c.Set(context.Background(), obj)
|
||||
err := c.Invalidate(context.Background(), testIndexName, "bar")
|
||||
require.NoError(t, err)
|
||||
got, ok := c.Get(context.Background(), testIndexID, "id")
|
||||
assert.Nil(t, got)
|
||||
assert.False(t, ok)
|
||||
}
|
||||
|
||||
func Test_mapCache_Delete(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
obj := &testObject{
|
||||
id: "id",
|
||||
names: []string{"foo", "bar"},
|
||||
}
|
||||
c.Set(context.Background(), obj)
|
||||
err := c.Delete(context.Background(), testIndexName, "bar")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Shouldn't find object by deleted name
|
||||
got, ok := c.Get(context.Background(), testIndexName, "bar")
|
||||
assert.Nil(t, got)
|
||||
assert.False(t, ok)
|
||||
|
||||
// Should find object by other name
|
||||
got, ok = c.Get(context.Background(), testIndexName, "foo")
|
||||
assert.Equal(t, obj, got)
|
||||
assert.True(t, ok)
|
||||
|
||||
// Should find object by id
|
||||
got, ok = c.Get(context.Background(), testIndexID, "id")
|
||||
assert.Equal(t, obj, got)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
func Test_mapCache_Prune(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
|
||||
objects := []*testObject{
|
||||
{
|
||||
id: "id1",
|
||||
names: []string{"foo", "bar"},
|
||||
},
|
||||
{
|
||||
id: "id2",
|
||||
names: []string{"hello"},
|
||||
},
|
||||
}
|
||||
for _, obj := range objects {
|
||||
c.Set(context.Background(), obj)
|
||||
}
|
||||
// invalidate one entry
|
||||
err := c.Invalidate(context.Background(), testIndexName, "bar")
|
||||
require.NoError(t, err)
|
||||
|
||||
err = c.(cache.Pruner).Prune(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
// Other object should still be found
|
||||
got, ok := c.Get(context.Background(), testIndexID, "id2")
|
||||
assert.Equal(t, objects[1], got)
|
||||
assert.True(t, ok)
|
||||
}
|
||||
|
||||
func Test_mapCache_Truncate(t *testing.T) {
|
||||
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
|
||||
MaxAge: time.Second,
|
||||
LastUseAge: time.Second / 4,
|
||||
Log: &logging.Config{
|
||||
Level: "debug",
|
||||
AddSource: true,
|
||||
},
|
||||
})
|
||||
objects := []*testObject{
|
||||
{
|
||||
id: "id1",
|
||||
names: []string{"foo", "bar"},
|
||||
},
|
||||
{
|
||||
id: "id2",
|
||||
names: []string{"hello"},
|
||||
},
|
||||
}
|
||||
for _, obj := range objects {
|
||||
c.Set(context.Background(), obj)
|
||||
}
|
||||
|
||||
err := c.Truncate(context.Background())
|
||||
require.NoError(t, err)
|
||||
|
||||
mc := c.(*mapCache[testIndex, string, *testObject])
|
||||
for _, index := range mc.indexMap {
|
||||
index.mutex.RLock()
|
||||
assert.Len(t, index.entries, 0)
|
||||
index.mutex.RUnlock()
|
||||
}
|
||||
}
|
||||
|
||||
func Test_entry_isValid(t *testing.T) {
|
||||
type fields struct {
|
||||
created time.Time
|
||||
invalid bool
|
||||
lastUse time.Time
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
fields fields
|
||||
config *cache.Config
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "invalid",
|
||||
fields: fields{
|
||||
created: time.Now(),
|
||||
invalid: true,
|
||||
lastUse: time.Now(),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "max age exceeded",
|
||||
fields: fields{
|
||||
created: time.Now().Add(-(time.Minute + time.Second)),
|
||||
invalid: false,
|
||||
lastUse: time.Now(),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "max age disabled",
|
||||
fields: fields{
|
||||
created: time.Now().Add(-(time.Minute + time.Second)),
|
||||
invalid: false,
|
||||
lastUse: time.Now(),
|
||||
},
|
||||
config: &cache.Config{
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "last use age exceeded",
|
||||
fields: fields{
|
||||
created: time.Now().Add(-(time.Minute / 2)),
|
||||
invalid: false,
|
||||
lastUse: time.Now().Add(-(time.Second * 2)),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "last use age disabled",
|
||||
fields: fields{
|
||||
created: time.Now().Add(-(time.Minute / 2)),
|
||||
invalid: false,
|
||||
lastUse: time.Now().Add(-(time.Second * 2)),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "valid",
|
||||
fields: fields{
|
||||
created: time.Now(),
|
||||
invalid: false,
|
||||
lastUse: time.Now(),
|
||||
},
|
||||
config: &cache.Config{
|
||||
MaxAge: time.Minute,
|
||||
LastUseAge: time.Second,
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
e := &entry[any]{
|
||||
created: tt.fields.created,
|
||||
}
|
||||
e.invalid.Store(tt.fields.invalid)
|
||||
e.lastUse.Store(tt.fields.lastUse.UnixMicro())
|
||||
got := e.isValid(tt.config)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
21
backend/v3/storage/cache/connector/noop/noop.go
vendored
Normal file
21
backend/v3/storage/cache/connector/noop/noop.go
vendored
Normal file
@@ -0,0 +1,21 @@
|
||||
package noop
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/cache"
|
||||
)
|
||||
|
||||
type noop[I, K comparable, V cache.Entry[I, K]] struct{}
|
||||
|
||||
// NewCache returns a cache that does nothing
|
||||
func NewCache[I, K comparable, V cache.Entry[I, K]]() cache.Cache[I, K, V] {
|
||||
return noop[I, K, V]{}
|
||||
}
|
||||
|
||||
func (noop[I, K, V]) Set(context.Context, V) {}
|
||||
func (noop[I, K, V]) Get(context.Context, I, K) (value V, ok bool) { return }
|
||||
func (noop[I, K, V]) Invalidate(context.Context, I, ...K) (err error) { return }
|
||||
func (noop[I, K, V]) Delete(context.Context, I, ...K) (err error) { return }
|
||||
func (noop[I, K, V]) Prune(context.Context) (err error) { return }
|
||||
func (noop[I, K, V]) Truncate(context.Context) (err error) { return }
|
||||
98
backend/v3/storage/cache/connector_enumer.go
vendored
Normal file
98
backend/v3/storage/cache/connector_enumer.go
vendored
Normal file
@@ -0,0 +1,98 @@
|
||||
// Code generated by "enumer -type Connector -transform snake -trimprefix Connector -linecomment -text"; DO NOT EDIT.
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const _ConnectorName = "memorypostgresredis"
|
||||
|
||||
var _ConnectorIndex = [...]uint8{0, 0, 6, 14, 19}
|
||||
|
||||
const _ConnectorLowerName = "memorypostgresredis"
|
||||
|
||||
func (i Connector) String() string {
|
||||
if i < 0 || i >= Connector(len(_ConnectorIndex)-1) {
|
||||
return fmt.Sprintf("Connector(%d)", i)
|
||||
}
|
||||
return _ConnectorName[_ConnectorIndex[i]:_ConnectorIndex[i+1]]
|
||||
}
|
||||
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
func _ConnectorNoOp() {
|
||||
var x [1]struct{}
|
||||
_ = x[ConnectorUnspecified-(0)]
|
||||
_ = x[ConnectorMemory-(1)]
|
||||
_ = x[ConnectorPostgres-(2)]
|
||||
_ = x[ConnectorRedis-(3)]
|
||||
}
|
||||
|
||||
var _ConnectorValues = []Connector{ConnectorUnspecified, ConnectorMemory, ConnectorPostgres, ConnectorRedis}
|
||||
|
||||
var _ConnectorNameToValueMap = map[string]Connector{
|
||||
_ConnectorName[0:0]: ConnectorUnspecified,
|
||||
_ConnectorLowerName[0:0]: ConnectorUnspecified,
|
||||
_ConnectorName[0:6]: ConnectorMemory,
|
||||
_ConnectorLowerName[0:6]: ConnectorMemory,
|
||||
_ConnectorName[6:14]: ConnectorPostgres,
|
||||
_ConnectorLowerName[6:14]: ConnectorPostgres,
|
||||
_ConnectorName[14:19]: ConnectorRedis,
|
||||
_ConnectorLowerName[14:19]: ConnectorRedis,
|
||||
}
|
||||
|
||||
var _ConnectorNames = []string{
|
||||
_ConnectorName[0:0],
|
||||
_ConnectorName[0:6],
|
||||
_ConnectorName[6:14],
|
||||
_ConnectorName[14:19],
|
||||
}
|
||||
|
||||
// ConnectorString retrieves an enum value from the enum constants string name.
|
||||
// Throws an error if the param is not part of the enum.
|
||||
func ConnectorString(s string) (Connector, error) {
|
||||
if val, ok := _ConnectorNameToValueMap[s]; ok {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
if val, ok := _ConnectorNameToValueMap[strings.ToLower(s)]; ok {
|
||||
return val, nil
|
||||
}
|
||||
return 0, fmt.Errorf("%s does not belong to Connector values", s)
|
||||
}
|
||||
|
||||
// ConnectorValues returns all values of the enum
|
||||
func ConnectorValues() []Connector {
|
||||
return _ConnectorValues
|
||||
}
|
||||
|
||||
// ConnectorStrings returns a slice of all String values of the enum
|
||||
func ConnectorStrings() []string {
|
||||
strs := make([]string, len(_ConnectorNames))
|
||||
copy(strs, _ConnectorNames)
|
||||
return strs
|
||||
}
|
||||
|
||||
// IsAConnector returns "true" if the value is listed in the enum definition. "false" otherwise
|
||||
func (i Connector) IsAConnector() bool {
|
||||
for _, v := range _ConnectorValues {
|
||||
if i == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// MarshalText implements the encoding.TextMarshaler interface for Connector
|
||||
func (i Connector) MarshalText() ([]byte, error) {
|
||||
return []byte(i.String()), nil
|
||||
}
|
||||
|
||||
// UnmarshalText implements the encoding.TextUnmarshaler interface for Connector
|
||||
func (i *Connector) UnmarshalText(text []byte) error {
|
||||
var err error
|
||||
*i, err = ConnectorString(string(text))
|
||||
return err
|
||||
}
|
||||
2
backend/v3/storage/cache/doc.go
vendored
Normal file
2
backend/v3/storage/cache/doc.go
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
// this package is copy pasted from the internal/cache package
|
||||
package cache
|
||||
29
backend/v3/storage/cache/error.go
vendored
Normal file
29
backend/v3/storage/cache/error.go
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type IndexUnknownError[I comparable] struct {
|
||||
index I
|
||||
}
|
||||
|
||||
func NewIndexUnknownErr[I comparable](index I) error {
|
||||
return IndexUnknownError[I]{index}
|
||||
}
|
||||
|
||||
func (i IndexUnknownError[I]) Error() string {
|
||||
return fmt.Sprintf("index %v unknown", i.index)
|
||||
}
|
||||
|
||||
func (a IndexUnknownError[I]) Is(err error) bool {
|
||||
if b, ok := err.(IndexUnknownError[I]); ok {
|
||||
return a.index == b.index
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var (
|
||||
ErrCacheMiss = errors.New("cache miss")
|
||||
)
|
||||
76
backend/v3/storage/cache/pruner.go
vendored
Normal file
76
backend/v3/storage/cache/pruner.go
vendored
Normal file
@@ -0,0 +1,76 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
"github.com/jonboulle/clockwork"
|
||||
"github.com/zitadel/logging"
|
||||
)
|
||||
|
||||
// Pruner is an optional [Cache] interface.
|
||||
type Pruner interface {
|
||||
// Prune deletes all invalidated or expired objects.
|
||||
Prune(ctx context.Context) error
|
||||
}
|
||||
|
||||
type PrunerCache[I, K comparable, V Entry[I, K]] interface {
|
||||
Cache[I, K, V]
|
||||
Pruner
|
||||
}
|
||||
|
||||
type AutoPruneConfig struct {
|
||||
// Interval at which the cache is automatically pruned.
|
||||
// 0 or lower disables automatic pruning.
|
||||
Interval time.Duration
|
||||
|
||||
// Timeout for an automatic prune.
|
||||
// It is recommended to keep the value shorter than AutoPruneInterval
|
||||
// 0 or lower disables automatic pruning.
|
||||
Timeout time.Duration
|
||||
}
|
||||
|
||||
func (c AutoPruneConfig) StartAutoPrune(background context.Context, pruner Pruner, purpose Purpose) (close func()) {
|
||||
return c.startAutoPrune(background, pruner, purpose, clockwork.NewRealClock())
|
||||
}
|
||||
|
||||
func (c *AutoPruneConfig) startAutoPrune(background context.Context, pruner Pruner, purpose Purpose, clock clockwork.Clock) (close func()) {
|
||||
if c.Interval <= 0 {
|
||||
return func() {}
|
||||
}
|
||||
background, cancel := context.WithCancel(background)
|
||||
// randomize the first interval
|
||||
timer := clock.NewTimer(time.Duration(rand.Int63n(int64(c.Interval))))
|
||||
go c.pruneTimer(background, pruner, purpose, timer)
|
||||
return cancel
|
||||
}
|
||||
|
||||
func (c *AutoPruneConfig) pruneTimer(background context.Context, pruner Pruner, purpose Purpose, timer clockwork.Timer) {
|
||||
defer func() {
|
||||
if !timer.Stop() {
|
||||
<-timer.Chan()
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-background.Done():
|
||||
return
|
||||
case <-timer.Chan():
|
||||
err := c.doPrune(background, pruner)
|
||||
logging.OnError(err).WithField("purpose", purpose).Error("cache auto prune")
|
||||
timer.Reset(c.Interval)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *AutoPruneConfig) doPrune(background context.Context, pruner Pruner) error {
|
||||
ctx, cancel := context.WithCancel(background)
|
||||
defer cancel()
|
||||
if c.Timeout > 0 {
|
||||
ctx, cancel = context.WithTimeout(background, c.Timeout)
|
||||
defer cancel()
|
||||
}
|
||||
return pruner.Prune(ctx)
|
||||
}
|
||||
43
backend/v3/storage/cache/pruner_test.go
vendored
Normal file
43
backend/v3/storage/cache/pruner_test.go
vendored
Normal file
@@ -0,0 +1,43 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jonboulle/clockwork"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
type testPruner struct {
|
||||
called chan struct{}
|
||||
}
|
||||
|
||||
func (p *testPruner) Prune(context.Context) error {
|
||||
p.called <- struct{}{}
|
||||
return nil
|
||||
}
|
||||
|
||||
func TestAutoPruneConfig_startAutoPrune(t *testing.T) {
|
||||
c := AutoPruneConfig{
|
||||
Interval: time.Second,
|
||||
Timeout: time.Millisecond,
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
|
||||
pruner := testPruner{
|
||||
called: make(chan struct{}),
|
||||
}
|
||||
clock := clockwork.NewFakeClock()
|
||||
close := c.startAutoPrune(ctx, &pruner, PurposeAuthzInstance, clock)
|
||||
defer close()
|
||||
clock.Advance(time.Second)
|
||||
|
||||
select {
|
||||
case _, ok := <-pruner.called:
|
||||
assert.True(t, ok)
|
||||
case <-ctx.Done():
|
||||
t.Fatal(ctx.Err())
|
||||
}
|
||||
}
|
||||
90
backend/v3/storage/cache/purpose_enumer.go
vendored
Normal file
90
backend/v3/storage/cache/purpose_enumer.go
vendored
Normal file
@@ -0,0 +1,90 @@
|
||||
// Code generated by "enumer -type Purpose -transform snake -trimprefix Purpose"; DO NOT EDIT.
|
||||
|
||||
package cache
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const _PurposeName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callback"
|
||||
|
||||
var _PurposeIndex = [...]uint8{0, 11, 25, 35, 47, 65}
|
||||
|
||||
const _PurposeLowerName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callback"
|
||||
|
||||
func (i Purpose) String() string {
|
||||
if i < 0 || i >= Purpose(len(_PurposeIndex)-1) {
|
||||
return fmt.Sprintf("Purpose(%d)", i)
|
||||
}
|
||||
return _PurposeName[_PurposeIndex[i]:_PurposeIndex[i+1]]
|
||||
}
|
||||
|
||||
// An "invalid array index" compiler error signifies that the constant values have changed.
|
||||
// Re-run the stringer command to generate them again.
|
||||
func _PurposeNoOp() {
|
||||
var x [1]struct{}
|
||||
_ = x[PurposeUnspecified-(0)]
|
||||
_ = x[PurposeAuthzInstance-(1)]
|
||||
_ = x[PurposeMilestones-(2)]
|
||||
_ = x[PurposeOrganization-(3)]
|
||||
_ = x[PurposeIdPFormCallback-(4)]
|
||||
}
|
||||
|
||||
var _PurposeValues = []Purpose{PurposeUnspecified, PurposeAuthzInstance, PurposeMilestones, PurposeOrganization, PurposeIdPFormCallback}
|
||||
|
||||
var _PurposeNameToValueMap = map[string]Purpose{
|
||||
_PurposeName[0:11]: PurposeUnspecified,
|
||||
_PurposeLowerName[0:11]: PurposeUnspecified,
|
||||
_PurposeName[11:25]: PurposeAuthzInstance,
|
||||
_PurposeLowerName[11:25]: PurposeAuthzInstance,
|
||||
_PurposeName[25:35]: PurposeMilestones,
|
||||
_PurposeLowerName[25:35]: PurposeMilestones,
|
||||
_PurposeName[35:47]: PurposeOrganization,
|
||||
_PurposeLowerName[35:47]: PurposeOrganization,
|
||||
_PurposeName[47:65]: PurposeIdPFormCallback,
|
||||
_PurposeLowerName[47:65]: PurposeIdPFormCallback,
|
||||
}
|
||||
|
||||
var _PurposeNames = []string{
|
||||
_PurposeName[0:11],
|
||||
_PurposeName[11:25],
|
||||
_PurposeName[25:35],
|
||||
_PurposeName[35:47],
|
||||
_PurposeName[47:65],
|
||||
}
|
||||
|
||||
// PurposeString retrieves an enum value from the enum constants string name.
|
||||
// Throws an error if the param is not part of the enum.
|
||||
func PurposeString(s string) (Purpose, error) {
|
||||
if val, ok := _PurposeNameToValueMap[s]; ok {
|
||||
return val, nil
|
||||
}
|
||||
|
||||
if val, ok := _PurposeNameToValueMap[strings.ToLower(s)]; ok {
|
||||
return val, nil
|
||||
}
|
||||
return 0, fmt.Errorf("%s does not belong to Purpose values", s)
|
||||
}
|
||||
|
||||
// PurposeValues returns all values of the enum
|
||||
func PurposeValues() []Purpose {
|
||||
return _PurposeValues
|
||||
}
|
||||
|
||||
// PurposeStrings returns a slice of all String values of the enum
|
||||
func PurposeStrings() []string {
|
||||
strs := make([]string, len(_PurposeNames))
|
||||
copy(strs, _PurposeNames)
|
||||
return strs
|
||||
}
|
||||
|
||||
// IsAPurpose returns "true" if the value is listed in the enum definition. "false" otherwise
|
||||
func (i Purpose) IsAPurpose() bool {
|
||||
for _, v := range _PurposeValues {
|
||||
if i == v {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
53
backend/v3/storage/database/change.go
Normal file
53
backend/v3/storage/database/change.go
Normal file
@@ -0,0 +1,53 @@
|
||||
package database
|
||||
|
||||
// Change represents a change to a column in a database table.
|
||||
// Its written in the SET clause of an UPDATE statement.
|
||||
type Change interface {
|
||||
Write(builder *StatementBuilder)
|
||||
}
|
||||
|
||||
type change[V Value] struct {
|
||||
column Column
|
||||
value V
|
||||
}
|
||||
|
||||
var _ Change = (*change[string])(nil)
|
||||
|
||||
func NewChange[V Value](col Column, value V) Change {
|
||||
return &change[V]{
|
||||
column: col,
|
||||
value: value,
|
||||
}
|
||||
}
|
||||
|
||||
func NewChangePtr[V Value](col Column, value *V) Change {
|
||||
if value == nil {
|
||||
return NewChange(col, NullInstruction)
|
||||
}
|
||||
return NewChange(col, *value)
|
||||
}
|
||||
|
||||
// Write implements [Change].
|
||||
func (c change[V]) Write(builder *StatementBuilder) {
|
||||
c.column.WriteUnqualified(builder)
|
||||
builder.WriteString(" = ")
|
||||
builder.WriteArg(c.value)
|
||||
}
|
||||
|
||||
type Changes []Change
|
||||
|
||||
func NewChanges(cols ...Change) Change {
|
||||
return Changes(cols)
|
||||
}
|
||||
|
||||
// Write implements [Change].
|
||||
func (m Changes) Write(builder *StatementBuilder) {
|
||||
for i, col := range m {
|
||||
if i > 0 {
|
||||
builder.WriteString(", ")
|
||||
}
|
||||
col.Write(builder)
|
||||
}
|
||||
}
|
||||
|
||||
var _ Change = Changes(nil)
|
||||
85
backend/v3/storage/database/column.go
Normal file
85
backend/v3/storage/database/column.go
Normal file
@@ -0,0 +1,85 @@
|
||||
package database
|
||||
|
||||
type Columns []Column
|
||||
|
||||
// WriteQualified implements [Column].
|
||||
func (m Columns) WriteQualified(builder *StatementBuilder) {
|
||||
for i, col := range m {
|
||||
if i > 0 {
|
||||
builder.WriteString(", ")
|
||||
}
|
||||
col.WriteQualified(builder)
|
||||
}
|
||||
}
|
||||
|
||||
// WriteUnqualified implements [Column].
|
||||
func (m Columns) WriteUnqualified(builder *StatementBuilder) {
|
||||
for i, col := range m {
|
||||
if i > 0 {
|
||||
builder.WriteString(", ")
|
||||
}
|
||||
col.WriteUnqualified(builder)
|
||||
}
|
||||
}
|
||||
|
||||
// Column represents a column in a database table.
|
||||
type Column interface {
|
||||
// Write(builder *StatementBuilder)
|
||||
WriteQualified(builder *StatementBuilder)
|
||||
WriteUnqualified(builder *StatementBuilder)
|
||||
}
|
||||
|
||||
type column struct {
|
||||
table string
|
||||
name string
|
||||
}
|
||||
|
||||
func NewColumn(table, name string) Column {
|
||||
return column{table: table, name: name}
|
||||
}
|
||||
|
||||
// WriteQualified implements [Column].
|
||||
func (c column) WriteQualified(builder *StatementBuilder) {
|
||||
builder.Grow(len(c.table) + len(c.name) + 1)
|
||||
builder.WriteString(c.table)
|
||||
builder.WriteRune('.')
|
||||
builder.WriteString(c.name)
|
||||
}
|
||||
|
||||
// WriteUnqualified implements [Column].
|
||||
func (c column) WriteUnqualified(builder *StatementBuilder) {
|
||||
builder.WriteString(c.name)
|
||||
}
|
||||
|
||||
var _ Column = (*column)(nil)
|
||||
|
||||
// // ignoreCaseColumn represents two database columns, one for the
|
||||
// // original value and one for the lower case value.
|
||||
// type ignoreCaseColumn interface {
|
||||
// Column
|
||||
// WriteIgnoreCase(builder *StatementBuilder)
|
||||
// }
|
||||
|
||||
// func NewIgnoreCaseColumn(col Column, suffix string) ignoreCaseColumn {
|
||||
// return ignoreCaseCol{
|
||||
// column: col,
|
||||
// suffix: suffix,
|
||||
// }
|
||||
// }
|
||||
|
||||
// type ignoreCaseCol struct {
|
||||
// column Column
|
||||
// suffix string
|
||||
// }
|
||||
|
||||
// // WriteIgnoreCase implements [ignoreCaseColumn].
|
||||
// func (c ignoreCaseCol) WriteIgnoreCase(builder *StatementBuilder) {
|
||||
// c.column.WriteQualified(builder)
|
||||
// builder.WriteString(c.suffix)
|
||||
// }
|
||||
|
||||
// // WriteQualified implements [ignoreCaseColumn].
|
||||
// func (c ignoreCaseCol) WriteQualified(builder *StatementBuilder) {
|
||||
// c.column.WriteQualified(builder)
|
||||
// builder.WriteString(c.suffix)
|
||||
// }
|
||||
130
backend/v3/storage/database/condition.go
Normal file
130
backend/v3/storage/database/condition.go
Normal file
@@ -0,0 +1,130 @@
|
||||
package database
|
||||
|
||||
// Condition represents a SQL condition.
|
||||
// Its written after the WHERE keyword in a SQL statement.
|
||||
type Condition interface {
|
||||
Write(builder *StatementBuilder)
|
||||
}
|
||||
|
||||
type and struct {
|
||||
conditions []Condition
|
||||
}
|
||||
|
||||
// Write implements [Condition].
|
||||
func (a *and) Write(builder *StatementBuilder) {
|
||||
if len(a.conditions) > 1 {
|
||||
builder.WriteString("(")
|
||||
defer builder.WriteString(")")
|
||||
}
|
||||
for i, condition := range a.conditions {
|
||||
if i > 0 {
|
||||
builder.WriteString(" AND ")
|
||||
}
|
||||
condition.Write(builder)
|
||||
}
|
||||
}
|
||||
|
||||
// And combines multiple conditions with AND.
|
||||
func And(conditions ...Condition) *and {
|
||||
return &and{conditions: conditions}
|
||||
}
|
||||
|
||||
var _ Condition = (*and)(nil)
|
||||
|
||||
type or struct {
|
||||
conditions []Condition
|
||||
}
|
||||
|
||||
// Write implements [Condition].
|
||||
func (o *or) Write(builder *StatementBuilder) {
|
||||
if len(o.conditions) > 1 {
|
||||
builder.WriteString("(")
|
||||
defer builder.WriteString(")")
|
||||
}
|
||||
for i, condition := range o.conditions {
|
||||
if i > 0 {
|
||||
builder.WriteString(" OR ")
|
||||
}
|
||||
condition.Write(builder)
|
||||
}
|
||||
}
|
||||
|
||||
// Or combines multiple conditions with OR.
|
||||
func Or(conditions ...Condition) *or {
|
||||
return &or{conditions: conditions}
|
||||
}
|
||||
|
||||
var _ Condition = (*or)(nil)
|
||||
|
||||
type isNull struct {
|
||||
column Column
|
||||
}
|
||||
|
||||
// Write implements [Condition].
|
||||
func (i *isNull) Write(builder *StatementBuilder) {
|
||||
i.column.WriteQualified(builder)
|
||||
builder.WriteString(" IS NULL")
|
||||
}
|
||||
|
||||
// IsNull creates a condition that checks if a column is NULL.
|
||||
func IsNull(column Column) *isNull {
|
||||
return &isNull{column: column}
|
||||
}
|
||||
|
||||
var _ Condition = (*isNull)(nil)
|
||||
|
||||
type isNotNull struct {
|
||||
column Column
|
||||
}
|
||||
|
||||
// Write implements [Condition].
|
||||
func (i *isNotNull) Write(builder *StatementBuilder) {
|
||||
i.column.WriteQualified(builder)
|
||||
builder.WriteString(" IS NOT NULL")
|
||||
}
|
||||
|
||||
// IsNotNull creates a condition that checks if a column is NOT NULL.
|
||||
func IsNotNull(column Column) *isNotNull {
|
||||
return &isNotNull{column: column}
|
||||
}
|
||||
|
||||
var _ Condition = (*isNotNull)(nil)
|
||||
|
||||
type valueCondition func(builder *StatementBuilder)
|
||||
|
||||
// NewTextCondition creates a condition that compares a text column with a value.
|
||||
func NewTextCondition[V Text](col Column, op TextOperation, value V) Condition {
|
||||
return valueCondition(func(builder *StatementBuilder) {
|
||||
writeTextOperation(builder, col, op, value)
|
||||
})
|
||||
}
|
||||
|
||||
// NewDateCondition creates a condition that compares a numeric column with a value.
|
||||
func NewNumberCondition[V Number](col Column, op NumberOperation, value V) Condition {
|
||||
return valueCondition(func(builder *StatementBuilder) {
|
||||
writeNumberOperation(builder, col, op, value)
|
||||
})
|
||||
}
|
||||
|
||||
// NewDateCondition creates a condition that compares a boolean column with a value.
|
||||
func NewBooleanCondition[V Boolean](col Column, value V) Condition {
|
||||
return valueCondition(func(builder *StatementBuilder) {
|
||||
writeBooleanOperation(builder, col, value)
|
||||
})
|
||||
}
|
||||
|
||||
// NewColumnCondition creates a condition that compares two columns on equality.
|
||||
func NewColumnCondition(col1, col2 Column) Condition {
|
||||
return valueCondition(func(builder *StatementBuilder) {
|
||||
col1.WriteQualified(builder)
|
||||
builder.WriteString(" = ")
|
||||
col2.WriteQualified(builder)
|
||||
})
|
||||
}
|
||||
|
||||
// Write implements [Condition].
|
||||
func (c valueCondition) Write(builder *StatementBuilder) {
|
||||
c(builder)
|
||||
}
|
||||
|
||||
var _ Condition = (*valueCondition)(nil)
|
||||
10
backend/v3/storage/database/config.go
Normal file
10
backend/v3/storage/database/config.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Connector abstracts the database driver.
|
||||
type Connector interface {
|
||||
Connect(ctx context.Context) (Pool, error)
|
||||
}
|
||||
83
backend/v3/storage/database/database.go
Normal file
83
backend/v3/storage/database/database.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"context"
|
||||
)
|
||||
|
||||
// Pool is a connection pool. e.g. pgxpool
|
||||
type Pool interface {
|
||||
Beginner
|
||||
QueryExecutor
|
||||
Migrator
|
||||
|
||||
Acquire(ctx context.Context) (Client, error)
|
||||
Close(ctx context.Context) error
|
||||
}
|
||||
|
||||
type PoolTest interface {
|
||||
Pool
|
||||
// MigrateTest is the same as [Migrator] but executes the migrations multiple times instead of only once.
|
||||
MigrateTest(ctx context.Context) error
|
||||
}
|
||||
|
||||
// Client is a single database connection which can be released back to the pool.
|
||||
type Client interface {
|
||||
Beginner
|
||||
QueryExecutor
|
||||
Migrator
|
||||
|
||||
Release(ctx context.Context) error
|
||||
}
|
||||
|
||||
// Querier is a database client that can execute queries and return rows.
|
||||
type Querier interface {
|
||||
Query(ctx context.Context, stmt string, args ...any) (Rows, error)
|
||||
QueryRow(ctx context.Context, stmt string, args ...any) Row
|
||||
}
|
||||
|
||||
// Executor is a database client that can execute statements.
|
||||
// It returns the number of rows affected or an error
|
||||
type Executor interface {
|
||||
Exec(ctx context.Context, stmt string, args ...any) (int64, error)
|
||||
}
|
||||
|
||||
// QueryExecutor is a database client that can execute queries and statements.
|
||||
type QueryExecutor interface {
|
||||
Querier
|
||||
Executor
|
||||
}
|
||||
|
||||
// Scanner scans a single row of data into the destination.
|
||||
type Scanner interface {
|
||||
Scan(dest ...any) error
|
||||
}
|
||||
|
||||
// Row is an abstraction of sql.Row.
|
||||
type Row interface {
|
||||
Scanner
|
||||
}
|
||||
|
||||
// Rows is an abstraction of sql.Rows.
|
||||
type Rows interface {
|
||||
Scanner
|
||||
Next() bool
|
||||
Close() error
|
||||
Err() error
|
||||
}
|
||||
|
||||
type CollectableRows interface {
|
||||
// Collect collects all rows and scans them into dest.
|
||||
// dest must be a pointer to a slice of pointer to structs
|
||||
// e.g. *[]*MyStruct
|
||||
// Rows are closed after this call.
|
||||
Collect(dest any) error
|
||||
// CollectFirst collects the first row and scans it into dest.
|
||||
// dest must be a pointer to a struct
|
||||
// e.g. *MyStruct{}
|
||||
// Rows are closed after this call.
|
||||
CollectFirst(dest any) error
|
||||
// CollectExactlyOneRow collects exactly one row and scans it into dest.
|
||||
// e.g. *MyStruct{}
|
||||
// Rows are closed after this call.
|
||||
CollectExactlyOneRow(dest any) error
|
||||
}
|
||||
1146
backend/v3/storage/database/dbmock/database.mock.go
Normal file
1146
backend/v3/storage/database/dbmock/database.mock.go
Normal file
File diff suppressed because it is too large
Load Diff
92
backend/v3/storage/database/dialect/config.go
Normal file
92
backend/v3/storage/database/dialect/config.go
Normal file
@@ -0,0 +1,92 @@
|
||||
package dialect
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"reflect"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
"github.com/spf13/viper"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres"
|
||||
)
|
||||
|
||||
type Hook struct {
|
||||
Match func(string) bool
|
||||
Decode func(config any) (database.Connector, error)
|
||||
Name string
|
||||
Constructor func() database.Connector
|
||||
}
|
||||
|
||||
var hooks = []Hook{
|
||||
{
|
||||
Match: postgres.NameMatcher,
|
||||
Decode: postgres.DecodeConfig,
|
||||
Name: postgres.Name,
|
||||
Constructor: func() database.Connector { return new(postgres.Config) },
|
||||
},
|
||||
// {
|
||||
// Match: gosql.NameMatcher,
|
||||
// Decode: gosql.DecodeConfig,
|
||||
// Name: gosql.Name,
|
||||
// Constructor: func() database.Connector { return new(gosql.Config) },
|
||||
// },
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
Dialects map[string]any `mapstructure:",remain" yaml:",inline"`
|
||||
|
||||
connector database.Connector
|
||||
}
|
||||
|
||||
func (c Config) Connect(ctx context.Context) (database.Pool, error) {
|
||||
if len(c.Dialects) != 1 {
|
||||
return nil, errors.New("exactly one dialect must be configured")
|
||||
}
|
||||
|
||||
return c.connector.Connect(ctx)
|
||||
}
|
||||
|
||||
// Hooks implements [configure.Unmarshaller].
|
||||
func (c Config) Hooks() []viper.DecoderConfigOption {
|
||||
return []viper.DecoderConfigOption{
|
||||
viper.DecodeHook(decodeHook),
|
||||
}
|
||||
}
|
||||
|
||||
func decodeHook(from, to reflect.Value) (_ any, err error) {
|
||||
if to.Type() != reflect.TypeOf(Config{}) {
|
||||
return from.Interface(), nil
|
||||
}
|
||||
|
||||
config := new(Config)
|
||||
if err = mapstructure.Decode(from.Interface(), config); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if err = config.decodeDialect(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (c *Config) decodeDialect() error {
|
||||
for _, hook := range hooks {
|
||||
for name, config := range c.Dialects {
|
||||
if !hook.Match(name) {
|
||||
continue
|
||||
}
|
||||
|
||||
connector, err := hook.Decode(config)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
c.connector = connector
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return errors.New("no dialect found")
|
||||
}
|
||||
89
backend/v3/storage/database/dialect/postgres/config.go
Normal file
89
backend/v3/storage/database/dialect/postgres/config.go
Normal file
@@ -0,0 +1,89 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
"github.com/mitchellh/mapstructure"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
var (
|
||||
_ database.Connector = (*Config)(nil)
|
||||
Name = "postgres"
|
||||
isMigrated bool
|
||||
)
|
||||
|
||||
type Config struct {
|
||||
*pgxpool.Config
|
||||
*pgxpool.Pool
|
||||
|
||||
// Host string
|
||||
// Port int32
|
||||
// Database string
|
||||
// MaxOpenConns uint32
|
||||
// MaxIdleConns uint32
|
||||
// MaxConnLifetime time.Duration
|
||||
// MaxConnIdleTime time.Duration
|
||||
// User User
|
||||
// // Additional options to be appended as options=<Options>
|
||||
// // The value will be taken as is. Multiple options are space separated.
|
||||
// Options string
|
||||
|
||||
// configuredFields []string
|
||||
}
|
||||
|
||||
// Connect implements [database.Connector].
|
||||
func (c *Config) Connect(ctx context.Context) (database.Pool, error) {
|
||||
pool, err := c.getPool(ctx)
|
||||
if err != nil {
|
||||
return nil, wrapError(err)
|
||||
}
|
||||
if err = pool.Ping(ctx); err != nil {
|
||||
return nil, wrapError(err)
|
||||
}
|
||||
return &pgxPool{Pool: pool}, nil
|
||||
}
|
||||
|
||||
func (c *Config) getPool(ctx context.Context) (*pgxpool.Pool, error) {
|
||||
if c.Pool != nil {
|
||||
return c.Pool, nil
|
||||
}
|
||||
return pgxpool.NewWithConfig(ctx, c.Config)
|
||||
}
|
||||
|
||||
func NameMatcher(name string) bool {
|
||||
return slices.Contains([]string{"postgres", "pg"}, strings.ToLower(name))
|
||||
}
|
||||
|
||||
func DecodeConfig(input any) (database.Connector, error) {
|
||||
switch c := input.(type) {
|
||||
case string:
|
||||
config, err := pgxpool.ParseConfig(c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Config{Config: config}, nil
|
||||
case map[string]any:
|
||||
connector := new(Config)
|
||||
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
|
||||
DecodeHook: mapstructure.StringToTimeDurationHookFunc(),
|
||||
WeaklyTypedInput: true,
|
||||
Result: connector,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err = decoder.Decode(c); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Config{
|
||||
Config: &pgxpool.Config{},
|
||||
}, nil
|
||||
}
|
||||
return nil, errors.New("invalid configuration")
|
||||
}
|
||||
67
backend/v3/storage/database/dialect/postgres/conn.go
Normal file
67
backend/v3/storage/database/dialect/postgres/conn.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres/migration"
|
||||
)
|
||||
|
||||
type pgxConn struct {
|
||||
*pgxpool.Conn
|
||||
}
|
||||
|
||||
var _ database.Client = (*pgxConn)(nil)
|
||||
|
||||
// Release implements [database.Client].
|
||||
func (c *pgxConn) Release(_ context.Context) error {
|
||||
c.Conn.Release()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Begin implements [database.Client].
|
||||
func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
||||
tx, err := c.BeginTx(ctx, transactionOptionsToPgx(opts))
|
||||
if err != nil {
|
||||
return nil, wrapError(err)
|
||||
}
|
||||
return &pgxTx{tx}, nil
|
||||
}
|
||||
|
||||
// Query implements sql.Client.
|
||||
// Subtle: this method shadows the method (*Conn).Query of pgxConn.Conn.
|
||||
func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
||||
rows, err := c.Conn.Query(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return nil, wrapError(err)
|
||||
}
|
||||
return &Rows{rows}, nil
|
||||
}
|
||||
|
||||
// QueryRow implements sql.Client.
|
||||
// Subtle: this method shadows the method (*Conn).QueryRow of pgxConn.Conn.
|
||||
func (c *pgxConn) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
||||
return &Row{c.Conn.QueryRow(ctx, sql, args...)}
|
||||
}
|
||||
|
||||
// Exec implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
|
||||
func (c *pgxConn) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
|
||||
res, err := c.Conn.Exec(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return 0, wrapError(err)
|
||||
}
|
||||
return res.RowsAffected(), nil
|
||||
}
|
||||
|
||||
// Migrate implements [database.Migrator].
|
||||
func (c *pgxConn) Migrate(ctx context.Context) error {
|
||||
if isMigrated {
|
||||
return nil
|
||||
}
|
||||
err := migration.Migrate(ctx, c.Conn.Conn())
|
||||
isMigrated = err == nil
|
||||
return wrapError(err)
|
||||
}
|
||||
2
backend/v3/storage/database/dialect/postgres/doc.go
Normal file
2
backend/v3/storage/database/dialect/postgres/doc.go
Normal file
@@ -0,0 +1,2 @@
|
||||
// pgxpool v5 implementation of the interfaces defined in the database package.
|
||||
package postgres
|
||||
@@ -0,0 +1,50 @@
|
||||
// embedded is used for testing purposes
|
||||
package embedded
|
||||
|
||||
import (
|
||||
"net"
|
||||
"os"
|
||||
|
||||
embeddedpostgres "github.com/fergusstrange/embedded-postgres"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres"
|
||||
)
|
||||
|
||||
// StartEmbedded starts an embedded postgres v16 instance and returns a database connector and a stop function
|
||||
// the database is started on a random port and data are stored in a temporary directory
|
||||
// its used for testing purposes only
|
||||
func StartEmbedded() (connector database.Connector, stop func(), err error) {
|
||||
path, err := os.MkdirTemp("", "zitadel-embedded-postgres-*")
|
||||
logging.OnError(err).Fatal("unable to create temp dir")
|
||||
|
||||
port, close := getPort()
|
||||
|
||||
config := embeddedpostgres.DefaultConfig().Version(embeddedpostgres.V16).Port(uint32(port)).RuntimePath(path)
|
||||
embedded := embeddedpostgres.NewDatabase(config)
|
||||
|
||||
close()
|
||||
err = embedded.Start()
|
||||
logging.OnError(err).Fatal("unable to start db")
|
||||
|
||||
connector, err = postgres.DecodeConfig(config.GetConnectionURL())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return connector, func() {
|
||||
logging.OnError(embedded.Stop()).Error("unable to stop db")
|
||||
}, nil
|
||||
}
|
||||
|
||||
// getPort returns a free port and locks it until close is called
|
||||
func getPort() (port uint16, close func()) {
|
||||
l, err := net.Listen("tcp", ":0")
|
||||
logging.OnError(err).Fatal("unable to get port")
|
||||
port = uint16(l.Addr().(*net.TCPAddr).Port)
|
||||
logging.WithFields("port", port).Info("Port is available")
|
||||
return port, func() {
|
||||
logging.OnError(l.Close()).Error("unable to close port listener")
|
||||
}
|
||||
}
|
||||
38
backend/v3/storage/database/dialect/postgres/error.go
Normal file
38
backend/v3/storage/database/dialect/postgres/error.go
Normal file
@@ -0,0 +1,38 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
func wrapError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, pgx.ErrNoRows) {
|
||||
return database.NewNoRowFoundError(err)
|
||||
}
|
||||
var pgxErr *pgconn.PgError
|
||||
if !errors.As(err, &pgxErr) {
|
||||
return database.NewUnknownError(err)
|
||||
}
|
||||
switch pgxErr.Code {
|
||||
// 23514: check_violation - A value violates a CHECK constraint.
|
||||
case "23514":
|
||||
return database.NewCheckError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr)
|
||||
// 23505: unique_violation - A value violates a UNIQUE constraint.
|
||||
case "23505":
|
||||
return database.NewUniqueError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr)
|
||||
// 23503: foreign_key_violation - A value violates a foreign key constraint.
|
||||
case "23503":
|
||||
return database.NewForeignKeyError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr)
|
||||
// 23502: not_null_violation - A value violates a NOT NULL constraint.
|
||||
case "23502":
|
||||
return database.NewNotNullError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr)
|
||||
}
|
||||
return database.NewUnknownError(err)
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
var (
|
||||
//go:embed 001_instance_table/up.sql
|
||||
up001InstanceTable string
|
||||
//go:embed 001_instance_table/down.sql
|
||||
down001InstanceTable string
|
||||
)
|
||||
|
||||
func init() {
|
||||
registerSQLMigration(1, up001InstanceTable, down001InstanceTable)
|
||||
}
|
||||
@@ -0,0 +1 @@
|
||||
DROP TABLE zitadel.instances;
|
||||
@@ -0,0 +1,25 @@
|
||||
CREATE TABLE IF NOT EXISTS zitadel.instances(
|
||||
id TEXT NOT NULL CHECK (id <> '') PRIMARY KEY,
|
||||
name TEXT NOT NULL CHECK (name <> ''),
|
||||
default_org_id TEXT, -- NOT NULL,
|
||||
iam_project_id TEXT, -- NOT NULL,
|
||||
console_client_id TEXT, -- NOT NULL,
|
||||
console_app_id TEXT, -- NOT NULL,
|
||||
default_language TEXT, -- NOT NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL,
|
||||
updated_at TIMESTAMPTZ DEFAULT NOW() NOT NULL
|
||||
);
|
||||
|
||||
CREATE OR REPLACE FUNCTION zitadel.set_updated_at()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
NEW.updated_at := NOW();
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
CREATE TRIGGER trigger_set_updated_at
|
||||
BEFORE UPDATE ON zitadel.instances
|
||||
FOR EACH ROW
|
||||
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
|
||||
EXECUTE FUNCTION zitadel.set_updated_at();
|
||||
@@ -0,0 +1,16 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
var (
|
||||
//go:embed 002_organization_table/up.sql
|
||||
up002OrganizationTable string
|
||||
//go:embed 002_organization_table/down.sql
|
||||
down002OrganizationTable string
|
||||
)
|
||||
|
||||
func init() {
|
||||
registerSQLMigration(2, up002OrganizationTable, down002OrganizationTable)
|
||||
}
|
||||
@@ -0,0 +1,2 @@
|
||||
DROP TABLE zitadel.organizations;
|
||||
DROP Type zitadel.organization_state;
|
||||
@@ -0,0 +1,24 @@
|
||||
CREATE TYPE zitadel.organization_state AS ENUM (
|
||||
'active',
|
||||
'inactive'
|
||||
);
|
||||
|
||||
CREATE TABLE zitadel.organizations(
|
||||
id TEXT NOT NULL CHECK (id <> ''),
|
||||
name TEXT NOT NULL CHECK (name <> ''),
|
||||
instance_id TEXT NOT NULL REFERENCES zitadel.instances (id) ON DELETE CASCADE,
|
||||
state zitadel.organization_state NOT NULL,
|
||||
created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL,
|
||||
updated_at TIMESTAMPTZ DEFAULT NOW() NOT NULL,
|
||||
|
||||
PRIMARY KEY (instance_id, id)
|
||||
);
|
||||
|
||||
CREATE UNIQUE INDEX org_unique_instance_id_name_idx
|
||||
ON zitadel.organizations (instance_id, name);
|
||||
|
||||
CREATE TRIGGER trigger_set_updated_at
|
||||
BEFORE UPDATE ON zitadel.organizations
|
||||
FOR EACH ROW
|
||||
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
|
||||
EXECUTE FUNCTION zitadel.set_updated_at();
|
||||
@@ -0,0 +1,16 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
_ "embed"
|
||||
)
|
||||
|
||||
var (
|
||||
//go:embed 003_domains_table/up.sql
|
||||
up003DomainsTable string
|
||||
//go:embed 003_domains_table/down.sql
|
||||
down003DomainsTable string
|
||||
)
|
||||
|
||||
func init() {
|
||||
registerSQLMigration(3, up003DomainsTable, down003DomainsTable)
|
||||
}
|
||||
@@ -0,0 +1,6 @@
|
||||
DROP TABLE IF EXISTS zitadel.instance_domains;
|
||||
DROP TABLE IF EXISTS zitadel.org_domains;
|
||||
DROP TYPE IF EXISTS zitadel.domain_type;
|
||||
DROP TYPE IF EXISTS zitadel.domain_validation_type;
|
||||
DROP FUNCTION IF EXISTS zitadel.check_verified_org_domain();
|
||||
DROP FUNCTION IF EXISTS zitadel.ensure_single_primary_instance_domain();
|
||||
@@ -0,0 +1,137 @@
|
||||
CREATE TYPE zitadel.domain_validation_type AS ENUM (
|
||||
'dns'
|
||||
, 'http'
|
||||
);
|
||||
|
||||
CREATE TYPE zitadel.domain_type AS ENUM (
|
||||
'custom'
|
||||
, 'trusted'
|
||||
);
|
||||
|
||||
CREATE TABLE zitadel.instance_domains(
|
||||
instance_id TEXT NOT NULL
|
||||
, domain TEXT NOT NULL CHECK (LENGTH(domain) BETWEEN 1 AND 255)
|
||||
, is_primary BOOLEAN
|
||||
, is_generated BOOLEAN
|
||||
, type zitadel.domain_type NOT NULL
|
||||
|
||||
, created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL
|
||||
, updated_at TIMESTAMPTZ DEFAULT NOW() NOT NULL
|
||||
|
||||
, PRIMARY KEY (domain)
|
||||
|
||||
, FOREIGN KEY (instance_id) REFERENCES zitadel.instances(id) ON DELETE CASCADE
|
||||
|
||||
, CONSTRAINT primary_cannot_be_trusted CHECK (is_primary IS NULL OR type != 'trusted')
|
||||
, CONSTRAINT generated_cannot_be_trusted CHECK (is_generated IS NULL OR type != 'trusted')
|
||||
, CONSTRAINT custom_values_set CHECK ((is_primary IS NOT NULL AND is_generated IS NOT NULL) OR type != 'custom')
|
||||
);
|
||||
|
||||
CREATE INDEX idx_instance_domain_instance ON zitadel.instance_domains(instance_id);
|
||||
|
||||
CREATE TABLE zitadel.org_domains(
|
||||
instance_id TEXT NOT NULL
|
||||
, org_id TEXT NOT NULL
|
||||
, domain TEXT NOT NULL CHECK (LENGTH(domain) BETWEEN 1 AND 255)
|
||||
, is_verified BOOLEAN NOT NULL DEFAULT FALSE
|
||||
, is_primary BOOLEAN NOT NULL DEFAULT FALSE
|
||||
, validation_type zitadel.domain_validation_type
|
||||
|
||||
, created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL
|
||||
, updated_at TIMESTAMPTZ DEFAULT NOW() NOT NULL
|
||||
|
||||
, PRIMARY KEY (instance_id, org_id, domain)
|
||||
|
||||
, FOREIGN KEY (instance_id, org_id) REFERENCES zitadel.organizations(instance_id, id) ON DELETE CASCADE
|
||||
|
||||
, UNIQUE (instance_id, org_id, domain)
|
||||
);
|
||||
|
||||
CREATE INDEX idx_org_domain ON zitadel.org_domains(instance_id, domain);
|
||||
|
||||
-- Trigger to update the updated_at timestamp on instance_domains
|
||||
CREATE TRIGGER trg_set_updated_at_instance_domains
|
||||
BEFORE UPDATE ON zitadel.instance_domains
|
||||
FOR EACH ROW
|
||||
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
|
||||
EXECUTE FUNCTION zitadel.set_updated_at();
|
||||
|
||||
-- Trigger to update the updated_at timestamp on org_domains
|
||||
CREATE TRIGGER trg_set_updated_at_org_domains
|
||||
BEFORE UPDATE ON zitadel.org_domains
|
||||
FOR EACH ROW
|
||||
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
|
||||
EXECUTE FUNCTION zitadel.set_updated_at();
|
||||
|
||||
-- Function to check for already verified org domains
|
||||
CREATE OR REPLACE FUNCTION zitadel.check_verified_org_domain()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
-- Check if there's already a verified domain within this instance (excluding the current record being updated)
|
||||
IF EXISTS (
|
||||
SELECT 1
|
||||
FROM zitadel.org_domains
|
||||
WHERE instance_id = NEW.instance_id
|
||||
AND domain = NEW.domain
|
||||
AND is_verified = TRUE
|
||||
AND (TG_OP = 'INSERT' OR (org_id != NEW.org_id))
|
||||
) THEN
|
||||
RAISE EXCEPTION 'org domain is already taken';
|
||||
END IF;
|
||||
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
-- Trigger to enforce verified domain constraint on org_domains
|
||||
CREATE TRIGGER trg_check_verified_org_domain
|
||||
BEFORE INSERT OR UPDATE ON zitadel.org_domains
|
||||
FOR EACH ROW
|
||||
WHEN (NEW.is_verified IS TRUE)
|
||||
EXECUTE FUNCTION zitadel.check_verified_org_domain();
|
||||
|
||||
-- Function to ensure only one primary domain per instance in instance_domains
|
||||
CREATE OR REPLACE FUNCTION zitadel.ensure_single_primary_instance_domain()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
-- If setting this domain as primary, update all other domains in the same instance to non-primary
|
||||
UPDATE zitadel.instance_domains
|
||||
SET is_primary = FALSE, updated_at = NOW()
|
||||
WHERE instance_id = NEW.instance_id
|
||||
AND domain != NEW.domain
|
||||
AND is_primary = TRUE
|
||||
AND type = 'custom';
|
||||
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
-- Trigger to enforce single primary domain constraint on instance_domains
|
||||
CREATE TRIGGER trg_ensure_single_primary_instance_domain
|
||||
BEFORE INSERT OR UPDATE ON zitadel.instance_domains
|
||||
FOR EACH ROW
|
||||
WHEN (NEW.is_primary IS TRUE)
|
||||
EXECUTE FUNCTION zitadel.ensure_single_primary_instance_domain();
|
||||
|
||||
-- Function to ensure only one primary domain per organization in org_domains
|
||||
CREATE OR REPLACE FUNCTION zitadel.ensure_single_primary_org_domain()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
-- If setting this domain as primary, update all other domains in the same organization to non-primary
|
||||
UPDATE zitadel.org_domains
|
||||
SET is_primary = FALSE, updated_at = NOW()
|
||||
WHERE instance_id = NEW.instance_id
|
||||
AND org_id = NEW.org_id
|
||||
AND domain != NEW.domain
|
||||
AND is_primary = TRUE;
|
||||
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
-- Trigger to enforce single primary domain constraint on org_domains
|
||||
CREATE TRIGGER trg_ensure_single_primary_org_domain
|
||||
BEFORE INSERT OR UPDATE ON zitadel.org_domains
|
||||
FOR EACH ROW
|
||||
WHEN (NEW.is_primary IS TRUE)
|
||||
EXECUTE FUNCTION zitadel.ensure_single_primary_org_domain();
|
||||
@@ -0,0 +1,13 @@
|
||||
// This package contains the migration logic for the PostgreSQL dialect.
|
||||
// It uses the [github.com/jackc/tern/v2/migrate] package to handle the migration process.
|
||||
//
|
||||
// **Developer Note**:
|
||||
//
|
||||
// Each migration MUST be registered in an init function.
|
||||
// Create a go file for each migration with the sequence of the migration as prefix and some descriptive name.
|
||||
// The file name MUST be in the format <sequence>_<name>.go.
|
||||
// Each migration SHOULD provide an up and down migration.
|
||||
// Prefer to write SQL statements instead of funcs if it is reasonable.
|
||||
// To keep the folder clean create a folder to store the sql files with the following format: <sequence>_<name>/{up/down}.sql.
|
||||
// And use the go embed directive to embed the sql files.
|
||||
package migration
|
||||
@@ -0,0 +1,33 @@
|
||||
package migration
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/tern/v2/migrate"
|
||||
)
|
||||
|
||||
var migrations []*migrate.Migration
|
||||
|
||||
func Migrate(ctx context.Context, conn *pgx.Conn) error {
|
||||
// we need to ensure that the schema exists before we can run the migration
|
||||
// because creating the migrations table already required the schema
|
||||
_, err := conn.Exec(ctx, "CREATE SCHEMA IF NOT EXISTS zitadel")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
migrator, err := migrate.NewMigrator(ctx, conn, "zitadel.migrations")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
migrator.Migrations = migrations
|
||||
return migrator.Migrate(ctx)
|
||||
}
|
||||
|
||||
func registerSQLMigration(sequence int32, up, down string) {
|
||||
migrations = append(migrations, &migrate.Migration{
|
||||
Sequence: sequence,
|
||||
UpSQL: up,
|
||||
DownSQL: down,
|
||||
})
|
||||
}
|
||||
@@ -0,0 +1,60 @@
|
||||
package migration_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres/embedded"
|
||||
)
|
||||
|
||||
func TestMigrate(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
stmt string
|
||||
args []any
|
||||
res []any
|
||||
}{
|
||||
{
|
||||
name: "schema",
|
||||
stmt: "SELECT EXISTS(SELECT 1 FROM information_schema.schemata where schema_name = 'zitadel') ;",
|
||||
res: []any{true},
|
||||
},
|
||||
{
|
||||
name: "001",
|
||||
stmt: "SELECT EXISTS(SELECT 1 FROM pg_catalog.pg_tables WHERE schemaname = 'zitadel' and tablename=$1)",
|
||||
args: []any{"instances"},
|
||||
res: []any{true},
|
||||
},
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
connector, stop, err := embedded.StartEmbedded()
|
||||
require.NoError(t, err, "failed to start embedded postgres")
|
||||
defer stop()
|
||||
|
||||
client, err := connector.Connect(ctx)
|
||||
require.NoError(t, err, "failed to connect to embedded postgres")
|
||||
|
||||
err = client.(database.Migrator).Migrate(ctx)
|
||||
require.NoError(t, err, "failed to execute migration steps")
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := make([]any, len(tt.res))
|
||||
for i := range got {
|
||||
got[i] = new(any)
|
||||
tt.res[i] = gu.Ptr(tt.res[i])
|
||||
}
|
||||
|
||||
require.NoError(t, client.QueryRow(ctx, tt.stmt, tt.args...).Scan(got...), "failed to execute check query")
|
||||
|
||||
assert.Equal(t, tt.res, got, "query result does not match")
|
||||
})
|
||||
}
|
||||
}
|
||||
100
backend/v3/storage/database/dialect/postgres/pool.go
Normal file
100
backend/v3/storage/database/dialect/postgres/pool.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres/migration"
|
||||
)
|
||||
|
||||
type pgxPool struct {
|
||||
*pgxpool.Pool
|
||||
}
|
||||
|
||||
var _ database.Pool = (*pgxPool)(nil)
|
||||
|
||||
func PGxPool(pool *pgxpool.Pool) *pgxPool {
|
||||
return &pgxPool{
|
||||
Pool: pool,
|
||||
}
|
||||
}
|
||||
|
||||
// Acquire implements [database.Pool].
|
||||
func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) {
|
||||
conn, err := c.Pool.Acquire(ctx)
|
||||
if err != nil {
|
||||
return nil, wrapError(err)
|
||||
}
|
||||
return &pgxConn{Conn: conn}, nil
|
||||
}
|
||||
|
||||
// Query implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool.
|
||||
func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
||||
rows, err := c.Pool.Query(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return nil, wrapError(err)
|
||||
}
|
||||
return &Rows{rows}, nil
|
||||
}
|
||||
|
||||
// QueryRow implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool.
|
||||
func (c *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
||||
return &Row{c.Pool.QueryRow(ctx, sql, args...)}
|
||||
}
|
||||
|
||||
// Exec implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
|
||||
func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
|
||||
res, err := c.Pool.Exec(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return 0, wrapError(err)
|
||||
}
|
||||
return res.RowsAffected(), nil
|
||||
}
|
||||
|
||||
// Begin implements [database.Pool].
|
||||
func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
||||
tx, err := c.BeginTx(ctx, transactionOptionsToPgx(opts))
|
||||
if err != nil {
|
||||
return nil, wrapError(err)
|
||||
}
|
||||
return &pgxTx{tx}, nil
|
||||
}
|
||||
|
||||
// Close implements [database.Pool].
|
||||
func (c *pgxPool) Close(_ context.Context) error {
|
||||
c.Pool.Close()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Migrate implements [database.Migrator].
|
||||
func (c *pgxPool) Migrate(ctx context.Context) error {
|
||||
if isMigrated {
|
||||
return nil
|
||||
}
|
||||
|
||||
client, err := c.Pool.Acquire(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = migration.Migrate(ctx, client.Conn())
|
||||
isMigrated = err == nil
|
||||
return wrapError(err)
|
||||
}
|
||||
|
||||
// Migrate implements [database.PoolTest].
|
||||
func (c *pgxPool) MigrateTest(ctx context.Context) error {
|
||||
client, err := c.Pool.Acquire(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = migration.Migrate(ctx, client.Conn())
|
||||
isMigrated = err == nil
|
||||
return err
|
||||
}
|
||||
77
backend/v3/storage/database/dialect/postgres/rows.go
Normal file
77
backend/v3/storage/database/dialect/postgres/rows.go
Normal file
@@ -0,0 +1,77 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"github.com/georgysavva/scany/v2/pgxscan"
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
var (
|
||||
_ database.Rows = (*Rows)(nil)
|
||||
_ database.CollectableRows = (*Rows)(nil)
|
||||
_ database.Row = (*Row)(nil)
|
||||
)
|
||||
|
||||
type Row struct{ pgx.Row }
|
||||
|
||||
// Scan implements [database.Row].
|
||||
// Subtle: this method shadows the method ([pgx.Row]).Scan of Row.Row.
|
||||
func (r *Row) Scan(dest ...any) error {
|
||||
return wrapError(r.Row.Scan(dest...))
|
||||
}
|
||||
|
||||
type Rows struct{ pgx.Rows }
|
||||
|
||||
// Err implements [database.Rows].
|
||||
// Subtle: this method shadows the method ([pgx.Rows]).Err of Rows.Rows.
|
||||
func (r *Rows) Err() error {
|
||||
return wrapError(r.Rows.Err())
|
||||
}
|
||||
|
||||
func (r *Rows) Scan(dest ...any) error {
|
||||
return wrapError(r.Rows.Scan(dest...))
|
||||
}
|
||||
|
||||
// Collect implements [database.CollectableRows].
|
||||
// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details.
|
||||
func (r *Rows) Collect(dest any) (err error) {
|
||||
defer func() {
|
||||
closeErr := r.Close()
|
||||
if err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
return wrapError(pgxscan.ScanAll(dest, r.Rows))
|
||||
}
|
||||
|
||||
// CollectFirst implements [database.CollectableRows].
|
||||
// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details.
|
||||
func (r *Rows) CollectFirst(dest any) (err error) {
|
||||
defer func() {
|
||||
closeErr := r.Close()
|
||||
if err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
return wrapError(pgxscan.ScanRow(dest, r.Rows))
|
||||
}
|
||||
|
||||
// CollectExactlyOneRow implements [database.CollectableRows].
|
||||
// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details.
|
||||
func (r *Rows) CollectExactlyOneRow(dest any) (err error) {
|
||||
defer func() {
|
||||
closeErr := r.Close()
|
||||
if err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
return wrapError(pgxscan.ScanOne(dest, r.Rows))
|
||||
}
|
||||
|
||||
// Close implements [database.Rows].
|
||||
// Subtle: this method shadows the method (Rows).Close of Rows.Rows.
|
||||
func (r *Rows) Close() error {
|
||||
r.Rows.Close()
|
||||
return nil
|
||||
}
|
||||
107
backend/v3/storage/database/dialect/postgres/tx.go
Normal file
107
backend/v3/storage/database/dialect/postgres/tx.go
Normal file
@@ -0,0 +1,107 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type pgxTx struct{ pgx.Tx }
|
||||
|
||||
var _ database.Transaction = (*pgxTx)(nil)
|
||||
|
||||
// Commit implements [database.Transaction].
|
||||
func (tx *pgxTx) Commit(ctx context.Context) error {
|
||||
err := tx.Tx.Commit(ctx)
|
||||
return wrapError(err)
|
||||
}
|
||||
|
||||
// Rollback implements [database.Transaction].
|
||||
func (tx *pgxTx) Rollback(ctx context.Context) error {
|
||||
err := tx.Tx.Rollback(ctx)
|
||||
return wrapError(err)
|
||||
}
|
||||
|
||||
// End implements [database.Transaction].
|
||||
func (tx *pgxTx) End(ctx context.Context, err error) error {
|
||||
if err != nil {
|
||||
rollbackErr := tx.Rollback(ctx)
|
||||
if rollbackErr != nil {
|
||||
err = errors.Join(err, rollbackErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
// Query implements [database.Transaction].
|
||||
// Subtle: this method shadows the method (Tx).Query of pgxTx.Tx.
|
||||
func (tx *pgxTx) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
||||
rows, err := tx.Tx.Query(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return nil, wrapError(err)
|
||||
}
|
||||
return &Rows{rows}, nil
|
||||
}
|
||||
|
||||
// QueryRow implements [database.Transaction].
|
||||
// Subtle: this method shadows the method (Tx).QueryRow of pgxTx.Tx.
|
||||
func (tx *pgxTx) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
||||
return &Row{tx.Tx.QueryRow(ctx, sql, args...)}
|
||||
}
|
||||
|
||||
// Exec implements [database.Transaction].
|
||||
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
|
||||
func (tx *pgxTx) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
|
||||
res, err := tx.Tx.Exec(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return 0, wrapError(err)
|
||||
}
|
||||
return res.RowsAffected(), nil
|
||||
}
|
||||
|
||||
// Begin implements [database.Transaction].
|
||||
// As postgres does not support nested transactions we use savepoints to emulate them.
|
||||
func (tx *pgxTx) Begin(ctx context.Context) (database.Transaction, error) {
|
||||
savepoint, err := tx.Tx.Begin(ctx)
|
||||
if err != nil {
|
||||
return nil, wrapError(err)
|
||||
}
|
||||
return &pgxTx{savepoint}, nil
|
||||
}
|
||||
|
||||
func transactionOptionsToPgx(opts *database.TransactionOptions) pgx.TxOptions {
|
||||
if opts == nil {
|
||||
return pgx.TxOptions{}
|
||||
}
|
||||
|
||||
return pgx.TxOptions{
|
||||
IsoLevel: isolationToPgx(opts.IsolationLevel),
|
||||
AccessMode: accessModeToPgx(opts.AccessMode),
|
||||
}
|
||||
}
|
||||
|
||||
func isolationToPgx(isolation database.IsolationLevel) pgx.TxIsoLevel {
|
||||
switch isolation {
|
||||
case database.IsolationLevelSerializable:
|
||||
return pgx.Serializable
|
||||
case database.IsolationLevelReadCommitted:
|
||||
return pgx.ReadCommitted
|
||||
default:
|
||||
return pgx.Serializable
|
||||
}
|
||||
}
|
||||
|
||||
func accessModeToPgx(accessMode database.AccessMode) pgx.TxAccessMode {
|
||||
switch accessMode {
|
||||
case database.AccessModeReadWrite:
|
||||
return pgx.ReadWrite
|
||||
case database.AccessModeReadOnly:
|
||||
return pgx.ReadOnly
|
||||
default:
|
||||
return pgx.ReadWrite
|
||||
}
|
||||
}
|
||||
60
backend/v3/storage/database/dialect/sql/conn.go
Normal file
60
backend/v3/storage/database/dialect/sql/conn.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type sqlConn struct {
|
||||
*sql.Conn
|
||||
}
|
||||
|
||||
var _ database.Client = (*sqlConn)(nil)
|
||||
|
||||
// Release implements [database.Client].
|
||||
func (c *sqlConn) Release(_ context.Context) error {
|
||||
return c.Close()
|
||||
}
|
||||
|
||||
// Begin implements [database.Client].
|
||||
func (c *sqlConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
||||
tx, err := c.BeginTx(ctx, transactionOptionsToSQL(opts))
|
||||
if err != nil {
|
||||
return nil, wrapError(err)
|
||||
}
|
||||
return &sqlTx{tx}, nil
|
||||
}
|
||||
|
||||
// Query implements sql.Client.
|
||||
// Subtle: this method shadows the method (*Conn).Query of pgxConn.Conn.
|
||||
func (c *sqlConn) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
||||
//nolint:rowserrcheck // Rows.Close is called by the caller
|
||||
rows, err := c.QueryContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return nil, wrapError(err)
|
||||
}
|
||||
return &Rows{rows}, nil
|
||||
}
|
||||
|
||||
// QueryRow implements sql.Client.
|
||||
// Subtle: this method shadows the method (*Conn).QueryRow of pgxConn.Conn.
|
||||
func (c *sqlConn) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
||||
return &Row{c.QueryRowContext(ctx, sql, args...)}
|
||||
}
|
||||
|
||||
// Exec implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
|
||||
func (c *sqlConn) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
|
||||
res, err := c.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return 0, wrapError(err)
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
// Migrate implements [database.Migrator].
|
||||
func (c *sqlConn) Migrate(ctx context.Context) error {
|
||||
return ErrMigrate
|
||||
}
|
||||
3
backend/v3/storage/database/dialect/sql/doc.go
Normal file
3
backend/v3/storage/database/dialect/sql/doc.go
Normal file
@@ -0,0 +1,3 @@
|
||||
// [database/sql] implementation of the interfaces defined in the database package.
|
||||
// This package is used to migrate from event driven to relational Zitadel.
|
||||
package sql
|
||||
41
backend/v3/storage/database/dialect/sql/error.go
Normal file
41
backend/v3/storage/database/dialect/sql/error.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
var ErrMigrate = errors.New("sql does not support migrations, use a different dialect")
|
||||
|
||||
func wrapError(err error) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, pgx.ErrNoRows) || errors.Is(err, sql.ErrNoRows) {
|
||||
return database.NewNoRowFoundError(err)
|
||||
}
|
||||
var pgxErr *pgconn.PgError
|
||||
if !errors.As(err, &pgxErr) {
|
||||
return database.NewUnknownError(err)
|
||||
}
|
||||
switch pgxErr.Code {
|
||||
// 23514: check_violation - A value violates a CHECK constraint.
|
||||
case "23514":
|
||||
return database.NewCheckError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr)
|
||||
// 23505: unique_violation - A value violates a UNIQUE constraint.
|
||||
case "23505":
|
||||
return database.NewUniqueError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr)
|
||||
// 23503: foreign_key_violation - A value violates a foreign key constraint.
|
||||
case "23503":
|
||||
return database.NewForeignKeyError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr)
|
||||
// 23502: not_null_violation - A value violates a NOT NULL constraint.
|
||||
case "23502":
|
||||
return database.NewNotNullError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr)
|
||||
}
|
||||
return database.NewUnknownError(err)
|
||||
}
|
||||
75
backend/v3/storage/database/dialect/sql/pool.go
Normal file
75
backend/v3/storage/database/dialect/sql/pool.go
Normal file
@@ -0,0 +1,75 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type sqlPool struct {
|
||||
*sql.DB
|
||||
}
|
||||
|
||||
var _ database.Pool = (*sqlPool)(nil)
|
||||
|
||||
func SQLPool(db *sql.DB) *sqlPool {
|
||||
return &sqlPool{
|
||||
DB: db,
|
||||
}
|
||||
}
|
||||
|
||||
// Acquire implements [database.Pool].
|
||||
func (c *sqlPool) Acquire(ctx context.Context) (database.Client, error) {
|
||||
conn, err := c.Conn(ctx)
|
||||
if err != nil {
|
||||
return nil, wrapError(err)
|
||||
}
|
||||
return &sqlConn{Conn: conn}, nil
|
||||
}
|
||||
|
||||
// Query implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool.
|
||||
func (c *sqlPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
||||
//nolint:rowserrcheck // Rows.Close is called by the caller
|
||||
rows, err := c.QueryContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return nil, wrapError(err)
|
||||
}
|
||||
return &Rows{rows}, nil
|
||||
}
|
||||
|
||||
// QueryRow implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool.
|
||||
func (c *sqlPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
||||
return &Row{c.QueryRowContext(ctx, sql, args...)}
|
||||
}
|
||||
|
||||
// Exec implements [database.Pool].
|
||||
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
|
||||
func (c *sqlPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
|
||||
res, err := c.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return 0, wrapError(err)
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
// Begin implements [database.Pool].
|
||||
func (c *sqlPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
|
||||
tx, err := c.BeginTx(ctx, transactionOptionsToSQL(opts))
|
||||
if err != nil {
|
||||
return nil, wrapError(err)
|
||||
}
|
||||
return &sqlTx{tx}, nil
|
||||
}
|
||||
|
||||
// Close implements [database.Pool].
|
||||
func (c *sqlPool) Close(_ context.Context) error {
|
||||
return c.DB.Close()
|
||||
}
|
||||
|
||||
// Migrate implements [database.Migrator].
|
||||
func (c *sqlPool) Migrate(ctx context.Context) error {
|
||||
return ErrMigrate
|
||||
}
|
||||
78
backend/v3/storage/database/dialect/sql/rows.go
Normal file
78
backend/v3/storage/database/dialect/sql/rows.go
Normal file
@@ -0,0 +1,78 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
pgxscan "github.com/georgysavva/scany/v2/dbscan"
|
||||
"github.com/jackc/pgx/v5"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
var (
|
||||
_ database.Rows = (*Rows)(nil)
|
||||
_ database.CollectableRows = (*Rows)(nil)
|
||||
_ database.Row = (*Row)(nil)
|
||||
)
|
||||
|
||||
type Row struct{ pgx.Row }
|
||||
|
||||
// Scan implements [database.Row].
|
||||
// Subtle: this method shadows the method ([pgx.Row]).Scan of Row.Row.
|
||||
func (r *Row) Scan(dest ...any) error {
|
||||
return wrapError(r.Row.Scan(dest...))
|
||||
}
|
||||
|
||||
type Rows struct{ *sql.Rows }
|
||||
|
||||
// Err implements [database.Rows].
|
||||
// Subtle: this method shadows the method ([pgx.Rows]).Err of Rows.Rows.
|
||||
func (r *Rows) Err() error {
|
||||
return wrapError(r.Rows.Err())
|
||||
}
|
||||
|
||||
func (r *Rows) Scan(dest ...any) error {
|
||||
return wrapError(r.Rows.Scan(dest...))
|
||||
}
|
||||
|
||||
// Collect implements [database.CollectableRows].
|
||||
// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details.
|
||||
func (r *Rows) Collect(dest any) (err error) {
|
||||
defer func() {
|
||||
closeErr := r.Close()
|
||||
if err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
return wrapError(pgxscan.ScanAll(dest, r.Rows))
|
||||
}
|
||||
|
||||
// CollectFirst implements [database.CollectableRows].
|
||||
// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details.
|
||||
func (r *Rows) CollectFirst(dest any) (err error) {
|
||||
defer func() {
|
||||
closeErr := r.Close()
|
||||
if err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
return wrapError(pgxscan.ScanRow(dest, r.Rows))
|
||||
}
|
||||
|
||||
// CollectExactlyOneRow implements [database.CollectableRows].
|
||||
// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details.
|
||||
func (r *Rows) CollectExactlyOneRow(dest any) (err error) {
|
||||
defer func() {
|
||||
closeErr := r.Close()
|
||||
if err == nil {
|
||||
err = closeErr
|
||||
}
|
||||
}()
|
||||
return wrapError(pgxscan.ScanOne(dest, r.Rows))
|
||||
}
|
||||
|
||||
// Close implements [database.Rows].
|
||||
// Subtle: this method shadows the method (Rows).Close of Rows.Rows.
|
||||
func (r *Rows) Close() error {
|
||||
return r.Rows.Close()
|
||||
}
|
||||
69
backend/v3/storage/database/dialect/sql/savepoint.go
Normal file
69
backend/v3/storage/database/dialect/sql/savepoint.go
Normal file
@@ -0,0 +1,69 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
var _ database.Transaction = (*sqlSavepoint)(nil)
|
||||
|
||||
const (
|
||||
savepointName = "zitadel_savepoint"
|
||||
createSavepoint = "SAVEPOINT " + savepointName
|
||||
rollbackToSavepoint = "ROLLBACK TO SAVEPOINT " + savepointName
|
||||
commitSavepoint = "RELEASE SAVEPOINT " + savepointName
|
||||
)
|
||||
|
||||
type sqlSavepoint struct {
|
||||
parent database.Transaction
|
||||
}
|
||||
|
||||
// Commit implements [database.Transaction].
|
||||
func (s *sqlSavepoint) Commit(ctx context.Context) error {
|
||||
_, err := s.parent.Exec(ctx, commitSavepoint)
|
||||
return wrapError(err)
|
||||
}
|
||||
|
||||
// Rollback implements [database.Transaction].
|
||||
func (s *sqlSavepoint) Rollback(ctx context.Context) error {
|
||||
_, err := s.parent.Exec(ctx, rollbackToSavepoint)
|
||||
return wrapError(err)
|
||||
}
|
||||
|
||||
// End implements [database.Transaction].
|
||||
func (s *sqlSavepoint) End(ctx context.Context, err error) error {
|
||||
if err != nil {
|
||||
rollbackErr := s.Rollback(ctx)
|
||||
if rollbackErr != nil {
|
||||
err = errors.Join(err, rollbackErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return s.Commit(ctx)
|
||||
}
|
||||
|
||||
// Query implements [database.Transaction].
|
||||
// Subtle: this method shadows the method (Tx).Query of pgxTx.Tx.
|
||||
func (s *sqlSavepoint) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
||||
return s.parent.Query(ctx, sql, args...)
|
||||
}
|
||||
|
||||
// QueryRow implements [database.Transaction].
|
||||
// Subtle: this method shadows the method (Tx).QueryRow of pgxTx.Tx.
|
||||
func (s *sqlSavepoint) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
||||
return s.parent.QueryRow(ctx, sql, args...)
|
||||
}
|
||||
|
||||
// Exec implements [database.Transaction].
|
||||
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
|
||||
func (s *sqlSavepoint) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
|
||||
return s.parent.Exec(ctx, sql, args...)
|
||||
}
|
||||
|
||||
// Begin implements [database.Transaction].
|
||||
// As postgres does not support nested transactions we use savepoints to emulate them.
|
||||
func (s *sqlSavepoint) Begin(ctx context.Context) (database.Transaction, error) {
|
||||
return s.parent.Begin(ctx)
|
||||
}
|
||||
104
backend/v3/storage/database/dialect/sql/tx.go
Normal file
104
backend/v3/storage/database/dialect/sql/tx.go
Normal file
@@ -0,0 +1,104 @@
|
||||
package sql
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type sqlTx struct{ *sql.Tx }
|
||||
|
||||
var _ database.Transaction = (*sqlTx)(nil)
|
||||
|
||||
func SQLTx(tx *sql.Tx) *sqlTx {
|
||||
return &sqlTx{
|
||||
Tx: tx,
|
||||
}
|
||||
}
|
||||
|
||||
// Commit implements [database.Transaction].
|
||||
func (tx *sqlTx) Commit(ctx context.Context) error {
|
||||
return wrapError(tx.Tx.Commit())
|
||||
}
|
||||
|
||||
// Rollback implements [database.Transaction].
|
||||
func (tx *sqlTx) Rollback(ctx context.Context) error {
|
||||
return wrapError(tx.Tx.Rollback())
|
||||
}
|
||||
|
||||
// End implements [database.Transaction].
|
||||
func (tx *sqlTx) End(ctx context.Context, err error) error {
|
||||
if err != nil {
|
||||
rollbackErr := tx.Rollback(ctx)
|
||||
if rollbackErr != nil {
|
||||
err = errors.Join(err, rollbackErr)
|
||||
}
|
||||
return err
|
||||
}
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
// Query implements [database.Transaction].
|
||||
// Subtle: this method shadows the method (Tx).Query of pgxTx.Tx.
|
||||
func (tx *sqlTx) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
|
||||
//nolint:rowserrcheck // Rows.Close is called by the caller
|
||||
rows, err := tx.QueryContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return nil, wrapError(err)
|
||||
}
|
||||
return &Rows{rows}, nil
|
||||
}
|
||||
|
||||
// QueryRow implements [database.Transaction].
|
||||
// Subtle: this method shadows the method (Tx).QueryRow of pgxTx.Tx.
|
||||
func (tx *sqlTx) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
|
||||
return &Row{tx.QueryRowContext(ctx, sql, args...)}
|
||||
}
|
||||
|
||||
// Exec implements [database.Transaction].
|
||||
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
|
||||
func (tx *sqlTx) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
|
||||
res, err := tx.ExecContext(ctx, sql, args...)
|
||||
if err != nil {
|
||||
return 0, wrapError(err)
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
// Begin implements [database.Transaction].
|
||||
// As postgres does not support nested transactions we use savepoints to emulate them.
|
||||
func (tx *sqlTx) Begin(ctx context.Context) (database.Transaction, error) {
|
||||
_, err := tx.ExecContext(ctx, createSavepoint)
|
||||
if err != nil {
|
||||
return nil, wrapError(err)
|
||||
}
|
||||
return &sqlSavepoint{tx}, nil
|
||||
}
|
||||
|
||||
func transactionOptionsToSQL(opts *database.TransactionOptions) *sql.TxOptions {
|
||||
if opts == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return &sql.TxOptions{
|
||||
Isolation: isolationToSQL(opts.IsolationLevel),
|
||||
ReadOnly: accessModeToSQL(opts.AccessMode),
|
||||
}
|
||||
}
|
||||
|
||||
func isolationToSQL(isolation database.IsolationLevel) sql.IsolationLevel {
|
||||
switch isolation {
|
||||
case database.IsolationLevelSerializable:
|
||||
return sql.LevelSerializable
|
||||
case database.IsolationLevelReadCommitted:
|
||||
return sql.LevelReadCommitted
|
||||
default:
|
||||
return sql.LevelSerializable
|
||||
}
|
||||
}
|
||||
|
||||
func accessModeToSQL(accessMode database.AccessMode) bool {
|
||||
return accessMode == database.AccessModeReadOnly
|
||||
}
|
||||
230
backend/v3/storage/database/errors.go
Normal file
230
backend/v3/storage/database/errors.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
var ErrNoChanges = errors.New("update must contain a change")
|
||||
|
||||
// NoRowFoundError is returned when QueryRow does not find any row.
|
||||
// It wraps the dialect specific original error to provide more context.
|
||||
type NoRowFoundError struct {
|
||||
original error
|
||||
}
|
||||
|
||||
func NewNoRowFoundError(original error) error {
|
||||
return &NoRowFoundError{
|
||||
original: original,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *NoRowFoundError) Error() string {
|
||||
return "no row found"
|
||||
}
|
||||
|
||||
func (e *NoRowFoundError) Is(target error) bool {
|
||||
_, ok := target.(*NoRowFoundError)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (e *NoRowFoundError) Unwrap() error {
|
||||
return e.original
|
||||
}
|
||||
|
||||
// MultipleRowsFoundError is returned when QueryRow finds multiple rows.
|
||||
// It wraps the dialect specific original error to provide more context.
|
||||
type MultipleRowsFoundError struct {
|
||||
original error
|
||||
count int
|
||||
}
|
||||
|
||||
func NewMultipleRowsFoundError(original error, count int) error {
|
||||
return &MultipleRowsFoundError{
|
||||
original: original,
|
||||
count: count,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *MultipleRowsFoundError) Error() string {
|
||||
return fmt.Sprintf("multiple rows found: %d", e.count)
|
||||
}
|
||||
|
||||
func (e *MultipleRowsFoundError) Is(target error) bool {
|
||||
_, ok := target.(*MultipleRowsFoundError)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (e *MultipleRowsFoundError) Unwrap() error {
|
||||
return e.original
|
||||
}
|
||||
|
||||
type IntegrityType string
|
||||
|
||||
const (
|
||||
IntegrityTypeCheck IntegrityType = "check"
|
||||
IntegrityTypeUnique IntegrityType = "unique"
|
||||
IntegrityTypeForeign IntegrityType = "foreign"
|
||||
IntegrityTypeNotNull IntegrityType = "not null"
|
||||
)
|
||||
|
||||
// IntegrityViolationError represents a generic integrity violation error.
|
||||
// It wraps the dialect specific original error to provide more context.
|
||||
type IntegrityViolationError struct {
|
||||
integrityType IntegrityType
|
||||
table string
|
||||
constraint string
|
||||
original error
|
||||
}
|
||||
|
||||
func NewIntegrityViolationError(typ IntegrityType, table, constraint string, original error) error {
|
||||
return &IntegrityViolationError{
|
||||
integrityType: typ,
|
||||
table: table,
|
||||
constraint: constraint,
|
||||
original: original,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *IntegrityViolationError) Error() string {
|
||||
return fmt.Sprintf("integrity violation of type %q on %q (constraint: %q): %v", e.integrityType, e.table, e.constraint, e.original)
|
||||
}
|
||||
|
||||
func (e *IntegrityViolationError) Is(target error) bool {
|
||||
_, ok := target.(*IntegrityViolationError)
|
||||
return ok
|
||||
}
|
||||
|
||||
// CheckError is returned when a check constraint fails.
|
||||
// It wraps the [IntegrityViolationError] to provide more context.
|
||||
// It is used to indicate that a check constraint was violated during an insert or update operation.
|
||||
type CheckError struct {
|
||||
IntegrityViolationError
|
||||
}
|
||||
|
||||
func NewCheckError(table, constraint string, original error) error {
|
||||
return &CheckError{
|
||||
IntegrityViolationError: IntegrityViolationError{
|
||||
integrityType: IntegrityTypeCheck,
|
||||
table: table,
|
||||
constraint: constraint,
|
||||
original: original,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (e *CheckError) Is(target error) bool {
|
||||
_, ok := target.(*CheckError)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (e *CheckError) Unwrap() error {
|
||||
return &e.IntegrityViolationError
|
||||
}
|
||||
|
||||
// UniqueError is returned when a unique constraint fails.
|
||||
// It wraps the [IntegrityViolationError] to provide more context.
|
||||
// It is used to indicate that a unique constraint was violated during an insert or update operation.
|
||||
type UniqueError struct {
|
||||
IntegrityViolationError
|
||||
}
|
||||
|
||||
func NewUniqueError(table, constraint string, original error) error {
|
||||
return &UniqueError{
|
||||
IntegrityViolationError: IntegrityViolationError{
|
||||
integrityType: IntegrityTypeUnique,
|
||||
table: table,
|
||||
constraint: constraint,
|
||||
original: original,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (e *UniqueError) Is(target error) bool {
|
||||
_, ok := target.(*UniqueError)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (e *UniqueError) Unwrap() error {
|
||||
return &e.IntegrityViolationError
|
||||
}
|
||||
|
||||
// ForeignKeyError is returned when a foreign key constraint fails.
|
||||
// It wraps the [IntegrityViolationError] to provide more context.
|
||||
// It is used to indicate that a foreign key constraint was violated during an insert or update operation
|
||||
type ForeignKeyError struct {
|
||||
IntegrityViolationError
|
||||
}
|
||||
|
||||
func NewForeignKeyError(table, constraint string, original error) error {
|
||||
return &ForeignKeyError{
|
||||
IntegrityViolationError: IntegrityViolationError{
|
||||
integrityType: IntegrityTypeForeign,
|
||||
table: table,
|
||||
constraint: constraint,
|
||||
original: original,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (e *ForeignKeyError) Is(target error) bool {
|
||||
_, ok := target.(*ForeignKeyError)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (e *ForeignKeyError) Unwrap() error {
|
||||
return &e.IntegrityViolationError
|
||||
}
|
||||
|
||||
// NotNullError is returned when a not null constraint fails.
|
||||
// It wraps the [IntegrityViolationError] to provide more context.
|
||||
// It is used to indicate that a not null constraint was violated during an insert or update operation.
|
||||
type NotNullError struct {
|
||||
IntegrityViolationError
|
||||
}
|
||||
|
||||
func NewNotNullError(table, constraint string, original error) error {
|
||||
return &NotNullError{
|
||||
IntegrityViolationError: IntegrityViolationError{
|
||||
integrityType: IntegrityTypeNotNull,
|
||||
table: table,
|
||||
constraint: constraint,
|
||||
original: original,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
func (e *NotNullError) Is(target error) bool {
|
||||
_, ok := target.(*NotNullError)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (e *NotNullError) Unwrap() error {
|
||||
return &e.IntegrityViolationError
|
||||
}
|
||||
|
||||
// UnknownError is returned when an unknown error occurs.
|
||||
// It wraps the dialect specific original error to provide more context.
|
||||
// It is used to indicate that an error occurred that does not fit into any of the other categories.
|
||||
type UnknownError struct {
|
||||
original error
|
||||
}
|
||||
|
||||
func NewUnknownError(original error) error {
|
||||
return &UnknownError{
|
||||
original: original,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *UnknownError) Error() string {
|
||||
return fmt.Sprintf("unknown database error: %v", e.original)
|
||||
}
|
||||
|
||||
func (e *UnknownError) Is(target error) bool {
|
||||
_, ok := target.(*UnknownError)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (e *UnknownError) Unwrap() error {
|
||||
return e.original
|
||||
}
|
||||
78
backend/v3/storage/database/events_testing/events_test.go
Normal file
78
backend/v3/storage/database/events_testing/events_test.go
Normal file
@@ -0,0 +1,78 @@
|
||||
//go:build integration
|
||||
|
||||
package events_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres"
|
||||
"github.com/zitadel/zitadel/internal/integration"
|
||||
v2beta "github.com/zitadel/zitadel/pkg/grpc/instance/v2beta"
|
||||
v2beta_org "github.com/zitadel/zitadel/pkg/grpc/org/v2beta"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/system"
|
||||
)
|
||||
|
||||
const ConnString = "host=localhost port=5432 user=zitadel password=zitadel dbname=zitadel sslmode=disable"
|
||||
|
||||
var (
|
||||
dbPool *pgxpool.Pool
|
||||
CTX context.Context
|
||||
Instance *integration.Instance
|
||||
SystemClient system.SystemServiceClient
|
||||
OrgClient v2beta_org.OrganizationServiceClient
|
||||
)
|
||||
|
||||
var pool database.Pool
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
os.Exit(func() int {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
CTX = integration.WithSystemAuthorization(ctx)
|
||||
Instance = integration.NewInstance(CTX)
|
||||
|
||||
SystemClient = integration.SystemClient()
|
||||
OrgClient = Instance.Client.OrgV2beta
|
||||
|
||||
defer func() {
|
||||
_, err := Instance.Client.InstanceV2Beta.DeleteInstance(CTX, &v2beta.DeleteInstanceRequest{
|
||||
InstanceId: Instance.Instance.Id,
|
||||
})
|
||||
if err != nil {
|
||||
log.Printf("Failed to delete instance on cleanup: %v\n", err)
|
||||
}
|
||||
}()
|
||||
|
||||
var err error
|
||||
dbConfig, err := pgxpool.ParseConfig(ConnString)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
dbConfig.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error {
|
||||
orgState, err := conn.LoadType(ctx, "zitadel.organization_state")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
conn.TypeMap().RegisterType(orgState)
|
||||
return nil
|
||||
}
|
||||
|
||||
dbPool, err = pgxpool.NewWithConfig(context.Background(), dbConfig)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
pool = postgres.PGxPool(dbPool)
|
||||
|
||||
return m.Run()
|
||||
}())
|
||||
}
|
||||
@@ -0,0 +1,310 @@
|
||||
//go:build integration
|
||||
|
||||
package events_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v6"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
|
||||
"github.com/zitadel/zitadel/internal/integration"
|
||||
v2beta "github.com/zitadel/zitadel/pkg/grpc/instance/v2beta"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/system"
|
||||
)
|
||||
|
||||
func TestServer_TestInstanceDomainReduces(t *testing.T) {
|
||||
instance := integration.NewInstance(CTX)
|
||||
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceDomainRepo := instanceRepo.Domains(true)
|
||||
|
||||
t.Cleanup(func() {
|
||||
_, err := instance.Client.InstanceV2Beta.DeleteInstance(CTX, &v2beta.DeleteInstanceRequest{
|
||||
InstanceId: instance.Instance.Id,
|
||||
})
|
||||
if err != nil {
|
||||
t.Logf("Failed to delete instance on cleanup: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Wait for instance to be created
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
_, err := instanceRepo.Get(CTX,
|
||||
database.WithCondition(instanceRepo.IDCondition(instance.Instance.Id)),
|
||||
)
|
||||
assert.NoError(ttt, err)
|
||||
}, retryDuration, tick)
|
||||
|
||||
t.Run("test instance custom domain add reduces", func(t *testing.T) {
|
||||
// Add a domain to the instance
|
||||
domainName := gofakeit.DomainName()
|
||||
beforeAdd := time.Now()
|
||||
_, err := instance.Client.InstanceV2Beta.AddCustomDomain(CTX, &v2beta.AddCustomDomainRequest{
|
||||
InstanceId: instance.Instance.Id,
|
||||
Domain: domainName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
afterAdd := time.Now()
|
||||
|
||||
t.Cleanup(func() {
|
||||
_, err := instance.Client.InstanceV2Beta.RemoveCustomDomain(CTX, &v2beta.RemoveCustomDomainRequest{
|
||||
InstanceId: instance.Instance.Id,
|
||||
Domain: domainName,
|
||||
})
|
||||
if err != nil {
|
||||
t.Logf("Failed to delete instance domain on cleanup: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Test that domain add reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||
instanceDomainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
||||
instanceDomainRepo.TypeCondition(domain.DomainTypeCustom),
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(ttt, err)
|
||||
// event instance.domain.added
|
||||
assert.Equal(ttt, domainName, domain.Domain)
|
||||
assert.Equal(ttt, instance.Instance.Id, domain.InstanceID)
|
||||
assert.False(ttt, *domain.IsPrimary)
|
||||
assert.WithinRange(ttt, domain.CreatedAt, beforeAdd, afterAdd)
|
||||
assert.WithinRange(ttt, domain.UpdatedAt, beforeAdd, afterAdd)
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
|
||||
t.Run("test instance custom domain set primary reduces", func(t *testing.T) {
|
||||
// Add a domain to the instance
|
||||
domainName := gofakeit.DomainName()
|
||||
_, err := instance.Client.InstanceV2Beta.AddCustomDomain(CTX, &v2beta.AddCustomDomainRequest{
|
||||
InstanceId: instance.Instance.Id,
|
||||
Domain: domainName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
// first we change the primary domain to something else
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||
instanceDomainRepo.TypeCondition(domain.DomainTypeCustom),
|
||||
instanceDomainRepo.IsPrimaryCondition(false),
|
||||
),
|
||||
),
|
||||
database.WithLimit(1),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
_, err = SystemClient.SetPrimaryDomain(CTX, &system.SetPrimaryDomainRequest{
|
||||
InstanceId: instance.Instance.Id,
|
||||
Domain: domain.Domain,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = instance.Client.InstanceV2Beta.RemoveCustomDomain(CTX, &v2beta.RemoveCustomDomainRequest{
|
||||
InstanceId: instance.Instance.Id,
|
||||
Domain: domainName,
|
||||
})
|
||||
if err != nil {
|
||||
t.Logf("Failed to delete instance domain on cleanup: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Wait for domain to be created
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||
instanceDomainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
||||
instanceDomainRepo.TypeCondition(domain.DomainTypeCustom),
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(ttt, err)
|
||||
require.False(ttt, *domain.IsPrimary)
|
||||
assert.Equal(ttt, domainName, domain.Domain)
|
||||
}, retryDuration, tick)
|
||||
|
||||
// Set domain as primary
|
||||
beforeSetPrimary := time.Now()
|
||||
_, err = SystemClient.SetPrimaryDomain(CTX, &system.SetPrimaryDomainRequest{
|
||||
InstanceId: instance.Instance.Id,
|
||||
Domain: domainName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
afterSetPrimary := time.Now()
|
||||
|
||||
// Test that set primary reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||
instanceDomainRepo.IsPrimaryCondition(true),
|
||||
instanceDomainRepo.TypeCondition(domain.DomainTypeCustom),
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(ttt, err)
|
||||
// event instance.domain.primary.set
|
||||
assert.Equal(ttt, domainName, domain.Domain)
|
||||
assert.True(ttt, *domain.IsPrimary)
|
||||
assert.WithinRange(ttt, domain.UpdatedAt, beforeSetPrimary, afterSetPrimary)
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
|
||||
t.Run("test instance custom domain remove reduces", func(t *testing.T) {
|
||||
// Add a domain to the instance
|
||||
domainName := gofakeit.DomainName()
|
||||
_, err := instance.Client.InstanceV2Beta.AddCustomDomain(CTX, &v2beta.AddCustomDomainRequest{
|
||||
InstanceId: instance.Instance.Id,
|
||||
Domain: domainName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for domain to be created and verify it exists
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
_, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||
instanceDomainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
||||
instanceDomainRepo.TypeCondition(domain.DomainTypeCustom),
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(ttt, err)
|
||||
}, retryDuration, tick)
|
||||
|
||||
// Remove the domain
|
||||
_, err = instance.Client.InstanceV2Beta.RemoveCustomDomain(CTX, &v2beta.RemoveCustomDomainRequest{
|
||||
InstanceId: instance.Instance.Id,
|
||||
Domain: domainName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test that domain remove reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||
instanceDomainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
||||
instanceDomainRepo.TypeCondition(domain.DomainTypeCustom),
|
||||
),
|
||||
),
|
||||
)
|
||||
// event instance.domain.removed
|
||||
assert.Nil(ttt, domain)
|
||||
require.ErrorIs(ttt, err, new(database.NoRowFoundError))
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
|
||||
t.Run("test instance trusted domain add reduces", func(t *testing.T) {
|
||||
// Add a domain to the instance
|
||||
domainName := gofakeit.DomainName()
|
||||
beforeAdd := time.Now()
|
||||
_, err := instance.Client.InstanceV2Beta.AddTrustedDomain(CTX, &v2beta.AddTrustedDomainRequest{
|
||||
InstanceId: instance.Instance.Id,
|
||||
Domain: domainName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
afterAdd := time.Now()
|
||||
|
||||
t.Cleanup(func() {
|
||||
_, err := instance.Client.InstanceV2Beta.RemoveTrustedDomain(CTX, &v2beta.RemoveTrustedDomainRequest{
|
||||
InstanceId: instance.Instance.Id,
|
||||
Domain: domainName,
|
||||
})
|
||||
if err != nil {
|
||||
t.Logf("Failed to delete instance domain on cleanup: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Test that domain add reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||
instanceDomainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
||||
instanceDomainRepo.TypeCondition(domain.DomainTypeTrusted),
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(ttt, err)
|
||||
// event instance.domain.added
|
||||
assert.Equal(ttt, domainName, domain.Domain)
|
||||
assert.Equal(ttt, instance.Instance.Id, domain.InstanceID)
|
||||
assert.WithinRange(ttt, domain.CreatedAt, beforeAdd, afterAdd)
|
||||
assert.WithinRange(ttt, domain.UpdatedAt, beforeAdd, afterAdd)
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
|
||||
t.Run("test instance trusted domain remove reduces", func(t *testing.T) {
|
||||
// Add a domain to the instance
|
||||
domainName := gofakeit.DomainName()
|
||||
_, err := instance.Client.InstanceV2Beta.AddTrustedDomain(CTX, &v2beta.AddTrustedDomainRequest{
|
||||
InstanceId: instance.Instance.Id,
|
||||
Domain: domainName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Wait for domain to be created and verify it exists
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
_, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||
instanceDomainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
||||
instanceDomainRepo.TypeCondition(domain.DomainTypeTrusted),
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(ttt, err)
|
||||
}, retryDuration, tick)
|
||||
|
||||
// Remove the domain
|
||||
_, err = instance.Client.InstanceV2Beta.RemoveTrustedDomain(CTX, &v2beta.RemoveTrustedDomainRequest{
|
||||
InstanceId: instance.Instance.Id,
|
||||
Domain: domainName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test that domain remove reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
domain, err := instanceDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
|
||||
instanceDomainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
||||
instanceDomainRepo.TypeCondition(domain.DomainTypeTrusted),
|
||||
),
|
||||
),
|
||||
)
|
||||
// event instance.domain.removed
|
||||
assert.Nil(ttt, domain)
|
||||
require.ErrorIs(ttt, err, new(database.NoRowFoundError))
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
}
|
||||
162
backend/v3/storage/database/events_testing/instance_test.go
Normal file
162
backend/v3/storage/database/events_testing/instance_test.go
Normal file
@@ -0,0 +1,162 @@
|
||||
//go:build integration
|
||||
|
||||
package events_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v6"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
|
||||
"github.com/zitadel/zitadel/internal/integration"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/system"
|
||||
)
|
||||
|
||||
func TestServer_TestInstanceReduces(t *testing.T) {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
|
||||
t.Run("test instance add reduces", func(t *testing.T) {
|
||||
instanceName := gofakeit.Name()
|
||||
beforeCreate := time.Now()
|
||||
instance, err := SystemClient.CreateInstance(CTX, &system.CreateInstanceRequest{
|
||||
InstanceName: instanceName,
|
||||
Owner: &system.CreateInstanceRequest_Machine_{
|
||||
Machine: &system.CreateInstanceRequest_Machine{
|
||||
UserName: "owner",
|
||||
Name: "owner",
|
||||
PersonalAccessToken: &system.CreateInstanceRequest_PersonalAccessToken{},
|
||||
},
|
||||
},
|
||||
})
|
||||
afterCreate := time.Now()
|
||||
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_, err = SystemClient.RemoveInstance(CTX, &system.RemoveInstanceRequest{
|
||||
InstanceId: instance.GetInstanceId(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Logf("Failed to delete instance on cleanup: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
instance, err := instanceRepo.Get(CTX,
|
||||
database.WithCondition(instanceRepo.IDCondition(instance.GetInstanceId())),
|
||||
)
|
||||
require.NoError(ttt, err)
|
||||
// event instance.added
|
||||
assert.Equal(ttt, instanceName, instance.Name)
|
||||
// event instance.default.org.set
|
||||
assert.NotNil(t, instance.DefaultOrgID)
|
||||
// event instance.iam.project.set
|
||||
assert.NotNil(t, instance.IAMProjectID)
|
||||
// event instance.iam.console.set
|
||||
assert.NotNil(t, instance.ConsoleAppID)
|
||||
// event instance.default.language.set
|
||||
assert.NotNil(t, instance.DefaultLanguage)
|
||||
// event instance.added
|
||||
assert.WithinRange(t, instance.CreatedAt, beforeCreate, afterCreate)
|
||||
// event instance.added
|
||||
assert.WithinRange(t, instance.UpdatedAt, beforeCreate, afterCreate)
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
|
||||
t.Run("test instance update reduces", func(t *testing.T) {
|
||||
instanceName := gofakeit.Name()
|
||||
res, err := SystemClient.CreateInstance(CTX, &system.CreateInstanceRequest{
|
||||
InstanceName: instanceName,
|
||||
Owner: &system.CreateInstanceRequest_Machine_{
|
||||
Machine: &system.CreateInstanceRequest_Machine{
|
||||
UserName: "owner",
|
||||
Name: "owner",
|
||||
PersonalAccessToken: &system.CreateInstanceRequest_PersonalAccessToken{},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
_, err = SystemClient.RemoveInstance(CTX, &system.RemoveInstanceRequest{
|
||||
InstanceId: res.GetInstanceId(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Logf("Failed to delete instance on cleanup: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// check instance exists
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
instance, err := instanceRepo.Get(CTX,
|
||||
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, instanceName, instance.Name)
|
||||
}, retryDuration, tick)
|
||||
|
||||
instanceName += "new"
|
||||
beforeUpdate := time.Now()
|
||||
_, err = SystemClient.UpdateInstance(CTX, &system.UpdateInstanceRequest{
|
||||
InstanceId: res.InstanceId,
|
||||
InstanceName: instanceName,
|
||||
})
|
||||
afterUpdate := time.Now()
|
||||
require.NoError(t, err)
|
||||
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
instance, err := instanceRepo.Get(CTX,
|
||||
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
// event instance.changed
|
||||
assert.Equal(t, instanceName, instance.Name)
|
||||
assert.WithinRange(t, instance.UpdatedAt, beforeUpdate, afterUpdate)
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
|
||||
t.Run("test instance delete reduces", func(t *testing.T) {
|
||||
instanceName := gofakeit.Name()
|
||||
res, err := SystemClient.CreateInstance(CTX, &system.CreateInstanceRequest{
|
||||
InstanceName: instanceName,
|
||||
Owner: &system.CreateInstanceRequest_Machine_{
|
||||
Machine: &system.CreateInstanceRequest_Machine{
|
||||
UserName: "owner",
|
||||
Name: "owner",
|
||||
PersonalAccessToken: &system.CreateInstanceRequest_PersonalAccessToken{},
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// check instance exists
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
instance, err := instanceRepo.Get(CTX,
|
||||
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, instanceName, instance.Name)
|
||||
}, retryDuration, tick)
|
||||
|
||||
_, err = SystemClient.RemoveInstance(CTX, &system.RemoveInstanceRequest{
|
||||
InstanceId: res.InstanceId,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
instance, err := instanceRepo.Get(CTX,
|
||||
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
|
||||
)
|
||||
// event instance.removed
|
||||
assert.Nil(t, instance)
|
||||
require.ErrorIs(t, err, new(database.NoRowFoundError))
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
}
|
||||
133
backend/v3/storage/database/events_testing/org_domain_test.go
Normal file
133
backend/v3/storage/database/events_testing/org_domain_test.go
Normal file
@@ -0,0 +1,133 @@
|
||||
//go:build integration
|
||||
|
||||
package events_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v6"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
|
||||
"github.com/zitadel/zitadel/internal/integration"
|
||||
v2beta "github.com/zitadel/zitadel/pkg/grpc/org/v2beta"
|
||||
)
|
||||
|
||||
func TestServer_TestOrgDomainReduces(t *testing.T) {
|
||||
org, err := OrgClient.CreateOrganization(CTX, &v2beta.CreateOrganizationRequest{
|
||||
Name: gofakeit.Name(),
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
orgRepo := repository.OrganizationRepository(pool)
|
||||
orgDomainRepo := orgRepo.Domains(false)
|
||||
|
||||
t.Cleanup(func() {
|
||||
_, err := OrgClient.DeleteOrganization(CTX, &v2beta.DeleteOrganizationRequest{
|
||||
Id: org.GetId(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Logf("Failed to delete organization on cleanup: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Wait for org to be created
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
_, err := orgRepo.Get(CTX,
|
||||
database.WithCondition(orgRepo.IDCondition(org.GetId())),
|
||||
)
|
||||
assert.NoError(ttt, err)
|
||||
}, retryDuration, tick)
|
||||
|
||||
// The API call also sets the domain as primary, so we don't do a separate test for that.
|
||||
t.Run("test organization domain add reduces", func(t *testing.T) {
|
||||
// Add a domain to the organization
|
||||
domainName := gofakeit.DomainName()
|
||||
beforeAdd := time.Now()
|
||||
_, err := OrgClient.AddOrganizationDomain(CTX, &v2beta.AddOrganizationDomainRequest{
|
||||
OrganizationId: org.GetId(),
|
||||
Domain: domainName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
afterAdd := time.Now()
|
||||
|
||||
t.Cleanup(func() {
|
||||
_, err := OrgClient.DeleteOrganizationDomain(CTX, &v2beta.DeleteOrganizationDomainRequest{
|
||||
OrganizationId: org.GetId(),
|
||||
Domain: domainName,
|
||||
})
|
||||
if err != nil {
|
||||
t.Logf("Failed to delete domain on cleanup: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Test that domain add reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
gottenDomain, err := orgDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
orgDomainRepo.InstanceIDCondition(Instance.Instance.Id),
|
||||
orgDomainRepo.OrgIDCondition(org.Id),
|
||||
orgDomainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(ttt, err)
|
||||
// event org.domain.added
|
||||
assert.Equal(t, domainName, gottenDomain.Domain)
|
||||
assert.Equal(t, Instance.Instance.Id, gottenDomain.InstanceID)
|
||||
assert.Equal(t, org.Id, gottenDomain.OrgID)
|
||||
|
||||
assert.WithinRange(t, gottenDomain.CreatedAt, beforeAdd, afterAdd)
|
||||
assert.WithinRange(t, gottenDomain.UpdatedAt, beforeAdd, afterAdd)
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
|
||||
t.Run("test org domain remove reduces", func(t *testing.T) {
|
||||
// Add a domain to the organization
|
||||
domainName := gofakeit.DomainName()
|
||||
_, err := OrgClient.AddOrganizationDomain(CTX, &v2beta.AddOrganizationDomainRequest{
|
||||
OrganizationId: org.GetId(),
|
||||
Domain: domainName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
_, err := OrgClient.DeleteOrganizationDomain(CTX, &v2beta.DeleteOrganizationDomainRequest{
|
||||
OrganizationId: org.GetId(),
|
||||
Domain: domainName,
|
||||
})
|
||||
if err != nil {
|
||||
t.Logf("Failed to delete domain on cleanup: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// Remove the domain
|
||||
_, err = OrgClient.DeleteOrganizationDomain(CTX, &v2beta.DeleteOrganizationDomainRequest{
|
||||
OrganizationId: org.GetId(),
|
||||
Domain: domainName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test that domain remove reduces
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
|
||||
domain, err := orgDomainRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
orgDomainRepo.InstanceIDCondition(Instance.Instance.Id),
|
||||
orgDomainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
||||
),
|
||||
),
|
||||
)
|
||||
// event instance.domain.removed
|
||||
assert.Nil(ttt, domain)
|
||||
require.ErrorIs(ttt, err, new(database.NoRowFoundError))
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
}
|
||||
269
backend/v3/storage/database/events_testing/organization_test.go
Normal file
269
backend/v3/storage/database/events_testing/organization_test.go
Normal file
@@ -0,0 +1,269 @@
|
||||
//go:build integration
|
||||
|
||||
package events_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v6"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
|
||||
"github.com/zitadel/zitadel/internal/integration"
|
||||
v2beta_org "github.com/zitadel/zitadel/pkg/grpc/org/v2beta"
|
||||
)
|
||||
|
||||
func TestServer_TestOrganizationReduces(t *testing.T) {
|
||||
instanceID := Instance.ID()
|
||||
orgRepo := repository.OrganizationRepository(pool)
|
||||
|
||||
t.Run("test org add reduces", func(t *testing.T) {
|
||||
beforeCreate := time.Now()
|
||||
orgName := gofakeit.Name()
|
||||
|
||||
org, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{
|
||||
Name: orgName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
afterCreate := time.Now()
|
||||
|
||||
t.Cleanup(func() {
|
||||
_, err = OrgClient.DeleteOrganization(CTX, &v2beta_org.DeleteOrganizationRequest{
|
||||
Id: org.GetId(),
|
||||
})
|
||||
if err != nil {
|
||||
t.Logf("Failed to delete organization on cleanup: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(tt *assert.CollectT) {
|
||||
organization, err := orgRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
orgRepo.IDCondition(org.GetId()),
|
||||
orgRepo.InstanceIDCondition(instanceID),
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(tt, err)
|
||||
|
||||
// event org.added
|
||||
assert.NotNil(t, organization.ID)
|
||||
assert.Equal(t, orgName, organization.Name)
|
||||
assert.NotNil(t, organization.InstanceID)
|
||||
assert.Equal(t, domain.OrgStateActive, organization.State)
|
||||
assert.WithinRange(t, organization.CreatedAt, beforeCreate, afterCreate)
|
||||
assert.WithinRange(t, organization.UpdatedAt, beforeCreate, afterCreate)
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
|
||||
t.Run("test org change reduces", func(t *testing.T) {
|
||||
orgName := gofakeit.Name()
|
||||
|
||||
// 1. create org
|
||||
organization, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{
|
||||
Name: orgName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Cleanup(func() {
|
||||
_, err = OrgClient.DeleteOrganization(CTX, &v2beta_org.DeleteOrganizationRequest{
|
||||
Id: organization.Id,
|
||||
})
|
||||
if err != nil {
|
||||
t.Logf("Failed to delete organization on cleanup: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// 2. update org name
|
||||
beforeUpdate := time.Now()
|
||||
orgName = orgName + "_new"
|
||||
_, err = OrgClient.UpdateOrganization(CTX, &v2beta_org.UpdateOrganizationRequest{
|
||||
Id: organization.Id,
|
||||
Name: orgName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
afterUpdate := time.Now()
|
||||
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
organization, err := orgRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
orgRepo.IDCondition(organization.Id),
|
||||
orgRepo.InstanceIDCondition(instanceID),
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// event org.changed
|
||||
assert.Equal(t, orgName, organization.Name)
|
||||
assert.WithinRange(t, organization.UpdatedAt, beforeUpdate, afterUpdate)
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
|
||||
t.Run("test org deactivate reduces", func(t *testing.T) {
|
||||
orgName := gofakeit.Name()
|
||||
|
||||
// 1. create org
|
||||
organization, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{
|
||||
Name: orgName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
// Cleanup: delete the organization
|
||||
_, err = OrgClient.DeleteOrganization(CTX, &v2beta_org.DeleteOrganizationRequest{
|
||||
Id: organization.Id,
|
||||
})
|
||||
if err != nil {
|
||||
t.Logf("Failed to delete organization on cleanup: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// 2. deactivate org name
|
||||
beforeDeactivate := time.Now()
|
||||
_, err = OrgClient.DeactivateOrganization(CTX, &v2beta_org.DeactivateOrganizationRequest{
|
||||
Id: organization.Id,
|
||||
})
|
||||
|
||||
require.NoError(t, err)
|
||||
afterDeactivate := time.Now()
|
||||
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
organization, err := orgRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
orgRepo.IDCondition(organization.Id),
|
||||
orgRepo.InstanceIDCondition(instanceID),
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// event org.deactivate
|
||||
assert.Equal(t, domain.OrgStateInactive, organization.State)
|
||||
assert.WithinRange(t, organization.UpdatedAt, beforeDeactivate, afterDeactivate)
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
|
||||
t.Run("test org activate reduces", func(t *testing.T) {
|
||||
orgName := gofakeit.Name()
|
||||
|
||||
// 1. create org
|
||||
organization, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{
|
||||
Name: orgName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
t.Cleanup(func() {
|
||||
// Cleanup: delete the organization
|
||||
_, err = OrgClient.DeleteOrganization(CTX, &v2beta_org.DeleteOrganizationRequest{
|
||||
Id: organization.Id,
|
||||
})
|
||||
if err != nil {
|
||||
t.Logf("Failed to delete organization on cleanup: %v", err)
|
||||
}
|
||||
})
|
||||
|
||||
// 2. deactivate org name
|
||||
_, err = OrgClient.DeactivateOrganization(CTX, &v2beta_org.DeactivateOrganizationRequest{
|
||||
Id: organization.Id,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
orgRepo := repository.OrganizationRepository(pool)
|
||||
// 3. check org deactivated
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
organization, err := orgRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
orgRepo.IDCondition(organization.Id),
|
||||
orgRepo.InstanceIDCondition(instanceID),
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, domain.OrgStateInactive, organization.State)
|
||||
}, retryDuration, tick)
|
||||
|
||||
// 4. activate org name
|
||||
beforeActivate := time.Now()
|
||||
_, err = OrgClient.ActivateOrganization(CTX, &v2beta_org.ActivateOrganizationRequest{
|
||||
Id: organization.Id,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
afterActivate := time.Now()
|
||||
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
organization, err := orgRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
orgRepo.IDCondition(organization.Id),
|
||||
orgRepo.InstanceIDCondition(instanceID),
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// event org.reactivate
|
||||
assert.Equal(t, orgName, organization.Name)
|
||||
assert.Equal(t, domain.OrgStateActive, organization.State)
|
||||
assert.WithinRange(t, organization.UpdatedAt, beforeActivate, afterActivate)
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
|
||||
t.Run("test org remove reduces", func(t *testing.T) {
|
||||
orgName := gofakeit.Name()
|
||||
|
||||
// 1. create org
|
||||
organization, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{
|
||||
Name: orgName,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// 2. check org retrievable
|
||||
orgRepo := repository.OrganizationRepository(pool)
|
||||
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
_, err := orgRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
orgRepo.IDCondition(organization.Id),
|
||||
orgRepo.InstanceIDCondition(instanceID),
|
||||
),
|
||||
),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
}, retryDuration, tick)
|
||||
|
||||
// 3. delete org
|
||||
_, err = OrgClient.DeleteOrganization(CTX, &v2beta_org.DeleteOrganizationRequest{
|
||||
Id: organization.Id,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
|
||||
assert.EventuallyWithT(t, func(t *assert.CollectT) {
|
||||
organization, err := orgRepo.Get(CTX,
|
||||
database.WithCondition(
|
||||
database.And(
|
||||
orgRepo.IDCondition(organization.Id),
|
||||
orgRepo.InstanceIDCondition(instanceID),
|
||||
),
|
||||
),
|
||||
)
|
||||
require.ErrorIs(t, err, new(database.NoRowFoundError))
|
||||
|
||||
// event org.remove
|
||||
assert.Nil(t, organization)
|
||||
}, retryDuration, tick)
|
||||
})
|
||||
}
|
||||
3
backend/v3/storage/database/gen_mock.go
Normal file
3
backend/v3/storage/database/gen_mock.go
Normal file
@@ -0,0 +1,3 @@
|
||||
package database
|
||||
|
||||
//go:generate mockgen -typed -package dbmock -destination ./dbmock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Client,Row,Rows,Transaction
|
||||
9
backend/v3/storage/database/migration.go
Normal file
9
backend/v3/storage/database/migration.go
Normal file
@@ -0,0 +1,9 @@
|
||||
package database
|
||||
|
||||
import "context"
|
||||
|
||||
type Migrator interface {
|
||||
// Migrate executes migrations to setup the database.
|
||||
// The method can be called once per running Zitadel.
|
||||
Migrate(ctx context.Context) error
|
||||
}
|
||||
136
backend/v3/storage/database/operators.go
Normal file
136
backend/v3/storage/database/operators.go
Normal file
@@ -0,0 +1,136 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
type Value interface {
|
||||
Boolean | Number | Text | Instruction
|
||||
}
|
||||
|
||||
type Operation interface {
|
||||
BooleanOperation | NumberOperation | TextOperation
|
||||
}
|
||||
|
||||
type Text interface {
|
||||
~string | ~[]byte
|
||||
}
|
||||
|
||||
// TextOperation are operations that can be performed on text values.
|
||||
type TextOperation uint8
|
||||
|
||||
const (
|
||||
// TextOperationEqual compares two strings for equality.
|
||||
TextOperationEqual TextOperation = iota + 1
|
||||
// TextOperationEqualIgnoreCase compares two strings for equality, ignoring case.
|
||||
TextOperationEqualIgnoreCase
|
||||
// TextOperationNotEqual compares two strings for inequality.
|
||||
TextOperationNotEqual
|
||||
// TextOperationNotEqualIgnoreCase compares two strings for inequality, ignoring case.
|
||||
TextOperationNotEqualIgnoreCase
|
||||
// TextOperationStartsWith checks if the first string starts with the second.
|
||||
TextOperationStartsWith
|
||||
// TextOperationStartsWithIgnoreCase checks if the first string starts with the second, ignoring case.
|
||||
TextOperationStartsWithIgnoreCase
|
||||
)
|
||||
|
||||
var textOperations = map[TextOperation]string{
|
||||
TextOperationEqual: " = ",
|
||||
TextOperationEqualIgnoreCase: " LIKE ",
|
||||
TextOperationNotEqual: " <> ",
|
||||
TextOperationNotEqualIgnoreCase: " NOT LIKE ",
|
||||
TextOperationStartsWith: " LIKE ",
|
||||
TextOperationStartsWithIgnoreCase: " LIKE ",
|
||||
}
|
||||
|
||||
func writeTextOperation[T Text](builder *StatementBuilder, col Column, op TextOperation, value T) {
|
||||
switch op {
|
||||
case TextOperationEqual, TextOperationNotEqual:
|
||||
col.WriteQualified(builder)
|
||||
builder.WriteString(textOperations[op])
|
||||
builder.WriteArg(value)
|
||||
case TextOperationEqualIgnoreCase, TextOperationNotEqualIgnoreCase:
|
||||
builder.WriteString("LOWER(")
|
||||
col.WriteQualified(builder)
|
||||
builder.WriteString(")")
|
||||
|
||||
builder.WriteString(textOperations[op])
|
||||
builder.WriteString("LOWER(")
|
||||
builder.WriteArg(value)
|
||||
builder.WriteString(")")
|
||||
case TextOperationStartsWith:
|
||||
col.WriteQualified(builder)
|
||||
builder.WriteString(textOperations[op])
|
||||
builder.WriteArg(value)
|
||||
builder.WriteString(" || '%'")
|
||||
case TextOperationStartsWithIgnoreCase:
|
||||
builder.WriteString("LOWER(")
|
||||
col.WriteQualified(builder)
|
||||
builder.WriteString(")")
|
||||
|
||||
builder.WriteString(textOperations[op])
|
||||
builder.WriteString("LOWER(")
|
||||
builder.WriteArg(value)
|
||||
builder.WriteString(")")
|
||||
builder.WriteString(" || '%'")
|
||||
default:
|
||||
panic("unsupported text operation")
|
||||
}
|
||||
}
|
||||
|
||||
type Number interface {
|
||||
constraints.Integer | constraints.Float | constraints.Complex | time.Time | time.Duration
|
||||
}
|
||||
|
||||
// NumberOperation are operations that can be performed on number values.
|
||||
type NumberOperation uint8
|
||||
|
||||
const (
|
||||
// NumberOperationEqual compares two numbers for equality.
|
||||
NumberOperationEqual NumberOperation = iota + 1
|
||||
// NumberOperationNotEqual compares two numbers for inequality.
|
||||
NumberOperationNotEqual
|
||||
// NumberOperationLessThan compares two numbers to check if the first is less than the second.
|
||||
NumberOperationLessThan
|
||||
// NumberOperationLessThanOrEqual compares two numbers to check if the first is less than or equal to the second.
|
||||
NumberOperationAtLeast
|
||||
// NumberOperationGreaterThan compares two numbers to check if the first is greater than the second.
|
||||
NumberOperationGreaterThan
|
||||
// NumberOperationGreaterThanOrEqual compares two numbers to check if the first is greater than or equal to the second.
|
||||
NumberOperationAtMost
|
||||
)
|
||||
|
||||
var numberOperations = map[NumberOperation]string{
|
||||
NumberOperationEqual: " = ",
|
||||
NumberOperationNotEqual: " <> ",
|
||||
NumberOperationLessThan: " < ",
|
||||
NumberOperationAtLeast: " <= ",
|
||||
NumberOperationGreaterThan: " > ",
|
||||
NumberOperationAtMost: " >= ",
|
||||
}
|
||||
|
||||
func writeNumberOperation[T Number](builder *StatementBuilder, col Column, op NumberOperation, value T) {
|
||||
col.WriteQualified(builder)
|
||||
builder.WriteString(numberOperations[op])
|
||||
builder.WriteArg(value)
|
||||
}
|
||||
|
||||
type Boolean interface {
|
||||
~bool
|
||||
}
|
||||
|
||||
// BooleanOperation are operations that can be performed on boolean values.
|
||||
type BooleanOperation uint8
|
||||
|
||||
const (
|
||||
BooleanOperationIsTrue BooleanOperation = iota + 1
|
||||
BooleanOperationIsFalse
|
||||
)
|
||||
|
||||
func writeBooleanOperation[T Boolean](builder *StatementBuilder, col Column, value T) {
|
||||
col.WriteQualified(builder)
|
||||
builder.WriteString(" = ")
|
||||
builder.WriteArg(value)
|
||||
}
|
||||
21
backend/v3/storage/database/order.go
Normal file
21
backend/v3/storage/database/order.go
Normal file
@@ -0,0 +1,21 @@
|
||||
package database
|
||||
|
||||
// Order represents a SQL condition.
|
||||
// Its written after the ORDER BY keyword in a SQL statement.
|
||||
type Order interface {
|
||||
Write(builder *StatementBuilder)
|
||||
}
|
||||
|
||||
type orderBy struct {
|
||||
column Column
|
||||
}
|
||||
|
||||
func OrderBy(column Column) Order {
|
||||
return &orderBy{column: column}
|
||||
}
|
||||
|
||||
// Write implements [Order].
|
||||
func (o *orderBy) Write(builder *StatementBuilder) {
|
||||
builder.WriteString(" ORDER BY ")
|
||||
o.column.WriteQualified(builder)
|
||||
}
|
||||
174
backend/v3/storage/database/query.go
Normal file
174
backend/v3/storage/database/query.go
Normal file
@@ -0,0 +1,174 @@
|
||||
package database
|
||||
|
||||
type QueryOption func(opts *QueryOpts)
|
||||
|
||||
// WithCondition sets the condition for the query.
|
||||
func WithCondition(condition Condition) QueryOption {
|
||||
return func(opts *QueryOpts) {
|
||||
opts.Condition = condition
|
||||
}
|
||||
}
|
||||
|
||||
// WithOrderBy sets the columns to order the results by.
|
||||
func WithOrderBy(ordering OrderDirection, orderBy ...Column) QueryOption {
|
||||
return func(opts *QueryOpts) {
|
||||
opts.OrderBy = orderBy
|
||||
opts.Ordering = ordering
|
||||
}
|
||||
}
|
||||
|
||||
func WithOrderByAscending(columns ...Column) QueryOption {
|
||||
return WithOrderBy(OrderDirectionAsc, columns...)
|
||||
}
|
||||
|
||||
func WithOrderByDescending(columns ...Column) QueryOption {
|
||||
return WithOrderBy(OrderDirectionDesc, columns...)
|
||||
}
|
||||
|
||||
// WithLimit sets the maximum number of results to return.
|
||||
func WithLimit(limit uint32) QueryOption {
|
||||
return func(opts *QueryOpts) {
|
||||
opts.Limit = limit
|
||||
}
|
||||
}
|
||||
|
||||
// WithOffset sets the number of results to skip before returning the results.
|
||||
func WithOffset(offset uint32) QueryOption {
|
||||
return func(opts *QueryOpts) {
|
||||
opts.Offset = offset
|
||||
}
|
||||
}
|
||||
|
||||
// WithGroupBy sets the columns to group the results by.
|
||||
func WithGroupBy(groupBy ...Column) QueryOption {
|
||||
return func(opts *QueryOpts) {
|
||||
opts.GroupBy = groupBy
|
||||
}
|
||||
}
|
||||
|
||||
// WithLeftJoin adds a LEFT JOIN to the query.
|
||||
func WithLeftJoin(table string, columns Condition) QueryOption {
|
||||
return func(opts *QueryOpts) {
|
||||
opts.Joins = append(opts.Joins, join{
|
||||
table: table,
|
||||
typ: JoinTypeLeft,
|
||||
columns: columns,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
type joinType string
|
||||
|
||||
const (
|
||||
JoinTypeLeft joinType = "LEFT"
|
||||
)
|
||||
|
||||
type join struct {
|
||||
table string
|
||||
typ joinType
|
||||
columns Condition
|
||||
}
|
||||
|
||||
type OrderDirection uint8
|
||||
|
||||
const (
|
||||
OrderDirectionAsc OrderDirection = iota
|
||||
OrderDirectionDesc
|
||||
)
|
||||
|
||||
// QueryOpts holds the options for a query.
|
||||
// It is used to build the SQL SELECT statement.
|
||||
type QueryOpts struct {
|
||||
// Condition is the condition to filter the results.
|
||||
// It is used to build the WHERE clause of the SQL statement.
|
||||
Condition Condition
|
||||
// OrderBy is the columns to order the results by.
|
||||
// It is used to build the ORDER BY clause of the SQL statement.
|
||||
OrderBy Columns
|
||||
// Ordering defines if the columns should be ordered ascending or descending.
|
||||
// Default is ascending.
|
||||
Ordering OrderDirection
|
||||
// Limit is the maximum number of results to return.
|
||||
// It is used to build the LIMIT clause of the SQL statement.
|
||||
Limit uint32
|
||||
// Offset is the number of results to skip before returning the results.
|
||||
// It is used to build the OFFSET clause of the SQL statement.
|
||||
Offset uint32
|
||||
// GroupBy is the columns to group the results by.
|
||||
// It is used to build the GROUP BY clause of the SQL statement.
|
||||
GroupBy Columns
|
||||
// Joins is a list of joins to be applied to the query.
|
||||
// It is used to build the JOIN clauses of the SQL statement.
|
||||
Joins []join
|
||||
}
|
||||
|
||||
func (opts *QueryOpts) Write(builder *StatementBuilder) {
|
||||
opts.WriteLeftJoins(builder)
|
||||
opts.WriteCondition(builder)
|
||||
opts.WriteGroupBy(builder)
|
||||
opts.WriteOrderBy(builder)
|
||||
opts.WriteLimit(builder)
|
||||
opts.WriteOffset(builder)
|
||||
}
|
||||
|
||||
func (opts *QueryOpts) WriteCondition(builder *StatementBuilder) {
|
||||
if opts.Condition == nil {
|
||||
return
|
||||
}
|
||||
builder.WriteString(" WHERE ")
|
||||
opts.Condition.Write(builder)
|
||||
}
|
||||
|
||||
func (opts *QueryOpts) WriteOrderBy(builder *StatementBuilder) {
|
||||
if len(opts.OrderBy) == 0 {
|
||||
return
|
||||
}
|
||||
builder.WriteString(" ORDER BY ")
|
||||
for i, col := range opts.OrderBy {
|
||||
if i > 0 {
|
||||
builder.WriteString(", ")
|
||||
}
|
||||
col.WriteQualified(builder)
|
||||
if opts.Ordering == OrderDirectionDesc {
|
||||
builder.WriteString(" DESC")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (opts *QueryOpts) WriteLimit(builder *StatementBuilder) {
|
||||
if opts.Limit == 0 {
|
||||
return
|
||||
}
|
||||
builder.WriteString(" LIMIT ")
|
||||
builder.WriteArg(opts.Limit)
|
||||
}
|
||||
|
||||
func (opts *QueryOpts) WriteOffset(builder *StatementBuilder) {
|
||||
if opts.Offset == 0 {
|
||||
return
|
||||
}
|
||||
builder.WriteString(" OFFSET ")
|
||||
builder.WriteArg(opts.Offset)
|
||||
}
|
||||
|
||||
func (opts *QueryOpts) WriteGroupBy(builder *StatementBuilder) {
|
||||
if len(opts.GroupBy) == 0 {
|
||||
return
|
||||
}
|
||||
builder.WriteString(" GROUP BY ")
|
||||
opts.GroupBy.WriteQualified(builder)
|
||||
}
|
||||
|
||||
func (opts *QueryOpts) WriteLeftJoins(builder *StatementBuilder) {
|
||||
if len(opts.Joins) == 0 {
|
||||
return
|
||||
}
|
||||
for _, join := range opts.Joins {
|
||||
builder.WriteString(" ")
|
||||
builder.WriteString(string(join.typ))
|
||||
builder.WriteString(" JOIN ")
|
||||
builder.WriteString(join.table)
|
||||
builder.WriteString(" ON ")
|
||||
join.columns.Write(builder)
|
||||
}
|
||||
}
|
||||
5
backend/v3/storage/database/repository/doc.go
Normal file
5
backend/v3/storage/database/repository/doc.go
Normal file
@@ -0,0 +1,5 @@
|
||||
// Package implements the repositories defined in the domain package.
|
||||
// The repositories are used by the domain package to access the database.
|
||||
// the inheritance.sql file is me over-engineering table inheritance.
|
||||
// I would create a user table which is inherited by human_user and machine_user and the same for objects like idps.
|
||||
package repository
|
||||
135
backend/v3/storage/database/repository/inheritance.sql
Normal file
135
backend/v3/storage/database/repository/inheritance.sql
Normal file
@@ -0,0 +1,135 @@
|
||||
CREATE TABLE objects (
|
||||
id SERIAL PRIMARY KEY,
|
||||
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
|
||||
deleted_at TIMESTAMP
|
||||
);
|
||||
|
||||
CREATE OR REPLACE FUNCTION update_updated_at_column()
|
||||
RETURNS TRIGGER AS $$
|
||||
BEGIN
|
||||
NEW.updated_at = NOW();
|
||||
RETURN NEW;
|
||||
END;
|
||||
$$ LANGUAGE plpgsql;
|
||||
|
||||
CREATE TABLE instances(
|
||||
name VARCHAR(50) NOT NULL
|
||||
, PRIMARY KEY (id)
|
||||
) INHERITS (objects);
|
||||
|
||||
CREATE TRIGGER set_updated_at
|
||||
BEFORE UPDATE
|
||||
ON instances
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION update_updated_at_column();
|
||||
|
||||
CREATE TABLE instance_objects(
|
||||
instance_id INT NOT NULL
|
||||
, PRIMARY KEY (instance_id, id)
|
||||
-- as foreign keys are not inherited we need to define them on the child tables
|
||||
--, CONSTRAINT fk_instance FOREIGN KEY (instance_id) REFERENCES instances(id)
|
||||
) INHERITS (objects);
|
||||
|
||||
CREATE TABLE orgs(
|
||||
name VARCHAR(50) NOT NULL
|
||||
, PRIMARY KEY (instance_id, id)
|
||||
, CONSTRAINT fk_instance FOREIGN KEY (instance_id) REFERENCES instances(id)
|
||||
) INHERITS (instance_objects);
|
||||
|
||||
CREATE TRIGGER set_updated_at
|
||||
BEFORE UPDATE
|
||||
ON orgs
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION update_updated_at_column();
|
||||
|
||||
CREATE TABLE org_objects(
|
||||
org_id INT NOT NULL
|
||||
, PRIMARY KEY (instance_id, org_id, id)
|
||||
-- as foreign keys are not inherited we need to define them on the child tables
|
||||
-- CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id),
|
||||
-- CONSTRAINT fk_instance FOREIGN KEY (instance_id) REFERENCES instances(id)
|
||||
) INHERITS (instance_objects);
|
||||
|
||||
CREATE TABLE users (
|
||||
username VARCHAR(50) NOT NULL
|
||||
, PRIMARY KEY (instance_id, org_id, id)
|
||||
-- as foreign keys are not inherited we need to define them on the child tables
|
||||
-- , CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id)
|
||||
-- , CONSTRAINT fk_instances FOREIGN KEY (instance_id) REFERENCES instances(id)
|
||||
) INHERITS (org_objects);
|
||||
|
||||
CREATE INDEX idx_users_username ON users(username);
|
||||
|
||||
CREATE TRIGGER set_updated_at
|
||||
BEFORE UPDATE
|
||||
ON users
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION update_updated_at_column();
|
||||
|
||||
CREATE TABLE human_users(
|
||||
first_name VARCHAR(50)
|
||||
, last_name VARCHAR(50)
|
||||
, PRIMARY KEY (instance_id, org_id, id)
|
||||
-- CONSTRAINT fk_user FOREIGN KEY (instance_id, org_id, id) REFERENCES users(instance_id, org_id, id),
|
||||
, CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id)
|
||||
, CONSTRAINT fk_instances FOREIGN KEY (instance_id) REFERENCES instances(id)
|
||||
) INHERITS (users);
|
||||
|
||||
CREATE INDEX idx_human_users_username ON human_users(username);
|
||||
|
||||
CREATE TRIGGER set_updated_at
|
||||
BEFORE UPDATE
|
||||
ON human_users
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION update_updated_at_column();
|
||||
|
||||
CREATE TABLE machine_users(
|
||||
description VARCHAR(50)
|
||||
, PRIMARY KEY (instance_id, org_id, id)
|
||||
-- , CONSTRAINT fk_user FOREIGN KEY (instance_id, org_id, id) REFERENCES users(instance_id, org_id, id)
|
||||
, CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id)
|
||||
, CONSTRAINT fk_instances FOREIGN KEY (instance_id) REFERENCES instances(id)
|
||||
) INHERITS (users);
|
||||
|
||||
CREATE INDEX idx_machine_users_username ON machine_users(username);
|
||||
|
||||
CREATE TRIGGER set_updated_at
|
||||
BEFORE UPDATE
|
||||
ON machine_users
|
||||
FOR EACH ROW
|
||||
EXECUTE FUNCTION update_updated_at_column();
|
||||
|
||||
CREATE VIEW users_view AS (
|
||||
SELECT
|
||||
id
|
||||
, created_at
|
||||
, updated_at
|
||||
, deleted_at
|
||||
, instance_id
|
||||
, org_id
|
||||
, username
|
||||
, tableoid::regclass::TEXT AS type
|
||||
, first_name
|
||||
, last_name
|
||||
, NULL AS description
|
||||
FROM
|
||||
human_users
|
||||
|
||||
UNION
|
||||
|
||||
SELECT
|
||||
id
|
||||
, created_at
|
||||
, updated_at
|
||||
, deleted_at
|
||||
, instance_id
|
||||
, org_id
|
||||
, username
|
||||
, tableoid::regclass::TEXT AS type
|
||||
, NULL AS first_name
|
||||
, NULL AS last_name
|
||||
, description
|
||||
FROM
|
||||
machine_users
|
||||
);
|
||||
306
backend/v3/storage/database/repository/instance.go
Normal file
306
backend/v3/storage/database/repository/instance.go
Normal file
@@ -0,0 +1,306 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"time"
|
||||
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
var _ domain.InstanceRepository = (*instance)(nil)
|
||||
|
||||
type instance struct {
|
||||
repository
|
||||
shouldLoadDomains bool
|
||||
domainRepo *instanceDomain
|
||||
}
|
||||
|
||||
func InstanceRepository(client database.QueryExecutor) domain.InstanceRepository {
|
||||
return &instance{
|
||||
repository: repository{
|
||||
client: client,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// repository
|
||||
// -------------------------------------------------------------
|
||||
|
||||
const (
|
||||
queryInstanceStmt = `SELECT instances.id, instances.name, instances.default_org_id, instances.iam_project_id, instances.console_client_id, instances.console_app_id, instances.default_language, instances.created_at, instances.updated_at` +
|
||||
` , CASE WHEN count(instance_domains.domain) > 0 THEN jsonb_agg(json_build_object('domain', instance_domains.domain, 'isPrimary', instance_domains.is_primary, 'isGenerated', instance_domains.is_generated, 'createdAt', instance_domains.created_at, 'updatedAt', instance_domains.updated_at)) ELSE NULL::JSONB END domains` +
|
||||
` FROM zitadel.instances`
|
||||
)
|
||||
|
||||
// Get implements [domain.InstanceRepository].
|
||||
func (i *instance) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Instance, error) {
|
||||
opts = append(opts,
|
||||
i.joinDomains(),
|
||||
database.WithGroupBy(i.IDColumn()),
|
||||
)
|
||||
|
||||
options := new(database.QueryOpts)
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
var builder database.StatementBuilder
|
||||
builder.WriteString(queryInstanceStmt)
|
||||
options.Write(&builder)
|
||||
|
||||
return scanInstance(ctx, i.client, &builder)
|
||||
}
|
||||
|
||||
// List implements [domain.InstanceRepository].
|
||||
func (i *instance) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Instance, error) {
|
||||
opts = append(opts,
|
||||
i.joinDomains(),
|
||||
database.WithGroupBy(i.IDColumn()),
|
||||
)
|
||||
|
||||
options := new(database.QueryOpts)
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
var builder database.StatementBuilder
|
||||
builder.WriteString(queryInstanceStmt)
|
||||
options.Write(&builder)
|
||||
|
||||
return scanInstances(ctx, i.client, &builder)
|
||||
}
|
||||
|
||||
func (i *instance) joinDomains() database.QueryOption {
|
||||
columns := make([]database.Condition, 0, 2)
|
||||
columns = append(columns, database.NewColumnCondition(i.IDColumn(), i.Domains(false).InstanceIDColumn()))
|
||||
|
||||
// If domains should not be joined, we make sure to return null for the domain columns
|
||||
// the query optimizer of the dialect should optimize this away if no domains are requested
|
||||
if !i.shouldLoadDomains {
|
||||
columns = append(columns, database.IsNull(i.Domains(false).InstanceIDColumn()))
|
||||
}
|
||||
|
||||
return database.WithLeftJoin(
|
||||
"zitadel.instance_domains",
|
||||
database.And(columns...),
|
||||
)
|
||||
}
|
||||
|
||||
// Create implements [domain.InstanceRepository].
|
||||
func (i *instance) Create(ctx context.Context, instance *domain.Instance) error {
|
||||
var (
|
||||
builder database.StatementBuilder
|
||||
createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction
|
||||
)
|
||||
if !instance.CreatedAt.IsZero() {
|
||||
createdAt = instance.CreatedAt
|
||||
}
|
||||
if !instance.UpdatedAt.IsZero() {
|
||||
updatedAt = instance.UpdatedAt
|
||||
}
|
||||
|
||||
builder.WriteString(`INSERT INTO zitadel.instances (id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language, created_at, updated_at) VALUES (`)
|
||||
builder.WriteArgs(instance.ID, instance.Name, instance.DefaultOrgID, instance.IAMProjectID, instance.ConsoleClientID, instance.ConsoleAppID, instance.DefaultLanguage, createdAt, updatedAt)
|
||||
builder.WriteString(`) RETURNING created_at, updated_at`)
|
||||
|
||||
return i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&instance.CreatedAt, &instance.UpdatedAt)
|
||||
}
|
||||
|
||||
// Update implements [domain.InstanceRepository].
|
||||
func (i instance) Update(ctx context.Context, id string, changes ...database.Change) (int64, error) {
|
||||
if len(changes) == 0 {
|
||||
return 0, database.ErrNoChanges
|
||||
}
|
||||
var builder database.StatementBuilder
|
||||
|
||||
builder.WriteString(`UPDATE zitadel.instances SET `)
|
||||
|
||||
database.Changes(changes).Write(&builder)
|
||||
|
||||
idCondition := i.IDCondition(id)
|
||||
writeCondition(&builder, idCondition)
|
||||
|
||||
stmt := builder.String()
|
||||
|
||||
return i.client.Exec(ctx, stmt, builder.Args()...)
|
||||
}
|
||||
|
||||
// Delete implements [domain.InstanceRepository].
|
||||
func (i instance) Delete(ctx context.Context, id string) (int64, error) {
|
||||
var builder database.StatementBuilder
|
||||
|
||||
builder.WriteString(`DELETE FROM zitadel.instances`)
|
||||
|
||||
idCondition := i.IDCondition(id)
|
||||
writeCondition(&builder, idCondition)
|
||||
|
||||
return i.client.Exec(ctx, builder.String(), builder.Args()...)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// changes
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// SetName implements [domain.instanceChanges].
|
||||
func (i instance) SetName(name string) database.Change {
|
||||
return database.NewChange(i.NameColumn(), name)
|
||||
}
|
||||
|
||||
// SetUpdatedAt implements [domain.instanceChanges].
|
||||
func (i instance) SetUpdatedAt(time time.Time) database.Change {
|
||||
return database.NewChange(i.UpdatedAtColumn(), time)
|
||||
}
|
||||
|
||||
func (i instance) SetIAMProject(id string) database.Change {
|
||||
return database.NewChange(i.IAMProjectIDColumn(), id)
|
||||
}
|
||||
func (i instance) SetDefaultOrg(id string) database.Change {
|
||||
return database.NewChange(i.DefaultOrgIDColumn(), id)
|
||||
}
|
||||
func (i instance) SetDefaultLanguage(lang language.Tag) database.Change {
|
||||
return database.NewChange(i.DefaultLanguageColumn(), lang.String())
|
||||
}
|
||||
func (i instance) SetConsoleClientID(id string) database.Change {
|
||||
return database.NewChange(i.ConsoleClientIDColumn(), id)
|
||||
}
|
||||
func (i instance) SetConsoleAppID(id string) database.Change {
|
||||
return database.NewChange(i.ConsoleAppIDColumn(), id)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// conditions
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// IDCondition implements [domain.instanceConditions].
|
||||
func (i instance) IDCondition(id string) database.Condition {
|
||||
return database.NewTextCondition(i.IDColumn(), database.TextOperationEqual, id)
|
||||
}
|
||||
|
||||
// NameCondition implements [domain.instanceConditions].
|
||||
func (i instance) NameCondition(op database.TextOperation, name string) database.Condition {
|
||||
return database.NewTextCondition(i.NameColumn(), op, name)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// columns
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// IDColumn implements [domain.instanceColumns].
|
||||
func (instance) IDColumn() database.Column {
|
||||
return database.NewColumn("instances", "id")
|
||||
}
|
||||
|
||||
// NameColumn implements [domain.instanceColumns].
|
||||
func (instance) NameColumn() database.Column {
|
||||
return database.NewColumn("instances", "name")
|
||||
}
|
||||
|
||||
// CreatedAtColumn implements [domain.instanceColumns].
|
||||
func (instance) CreatedAtColumn() database.Column {
|
||||
return database.NewColumn("instances", "created_at")
|
||||
}
|
||||
|
||||
// DefaultOrgIdColumn implements [domain.instanceColumns].
|
||||
func (instance) DefaultOrgIDColumn() database.Column {
|
||||
return database.NewColumn("instances", "default_org_id")
|
||||
}
|
||||
|
||||
// IAMProjectIDColumn implements [domain.instanceColumns].
|
||||
func (instance) IAMProjectIDColumn() database.Column {
|
||||
return database.NewColumn("instances", "iam_project_id")
|
||||
}
|
||||
|
||||
// ConsoleClientIDColumn implements [domain.instanceColumns].
|
||||
func (instance) ConsoleClientIDColumn() database.Column {
|
||||
return database.NewColumn("instances", "console_client_id")
|
||||
}
|
||||
|
||||
// ConsoleAppIDColumn implements [domain.instanceColumns].
|
||||
func (instance) ConsoleAppIDColumn() database.Column {
|
||||
return database.NewColumn("instances", "console_app_id")
|
||||
}
|
||||
|
||||
// DefaultLanguageColumn implements [domain.instanceColumns].
|
||||
func (instance) DefaultLanguageColumn() database.Column {
|
||||
return database.NewColumn("instances", "default_language")
|
||||
}
|
||||
|
||||
// UpdatedAtColumn implements [domain.instanceColumns].
|
||||
func (instance) UpdatedAtColumn() database.Column {
|
||||
return database.NewColumn("instances", "updated_at")
|
||||
}
|
||||
|
||||
type rawInstance struct {
|
||||
*domain.Instance
|
||||
RawDomains sql.Null[json.RawMessage] `json:"domains,omitzero" db:"domains"`
|
||||
}
|
||||
|
||||
func scanInstance(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Instance, error) {
|
||||
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var instance rawInstance
|
||||
if err := rows.(database.CollectableRows).CollectExactlyOneRow(&instance); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if instance.RawDomains.Valid {
|
||||
if err := json.Unmarshal(instance.RawDomains.V, &instance.Domains); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return instance.Instance, nil
|
||||
}
|
||||
|
||||
func scanInstances(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.Instance, error) {
|
||||
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var rawInstances []*rawInstance
|
||||
if err := rows.(database.CollectableRows).Collect(&rawInstances); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
instances := make([]*domain.Instance, len(rawInstances))
|
||||
for i, instance := range rawInstances {
|
||||
if instance.RawDomains.Valid {
|
||||
if err := json.Unmarshal(instance.RawDomains.V, &instance.Domains); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
instances[i] = instance.Instance
|
||||
}
|
||||
return instances, nil
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// sub repositories
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// Domains implements [domain.InstanceRepository].
|
||||
func (i *instance) Domains(shouldLoad bool) domain.InstanceDomainRepository {
|
||||
if !i.shouldLoadDomains {
|
||||
i.shouldLoadDomains = shouldLoad
|
||||
}
|
||||
|
||||
if i.domainRepo != nil {
|
||||
return i.domainRepo
|
||||
}
|
||||
|
||||
i.domainRepo = &instanceDomain{
|
||||
repository: i.repository,
|
||||
instance: i,
|
||||
}
|
||||
return i.domainRepo
|
||||
}
|
||||
214
backend/v3/storage/database/repository/instance_domain.go
Normal file
214
backend/v3/storage/database/repository/instance_domain.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
var _ domain.InstanceDomainRepository = (*instanceDomain)(nil)
|
||||
|
||||
type instanceDomain struct {
|
||||
repository
|
||||
*instance
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// repository
|
||||
// -------------------------------------------------------------
|
||||
|
||||
const queryInstanceDomainStmt = `SELECT instance_domains.instance_id, instance_domains.domain, instance_domains.is_primary, instance_domains.created_at, instance_domains.updated_at ` +
|
||||
`FROM zitadel.instance_domains`
|
||||
|
||||
// Get implements [domain.InstanceDomainRepository].
|
||||
// Subtle: this method shadows the method ([domain.InstanceRepository]).Get of instanceDomain.instance.
|
||||
func (i *instanceDomain) Get(ctx context.Context, opts ...database.QueryOption) (*domain.InstanceDomain, error) {
|
||||
options := new(database.QueryOpts)
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
var builder database.StatementBuilder
|
||||
builder.WriteString(queryInstanceDomainStmt)
|
||||
options.Write(&builder)
|
||||
|
||||
return scanInstanceDomain(ctx, i.client, &builder)
|
||||
}
|
||||
|
||||
// List implements [domain.InstanceDomainRepository].
|
||||
// Subtle: this method shadows the method ([domain.InstanceRepository]).List of instanceDomain.instance.
|
||||
func (i *instanceDomain) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.InstanceDomain, error) {
|
||||
options := new(database.QueryOpts)
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
var builder database.StatementBuilder
|
||||
builder.WriteString(queryInstanceDomainStmt)
|
||||
options.Write(&builder)
|
||||
|
||||
return scanInstanceDomains(ctx, i.client, &builder)
|
||||
}
|
||||
|
||||
// Add implements [domain.InstanceDomainRepository].
|
||||
func (i *instanceDomain) Add(ctx context.Context, domain *domain.AddInstanceDomain) error {
|
||||
var (
|
||||
builder database.StatementBuilder
|
||||
createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction
|
||||
)
|
||||
if !domain.CreatedAt.IsZero() {
|
||||
createdAt = domain.CreatedAt
|
||||
}
|
||||
if !domain.UpdatedAt.IsZero() {
|
||||
updatedAt = domain.UpdatedAt
|
||||
}
|
||||
|
||||
builder.WriteString(`INSERT INTO zitadel.instance_domains (instance_id, domain, is_primary, is_generated, type, created_at, updated_at) VALUES (`)
|
||||
builder.WriteArgs(domain.InstanceID, domain.Domain, domain.IsPrimary, domain.IsGenerated, domain.Type, createdAt, updatedAt)
|
||||
builder.WriteString(`) RETURNING created_at, updated_at`)
|
||||
|
||||
return i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt)
|
||||
}
|
||||
|
||||
// Update implements [domain.InstanceDomainRepository].
|
||||
// Subtle: this method shadows the method ([domain.InstanceRepository]).Update of instanceDomain.instance.
|
||||
func (i *instanceDomain) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) {
|
||||
if len(changes) == 0 {
|
||||
return 0, database.ErrNoChanges
|
||||
}
|
||||
var builder database.StatementBuilder
|
||||
|
||||
builder.WriteString(`UPDATE zitadel.instance_domains SET `)
|
||||
database.Changes(changes).Write(&builder)
|
||||
|
||||
writeCondition(&builder, condition)
|
||||
|
||||
return i.client.Exec(ctx, builder.String(), builder.Args()...)
|
||||
}
|
||||
|
||||
// Remove implements [domain.InstanceDomainRepository].
|
||||
func (i *instanceDomain) Remove(ctx context.Context, condition database.Condition) (int64, error) {
|
||||
var builder database.StatementBuilder
|
||||
|
||||
builder.WriteString(`DELETE FROM zitadel.instance_domains WHERE `)
|
||||
condition.Write(&builder)
|
||||
|
||||
return i.client.Exec(ctx, builder.String(), builder.Args()...)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// changes
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// SetPrimary implements [domain.InstanceDomainRepository].
|
||||
func (i instanceDomain) SetPrimary() database.Change {
|
||||
return database.NewChange(i.IsPrimaryColumn(), true)
|
||||
}
|
||||
|
||||
// SetUpdatedAt implements [domain.OrganizationDomainRepository].
|
||||
func (i instanceDomain) SetUpdatedAt(updatedAt time.Time) database.Change {
|
||||
return database.NewChange(i.UpdatedAtColumn(), updatedAt)
|
||||
}
|
||||
|
||||
// SetType implements [domain.InstanceDomainRepository].
|
||||
func (i instanceDomain) SetType(typ domain.DomainType) database.Change {
|
||||
return database.NewChange(i.TypeColumn(), typ)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// conditions
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// DomainCondition implements [domain.InstanceDomainRepository].
|
||||
func (i instanceDomain) DomainCondition(op database.TextOperation, domain string) database.Condition {
|
||||
return database.NewTextCondition(i.DomainColumn(), op, domain)
|
||||
}
|
||||
|
||||
// InstanceIDCondition implements [domain.InstanceDomainRepository].
|
||||
func (i instanceDomain) InstanceIDCondition(instanceID string) database.Condition {
|
||||
return database.NewTextCondition(i.InstanceIDColumn(), database.TextOperationEqual, instanceID)
|
||||
}
|
||||
|
||||
// IsPrimaryCondition implements [domain.InstanceDomainRepository].
|
||||
func (i instanceDomain) IsPrimaryCondition(isPrimary bool) database.Condition {
|
||||
return database.NewBooleanCondition(i.IsPrimaryColumn(), isPrimary)
|
||||
}
|
||||
|
||||
// TypeCondition implements [domain.InstanceDomainRepository].
|
||||
func (i instanceDomain) TypeCondition(typ domain.DomainType) database.Condition {
|
||||
return database.NewTextCondition(i.TypeColumn(), database.TextOperationEqual, typ.String())
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// columns
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// CreatedAtColumn implements [domain.InstanceDomainRepository].
|
||||
// Subtle: this method shadows the method ([domain.InstanceRepository]).CreatedAtColumn of instanceDomain.instance.
|
||||
func (instanceDomain) CreatedAtColumn() database.Column {
|
||||
return database.NewColumn("instance_domains", "created_at")
|
||||
}
|
||||
|
||||
// DomainColumn implements [domain.InstanceDomainRepository].
|
||||
func (instanceDomain) DomainColumn() database.Column {
|
||||
return database.NewColumn("instance_domains", "domain")
|
||||
}
|
||||
|
||||
// InstanceIDColumn implements [domain.InstanceDomainRepository].
|
||||
func (instanceDomain) InstanceIDColumn() database.Column {
|
||||
return database.NewColumn("instance_domains", "instance_id")
|
||||
}
|
||||
|
||||
// IsPrimaryColumn implements [domain.InstanceDomainRepository].
|
||||
func (instanceDomain) IsPrimaryColumn() database.Column {
|
||||
return database.NewColumn("instance_domains", "is_primary")
|
||||
}
|
||||
|
||||
// UpdatedAtColumn implements [domain.InstanceDomainRepository].
|
||||
// Subtle: this method shadows the method ([domain.InstanceRepository]).UpdatedAtColumn of instanceDomain.instance.
|
||||
func (instanceDomain) UpdatedAtColumn() database.Column {
|
||||
return database.NewColumn("instance_domains", "updated_at")
|
||||
}
|
||||
|
||||
// IsGeneratedColumn implements [domain.InstanceDomainRepository].
|
||||
func (instanceDomain) IsGeneratedColumn() database.Column {
|
||||
return database.NewColumn("instance_domains", "is_generated")
|
||||
}
|
||||
|
||||
// TypeColumn implements [domain.InstanceDomainRepository].
|
||||
func (instanceDomain) TypeColumn() database.Column {
|
||||
return database.NewColumn("instance_domains", "type")
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// scanners
|
||||
// -------------------------------------------------------------
|
||||
|
||||
func scanInstanceDomains(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.InstanceDomain, error) {
|
||||
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var domains []*domain.InstanceDomain
|
||||
if err := rows.(database.CollectableRows).Collect(&domains); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return domains, nil
|
||||
}
|
||||
|
||||
func scanInstanceDomain(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.InstanceDomain, error) {
|
||||
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
domain := new(domain.InstanceDomain)
|
||||
if err := rows.(database.CollectableRows).CollectExactlyOneRow(domain); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return domain, nil
|
||||
}
|
||||
701
backend/v3/storage/database/repository/instance_domain_test.go
Normal file
701
backend/v3/storage/database/repository/instance_domain_test.go
Normal file
@@ -0,0 +1,701 @@
|
||||
package repository_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v6"
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
|
||||
)
|
||||
|
||||
func TestAddInstanceDomain(t *testing.T) {
|
||||
// create instance
|
||||
instanceID := gofakeit.UUID()
|
||||
instance := domain.Instance{
|
||||
ID: instanceID,
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleClient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
err := instanceRepo.Create(t.Context(), &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain
|
||||
instanceDomain domain.AddInstanceDomain
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "happy path custom domain",
|
||||
instanceDomain: domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
Type: domain.DomainTypeCustom,
|
||||
IsPrimary: gu.Ptr(false),
|
||||
IsGenerated: gu.Ptr(false),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "happy path trusted domain",
|
||||
instanceDomain: domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
Type: domain.DomainTypeTrusted,
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "add primary domain",
|
||||
instanceDomain: domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
Type: domain.DomainTypeCustom,
|
||||
IsPrimary: gu.Ptr(true),
|
||||
IsGenerated: gu.Ptr(false),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "add custom domain without domain name",
|
||||
instanceDomain: domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: "",
|
||||
Type: domain.DomainTypeCustom,
|
||||
IsPrimary: gu.Ptr(false),
|
||||
IsGenerated: gu.Ptr(false),
|
||||
},
|
||||
err: new(database.CheckError),
|
||||
},
|
||||
{
|
||||
name: "add trusted domain without domain name",
|
||||
instanceDomain: domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: "",
|
||||
Type: domain.DomainTypeTrusted,
|
||||
},
|
||||
err: new(database.CheckError),
|
||||
},
|
||||
{
|
||||
name: "add custom domain with same domain twice",
|
||||
testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain {
|
||||
domainName := gofakeit.DomainName()
|
||||
|
||||
instanceDomain := &domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: domainName,
|
||||
Type: domain.DomainTypeCustom,
|
||||
IsPrimary: gu.Ptr(false),
|
||||
IsGenerated: gu.Ptr(false),
|
||||
}
|
||||
|
||||
err := domainRepo.Add(ctx, instanceDomain)
|
||||
require.NoError(t, err)
|
||||
|
||||
// return same domain again
|
||||
return &domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: domainName,
|
||||
Type: domain.DomainTypeCustom,
|
||||
IsPrimary: gu.Ptr(false),
|
||||
IsGenerated: gu.Ptr(false),
|
||||
}
|
||||
},
|
||||
err: new(database.UniqueError),
|
||||
},
|
||||
{
|
||||
name: "add trusted domain with same domain twice",
|
||||
testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain {
|
||||
domainName := gofakeit.DomainName()
|
||||
|
||||
instanceDomain := &domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: domainName,
|
||||
Type: domain.DomainTypeTrusted,
|
||||
}
|
||||
|
||||
err := domainRepo.Add(ctx, instanceDomain)
|
||||
require.NoError(t, err)
|
||||
|
||||
// return same domain again
|
||||
return &domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: domainName,
|
||||
Type: domain.DomainTypeTrusted,
|
||||
}
|
||||
},
|
||||
err: new(database.UniqueError),
|
||||
},
|
||||
{
|
||||
name: "add domain with non-existent instance",
|
||||
instanceDomain: domain.AddInstanceDomain{
|
||||
InstanceID: "non-existent-instance",
|
||||
Domain: gofakeit.DomainName(),
|
||||
Type: domain.DomainTypeCustom,
|
||||
IsPrimary: gu.Ptr(false),
|
||||
IsGenerated: gu.Ptr(false),
|
||||
},
|
||||
err: new(database.ForeignKeyError),
|
||||
},
|
||||
{
|
||||
name: "add domain without instance id",
|
||||
instanceDomain: domain.AddInstanceDomain{
|
||||
Domain: gofakeit.DomainName(),
|
||||
Type: domain.DomainTypeCustom,
|
||||
IsPrimary: gu.Ptr(false),
|
||||
IsGenerated: gu.Ptr(false),
|
||||
},
|
||||
err: new(database.ForeignKeyError),
|
||||
},
|
||||
{
|
||||
name: "add custom domain without primary",
|
||||
instanceDomain: domain.AddInstanceDomain{
|
||||
Domain: gofakeit.DomainName(),
|
||||
Type: domain.DomainTypeCustom,
|
||||
IsGenerated: gu.Ptr(false),
|
||||
},
|
||||
err: new(database.CheckError),
|
||||
},
|
||||
{
|
||||
name: "add custom domain without generated",
|
||||
instanceDomain: domain.AddInstanceDomain{
|
||||
Domain: gofakeit.DomainName(),
|
||||
Type: domain.DomainTypeCustom,
|
||||
IsPrimary: gu.Ptr(false),
|
||||
},
|
||||
err: new(database.CheckError),
|
||||
},
|
||||
{
|
||||
name: "add trusted domain with primary",
|
||||
instanceDomain: domain.AddInstanceDomain{
|
||||
Domain: gofakeit.DomainName(),
|
||||
Type: domain.DomainTypeTrusted,
|
||||
IsPrimary: gu.Ptr(false),
|
||||
},
|
||||
err: new(database.CheckError),
|
||||
},
|
||||
{
|
||||
name: "add trusted domain with generated",
|
||||
instanceDomain: domain.AddInstanceDomain{
|
||||
Domain: gofakeit.DomainName(),
|
||||
Type: domain.DomainTypeTrusted,
|
||||
IsGenerated: gu.Ptr(false),
|
||||
},
|
||||
err: new(database.CheckError),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
|
||||
// we take now here because the timestamp of the transaction is used to set the createdAt and updatedAt fields
|
||||
beforeAdd := time.Now()
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, tx.Rollback(t.Context()))
|
||||
}()
|
||||
instanceRepo := repository.InstanceRepository(tx)
|
||||
domainRepo := instanceRepo.Domains(false)
|
||||
|
||||
var instanceDomain *domain.AddInstanceDomain
|
||||
if test.testFunc != nil {
|
||||
instanceDomain = test.testFunc(ctx, t, domainRepo)
|
||||
} else {
|
||||
instanceDomain = &test.instanceDomain
|
||||
}
|
||||
|
||||
err = domainRepo.Add(ctx, instanceDomain)
|
||||
afterAdd := time.Now()
|
||||
if test.err != nil {
|
||||
assert.ErrorIs(t, err, test.err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, instanceDomain.CreatedAt)
|
||||
assert.NotZero(t, instanceDomain.UpdatedAt)
|
||||
assert.WithinRange(t, instanceDomain.CreatedAt, beforeAdd, afterAdd)
|
||||
assert.WithinRange(t, instanceDomain.UpdatedAt, beforeAdd, afterAdd)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetInstanceDomain(t *testing.T) {
|
||||
// create instance
|
||||
instanceID := gofakeit.UUID()
|
||||
instance := domain.Instance{
|
||||
ID: instanceID,
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleClient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, tx.Rollback(t.Context()))
|
||||
}()
|
||||
instanceRepo := repository.InstanceRepository(tx)
|
||||
err = instanceRepo.Create(t.Context(), &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
// add domains
|
||||
domainRepo := instanceRepo.Domains(false)
|
||||
domainName1 := gofakeit.DomainName()
|
||||
domainName2 := gofakeit.DomainName()
|
||||
|
||||
domain1 := &domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: domainName1,
|
||||
IsPrimary: gu.Ptr(true),
|
||||
IsGenerated: gu.Ptr(false),
|
||||
Type: domain.DomainTypeCustom,
|
||||
}
|
||||
domain2 := &domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: domainName2,
|
||||
IsPrimary: gu.Ptr(false),
|
||||
IsGenerated: gu.Ptr(false),
|
||||
Type: domain.DomainTypeCustom,
|
||||
}
|
||||
|
||||
err = domainRepo.Add(t.Context(), domain1)
|
||||
require.NoError(t, err)
|
||||
err = domainRepo.Add(t.Context(), domain2)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts []database.QueryOption
|
||||
expected *domain.InstanceDomain
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "get primary domain",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.IsPrimaryCondition(true)),
|
||||
},
|
||||
expected: &domain.InstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: domainName1,
|
||||
IsPrimary: gu.Ptr(true),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "get by domain name",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, domainName2)),
|
||||
},
|
||||
expected: &domain.InstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: domainName2,
|
||||
IsPrimary: gu.Ptr(false),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "get non-existent domain",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com")),
|
||||
},
|
||||
err: new(database.NoRowFoundError),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
|
||||
result, err := domainRepo.Get(ctx, test.opts...)
|
||||
if test.err != nil {
|
||||
assert.ErrorIs(t, err, test.err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.expected.InstanceID, result.InstanceID)
|
||||
assert.Equal(t, test.expected.Domain, result.Domain)
|
||||
assert.Equal(t, test.expected.IsPrimary, result.IsPrimary)
|
||||
assert.NotEmpty(t, result.CreatedAt)
|
||||
assert.NotEmpty(t, result.UpdatedAt)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListInstanceDomains(t *testing.T) {
|
||||
// create instance
|
||||
instanceID := gofakeit.UUID()
|
||||
instance := domain.Instance{
|
||||
ID: instanceID,
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleClient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, tx.Rollback(t.Context()))
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository(tx)
|
||||
err = instanceRepo.Create(t.Context(), &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
// add multiple domains
|
||||
domainRepo := instanceRepo.Domains(false)
|
||||
domains := []domain.AddInstanceDomain{
|
||||
{
|
||||
InstanceID: instanceID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsPrimary: gu.Ptr(true),
|
||||
IsGenerated: gu.Ptr(false),
|
||||
Type: domain.DomainTypeCustom,
|
||||
},
|
||||
{
|
||||
InstanceID: instanceID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsPrimary: gu.Ptr(false),
|
||||
IsGenerated: gu.Ptr(false),
|
||||
Type: domain.DomainTypeCustom,
|
||||
},
|
||||
{
|
||||
InstanceID: instanceID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
Type: domain.DomainTypeTrusted,
|
||||
},
|
||||
}
|
||||
|
||||
for i := range domains {
|
||||
err = domainRepo.Add(t.Context(), &domains[i])
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts []database.QueryOption
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "list all domains",
|
||||
opts: []database.QueryOption{},
|
||||
expectedCount: 3,
|
||||
},
|
||||
{
|
||||
name: "list primary domains",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.IsPrimaryCondition(true)),
|
||||
},
|
||||
expectedCount: 1,
|
||||
},
|
||||
{
|
||||
name: "list by instance",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.InstanceIDCondition(instanceID)),
|
||||
},
|
||||
expectedCount: 3,
|
||||
},
|
||||
{
|
||||
name: "list non-existent instance",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.InstanceIDCondition("non-existent")),
|
||||
},
|
||||
expectedCount: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
|
||||
results, err := domainRepo.List(ctx, test.opts...)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, test.expectedCount)
|
||||
|
||||
for _, result := range results {
|
||||
assert.Equal(t, instanceID, result.InstanceID)
|
||||
assert.NotEmpty(t, result.Domain)
|
||||
assert.NotEmpty(t, result.CreatedAt)
|
||||
assert.NotEmpty(t, result.UpdatedAt)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateInstanceDomain(t *testing.T) {
|
||||
// create instance
|
||||
instanceID := gofakeit.UUID()
|
||||
instance := domain.Instance{
|
||||
ID: instanceID,
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleClient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, tx.Rollback(t.Context()))
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository(tx)
|
||||
err = instanceRepo.Create(t.Context(), &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
// add domain
|
||||
domainRepo := instanceRepo.Domains(false)
|
||||
domainName := gofakeit.DomainName()
|
||||
instanceDomain := &domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: domainName,
|
||||
IsPrimary: gu.Ptr(false),
|
||||
IsGenerated: gu.Ptr(false),
|
||||
Type: domain.DomainTypeCustom,
|
||||
}
|
||||
|
||||
err = domainRepo.Add(t.Context(), instanceDomain)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
condition database.Condition
|
||||
changes []database.Change
|
||||
expected int64
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "set primary",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
||||
changes: []database.Change{domainRepo.SetPrimary()},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "update non-existent domain",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
|
||||
changes: []database.Change{domainRepo.SetPrimary()},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "no changes",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
||||
changes: []database.Change{},
|
||||
expected: 0,
|
||||
err: database.ErrNoChanges,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
|
||||
rowsAffected, err := domainRepo.Update(ctx, test.condition, test.changes...)
|
||||
if test.err != nil {
|
||||
assert.ErrorIs(t, err, test.err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.expected, rowsAffected)
|
||||
|
||||
// verify changes were applied if rows were affected
|
||||
if rowsAffected > 0 && len(test.changes) > 0 {
|
||||
result, err := domainRepo.Get(ctx, database.WithCondition(test.condition))
|
||||
require.NoError(t, err)
|
||||
|
||||
// We know changes were applied since rowsAffected > 0
|
||||
// The specific verification of what changed is less important
|
||||
// than knowing the operation succeeded
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveInstanceDomain(t *testing.T) {
|
||||
// create instance
|
||||
instanceID := gofakeit.UUID()
|
||||
instance := domain.Instance{
|
||||
ID: instanceID,
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleClient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, tx.Rollback(t.Context()))
|
||||
}()
|
||||
instanceRepo := repository.InstanceRepository(tx)
|
||||
err = instanceRepo.Create(t.Context(), &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
// add domains
|
||||
domainRepo := instanceRepo.Domains(false)
|
||||
domainName1 := gofakeit.DomainName()
|
||||
|
||||
domain1 := &domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: domainName1,
|
||||
IsPrimary: gu.Ptr(true),
|
||||
IsGenerated: gu.Ptr(false),
|
||||
Type: domain.DomainTypeCustom,
|
||||
}
|
||||
domain2 := &domain.AddInstanceDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsPrimary: gu.Ptr(false),
|
||||
IsGenerated: gu.Ptr(false),
|
||||
Type: domain.DomainTypeCustom,
|
||||
}
|
||||
|
||||
err = domainRepo.Add(t.Context(), domain1)
|
||||
require.NoError(t, err)
|
||||
err = domainRepo.Add(t.Context(), domain2)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
condition database.Condition
|
||||
expected int64
|
||||
}{
|
||||
{
|
||||
name: "remove by domain name",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName1),
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "remove by primary condition",
|
||||
condition: domainRepo.IsPrimaryCondition(false),
|
||||
expected: 1, // domain2 should still exist and be non-primary
|
||||
},
|
||||
{
|
||||
name: "remove non-existent domain",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
|
||||
expected: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
|
||||
// count before removal
|
||||
beforeCount, err := domainRepo.List(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
rowsAffected, err := domainRepo.Remove(ctx, test.condition)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.expected, rowsAffected)
|
||||
|
||||
// verify removal
|
||||
afterCount, err := domainRepo.List(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(beforeCount)-int(test.expected), len(afterCount))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstanceDomainConditions(t *testing.T) {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
domainRepo := instanceRepo.Domains(false)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
condition database.Condition
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "domain condition equal",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, "example.com"),
|
||||
expected: "instance_domains.domain = $1",
|
||||
},
|
||||
{
|
||||
name: "domain condition starts with",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationStartsWith, "example"),
|
||||
expected: "instance_domains.domain LIKE $1 || '%'",
|
||||
},
|
||||
{
|
||||
name: "instance id condition",
|
||||
condition: domainRepo.InstanceIDCondition("instance-123"),
|
||||
expected: "instance_domains.instance_id = $1",
|
||||
},
|
||||
{
|
||||
name: "is primary true",
|
||||
condition: domainRepo.IsPrimaryCondition(true),
|
||||
expected: "instance_domains.is_primary = $1",
|
||||
},
|
||||
{
|
||||
name: "is primary false",
|
||||
condition: domainRepo.IsPrimaryCondition(false),
|
||||
expected: "instance_domains.is_primary = $1",
|
||||
},
|
||||
{
|
||||
name: "is type custom",
|
||||
condition: domainRepo.TypeCondition(domain.DomainTypeCustom),
|
||||
expected: "instance_domains.type = $1",
|
||||
},
|
||||
{
|
||||
name: "is type trusted",
|
||||
condition: domainRepo.TypeCondition(domain.DomainTypeTrusted),
|
||||
expected: "instance_domains.type = $1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var builder database.StatementBuilder
|
||||
test.condition.Write(&builder)
|
||||
assert.Equal(t, test.expected, builder.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestInstanceDomainChanges(t *testing.T) {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
domainRepo := instanceRepo.Domains(false)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
change database.Change
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "set primary",
|
||||
change: domainRepo.SetPrimary(),
|
||||
expected: "is_primary = $1",
|
||||
},
|
||||
{
|
||||
name: "set type",
|
||||
change: domainRepo.SetType(domain.DomainTypeCustom),
|
||||
expected: "type = $1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var builder database.StatementBuilder
|
||||
test.change.Write(&builder)
|
||||
assert.Equal(t, test.expected, builder.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
723
backend/v3/storage/database/repository/instance_test.go
Normal file
723
backend/v3/storage/database/repository/instance_test.go
Normal file
@@ -0,0 +1,723 @@
|
||||
package repository_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v6"
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
|
||||
)
|
||||
|
||||
func TestCreateInstance(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
|
||||
instance domain.Instance
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "happy path",
|
||||
instance: func() domain.Instance {
|
||||
instanceId := gofakeit.Name()
|
||||
instanceName := gofakeit.Name()
|
||||
instance := domain.Instance{
|
||||
ID: instanceId,
|
||||
Name: instanceName,
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
return instance
|
||||
}(),
|
||||
},
|
||||
{
|
||||
name: "create instance without name",
|
||||
instance: func() domain.Instance {
|
||||
instanceId := gofakeit.Name()
|
||||
// instanceName := gofakeit.Name()
|
||||
instance := domain.Instance{
|
||||
ID: instanceId,
|
||||
Name: "",
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
return instance
|
||||
}(),
|
||||
err: new(database.CheckError),
|
||||
},
|
||||
{
|
||||
name: "adding same instance twice",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceId := gofakeit.Name()
|
||||
instanceName := gofakeit.Name()
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: instanceId,
|
||||
Name: instanceName,
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
// change the name to make sure same only the id clashes
|
||||
inst.Name = gofakeit.Name()
|
||||
require.NoError(t, err)
|
||||
return &inst
|
||||
},
|
||||
err: new(database.UniqueError),
|
||||
},
|
||||
func() struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
|
||||
instance domain.Instance
|
||||
err error
|
||||
} {
|
||||
instanceId := gofakeit.Name()
|
||||
instanceName := gofakeit.Name()
|
||||
return struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
|
||||
instance domain.Instance
|
||||
err error
|
||||
}{
|
||||
name: "adding instance with same name twice",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: gofakeit.Name(),
|
||||
Name: instanceName,
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
// change the id
|
||||
inst.ID = instanceId
|
||||
inst.CreatedAt = time.Time{}
|
||||
inst.UpdatedAt = time.Time{}
|
||||
return &inst
|
||||
},
|
||||
instance: domain.Instance{
|
||||
ID: instanceId,
|
||||
Name: instanceName,
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
},
|
||||
// two instances can have the sane name
|
||||
err: nil,
|
||||
}
|
||||
}(),
|
||||
{
|
||||
name: "adding instance with no id",
|
||||
instance: func() domain.Instance {
|
||||
// instanceId := gofakeit.Name()
|
||||
instanceName := gofakeit.Name()
|
||||
instance := domain.Instance{
|
||||
// ID: instanceId,
|
||||
Name: instanceName,
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
return instance
|
||||
}(),
|
||||
err: new(database.CheckError),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
var instance *domain.Instance
|
||||
if tt.testFunc != nil {
|
||||
instance = tt.testFunc(ctx, t)
|
||||
} else {
|
||||
instance = &tt.instance
|
||||
}
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
|
||||
// create instance
|
||||
beforeCreate := time.Now()
|
||||
err := instanceRepo.Create(ctx, instance)
|
||||
assert.ErrorIs(t, err, tt.err)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
afterCreate := time.Now()
|
||||
|
||||
// check instance values
|
||||
instance, err = instanceRepo.Get(ctx,
|
||||
database.WithCondition(
|
||||
instanceRepo.IDCondition(instance.ID),
|
||||
),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tt.instance.ID, instance.ID)
|
||||
assert.Equal(t, tt.instance.Name, instance.Name)
|
||||
assert.Equal(t, tt.instance.DefaultOrgID, instance.DefaultOrgID)
|
||||
assert.Equal(t, tt.instance.IAMProjectID, instance.IAMProjectID)
|
||||
assert.Equal(t, tt.instance.ConsoleClientID, instance.ConsoleClientID)
|
||||
assert.Equal(t, tt.instance.ConsoleAppID, instance.ConsoleAppID)
|
||||
assert.Equal(t, tt.instance.DefaultLanguage, instance.DefaultLanguage)
|
||||
assert.WithinRange(t, instance.CreatedAt, beforeCreate, afterCreate)
|
||||
assert.WithinRange(t, instance.UpdatedAt, beforeCreate, afterCreate)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateInstance(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
|
||||
rowsAffected int64
|
||||
getErr error
|
||||
}{
|
||||
{
|
||||
name: "happy path",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceId := gofakeit.Name()
|
||||
instanceName := gofakeit.Name()
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: instanceId,
|
||||
Name: instanceName,
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
require.NoError(t, err)
|
||||
return &inst
|
||||
},
|
||||
rowsAffected: 1,
|
||||
},
|
||||
{
|
||||
name: "update deleted instance",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceId := gofakeit.Name()
|
||||
instanceName := gofakeit.Name()
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: instanceId,
|
||||
Name: instanceName,
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
// delete instance
|
||||
affectedRows, err := instanceRepo.Delete(ctx,
|
||||
inst.ID,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), affectedRows)
|
||||
|
||||
return &inst
|
||||
},
|
||||
rowsAffected: 0,
|
||||
},
|
||||
{
|
||||
name: "update non existent instance",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
||||
instanceId := gofakeit.Name()
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: instanceId,
|
||||
}
|
||||
return &inst
|
||||
},
|
||||
rowsAffected: 0,
|
||||
getErr: new(database.NoRowFoundError),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
|
||||
instance := tt.testFunc(ctx, t)
|
||||
|
||||
beforeUpdate := time.Now()
|
||||
// update name
|
||||
newName := "new_" + instance.Name
|
||||
rowsAffected, err := instanceRepo.Update(ctx,
|
||||
instance.ID,
|
||||
instanceRepo.SetName(newName),
|
||||
)
|
||||
afterUpdate := time.Now()
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.Equal(t, tt.rowsAffected, rowsAffected)
|
||||
|
||||
if rowsAffected == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
// check instance values
|
||||
instance, err = instanceRepo.Get(ctx,
|
||||
database.WithCondition(
|
||||
instanceRepo.IDCondition(instance.ID),
|
||||
),
|
||||
)
|
||||
require.Equal(t, tt.getErr, err)
|
||||
|
||||
assert.Equal(t, newName, instance.Name)
|
||||
assert.WithinRange(t, instance.UpdatedAt, beforeUpdate, afterUpdate)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetInstance(t *testing.T) {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
type test struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
|
||||
err error
|
||||
}
|
||||
|
||||
tests := []test{
|
||||
func() test {
|
||||
instanceId := gofakeit.Name()
|
||||
return test{
|
||||
name: "happy path get using id",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
||||
instanceName := gofakeit.Name()
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: instanceId,
|
||||
Name: instanceName,
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
require.NoError(t, err)
|
||||
return &inst
|
||||
},
|
||||
}
|
||||
}(),
|
||||
{
|
||||
name: "happy path including domains",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceId := gofakeit.Name()
|
||||
instanceName := gofakeit.Name()
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: instanceId,
|
||||
Name: instanceName,
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
domainRepo := instanceRepo.Domains(false)
|
||||
d := &domain.AddInstanceDomain{
|
||||
InstanceID: inst.ID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsPrimary: gu.Ptr(true),
|
||||
IsGenerated: gu.Ptr(false),
|
||||
Type: domain.DomainTypeCustom,
|
||||
}
|
||||
err = domainRepo.Add(ctx, d)
|
||||
require.NoError(t, err)
|
||||
|
||||
inst.Domains = append(inst.Domains, &domain.InstanceDomain{
|
||||
InstanceID: d.InstanceID,
|
||||
Domain: d.Domain,
|
||||
IsPrimary: d.IsPrimary,
|
||||
IsGenerated: d.IsGenerated,
|
||||
Type: d.Type,
|
||||
CreatedAt: d.CreatedAt,
|
||||
UpdatedAt: d.UpdatedAt,
|
||||
})
|
||||
|
||||
return &inst
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "get non existent instance",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
||||
inst := domain.Instance{
|
||||
ID: "get non existent instance",
|
||||
}
|
||||
return &inst
|
||||
},
|
||||
err: new(database.NoRowFoundError),
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
|
||||
var instance *domain.Instance
|
||||
if tt.testFunc != nil {
|
||||
instance = tt.testFunc(ctx, t)
|
||||
}
|
||||
|
||||
// check instance values
|
||||
returnedInstance, err := instanceRepo.Get(ctx,
|
||||
database.WithCondition(
|
||||
instanceRepo.IDCondition(instance.ID),
|
||||
),
|
||||
)
|
||||
if tt.err != nil {
|
||||
require.ErrorIs(t, err, tt.err)
|
||||
return
|
||||
}
|
||||
|
||||
if instance.ID == "get non existent instance" {
|
||||
assert.Nil(t, returnedInstance)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, returnedInstance.ID, instance.ID)
|
||||
assert.Equal(t, returnedInstance.Name, instance.Name)
|
||||
assert.Equal(t, returnedInstance.DefaultOrgID, instance.DefaultOrgID)
|
||||
assert.Equal(t, returnedInstance.IAMProjectID, instance.IAMProjectID)
|
||||
assert.Equal(t, returnedInstance.ConsoleClientID, instance.ConsoleClientID)
|
||||
assert.Equal(t, returnedInstance.ConsoleAppID, instance.ConsoleAppID)
|
||||
assert.Equal(t, returnedInstance.DefaultLanguage, instance.DefaultLanguage)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListInstance(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
pool, stop, err := newEmbeddedDB(ctx)
|
||||
require.NoError(t, err)
|
||||
defer stop()
|
||||
|
||||
type test struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T) []*domain.Instance
|
||||
conditionClauses []database.Condition
|
||||
noInstanceReturned bool
|
||||
}
|
||||
tests := []test{
|
||||
{
|
||||
name: "happy path single instance no filter",
|
||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
noOfInstances := 1
|
||||
instances := make([]*domain.Instance, noOfInstances)
|
||||
for i := range noOfInstances {
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: gofakeit.Name(),
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
instances[i] = &inst
|
||||
}
|
||||
|
||||
return instances
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "happy path multiple instance no filter",
|
||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
noOfInstances := 5
|
||||
instances := make([]*domain.Instance, noOfInstances)
|
||||
for i := range noOfInstances {
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: gofakeit.Name(),
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
instances[i] = &inst
|
||||
}
|
||||
|
||||
return instances
|
||||
},
|
||||
},
|
||||
func() test {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceId := gofakeit.Name()
|
||||
return test{
|
||||
name: "instance filter on id",
|
||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
|
||||
noOfInstances := 1
|
||||
instances := make([]*domain.Instance, noOfInstances)
|
||||
for i := range noOfInstances {
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: instanceId,
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
instances[i] = &inst
|
||||
}
|
||||
|
||||
return instances
|
||||
},
|
||||
conditionClauses: []database.Condition{instanceRepo.IDCondition(instanceId)},
|
||||
}
|
||||
}(),
|
||||
func() test {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceName := gofakeit.Name()
|
||||
return test{
|
||||
name: "multiple instance filter on name",
|
||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
|
||||
noOfInstances := 5
|
||||
instances := make([]*domain.Instance, noOfInstances)
|
||||
for i := range noOfInstances {
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: gofakeit.Name(),
|
||||
Name: instanceName,
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
instances[i] = &inst
|
||||
}
|
||||
|
||||
return instances
|
||||
},
|
||||
conditionClauses: []database.Condition{instanceRepo.NameCondition(database.TextOperationEqual, instanceName)},
|
||||
}
|
||||
}(),
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
_, err := pool.Exec(ctx, "DELETE FROM zitadel.instances")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
instances := tt.testFunc(ctx, t)
|
||||
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
|
||||
var condition database.Condition
|
||||
if len(tt.conditionClauses) > 0 {
|
||||
condition = database.And(tt.conditionClauses...)
|
||||
}
|
||||
|
||||
// check instance values
|
||||
returnedInstances, err := instanceRepo.List(ctx,
|
||||
database.WithCondition(condition),
|
||||
database.WithOrderByAscending(instanceRepo.CreatedAtColumn()),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
if tt.noInstanceReturned {
|
||||
assert.Nil(t, returnedInstances)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, len(instances), len(returnedInstances))
|
||||
for i, instance := range instances {
|
||||
assert.Equal(t, returnedInstances[i].ID, instance.ID)
|
||||
assert.Equal(t, returnedInstances[i].Name, instance.Name)
|
||||
assert.Equal(t, returnedInstances[i].DefaultOrgID, instance.DefaultOrgID)
|
||||
assert.Equal(t, returnedInstances[i].IAMProjectID, instance.IAMProjectID)
|
||||
assert.Equal(t, returnedInstances[i].ConsoleClientID, instance.ConsoleClientID)
|
||||
assert.Equal(t, returnedInstances[i].ConsoleAppID, instance.ConsoleAppID)
|
||||
assert.Equal(t, returnedInstances[i].DefaultLanguage, instance.DefaultLanguage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeleteInstance(t *testing.T) {
|
||||
type test struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T)
|
||||
instanceID string
|
||||
noOfDeletedRows int64
|
||||
}
|
||||
tests := []test{
|
||||
func() test {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceId := gofakeit.Name()
|
||||
var noOfInstances int64 = 1
|
||||
return test{
|
||||
name: "happy path delete single instance filter id",
|
||||
testFunc: func(ctx context.Context, t *testing.T) {
|
||||
instances := make([]*domain.Instance, noOfInstances)
|
||||
for i := range noOfInstances {
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: instanceId,
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
instances[i] = &inst
|
||||
}
|
||||
},
|
||||
instanceID: instanceId,
|
||||
noOfDeletedRows: noOfInstances,
|
||||
}
|
||||
}(),
|
||||
func() test {
|
||||
non_existent_instance_name := gofakeit.Name()
|
||||
return test{
|
||||
name: "delete non existent instance",
|
||||
instanceID: non_existent_instance_name,
|
||||
}
|
||||
}(),
|
||||
func() test {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceName := gofakeit.Name()
|
||||
return test{
|
||||
name: "deleted already deleted instance",
|
||||
testFunc: func(ctx context.Context, t *testing.T) {
|
||||
noOfInstances := 1
|
||||
instances := make([]*domain.Instance, noOfInstances)
|
||||
for i := range noOfInstances {
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: gofakeit.Name(),
|
||||
Name: instanceName,
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
instances[i] = &inst
|
||||
}
|
||||
|
||||
// delete instance
|
||||
affectedRows, err := instanceRepo.Delete(ctx,
|
||||
instances[0].ID,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), affectedRows)
|
||||
},
|
||||
instanceID: instanceName,
|
||||
// this test should return 0 affected rows as the instance was already deleted
|
||||
noOfDeletedRows: 0,
|
||||
}
|
||||
}(),
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
|
||||
if tt.testFunc != nil {
|
||||
tt.testFunc(ctx, t)
|
||||
}
|
||||
|
||||
// delete instance
|
||||
noOfDeletedRows, err := instanceRepo.Delete(ctx,
|
||||
tt.instanceID,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, noOfDeletedRows, tt.noOfDeletedRows)
|
||||
|
||||
// check instance was deleted
|
||||
instance, err := instanceRepo.Get(ctx,
|
||||
database.WithCondition(
|
||||
instanceRepo.IDCondition(tt.instanceID),
|
||||
),
|
||||
)
|
||||
require.ErrorIs(t, err, new(database.NoRowFoundError))
|
||||
assert.Nil(t, instance)
|
||||
})
|
||||
}
|
||||
}
|
||||
282
backend/v3/storage/database/repository/org.go
Normal file
282
backend/v3/storage/database/repository/org.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// repository
|
||||
// -------------------------------------------------------------
|
||||
|
||||
var _ domain.OrganizationRepository = (*org)(nil)
|
||||
|
||||
type org struct {
|
||||
repository
|
||||
shouldLoadDomains bool
|
||||
domainRepo domain.OrganizationDomainRepository
|
||||
}
|
||||
|
||||
func OrganizationRepository(client database.QueryExecutor) domain.OrganizationRepository {
|
||||
return &org{
|
||||
repository: repository{
|
||||
client: client,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
const queryOrganizationStmt = `SELECT organizations.id, organizations.name, organizations.instance_id, organizations.state, organizations.created_at, organizations.updated_at` +
|
||||
` , CASE WHEN count(org_domains.domain) > 0 THEN jsonb_agg(json_build_object('domain', org_domains.domain, 'isVerified', org_domains.is_verified, 'isPrimary', org_domains.is_primary, 'validationType', org_domains.validation_type, 'createdAt', org_domains.created_at, 'updatedAt', org_domains.updated_at)) ELSE NULL::JSONB END domains` +
|
||||
` FROM zitadel.organizations`
|
||||
|
||||
// Get implements [domain.OrganizationRepository].
|
||||
func (o *org) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Organization, error) {
|
||||
opts = append(opts,
|
||||
o.joinDomains(),
|
||||
database.WithGroupBy(o.InstanceIDColumn(), o.IDColumn()),
|
||||
)
|
||||
|
||||
options := new(database.QueryOpts)
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
var builder database.StatementBuilder
|
||||
builder.WriteString(queryOrganizationStmt)
|
||||
options.Write(&builder)
|
||||
|
||||
return scanOrganization(ctx, o.client, &builder)
|
||||
}
|
||||
|
||||
// List implements [domain.OrganizationRepository].
|
||||
func (o *org) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Organization, error) {
|
||||
opts = append(opts,
|
||||
o.joinDomains(),
|
||||
database.WithGroupBy(o.InstanceIDColumn(), o.IDColumn()),
|
||||
)
|
||||
|
||||
options := new(database.QueryOpts)
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
var builder database.StatementBuilder
|
||||
builder.WriteString(queryOrganizationStmt)
|
||||
options.Write(&builder)
|
||||
|
||||
return scanOrganizations(ctx, o.client, &builder)
|
||||
}
|
||||
|
||||
func (o *org) joinDomains() database.QueryOption {
|
||||
columns := make([]database.Condition, 0, 3)
|
||||
columns = append(columns,
|
||||
database.NewColumnCondition(o.InstanceIDColumn(), o.Domains(false).InstanceIDColumn()),
|
||||
database.NewColumnCondition(o.IDColumn(), o.Domains(false).OrgIDColumn()),
|
||||
)
|
||||
|
||||
// If domains should not be joined, we make sure to return null for the domain columns
|
||||
// the query optimizer of the dialect should optimize this away if no domains are requested
|
||||
if !o.shouldLoadDomains {
|
||||
columns = append(columns, database.IsNull(o.domainRepo.OrgIDColumn()))
|
||||
}
|
||||
|
||||
return database.WithLeftJoin(
|
||||
"zitadel.org_domains",
|
||||
database.And(columns...),
|
||||
)
|
||||
}
|
||||
|
||||
const createOrganizationStmt = `INSERT INTO zitadel.organizations (id, name, instance_id, state)` +
|
||||
` VALUES ($1, $2, $3, $4)` +
|
||||
` RETURNING created_at, updated_at`
|
||||
|
||||
// Create implements [domain.OrganizationRepository].
|
||||
func (o *org) Create(ctx context.Context, organization *domain.Organization) error {
|
||||
builder := database.StatementBuilder{}
|
||||
builder.AppendArgs(organization.ID, organization.Name, organization.InstanceID, organization.State)
|
||||
builder.WriteString(createOrganizationStmt)
|
||||
|
||||
return o.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&organization.CreatedAt, &organization.UpdatedAt)
|
||||
}
|
||||
|
||||
// Update implements [domain.OrganizationRepository].
|
||||
func (o *org) Update(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string, changes ...database.Change) (int64, error) {
|
||||
if len(changes) == 0 {
|
||||
return 0, database.ErrNoChanges
|
||||
}
|
||||
builder := database.StatementBuilder{}
|
||||
builder.WriteString(`UPDATE zitadel.organizations SET `)
|
||||
|
||||
instanceIDCondition := o.InstanceIDCondition(instanceID)
|
||||
|
||||
conditions := []database.Condition{id, instanceIDCondition}
|
||||
database.Changes(changes).Write(&builder)
|
||||
writeCondition(&builder, database.And(conditions...))
|
||||
|
||||
stmt := builder.String()
|
||||
|
||||
rowsAffected, err := o.client.Exec(ctx, stmt, builder.Args()...)
|
||||
return rowsAffected, err
|
||||
}
|
||||
|
||||
// Delete implements [domain.OrganizationRepository].
|
||||
func (o *org) Delete(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string) (int64, error) {
|
||||
builder := database.StatementBuilder{}
|
||||
|
||||
builder.WriteString(`DELETE FROM zitadel.organizations`)
|
||||
|
||||
instanceIDCondition := o.InstanceIDCondition(instanceID)
|
||||
|
||||
conditions := []database.Condition{id, instanceIDCondition}
|
||||
writeCondition(&builder, database.And(conditions...))
|
||||
|
||||
return o.client.Exec(ctx, builder.String(), builder.Args()...)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// changes
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// SetName implements [domain.organizationChanges].
|
||||
func (o org) SetName(name string) database.Change {
|
||||
return database.NewChange(o.NameColumn(), name)
|
||||
}
|
||||
|
||||
// SetState implements [domain.organizationChanges].
|
||||
func (o org) SetState(state domain.OrgState) database.Change {
|
||||
return database.NewChange(o.StateColumn(), state)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// conditions
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// IDCondition implements [domain.organizationConditions].
|
||||
func (o org) IDCondition(id string) domain.OrgIdentifierCondition {
|
||||
return database.NewTextCondition(o.IDColumn(), database.TextOperationEqual, id)
|
||||
}
|
||||
|
||||
// NameCondition implements [domain.organizationConditions].
|
||||
func (o org) NameCondition(name string) domain.OrgIdentifierCondition {
|
||||
return database.NewTextCondition(o.NameColumn(), database.TextOperationEqual, name)
|
||||
}
|
||||
|
||||
// InstanceIDCondition implements [domain.organizationConditions].
|
||||
func (o org) InstanceIDCondition(instanceID string) database.Condition {
|
||||
return database.NewTextCondition(o.InstanceIDColumn(), database.TextOperationEqual, instanceID)
|
||||
}
|
||||
|
||||
// StateCondition implements [domain.organizationConditions].
|
||||
func (o org) StateCondition(state domain.OrgState) database.Condition {
|
||||
return database.NewTextCondition(o.StateColumn(), database.TextOperationEqual, state.String())
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// columns
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// IDColumn implements [domain.organizationColumns].
|
||||
func (org) IDColumn() database.Column {
|
||||
return database.NewColumn("organizations", "id")
|
||||
}
|
||||
|
||||
// NameColumn implements [domain.organizationColumns].
|
||||
func (org) NameColumn() database.Column {
|
||||
return database.NewColumn("organizations", "name")
|
||||
}
|
||||
|
||||
// InstanceIDColumn implements [domain.organizationColumns].
|
||||
func (org) InstanceIDColumn() database.Column {
|
||||
return database.NewColumn("organizations", "instance_id")
|
||||
}
|
||||
|
||||
// StateColumn implements [domain.organizationColumns].
|
||||
func (org) StateColumn() database.Column {
|
||||
return database.NewColumn("organizations", "state")
|
||||
}
|
||||
|
||||
// CreatedAtColumn implements [domain.organizationColumns].
|
||||
func (org) CreatedAtColumn() database.Column {
|
||||
return database.NewColumn("organizations", "created_at")
|
||||
}
|
||||
|
||||
// UpdatedAtColumn implements [domain.organizationColumns].
|
||||
func (org) UpdatedAtColumn() database.Column {
|
||||
return database.NewColumn("organizations", "updated_at")
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// scanners
|
||||
// -------------------------------------------------------------
|
||||
|
||||
type rawOrganization struct {
|
||||
*domain.Organization
|
||||
RawDomains json.RawMessage `json:"domains,omitempty" db:"domains"`
|
||||
}
|
||||
|
||||
func scanOrganization(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Organization, error) {
|
||||
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var org rawOrganization
|
||||
if err := rows.(database.CollectableRows).CollectExactlyOneRow(&org); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(org.RawDomains) > 0 {
|
||||
if err := json.Unmarshal(org.RawDomains, &org.Domains); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return org.Organization, nil
|
||||
}
|
||||
|
||||
func scanOrganizations(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.Organization, error) {
|
||||
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var rawOrgs []*rawOrganization
|
||||
if err := rows.(database.CollectableRows).Collect(&rawOrgs); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
organizations := make([]*domain.Organization, len(rawOrgs))
|
||||
for i, org := range rawOrgs {
|
||||
if len(org.RawDomains) > 0 {
|
||||
if err := json.Unmarshal(org.RawDomains, &org.Domains); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
organizations[i] = org.Organization
|
||||
}
|
||||
return organizations, nil
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// sub repositories
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// Domains implements [domain.OrganizationRepository].
|
||||
func (o *org) Domains(shouldLoad bool) domain.OrganizationDomainRepository {
|
||||
if !o.shouldLoadDomains {
|
||||
o.shouldLoadDomains = shouldLoad
|
||||
}
|
||||
|
||||
if o.domainRepo != nil {
|
||||
return o.domainRepo
|
||||
}
|
||||
|
||||
o.domainRepo = &orgDomain{
|
||||
repository: o.repository,
|
||||
org: o,
|
||||
}
|
||||
|
||||
return o.domainRepo
|
||||
}
|
||||
230
backend/v3/storage/database/repository/org_domain.go
Normal file
230
backend/v3/storage/database/repository/org_domain.go
Normal file
@@ -0,0 +1,230 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
var _ domain.OrganizationDomainRepository = (*orgDomain)(nil)
|
||||
|
||||
type orgDomain struct {
|
||||
repository
|
||||
*org
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// repository
|
||||
// -------------------------------------------------------------
|
||||
|
||||
const queryOrganizationDomainStmt = `SELECT instance_id, org_id, domain, is_verified, is_primary, validation_type, created_at, updated_at ` +
|
||||
`FROM zitadel.org_domains`
|
||||
|
||||
// Get implements [domain.OrganizationDomainRepository].
|
||||
// Subtle: this method shadows the method ([domain.OrganizationRepository]).Get of orgDomain.org.
|
||||
func (o *orgDomain) Get(ctx context.Context, opts ...database.QueryOption) (*domain.OrganizationDomain, error) {
|
||||
options := new(database.QueryOpts)
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
var builder database.StatementBuilder
|
||||
builder.WriteString(queryOrganizationDomainStmt)
|
||||
options.Write(&builder)
|
||||
|
||||
return scanOrganizationDomain(ctx, o.client, &builder)
|
||||
}
|
||||
|
||||
// List implements [domain.OrganizationDomainRepository].
|
||||
// Subtle: this method shadows the method ([domain.OrganizationRepository]).List of orgDomain.org.
|
||||
func (o *orgDomain) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.OrganizationDomain, error) {
|
||||
options := new(database.QueryOpts)
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
var builder database.StatementBuilder
|
||||
builder.WriteString(queryOrganizationDomainStmt)
|
||||
options.Write(&builder)
|
||||
|
||||
return scanOrganizationDomains(ctx, o.client, &builder)
|
||||
}
|
||||
|
||||
// Add implements [domain.OrganizationDomainRepository].
|
||||
func (o *orgDomain) Add(ctx context.Context, domain *domain.AddOrganizationDomain) error {
|
||||
var (
|
||||
builder database.StatementBuilder
|
||||
createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction
|
||||
)
|
||||
if !domain.CreatedAt.IsZero() {
|
||||
createdAt = domain.CreatedAt
|
||||
}
|
||||
if !domain.UpdatedAt.IsZero() {
|
||||
updatedAt = domain.UpdatedAt
|
||||
}
|
||||
|
||||
builder.WriteString(`INSERT INTO zitadel.org_domains (instance_id, org_id, domain, is_verified, is_primary, validation_type, created_at, updated_at) VALUES (`)
|
||||
builder.WriteArgs(domain.InstanceID, domain.OrgID, domain.Domain, domain.IsVerified, domain.IsPrimary, domain.ValidationType, createdAt, updatedAt)
|
||||
builder.WriteString(`) RETURNING created_at, updated_at`)
|
||||
|
||||
return o.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt)
|
||||
}
|
||||
|
||||
// Update implements [domain.OrganizationDomainRepository].
|
||||
// Subtle: this method shadows the method ([domain.OrganizationRepository]).Update of orgDomain.org.
|
||||
func (o *orgDomain) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) {
|
||||
if len(changes) == 0 {
|
||||
return 0, database.ErrNoChanges
|
||||
}
|
||||
|
||||
var builder database.StatementBuilder
|
||||
|
||||
builder.WriteString(`UPDATE zitadel.org_domains SET `)
|
||||
database.Changes(changes).Write(&builder)
|
||||
writeCondition(&builder, condition)
|
||||
|
||||
return o.client.Exec(ctx, builder.String(), builder.Args()...)
|
||||
}
|
||||
|
||||
// Remove implements [domain.OrganizationDomainRepository].
|
||||
func (o *orgDomain) Remove(ctx context.Context, condition database.Condition) (int64, error) {
|
||||
var builder database.StatementBuilder
|
||||
|
||||
builder.WriteString(`DELETE FROM zitadel.org_domains `)
|
||||
writeCondition(&builder, condition)
|
||||
|
||||
return o.client.Exec(ctx, builder.String(), builder.Args()...)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// changes
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// SetPrimary implements [domain.OrganizationDomainRepository].
|
||||
func (o orgDomain) SetPrimary() database.Change {
|
||||
return database.NewChange(o.IsPrimaryColumn(), true)
|
||||
}
|
||||
|
||||
// SetValidationType implements [domain.OrganizationDomainRepository].
|
||||
func (o orgDomain) SetValidationType(verificationType domain.DomainValidationType) database.Change {
|
||||
return database.NewChange(o.ValidationTypeColumn(), verificationType)
|
||||
}
|
||||
|
||||
// SetVerified implements [domain.OrganizationDomainRepository].
|
||||
func (o orgDomain) SetVerified() database.Change {
|
||||
return database.NewChange(o.IsVerifiedColumn(), true)
|
||||
}
|
||||
|
||||
// SetUpdatedAt implements [domain.OrganizationDomainRepository].
|
||||
func (o orgDomain) SetUpdatedAt(updatedAt time.Time) database.Change {
|
||||
return database.NewChange(o.UpdatedAtColumn(), updatedAt)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// conditions
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// DomainCondition implements [domain.OrganizationDomainRepository].
|
||||
func (o orgDomain) DomainCondition(op database.TextOperation, domain string) database.Condition {
|
||||
return database.NewTextCondition(o.DomainColumn(), op, domain)
|
||||
}
|
||||
|
||||
// InstanceIDCondition implements [domain.OrganizationDomainRepository].
|
||||
// Subtle: this method shadows the method ([domain.OrganizationRepository]).InstanceIDCondition of orgDomain.org.
|
||||
func (o orgDomain) InstanceIDCondition(instanceID string) database.Condition {
|
||||
return database.NewTextCondition(o.InstanceIDColumn(), database.TextOperationEqual, instanceID)
|
||||
}
|
||||
|
||||
// IsPrimaryCondition implements [domain.OrganizationDomainRepository].
|
||||
func (o orgDomain) IsPrimaryCondition(isPrimary bool) database.Condition {
|
||||
return database.NewBooleanCondition(o.IsPrimaryColumn(), isPrimary)
|
||||
}
|
||||
|
||||
// IsVerifiedCondition implements [domain.OrganizationDomainRepository].
|
||||
func (o orgDomain) IsVerifiedCondition(isVerified bool) database.Condition {
|
||||
return database.NewBooleanCondition(o.IsVerifiedColumn(), isVerified)
|
||||
}
|
||||
|
||||
// OrgIDCondition implements [domain.OrganizationDomainRepository].
|
||||
func (o orgDomain) OrgIDCondition(orgID string) database.Condition {
|
||||
return database.NewTextCondition(o.OrgIDColumn(), database.TextOperationEqual, orgID)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// columns
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// CreatedAtColumn implements [domain.OrganizationDomainRepository].
|
||||
// Subtle: this method shadows the method ([domain.OrganizationRepository]).CreatedAtColumn of orgDomain.org.
|
||||
func (orgDomain) CreatedAtColumn() database.Column {
|
||||
return database.NewColumn("org_domains", "created_at")
|
||||
}
|
||||
|
||||
// DomainColumn implements [domain.OrganizationDomainRepository].
|
||||
func (orgDomain) DomainColumn() database.Column {
|
||||
return database.NewColumn("org_domains", "domain")
|
||||
}
|
||||
|
||||
// InstanceIDColumn implements [domain.OrganizationDomainRepository].
|
||||
// Subtle: this method shadows the method ([domain.OrganizationRepository]).InstanceIDColumn of orgDomain.org.
|
||||
func (orgDomain) InstanceIDColumn() database.Column {
|
||||
return database.NewColumn("org_domains", "instance_id")
|
||||
}
|
||||
|
||||
// IsPrimaryColumn implements [domain.OrganizationDomainRepository].
|
||||
func (orgDomain) IsPrimaryColumn() database.Column {
|
||||
return database.NewColumn("org_domains", "is_primary")
|
||||
}
|
||||
|
||||
// IsVerifiedColumn implements [domain.OrganizationDomainRepository].
|
||||
func (orgDomain) IsVerifiedColumn() database.Column {
|
||||
return database.NewColumn("org_domains", "is_verified")
|
||||
}
|
||||
|
||||
// OrgIDColumn implements [domain.OrganizationDomainRepository].
|
||||
func (orgDomain) OrgIDColumn() database.Column {
|
||||
return database.NewColumn("org_domains", "org_id")
|
||||
}
|
||||
|
||||
// UpdatedAtColumn implements [domain.OrganizationDomainRepository].
|
||||
// Subtle: this method shadows the method ([domain.OrganizationRepository]).UpdatedAtColumn of orgDomain.org.
|
||||
func (orgDomain) UpdatedAtColumn() database.Column {
|
||||
return database.NewColumn("org_domains", "updated_at")
|
||||
}
|
||||
|
||||
// ValidationTypeColumn implements [domain.OrganizationDomainRepository].
|
||||
func (orgDomain) ValidationTypeColumn() database.Column {
|
||||
return database.NewColumn("org_domains", "validation_type")
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// scanners
|
||||
// -------------------------------------------------------------
|
||||
|
||||
func scanOrganizationDomain(ctx context.Context, client database.Querier, builder *database.StatementBuilder) (*domain.OrganizationDomain, error) {
|
||||
rows, err := client.Query(ctx, builder.String(), builder.Args()...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
domain := &domain.OrganizationDomain{}
|
||||
if err := rows.(database.CollectableRows).CollectExactlyOneRow(domain); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return domain, nil
|
||||
}
|
||||
|
||||
func scanOrganizationDomains(ctx context.Context, client database.Querier, builder *database.StatementBuilder) ([]*domain.OrganizationDomain, error) {
|
||||
rows, err := client.Query(ctx, builder.String(), builder.Args()...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var domains []*domain.OrganizationDomain
|
||||
if err := rows.(database.CollectableRows).Collect(&domains); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return domains, nil
|
||||
}
|
||||
972
backend/v3/storage/database/repository/org_domain_test.go
Normal file
972
backend/v3/storage/database/repository/org_domain_test.go
Normal file
@@ -0,0 +1,972 @@
|
||||
package repository_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/brianvoe/gofakeit/v6"
|
||||
"github.com/muhlemmer/gu"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
|
||||
)
|
||||
|
||||
func TestAddOrganizationDomain(t *testing.T) {
|
||||
// create instance
|
||||
instanceID := gofakeit.UUID()
|
||||
instance := domain.Instance{
|
||||
ID: instanceID,
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleClient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
err := instanceRepo.Create(t.Context(), &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
// create organization
|
||||
orgID := gofakeit.UUID()
|
||||
organization := domain.Organization{
|
||||
ID: orgID,
|
||||
Name: gofakeit.Name(),
|
||||
InstanceID: instanceID,
|
||||
State: domain.OrgStateActive,
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T, domainRepo domain.OrganizationDomainRepository) *domain.AddOrganizationDomain
|
||||
organizationDomain domain.AddOrganizationDomain
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "happy path",
|
||||
organizationDomain: domain.AddOrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsVerified: false,
|
||||
IsPrimary: false,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "add verified domain",
|
||||
organizationDomain: domain.AddOrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsVerified: true,
|
||||
IsPrimary: false,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "add primary domain",
|
||||
organizationDomain: domain.AddOrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsVerified: true,
|
||||
IsPrimary: true,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "add domain without domain name",
|
||||
organizationDomain: domain.AddOrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: "",
|
||||
IsVerified: false,
|
||||
IsPrimary: false,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||
},
|
||||
err: new(database.CheckError),
|
||||
},
|
||||
{
|
||||
name: "add domain with same domain twice",
|
||||
testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.OrganizationDomainRepository) *domain.AddOrganizationDomain {
|
||||
domainName := gofakeit.DomainName()
|
||||
|
||||
organizationDomain := &domain.AddOrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: domainName,
|
||||
IsVerified: false,
|
||||
IsPrimary: false,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||
}
|
||||
|
||||
err := domainRepo.Add(ctx, organizationDomain)
|
||||
require.NoError(t, err)
|
||||
|
||||
// return same domain again
|
||||
return &domain.AddOrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: domainName,
|
||||
IsVerified: true,
|
||||
IsPrimary: true,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
|
||||
}
|
||||
},
|
||||
err: new(database.UniqueError),
|
||||
},
|
||||
{
|
||||
name: "add domain with non-existent instance",
|
||||
organizationDomain: domain.AddOrganizationDomain{
|
||||
InstanceID: "non-existent-instance",
|
||||
OrgID: orgID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsVerified: false,
|
||||
IsPrimary: false,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||
},
|
||||
err: new(database.ForeignKeyError),
|
||||
},
|
||||
{
|
||||
name: "add domain with non-existent organization",
|
||||
organizationDomain: domain.AddOrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: "non-existent-org",
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsVerified: false,
|
||||
IsPrimary: false,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||
},
|
||||
err: new(database.ForeignKeyError),
|
||||
},
|
||||
{
|
||||
name: "add domain without instance id",
|
||||
organizationDomain: domain.AddOrganizationDomain{
|
||||
OrgID: orgID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsVerified: false,
|
||||
IsPrimary: false,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||
},
|
||||
err: new(database.ForeignKeyError),
|
||||
},
|
||||
{
|
||||
name: "add domain without org id",
|
||||
organizationDomain: domain.AddOrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsVerified: false,
|
||||
IsPrimary: false,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||
},
|
||||
err: new(database.ForeignKeyError),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, tx.Rollback(t.Context()))
|
||||
}()
|
||||
|
||||
orgRepo := repository.OrganizationRepository(tx)
|
||||
err = orgRepo.Create(t.Context(), &organization)
|
||||
require.NoError(t, err)
|
||||
|
||||
domainRepo := orgRepo.Domains(false)
|
||||
|
||||
var organizationDomain *domain.AddOrganizationDomain
|
||||
if test.testFunc != nil {
|
||||
organizationDomain = test.testFunc(ctx, t, domainRepo)
|
||||
} else {
|
||||
organizationDomain = &test.organizationDomain
|
||||
}
|
||||
|
||||
err = domainRepo.Add(ctx, organizationDomain)
|
||||
if test.err != nil {
|
||||
assert.ErrorIs(t, err, test.err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.NotZero(t, organizationDomain.CreatedAt)
|
||||
assert.NotZero(t, organizationDomain.UpdatedAt)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetOrganizationDomain(t *testing.T) {
|
||||
// create instance
|
||||
instanceID := gofakeit.UUID()
|
||||
instance := domain.Instance{
|
||||
ID: instanceID,
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleClient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
// create organization
|
||||
orgID := gofakeit.UUID()
|
||||
organization := domain.Organization{
|
||||
ID: orgID,
|
||||
Name: gofakeit.Name(),
|
||||
InstanceID: instanceID,
|
||||
State: domain.OrgStateActive,
|
||||
}
|
||||
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, tx.Rollback(t.Context()))
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository(tx)
|
||||
err = instanceRepo.Create(t.Context(), &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
orgRepo := repository.OrganizationRepository(tx)
|
||||
err = orgRepo.Create(t.Context(), &organization)
|
||||
require.NoError(t, err)
|
||||
|
||||
// add domains
|
||||
domainRepo := orgRepo.Domains(false)
|
||||
domainName1 := gofakeit.DomainName()
|
||||
domainName2 := gofakeit.DomainName()
|
||||
|
||||
domain1 := &domain.AddOrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: domainName1,
|
||||
IsVerified: true,
|
||||
IsPrimary: true,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||
}
|
||||
domain2 := &domain.AddOrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: domainName2,
|
||||
IsVerified: false,
|
||||
IsPrimary: false,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
|
||||
}
|
||||
|
||||
err = domainRepo.Add(t.Context(), domain1)
|
||||
require.NoError(t, err)
|
||||
err = domainRepo.Add(t.Context(), domain2)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts []database.QueryOption
|
||||
expected *domain.OrganizationDomain
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "get primary domain",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.IsPrimaryCondition(true)),
|
||||
},
|
||||
expected: &domain.OrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: domainName1,
|
||||
IsVerified: true,
|
||||
IsPrimary: true,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "get by domain name",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, domainName2)),
|
||||
},
|
||||
expected: &domain.OrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: domainName2,
|
||||
IsVerified: false,
|
||||
IsPrimary: false,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "get by org ID",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.OrgIDCondition(orgID)),
|
||||
database.WithCondition(domainRepo.IsPrimaryCondition(true)),
|
||||
},
|
||||
expected: &domain.OrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: domainName1,
|
||||
IsVerified: true,
|
||||
IsPrimary: true,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "get verified domain",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.IsVerifiedCondition(true)),
|
||||
},
|
||||
expected: &domain.OrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: domainName1,
|
||||
IsVerified: true,
|
||||
IsPrimary: true,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "get non-existent domain",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com")),
|
||||
},
|
||||
err: new(database.NoRowFoundError),
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
|
||||
result, err := domainRepo.Get(ctx, test.opts...)
|
||||
if test.err != nil {
|
||||
assert.ErrorIs(t, err, test.err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.expected.InstanceID, result.InstanceID)
|
||||
assert.Equal(t, test.expected.OrgID, result.OrgID)
|
||||
assert.Equal(t, test.expected.Domain, result.Domain)
|
||||
assert.Equal(t, test.expected.IsVerified, result.IsVerified)
|
||||
assert.Equal(t, test.expected.IsPrimary, result.IsPrimary)
|
||||
assert.Equal(t, test.expected.ValidationType, result.ValidationType)
|
||||
assert.NotEmpty(t, result.CreatedAt)
|
||||
assert.NotEmpty(t, result.UpdatedAt)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestListOrganizationDomains(t *testing.T) {
|
||||
// create instance
|
||||
instanceID := gofakeit.UUID()
|
||||
instance := domain.Instance{
|
||||
ID: instanceID,
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleClient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
// create organization
|
||||
orgID := gofakeit.UUID()
|
||||
organization := domain.Organization{
|
||||
ID: orgID,
|
||||
Name: gofakeit.Name(),
|
||||
InstanceID: instanceID,
|
||||
State: domain.OrgStateActive,
|
||||
}
|
||||
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, tx.Rollback(t.Context()))
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository(tx)
|
||||
err = instanceRepo.Create(t.Context(), &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
orgRepo := repository.OrganizationRepository(tx)
|
||||
err = orgRepo.Create(t.Context(), &organization)
|
||||
require.NoError(t, err)
|
||||
|
||||
// add multiple domains
|
||||
domainRepo := orgRepo.Domains(false)
|
||||
domains := []domain.AddOrganizationDomain{
|
||||
{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsVerified: true,
|
||||
IsPrimary: true,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||
},
|
||||
{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsVerified: false,
|
||||
IsPrimary: false,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
|
||||
},
|
||||
{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
IsVerified: true,
|
||||
IsPrimary: false,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||
},
|
||||
}
|
||||
|
||||
for i := range domains {
|
||||
err = domainRepo.Add(t.Context(), &domains[i])
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
opts []database.QueryOption
|
||||
expectedCount int
|
||||
}{
|
||||
{
|
||||
name: "list all domains",
|
||||
opts: []database.QueryOption{},
|
||||
expectedCount: 3,
|
||||
},
|
||||
{
|
||||
name: "list verified domains",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.IsVerifiedCondition(true)),
|
||||
},
|
||||
expectedCount: 2,
|
||||
},
|
||||
{
|
||||
name: "list primary domains",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.IsPrimaryCondition(true)),
|
||||
},
|
||||
expectedCount: 1,
|
||||
},
|
||||
{
|
||||
name: "list by organization",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.OrgIDCondition(orgID)),
|
||||
},
|
||||
expectedCount: 3,
|
||||
},
|
||||
{
|
||||
name: "list by instance",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.InstanceIDCondition(instanceID)),
|
||||
},
|
||||
expectedCount: 3,
|
||||
},
|
||||
{
|
||||
name: "list non-existent organization",
|
||||
opts: []database.QueryOption{
|
||||
database.WithCondition(domainRepo.OrgIDCondition("non-existent")),
|
||||
},
|
||||
expectedCount: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
|
||||
results, err := domainRepo.List(ctx, test.opts...)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, results, test.expectedCount)
|
||||
|
||||
for _, result := range results {
|
||||
assert.Equal(t, instanceID, result.InstanceID)
|
||||
assert.Equal(t, orgID, result.OrgID)
|
||||
assert.NotEmpty(t, result.Domain)
|
||||
assert.NotEmpty(t, result.CreatedAt)
|
||||
assert.NotEmpty(t, result.UpdatedAt)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUpdateOrganizationDomain(t *testing.T) {
|
||||
// create instance
|
||||
instanceID := gofakeit.UUID()
|
||||
instance := domain.Instance{
|
||||
ID: instanceID,
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleClient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
// create organization
|
||||
orgID := gofakeit.UUID()
|
||||
organization := domain.Organization{
|
||||
ID: orgID,
|
||||
Name: gofakeit.Name(),
|
||||
InstanceID: instanceID,
|
||||
State: domain.OrgStateActive,
|
||||
}
|
||||
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, tx.Rollback(t.Context()))
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository(tx)
|
||||
err = instanceRepo.Create(t.Context(), &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
orgRepo := repository.OrganizationRepository(tx)
|
||||
err = orgRepo.Create(t.Context(), &organization)
|
||||
require.NoError(t, err)
|
||||
|
||||
// add domain
|
||||
domainRepo := orgRepo.Domains(false)
|
||||
domainName := gofakeit.DomainName()
|
||||
organizationDomain := &domain.AddOrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: domainName,
|
||||
IsVerified: false,
|
||||
IsPrimary: false,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||
}
|
||||
|
||||
err = domainRepo.Add(t.Context(), organizationDomain)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
condition database.Condition
|
||||
changes []database.Change
|
||||
expected int64
|
||||
err error
|
||||
}{
|
||||
{
|
||||
name: "set verified",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
||||
changes: []database.Change{domainRepo.SetVerified()},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "set primary",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
||||
changes: []database.Change{domainRepo.SetPrimary()},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "set validation type",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
||||
changes: []database.Change{domainRepo.SetValidationType(domain.DomainValidationTypeHTTP)},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple changes",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
||||
changes: []database.Change{
|
||||
domainRepo.SetVerified(),
|
||||
domainRepo.SetPrimary(),
|
||||
domainRepo.SetValidationType(domain.DomainValidationTypeDNS),
|
||||
},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "update by org ID and domain",
|
||||
condition: database.And(domainRepo.OrgIDCondition(orgID), domainRepo.DomainCondition(database.TextOperationEqual, domainName)),
|
||||
changes: []database.Change{domainRepo.SetVerified()},
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "update non-existent domain",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
|
||||
changes: []database.Change{domainRepo.SetVerified()},
|
||||
expected: 0,
|
||||
},
|
||||
{
|
||||
name: "no changes",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
|
||||
changes: []database.Change{},
|
||||
expected: 0,
|
||||
err: database.ErrNoChanges,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
|
||||
rowsAffected, err := domainRepo.Update(ctx, test.condition, test.changes...)
|
||||
if test.err != nil {
|
||||
assert.ErrorIs(t, err, test.err)
|
||||
return
|
||||
}
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.expected, rowsAffected)
|
||||
|
||||
// verify changes were applied if rows were affected
|
||||
if rowsAffected > 0 && len(test.changes) > 0 {
|
||||
result, err := domainRepo.Get(ctx, database.WithCondition(test.condition))
|
||||
require.NoError(t, err)
|
||||
|
||||
// We know changes were applied since rowsAffected > 0
|
||||
// The specific verification of what changed is less important
|
||||
// than knowing the operation succeeded
|
||||
assert.NotNil(t, result)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestRemoveOrganizationDomain(t *testing.T) {
|
||||
// create instance
|
||||
instanceID := gofakeit.UUID()
|
||||
instance := domain.Instance{
|
||||
ID: instanceID,
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleClient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
// create organization
|
||||
orgID := gofakeit.UUID()
|
||||
organization := domain.Organization{
|
||||
ID: orgID,
|
||||
Name: gofakeit.Name(),
|
||||
InstanceID: instanceID,
|
||||
State: domain.OrgStateActive,
|
||||
}
|
||||
|
||||
tx, err := pool.Begin(t.Context(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, tx.Rollback(t.Context()))
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository(tx)
|
||||
err = instanceRepo.Create(t.Context(), &instance)
|
||||
require.NoError(t, err)
|
||||
|
||||
orgRepo := repository.OrganizationRepository(tx)
|
||||
err = orgRepo.Create(t.Context(), &organization)
|
||||
require.NoError(t, err)
|
||||
|
||||
// add domains
|
||||
domainRepo := orgRepo.Domains(false)
|
||||
domainName1 := gofakeit.DomainName()
|
||||
domainName2 := gofakeit.DomainName()
|
||||
|
||||
domain1 := &domain.AddOrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: domainName1,
|
||||
IsVerified: true,
|
||||
IsPrimary: true,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
|
||||
}
|
||||
domain2 := &domain.AddOrganizationDomain{
|
||||
InstanceID: instanceID,
|
||||
OrgID: orgID,
|
||||
Domain: domainName2,
|
||||
IsVerified: false,
|
||||
IsPrimary: false,
|
||||
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
|
||||
}
|
||||
|
||||
err = domainRepo.Add(t.Context(), domain1)
|
||||
require.NoError(t, err)
|
||||
err = domainRepo.Add(t.Context(), domain2)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
condition database.Condition
|
||||
expected int64
|
||||
}{
|
||||
{
|
||||
name: "remove by domain name",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName1),
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "remove by primary condition",
|
||||
condition: domainRepo.IsPrimaryCondition(false),
|
||||
expected: 1, // domain2 should still exist and be non-primary
|
||||
},
|
||||
{
|
||||
name: "remove by org ID and domain",
|
||||
condition: database.And(domainRepo.OrgIDCondition(orgID), domainRepo.DomainCondition(database.TextOperationEqual, domainName2)),
|
||||
expected: 1,
|
||||
},
|
||||
{
|
||||
name: "remove non-existent domain",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
|
||||
expected: 0,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
ctx := t.Context()
|
||||
|
||||
snapshot, err := tx.Begin(ctx)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
require.NoError(t, snapshot.Rollback(ctx))
|
||||
}()
|
||||
|
||||
orgRepo := repository.OrganizationRepository(snapshot)
|
||||
domainRepo := orgRepo.Domains(false)
|
||||
|
||||
// count before removal
|
||||
beforeCount, err := domainRepo.List(ctx)
|
||||
require.NoError(t, err)
|
||||
|
||||
rowsAffected, err := domainRepo.Remove(ctx, test.condition)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, test.expected, rowsAffected)
|
||||
|
||||
// verify removal
|
||||
afterCount, err := domainRepo.List(ctx)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, len(beforeCount)-int(test.expected), len(afterCount))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrganizationDomainConditions(t *testing.T) {
|
||||
orgRepo := repository.OrganizationRepository(pool)
|
||||
domainRepo := orgRepo.Domains(false)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
condition database.Condition
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "domain condition equal",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationEqual, "example.com"),
|
||||
expected: "org_domains.domain = $1",
|
||||
},
|
||||
{
|
||||
name: "domain condition starts with",
|
||||
condition: domainRepo.DomainCondition(database.TextOperationStartsWith, "example"),
|
||||
expected: "org_domains.domain LIKE $1 || '%'",
|
||||
},
|
||||
{
|
||||
name: "instance id condition",
|
||||
condition: domainRepo.InstanceIDCondition("instance-123"),
|
||||
expected: "org_domains.instance_id = $1",
|
||||
},
|
||||
{
|
||||
name: "org id condition",
|
||||
condition: domainRepo.OrgIDCondition("org-123"),
|
||||
expected: "org_domains.org_id = $1",
|
||||
},
|
||||
{
|
||||
name: "is primary true",
|
||||
condition: domainRepo.IsPrimaryCondition(true),
|
||||
expected: "org_domains.is_primary = $1",
|
||||
},
|
||||
{
|
||||
name: "is primary false",
|
||||
condition: domainRepo.IsPrimaryCondition(false),
|
||||
expected: "org_domains.is_primary = $1",
|
||||
},
|
||||
{
|
||||
name: "is verified true",
|
||||
condition: domainRepo.IsVerifiedCondition(true),
|
||||
expected: "org_domains.is_verified = $1",
|
||||
},
|
||||
{
|
||||
name: "is verified false",
|
||||
condition: domainRepo.IsVerifiedCondition(false),
|
||||
expected: "org_domains.is_verified = $1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var builder database.StatementBuilder
|
||||
test.condition.Write(&builder)
|
||||
assert.Equal(t, test.expected, builder.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrganizationDomainChanges(t *testing.T) {
|
||||
orgRepo := repository.OrganizationRepository(pool)
|
||||
domainRepo := orgRepo.Domains(false)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
change database.Change
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "set verified",
|
||||
change: domainRepo.SetVerified(),
|
||||
expected: "is_verified = $1",
|
||||
},
|
||||
{
|
||||
name: "set primary",
|
||||
change: domainRepo.SetPrimary(),
|
||||
expected: "is_primary = $1",
|
||||
},
|
||||
{
|
||||
name: "set validation type DNS",
|
||||
change: domainRepo.SetValidationType(domain.DomainValidationTypeDNS),
|
||||
expected: "validation_type = $1",
|
||||
},
|
||||
{
|
||||
name: "set validation type HTTP",
|
||||
change: domainRepo.SetValidationType(domain.DomainValidationTypeHTTP),
|
||||
expected: "validation_type = $1",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var builder database.StatementBuilder
|
||||
test.change.Write(&builder)
|
||||
assert.Equal(t, test.expected, builder.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestOrganizationDomainColumns(t *testing.T) {
|
||||
orgRepo := repository.OrganizationRepository(pool)
|
||||
domainRepo := orgRepo.Domains(false)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
column database.Column
|
||||
qualified bool
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "instance id column qualified",
|
||||
column: domainRepo.InstanceIDColumn(),
|
||||
qualified: true,
|
||||
expected: "org_domains.instance_id",
|
||||
},
|
||||
{
|
||||
name: "instance id column unqualified",
|
||||
column: domainRepo.InstanceIDColumn(),
|
||||
qualified: false,
|
||||
expected: "instance_id",
|
||||
},
|
||||
{
|
||||
name: "org id column qualified",
|
||||
column: domainRepo.OrgIDColumn(),
|
||||
qualified: true,
|
||||
expected: "org_domains.org_id",
|
||||
},
|
||||
{
|
||||
name: "org id column unqualified",
|
||||
column: domainRepo.OrgIDColumn(),
|
||||
qualified: false,
|
||||
expected: "org_id",
|
||||
},
|
||||
{
|
||||
name: "domain column qualified",
|
||||
column: domainRepo.DomainColumn(),
|
||||
qualified: true,
|
||||
expected: "org_domains.domain",
|
||||
},
|
||||
{
|
||||
name: "domain column unqualified",
|
||||
column: domainRepo.DomainColumn(),
|
||||
qualified: false,
|
||||
expected: "domain",
|
||||
},
|
||||
{
|
||||
name: "is verified column qualified",
|
||||
column: domainRepo.IsVerifiedColumn(),
|
||||
qualified: true,
|
||||
expected: "org_domains.is_verified",
|
||||
},
|
||||
{
|
||||
name: "is verified column unqualified",
|
||||
column: domainRepo.IsVerifiedColumn(),
|
||||
qualified: false,
|
||||
expected: "is_verified",
|
||||
},
|
||||
{
|
||||
name: "is primary column qualified",
|
||||
column: domainRepo.IsPrimaryColumn(),
|
||||
qualified: true,
|
||||
expected: "org_domains.is_primary",
|
||||
},
|
||||
{
|
||||
name: "is primary column unqualified",
|
||||
column: domainRepo.IsPrimaryColumn(),
|
||||
qualified: false,
|
||||
expected: "is_primary",
|
||||
},
|
||||
{
|
||||
name: "validation type column qualified",
|
||||
column: domainRepo.ValidationTypeColumn(),
|
||||
qualified: true,
|
||||
expected: "org_domains.validation_type",
|
||||
},
|
||||
{
|
||||
name: "validation type column unqualified",
|
||||
column: domainRepo.ValidationTypeColumn(),
|
||||
qualified: false,
|
||||
expected: "validation_type",
|
||||
},
|
||||
{
|
||||
name: "created at column qualified",
|
||||
column: domainRepo.CreatedAtColumn(),
|
||||
qualified: true,
|
||||
expected: "org_domains.created_at",
|
||||
},
|
||||
{
|
||||
name: "created at column unqualified",
|
||||
column: domainRepo.CreatedAtColumn(),
|
||||
qualified: false,
|
||||
expected: "created_at",
|
||||
},
|
||||
{
|
||||
name: "updated at column qualified",
|
||||
column: domainRepo.UpdatedAtColumn(),
|
||||
qualified: true,
|
||||
expected: "org_domains.updated_at",
|
||||
},
|
||||
{
|
||||
name: "updated at column unqualified",
|
||||
column: domainRepo.UpdatedAtColumn(),
|
||||
qualified: false,
|
||||
expected: "updated_at",
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var builder database.StatementBuilder
|
||||
if test.qualified {
|
||||
test.column.WriteQualified(&builder)
|
||||
} else {
|
||||
test.column.WriteUnqualified(&builder)
|
||||
}
|
||||
assert.Equal(t, test.expected, builder.String())
|
||||
})
|
||||
}
|
||||
}
|
||||
28
backend/v3/storage/database/repository/org_member.go
Normal file
28
backend/v3/storage/database/repository/org_member.go
Normal file
@@ -0,0 +1,28 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
)
|
||||
|
||||
type orgMember struct {
|
||||
*org
|
||||
}
|
||||
|
||||
// AddMember implements [domain.MemberRepository].
|
||||
func (o *orgMember) AddMember(ctx context.Context, orgID string, userID string, roles []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveMember implements [domain.MemberRepository].
|
||||
func (o *orgMember) RemoveMember(ctx context.Context, orgID string, userID string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetMemberRoles implements [domain.MemberRepository].
|
||||
func (o *orgMember) SetMemberRoles(ctx context.Context, orgID string, userID string, roles []string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ domain.MemberRepository = (*orgMember)(nil)
|
||||
1008
backend/v3/storage/database/repository/org_test.go
Normal file
1008
backend/v3/storage/database/repository/org_test.go
Normal file
File diff suppressed because it is too large
Load Diff
20
backend/v3/storage/database/repository/repository.go
Normal file
20
backend/v3/storage/database/repository/repository.go
Normal file
@@ -0,0 +1,20 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type repository struct {
|
||||
client database.QueryExecutor
|
||||
}
|
||||
|
||||
func writeCondition(
|
||||
builder *database.StatementBuilder,
|
||||
condition database.Condition,
|
||||
) {
|
||||
if condition == nil {
|
||||
return
|
||||
}
|
||||
builder.WriteString(" WHERE ")
|
||||
condition.Write(builder)
|
||||
}
|
||||
51
backend/v3/storage/database/repository/repository_test.go
Normal file
51
backend/v3/storage/database/repository/repository_test.go
Normal file
@@ -0,0 +1,51 @@
|
||||
package repository_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres/embedded"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
os.Exit(runTests(m))
|
||||
}
|
||||
|
||||
var pool database.PoolTest
|
||||
|
||||
func runTests(m *testing.M) int {
|
||||
var stop func()
|
||||
var err error
|
||||
ctx := context.Background()
|
||||
pool, stop, err = newEmbeddedDB(ctx)
|
||||
if err != nil {
|
||||
log.Printf("error with embedded postgres database: %v", err)
|
||||
return 1
|
||||
}
|
||||
defer stop()
|
||||
|
||||
return m.Run()
|
||||
}
|
||||
|
||||
func newEmbeddedDB(ctx context.Context) (pool database.PoolTest, stop func(), err error) {
|
||||
connector, stop, err := embedded.StartEmbedded()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unable to start embedded postgres: %w", err)
|
||||
}
|
||||
|
||||
pool_, err := connector.Connect(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unable to connect to embedded postgres: %w", err)
|
||||
}
|
||||
pool = pool_.(database.PoolTest)
|
||||
|
||||
err = pool.MigrateTest(ctx)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("unable to migrate database: %w", err)
|
||||
}
|
||||
return pool, stop, err
|
||||
}
|
||||
282
backend/v3/storage/database/repository/user.go
Normal file
282
backend/v3/storage/database/repository/user.go
Normal file
@@ -0,0 +1,282 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
const queryUserStmt = `SELECT instance_id, org_id, id, username, type, created_at, updated_at, deleted_at,` +
|
||||
` first_name, last_name, email_address, email_verified_at, phone_number, phone_verified_at, description` +
|
||||
` FROM users_view users`
|
||||
|
||||
type user struct {
|
||||
repository
|
||||
}
|
||||
|
||||
func UserRepository(client database.QueryExecutor) domain.UserRepository {
|
||||
return &user{
|
||||
repository: repository{
|
||||
client: client,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
var _ domain.UserRepository = (*user)(nil)
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// repository
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// Human implements [domain.UserRepository].
|
||||
func (u *user) Human() domain.HumanRepository {
|
||||
return &userHuman{user: u}
|
||||
}
|
||||
|
||||
// Machine implements [domain.UserRepository].
|
||||
func (u *user) Machine() domain.MachineRepository {
|
||||
return &userMachine{user: u}
|
||||
}
|
||||
|
||||
// List implements [domain.UserRepository].
|
||||
func (u *user) List(ctx context.Context, opts ...database.QueryOption) (users []*domain.User, err error) {
|
||||
options := new(database.QueryOpts)
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
builder := database.StatementBuilder{}
|
||||
builder.WriteString(queryUserStmt)
|
||||
options.WriteCondition(&builder)
|
||||
options.WriteOrderBy(&builder)
|
||||
options.WriteLimit(&builder)
|
||||
options.WriteOffset(&builder)
|
||||
|
||||
rows, err := u.client.Query(ctx, builder.String(), builder.Args()...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
defer func() {
|
||||
closeErr := rows.Close()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
err = closeErr
|
||||
}()
|
||||
for rows.Next() {
|
||||
user, err := scanUser(rows)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users = append(users, user)
|
||||
}
|
||||
if err := rows.Err(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return users, nil
|
||||
}
|
||||
|
||||
// Get implements [domain.UserRepository].
|
||||
func (u *user) Get(ctx context.Context, opts ...database.QueryOption) (*domain.User, error) {
|
||||
options := new(database.QueryOpts)
|
||||
for _, opt := range opts {
|
||||
opt(options)
|
||||
}
|
||||
|
||||
builder := database.StatementBuilder{}
|
||||
builder.WriteString(queryUserStmt)
|
||||
options.WriteCondition(&builder)
|
||||
options.WriteOrderBy(&builder)
|
||||
options.WriteLimit(&builder)
|
||||
options.WriteOffset(&builder)
|
||||
|
||||
return scanUser(u.client.QueryRow(ctx, builder.String(), builder.Args()...))
|
||||
}
|
||||
|
||||
const (
|
||||
createHumanStmt = `INSERT INTO human_users (instance_id, org_id, user_id, username, first_name, last_name, email_address, email_verified_at, phone_number, phone_verified_at)` +
|
||||
` VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)` +
|
||||
` RETURNING created_at, updated_at`
|
||||
createMachineStmt = `INSERT INTO user_machines (instance_id, org_id, user_id, username, description)` +
|
||||
` VALUES ($1, $2, $3, $4, $5)` +
|
||||
` RETURNING created_at, updated_at`
|
||||
)
|
||||
|
||||
// Create implements [domain.UserRepository].
|
||||
func (u *user) Create(ctx context.Context, user *domain.User) error {
|
||||
builder := database.StatementBuilder{}
|
||||
builder.AppendArgs(user.InstanceID, user.OrgID, user.ID, user.Username, user.Traits.Type())
|
||||
switch trait := user.Traits.(type) {
|
||||
case *domain.Human:
|
||||
builder.WriteString(createHumanStmt)
|
||||
builder.AppendArgs(trait.FirstName, trait.LastName, trait.Email.Address, trait.Email.VerifiedAt, trait.Phone.Number, trait.Phone.VerifiedAt)
|
||||
case *domain.Machine:
|
||||
builder.WriteString(createMachineStmt)
|
||||
builder.AppendArgs(trait.Description)
|
||||
}
|
||||
return u.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&user.CreatedAt, &user.UpdatedAt)
|
||||
}
|
||||
|
||||
// Delete implements [domain.UserRepository].
|
||||
func (u *user) Delete(ctx context.Context, condition database.Condition) error {
|
||||
builder := database.StatementBuilder{}
|
||||
builder.WriteString("DELETE FROM users")
|
||||
writeCondition(&builder, condition)
|
||||
_, err := u.client.Exec(ctx, builder.String(), builder.Args()...)
|
||||
return err
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// changes
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// SetUsername implements [domain.userChanges].
|
||||
func (u user) SetUsername(username string) database.Change {
|
||||
return database.NewChange(u.UsernameColumn(), username)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// conditions
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// InstanceIDCondition implements [domain.userConditions].
|
||||
func (u user) InstanceIDCondition(instanceID string) database.Condition {
|
||||
return database.NewTextCondition(u.InstanceIDColumn(), database.TextOperationEqual, instanceID)
|
||||
}
|
||||
|
||||
// OrgIDCondition implements [domain.userConditions].
|
||||
func (u user) OrgIDCondition(orgID string) database.Condition {
|
||||
return database.NewTextCondition(u.OrgIDColumn(), database.TextOperationEqual, orgID)
|
||||
}
|
||||
|
||||
// IDCondition implements [domain.userConditions].
|
||||
func (u user) IDCondition(userID string) database.Condition {
|
||||
return database.NewTextCondition(u.IDColumn(), database.TextOperationEqual, userID)
|
||||
}
|
||||
|
||||
// UsernameCondition implements [domain.userConditions].
|
||||
func (u user) UsernameCondition(op database.TextOperation, username string) database.Condition {
|
||||
return database.NewTextCondition(u.UsernameColumn(), op, username)
|
||||
}
|
||||
|
||||
// CreatedAtCondition implements [domain.userConditions].
|
||||
func (u user) CreatedAtCondition(op database.NumberOperation, createdAt time.Time) database.Condition {
|
||||
return database.NewNumberCondition(u.CreatedAtColumn(), op, createdAt)
|
||||
}
|
||||
|
||||
// UpdatedAtCondition implements [domain.userConditions].
|
||||
func (u user) UpdatedAtCondition(op database.NumberOperation, updatedAt time.Time) database.Condition {
|
||||
return database.NewNumberCondition(u.UpdatedAtColumn(), op, updatedAt)
|
||||
}
|
||||
|
||||
// DeletedCondition implements [domain.userConditions].
|
||||
func (u user) DeletedCondition(isDeleted bool) database.Condition {
|
||||
if isDeleted {
|
||||
return database.IsNotNull(u.DeletedAtColumn())
|
||||
}
|
||||
return database.IsNull(u.DeletedAtColumn())
|
||||
}
|
||||
|
||||
// DeletedAtCondition implements [domain.userConditions].
|
||||
func (u user) DeletedAtCondition(op database.NumberOperation, deletedAt time.Time) database.Condition {
|
||||
return database.NewNumberCondition(u.DeletedAtColumn(), op, deletedAt)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// columns
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// InstanceIDColumn implements [domain.userColumns].
|
||||
func (user) InstanceIDColumn() database.Column {
|
||||
return database.NewColumn("users", "instance_id")
|
||||
}
|
||||
|
||||
// OrgIDColumn implements [domain.userColumns].
|
||||
func (user) OrgIDColumn() database.Column {
|
||||
return database.NewColumn("users", "org_id")
|
||||
}
|
||||
|
||||
// IDColumn implements [domain.userColumns].
|
||||
func (user) IDColumn() database.Column {
|
||||
return database.NewColumn("users", "id")
|
||||
}
|
||||
|
||||
// UsernameColumn implements [domain.userColumns].
|
||||
func (user) UsernameColumn() database.Column {
|
||||
return database.NewColumn("users", "username")
|
||||
}
|
||||
|
||||
// FirstNameColumn implements [domain.userColumns].
|
||||
func (user) CreatedAtColumn() database.Column {
|
||||
return database.NewColumn("users", "created_at")
|
||||
}
|
||||
|
||||
// UpdatedAtColumn implements [domain.userColumns].
|
||||
func (user) UpdatedAtColumn() database.Column {
|
||||
return database.NewColumn("users", "updated_at")
|
||||
}
|
||||
|
||||
// DeletedAtColumn implements [domain.userColumns].
|
||||
func (user) DeletedAtColumn() database.Column {
|
||||
return database.NewColumn("users", "deleted_at")
|
||||
}
|
||||
|
||||
func (u user) columns() database.Columns {
|
||||
return database.Columns{
|
||||
u.InstanceIDColumn(),
|
||||
u.OrgIDColumn(),
|
||||
u.IDColumn(),
|
||||
u.UsernameColumn(),
|
||||
u.CreatedAtColumn(),
|
||||
u.UpdatedAtColumn(),
|
||||
u.DeletedAtColumn(),
|
||||
}
|
||||
}
|
||||
|
||||
func scanUser(scanner database.Scanner) (*domain.User, error) {
|
||||
var (
|
||||
user domain.User
|
||||
human domain.Human
|
||||
email domain.Email
|
||||
phone domain.Phone
|
||||
machine domain.Machine
|
||||
typ domain.UserType
|
||||
)
|
||||
err := scanner.Scan(
|
||||
&user.InstanceID,
|
||||
&user.OrgID,
|
||||
&user.ID,
|
||||
&user.Username,
|
||||
&typ,
|
||||
&user.CreatedAt,
|
||||
&user.UpdatedAt,
|
||||
&user.DeletedAt,
|
||||
&human.FirstName,
|
||||
&human.LastName,
|
||||
&email.Address,
|
||||
&email.VerifiedAt,
|
||||
&phone.Number,
|
||||
&phone.VerifiedAt,
|
||||
&machine.Description,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
switch typ {
|
||||
case domain.UserTypeHuman:
|
||||
if email.Address != "" {
|
||||
human.Email = &email
|
||||
}
|
||||
if phone.Number != "" {
|
||||
human.Phone = &phone
|
||||
}
|
||||
user.Traits = &human
|
||||
case domain.UserTypeMachine:
|
||||
user.Traits = &machine
|
||||
}
|
||||
|
||||
return &user, nil
|
||||
}
|
||||
208
backend/v3/storage/database/repository/user_human.go
Normal file
208
backend/v3/storage/database/repository/user_human.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// repository
|
||||
// -------------------------------------------------------------
|
||||
|
||||
type userHuman struct {
|
||||
*user
|
||||
}
|
||||
|
||||
var _ domain.HumanRepository = (*userHuman)(nil)
|
||||
|
||||
const userEmailQuery = `SELECT h.email_address, h.email_verified_at FROM user_humans h`
|
||||
|
||||
// GetEmail implements [domain.HumanRepository].
|
||||
func (u *userHuman) GetEmail(ctx context.Context, condition database.Condition) (*domain.Email, error) {
|
||||
var email domain.Email
|
||||
|
||||
builder := database.StatementBuilder{}
|
||||
builder.WriteString(userEmailQuery)
|
||||
writeCondition(&builder, condition)
|
||||
|
||||
err := u.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(
|
||||
&email.Address,
|
||||
&email.VerifiedAt,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &email, nil
|
||||
}
|
||||
|
||||
// Update implements [domain.HumanRepository].
|
||||
func (h userHuman) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error {
|
||||
builder := database.StatementBuilder{}
|
||||
builder.WriteString(`UPDATE human_users SET `)
|
||||
database.Changes(changes).Write(&builder)
|
||||
writeCondition(&builder, condition)
|
||||
|
||||
stmt := builder.String()
|
||||
|
||||
_, err := h.client.Exec(ctx, stmt, builder.Args()...)
|
||||
return err
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// changes
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// SetFirstName implements [domain.humanChanges].
|
||||
func (h userHuman) SetFirstName(firstName string) database.Change {
|
||||
return database.NewChange(h.FirstNameColumn(), firstName)
|
||||
}
|
||||
|
||||
// SetLastName implements [domain.humanChanges].
|
||||
func (h userHuman) SetLastName(lastName string) database.Change {
|
||||
return database.NewChange(h.LastNameColumn(), lastName)
|
||||
}
|
||||
|
||||
// SetEmail implements [domain.humanChanges].
|
||||
func (h userHuman) SetEmail(address string, verified *time.Time) database.Change {
|
||||
return database.NewChanges(
|
||||
h.SetEmailAddress(address),
|
||||
database.NewChangePtr(h.EmailVerifiedAtColumn(), verified),
|
||||
)
|
||||
}
|
||||
|
||||
// SetEmailAddress implements [domain.humanChanges].
|
||||
func (h userHuman) SetEmailAddress(address string) database.Change {
|
||||
return database.NewChange(h.EmailAddressColumn(), address)
|
||||
}
|
||||
|
||||
// SetEmailVerifiedAt implements [domain.humanChanges].
|
||||
func (h userHuman) SetEmailVerifiedAt(at time.Time) database.Change {
|
||||
if at.IsZero() {
|
||||
return database.NewChange(h.EmailVerifiedAtColumn(), database.NowInstruction)
|
||||
}
|
||||
return database.NewChange(h.EmailVerifiedAtColumn(), at)
|
||||
}
|
||||
|
||||
// SetPhone implements [domain.humanChanges].
|
||||
func (h userHuman) SetPhone(number string, verifiedAt *time.Time) database.Change {
|
||||
return database.NewChanges(
|
||||
h.SetPhoneNumber(number),
|
||||
database.NewChangePtr(h.PhoneVerifiedAtColumn(), verifiedAt),
|
||||
)
|
||||
}
|
||||
|
||||
// SetPhoneNumber implements [domain.humanChanges].
|
||||
func (h userHuman) SetPhoneNumber(number string) database.Change {
|
||||
return database.NewChange(h.PhoneNumberColumn(), number)
|
||||
}
|
||||
|
||||
// SetPhoneVerifiedAt implements [domain.humanChanges].
|
||||
func (h userHuman) SetPhoneVerifiedAt(at time.Time) database.Change {
|
||||
if at.IsZero() {
|
||||
return database.NewChange(h.PhoneVerifiedAtColumn(), database.NowInstruction)
|
||||
}
|
||||
return database.NewChange(h.PhoneVerifiedAtColumn(), at)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// conditions
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// FirstNameCondition implements [domain.humanConditions].
|
||||
func (h userHuman) FirstNameCondition(op database.TextOperation, firstName string) database.Condition {
|
||||
return database.NewTextCondition(h.FirstNameColumn(), op, firstName)
|
||||
}
|
||||
|
||||
// LastNameCondition implements [domain.humanConditions].
|
||||
func (h userHuman) LastNameCondition(op database.TextOperation, lastName string) database.Condition {
|
||||
return database.NewTextCondition(h.LastNameColumn(), op, lastName)
|
||||
}
|
||||
|
||||
// EmailAddressCondition implements [domain.humanConditions].
|
||||
func (h userHuman) EmailAddressCondition(op database.TextOperation, email string) database.Condition {
|
||||
return database.NewTextCondition(h.EmailAddressColumn(), op, email)
|
||||
}
|
||||
|
||||
// EmailVerifiedCondition implements [domain.humanConditions].
|
||||
func (h *userHuman) EmailVerifiedCondition(isVerified bool) database.Condition {
|
||||
if isVerified {
|
||||
return database.IsNotNull(h.EmailVerifiedAtColumn())
|
||||
}
|
||||
return database.IsNull(h.EmailVerifiedAtColumn())
|
||||
}
|
||||
|
||||
// EmailVerifiedAtCondition implements [domain.humanConditions].
|
||||
func (h userHuman) EmailVerifiedAtCondition(op database.NumberOperation, verifiedAt time.Time) database.Condition {
|
||||
return database.NewNumberCondition(h.EmailVerifiedAtColumn(), op, verifiedAt)
|
||||
}
|
||||
|
||||
// PhoneNumberCondition implements [domain.humanConditions].
|
||||
func (h userHuman) PhoneNumberCondition(op database.TextOperation, phoneNumber string) database.Condition {
|
||||
return database.NewTextCondition(h.PhoneNumberColumn(), op, phoneNumber)
|
||||
}
|
||||
|
||||
// PhoneVerifiedCondition implements [domain.humanConditions].
|
||||
func (h userHuman) PhoneVerifiedCondition(isVerified bool) database.Condition {
|
||||
if isVerified {
|
||||
return database.IsNotNull(h.PhoneVerifiedAtColumn())
|
||||
}
|
||||
return database.IsNull(h.PhoneVerifiedAtColumn())
|
||||
}
|
||||
|
||||
// PhoneVerifiedAtCondition implements [domain.humanConditions].
|
||||
func (h userHuman) PhoneVerifiedAtCondition(op database.NumberOperation, verifiedAt time.Time) database.Condition {
|
||||
return database.NewNumberCondition(h.PhoneVerifiedAtColumn(), op, verifiedAt)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// columns
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// FirstNameColumn implements [domain.humanColumns].
|
||||
func (h userHuman) FirstNameColumn() database.Column {
|
||||
return database.NewColumn("user_humans", "first_name")
|
||||
}
|
||||
|
||||
// LastNameColumn implements [domain.humanColumns].
|
||||
func (h userHuman) LastNameColumn() database.Column {
|
||||
return database.NewColumn("user_humans", "last_name")
|
||||
}
|
||||
|
||||
// EmailAddressColumn implements [domain.humanColumns].
|
||||
func (h userHuman) EmailAddressColumn() database.Column {
|
||||
return database.NewColumn("user_humans", "email_address")
|
||||
}
|
||||
|
||||
// EmailVerifiedAtColumn implements [domain.humanColumns].
|
||||
func (h userHuman) EmailVerifiedAtColumn() database.Column {
|
||||
return database.NewColumn("user_humans", "email_verified_at")
|
||||
}
|
||||
|
||||
// PhoneNumberColumn implements [domain.humanColumns].
|
||||
func (h userHuman) PhoneNumberColumn() database.Column {
|
||||
return database.NewColumn("user_humans", "phone_number")
|
||||
}
|
||||
|
||||
// PhoneVerifiedAtColumn implements [domain.humanColumns].
|
||||
func (h userHuman) PhoneVerifiedAtColumn() database.Column {
|
||||
return database.NewColumn("user_humans", "phone_verified_at")
|
||||
}
|
||||
|
||||
// func (h userHuman) columns() database.Columns {
|
||||
// return append(h.user.columns(),
|
||||
// h.FirstNameColumn(),
|
||||
// h.LastNameColumn(),
|
||||
// h.EmailAddressColumn(),
|
||||
// h.EmailVerifiedAtColumn(),
|
||||
// h.PhoneNumberColumn(),
|
||||
// h.PhoneVerifiedAtColumn(),
|
||||
// )
|
||||
// }
|
||||
|
||||
// func (h userHuman) writeReturning(builder *database.StatementBuilder) {
|
||||
// builder.WriteString(" RETURNING ")
|
||||
// h.columns().Write(builder)
|
||||
// }
|
||||
67
backend/v3/storage/database/repository/user_machine.go
Normal file
67
backend/v3/storage/database/repository/user_machine.go
Normal file
@@ -0,0 +1,67 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/backend/v3/domain"
|
||||
"github.com/zitadel/zitadel/backend/v3/storage/database"
|
||||
)
|
||||
|
||||
type userMachine struct {
|
||||
*user
|
||||
}
|
||||
|
||||
var _ domain.MachineRepository = (*userMachine)(nil)
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// repository
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// Update implements [domain.MachineRepository].
|
||||
func (m userMachine) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error {
|
||||
builder := database.StatementBuilder{}
|
||||
builder.WriteString("UPDATE user_machines SET ")
|
||||
database.Changes(changes).Write(&builder)
|
||||
writeCondition(&builder, condition)
|
||||
m.writeReturning()
|
||||
|
||||
_, err := m.client.Exec(ctx, builder.String(), builder.Args()...)
|
||||
return err
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// changes
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// SetDescription implements [domain.machineChanges].
|
||||
func (m userMachine) SetDescription(description string) database.Change {
|
||||
return database.NewChange(m.DescriptionColumn(), description)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// conditions
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// DescriptionCondition implements [domain.machineConditions].
|
||||
func (m userMachine) DescriptionCondition(op database.TextOperation, description string) database.Condition {
|
||||
return database.NewTextCondition(m.DescriptionColumn(), op, description)
|
||||
}
|
||||
|
||||
// -------------------------------------------------------------
|
||||
// columns
|
||||
// -------------------------------------------------------------
|
||||
|
||||
// DescriptionColumn implements [domain.machineColumns].
|
||||
func (m userMachine) DescriptionColumn() database.Column {
|
||||
return database.NewColumn("user_machines", "description")
|
||||
}
|
||||
|
||||
func (m userMachine) columns() database.Columns {
|
||||
return append(m.user.columns(), m.DescriptionColumn())
|
||||
}
|
||||
|
||||
func (m *userMachine) writeReturning() {
|
||||
builder := database.StatementBuilder{}
|
||||
builder.WriteString(" RETURNING ")
|
||||
m.columns().WriteQualified(&builder)
|
||||
}
|
||||
68
backend/v3/storage/database/statement.go
Normal file
68
backend/v3/storage/database/statement.go
Normal file
@@ -0,0 +1,68 @@
|
||||
package database
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type Instruction string
|
||||
|
||||
const (
|
||||
DefaultInstruction Instruction = "DEFAULT"
|
||||
NowInstruction Instruction = "NOW()"
|
||||
NullInstruction Instruction = "NULL"
|
||||
)
|
||||
|
||||
// StatementBuilder is a helper to build SQL statement.
|
||||
type StatementBuilder struct {
|
||||
strings.Builder
|
||||
args []any
|
||||
existingArgs map[any]string
|
||||
}
|
||||
|
||||
// WriteArgs adds the argument to the statement and writes the placeholder to the query.
|
||||
func (b *StatementBuilder) WriteArg(arg any) {
|
||||
b.WriteString(b.AppendArg(arg))
|
||||
}
|
||||
|
||||
// WriteArgs adds the arguments to the statement and writes the placeholders to the query.
|
||||
// The placeholders are comma separated.
|
||||
func (b *StatementBuilder) WriteArgs(args ...any) {
|
||||
for i, arg := range args {
|
||||
if i > 0 {
|
||||
b.WriteString(", ")
|
||||
}
|
||||
b.WriteArg(arg)
|
||||
}
|
||||
}
|
||||
|
||||
// AppendArg adds the argument to the statement and returns the placeholder.
|
||||
func (b *StatementBuilder) AppendArg(arg any) (placeholder string) {
|
||||
if b.existingArgs == nil {
|
||||
b.existingArgs = make(map[any]string)
|
||||
}
|
||||
if placeholder, ok := b.existingArgs[arg]; ok {
|
||||
return placeholder
|
||||
}
|
||||
if instruction, ok := arg.(Instruction); ok {
|
||||
return string(instruction)
|
||||
}
|
||||
|
||||
b.args = append(b.args, arg)
|
||||
placeholder = "$" + strconv.Itoa(len(b.args))
|
||||
b.existingArgs[arg] = placeholder
|
||||
return placeholder
|
||||
}
|
||||
|
||||
// AppendArgs adds the arguments to the statement and doesn't return the placeholders.
|
||||
// If an argument is already added, it will not be added again.
|
||||
func (b *StatementBuilder) AppendArgs(args ...any) {
|
||||
for _, arg := range args {
|
||||
b.AppendArg(arg)
|
||||
}
|
||||
}
|
||||
|
||||
// Args returns the arguments added to the statement.
|
||||
func (b *StatementBuilder) Args() []any {
|
||||
return b.args
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user