feat(backend): state persisted objects (#9870)

This PR initiates the rework of Zitadel's backend to state-persisted
objects. This change is a step towards a more scalable and maintainable
architecture.

## Changes

* **New `/backend/v3` package**: A new package structure has been
introduced to house the reworked backend logic. This includes:
* `domain`: Contains the core business logic, commands, and repository
interfaces.
* `storage`: Implements the repository interfaces for database
interactions with new transactional tables.
  * `telemetry`: Provides logging and tracing capabilities.
* **Transactional Tables**: New database tables have been defined for
`instances`, `instance_domains`, `organizations`, and `org_domains`.
* **Projections**: New projections have been created to populate the new
relational tables from the existing event store, ensuring data
consistency during the migration.
* **Repositories**: New repositories provide an abstraction layer for
accessing and manipulating the data in the new tables.
* **Setup**: A new setup step for `TransactionalTables` has been added
to manage the database migrations for the new tables.

This PR lays the foundation for future work to fully transition to
state-persisted objects for these components, which will improve
performance and simplify data access patterns.

This PR initiates the rework of ZITADEL's backend to state-persisted
objects. This is a foundational step towards a new architecture that
will improve performance and maintainability.

The following objects are migrated from event-sourced aggregates to
state-persisted objects:

* Instances
  * incl. Domains
* Orgs
  * incl. Domains

The structure of the new backend implementation follows the software
architecture defined in this [wiki
page](https://github.com/zitadel/zitadel/wiki/Software-Architecturel).

This PR includes:

* The initial implementation of the new transactional repositories for
the objects listed above.
* Projections to populate the new relational tables from the existing
event store.
* Adjustments to the build and test process to accommodate the new
backend structure.

This is a work in progress and further changes will be made to complete
the migration.

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Iraq Jaber <iraq+github@zitadel.com>
Co-authored-by: Iraq <66622793+kkrime@users.noreply.github.com>
Co-authored-by: Tim Möhlmann <tim+github@zitadel.com>
This commit is contained in:
Silvan
2025-09-05 10:54:34 +02:00
committed by GitHub
parent 8cc79c1376
commit 61cab8878e
119 changed files with 13940 additions and 11 deletions

View File

@@ -138,7 +138,7 @@ core_integration_server_start: core_integration_setup
.PHONY: core_integration_test_packages
core_integration_test_packages:
go test -race -count 1 -tags integration -timeout 60m -parallel 1 $$(go list -tags integration ./... | grep "integration_test")
go test -race -count 1 -tags integration -timeout 5m -parallel 1 $$(go list -tags integration ./... | grep -e "integration_test" -e "events_testing") -run ^TestServer_TestInstanceReduces$
.PHONY: core_integration_server_stop
core_integration_server_stop:
@@ -152,7 +152,7 @@ core_integration_server_stop:
.PHONY: core_integration_reports
core_integration_reports:
go tool covdata textfmt -i=tmp/coverage -pkg=github.com/zitadel/zitadel/internal/...,github.com/zitadel/zitadel/cmd/... -o profile.cov
go tool covdata textfmt -i=tmp/coverage -pkg=github.com/zitadel/zitadel/internal/...,github.com/zitadel/zitadel/cmd/...,github.com/zitadel/zitadel/backend/... -o profile.cov
.PHONY: core_integration_test
core_integration_test: core_integration_server_start core_integration_test_packages core_integration_server_stop core_integration_reports

10
backend/main.go Normal file
View File

@@ -0,0 +1,10 @@
// huhu thanks for looking at this.
// you can find more comments and doc.go files in the v3 package.
// to get an overview i would start in the api package.
package main
// import "github.com/zitadel/zitadel/backend/cmd"
// func main() {
// cmd.Execute()
// }

View File

@@ -0,0 +1,21 @@
package v2
// this file has been commented out to pass the linter
// import (
// "github.com/zitadel/zitadel/backend/v3/telemetry/logging"
// "github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
// )
// var (
// logger logging.Logger
// tracer tracing.Tracer
// )
// func SetLogger(l logging.Logger) {
// logger = l
// }
// func SetTracer(t tracing.Tracer) {
// tracer = t
// }

View File

@@ -0,0 +1,33 @@
package orgv2
// import (
// "context"
// "github.com/zitadel/zitadel/backend/v3/domain"
// "github.com/zitadel/zitadel/pkg/grpc/org/v2"
// )
// func CreateOrg(ctx context.Context, req *org.AddOrganizationRequest) (resp *org.AddOrganizationResponse, err error) {
// cmd := domain.NewAddOrgCommand(
// req.GetName(),
// addOrgAdminToCommand(req.GetAdmins()...)...,
// )
// err = domain.Invoke(ctx, cmd)
// if err != nil {
// return nil, err
// }
// return &org.AddOrganizationResponse{
// OrganizationId: cmd.ID,
// }, nil
// }
// func addOrgAdminToCommand(admins ...*org.AddOrganizationRequest_Admin) []*domain.AddMemberCommand {
// cmds := make([]*domain.AddMemberCommand, len(admins))
// for i, admin := range admins {
// cmds[i] = &domain.AddMemberCommand{
// UserID: admin.GetUserId(),
// Roles: admin.GetRoles(),
// }
// }
// return cmds
// }

View File

@@ -0,0 +1,21 @@
package orgv2
// this file has been commented out to pass the linter
// import (
// "github.com/zitadel/zitadel/backend/v3/telemetry/logging"
// "github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
// )
// var (
// logger logging.Logger
// tracer tracing.Tracer
// )
// func SetLogger(l logging.Logger) {
// logger = l
// }
// func SetTracer(t tracing.Tracer) {
// tracer = t
// }

View File

@@ -0,0 +1,93 @@
package userv2
// import (
// "context"
// "github.com/zitadel/zitadel/backend/v3/domain"
// "github.com/zitadel/zitadel/pkg/grpc/user/v2"
// )
// func SetEmail(ctx context.Context, req *user.SetEmailRequest) (resp *user.SetEmailResponse, err error) {
// var (
// verification domain.SetEmailOpt
// returnCode *domain.ReturnCodeCommand
// )
// switch req.GetVerification().(type) {
// case *user.SetEmailRequest_IsVerified:
// verification = domain.NewEmailVerifiedCommand(req.GetUserId(), req.GetIsVerified())
// case *user.SetEmailRequest_SendCode:
// verification = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
// case *user.SetEmailRequest_ReturnCode:
// returnCode = domain.NewReturnCodeCommand(req.GetUserId())
// verification = returnCode
// default:
// verification = domain.NewSendCodeCommand(req.GetUserId(), nil)
// }
// err = domain.Invoke(ctx, domain.NewSetEmailCommand(req.GetUserId(), req.GetEmail(), verification))
// if err != nil {
// return nil, err
// }
// var code *string
// if returnCode != nil && returnCode.Code != "" {
// code = &returnCode.Code
// }
// return &user.SetEmailResponse{
// VerificationCode: code,
// }, nil
// }
// func SendEmailCode(ctx context.Context, req *user.SendEmailCodeRequest) (resp *user.SendEmailCodeResponse, err error) {
// var (
// returnCode *domain.ReturnCodeCommand
// cmd domain.Commander
// )
// switch req.GetVerification().(type) {
// case *user.SendEmailCodeRequest_SendCode:
// cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
// case *user.SendEmailCodeRequest_ReturnCode:
// returnCode = domain.NewReturnCodeCommand(req.GetUserId())
// cmd = returnCode
// default:
// cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
// }
// err = domain.Invoke(ctx, cmd)
// if err != nil {
// return nil, err
// }
// resp = new(user.SendEmailCodeResponse)
// if returnCode != nil {
// resp.VerificationCode = &returnCode.Code
// }
// return resp, nil
// }
// func ResendEmailCode(ctx context.Context, req *user.ResendEmailCodeRequest) (resp *user.SendEmailCodeResponse, err error) {
// var (
// returnCode *domain.ReturnCodeCommand
// cmd domain.Commander
// )
// switch req.GetVerification().(type) {
// case *user.ResendEmailCodeRequest_SendCode:
// cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
// case *user.ResendEmailCodeRequest_ReturnCode:
// returnCode = domain.NewReturnCodeCommand(req.GetUserId())
// cmd = returnCode
// default:
// cmd = domain.NewSendCodeCommand(req.GetUserId(), req.GetSendCode().UrlTemplate)
// }
// err = domain.Invoke(ctx, cmd)
// if err != nil {
// return nil, err
// }
// resp = new(user.SendEmailCodeResponse)
// if returnCode != nil {
// resp.VerificationCode = &returnCode.Code
// }
// return resp, nil
// }

View File

@@ -0,0 +1,19 @@
package userv2
// this file has been commented out to pass the linter
// import (
// "github.com/zitadel/zitadel/backend/v3/telemetry/logging"
// "github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
// )
// logger logging.Logger
// var tracer tracing.Tracer
// func SetLogger(l logging.Logger) {
// logger = l
// }
// func SetTracer(t tracing.Tracer) {
// tracer = t
// }

21
backend/v3/doc.go Normal file
View File

@@ -0,0 +1,21 @@
// the test used the manly relies on the following patterns:
// - api:
// - some example stubs for the grpc api, it maps the calls and responses to the domain objects
//
// - domain:
// - hexagonal architecture, it defines its dependencies as interfaces and the dependencies must use the objects defined by this package
// - command pattern which implements the changes
// - the invoker decorates the commands by checking for events, tracing, logging, potentially caching, etc.
// - the database connections are manged in this package
// - the database connections are passed to the repositories
//
// - storage:
// - repository pattern, the repositories are defined as interfaces and the implementations are in the storage package
// - the repositories are used by the domain package to access the database
// - the eventstore to store events. At the beginning it writes to the same events table as the /internal package, afterwards it writes to a different table
//
// - telemetry:
// - logging for standard output
// - tracing for distributed tracing
// - metrics for monitoring
package v3

View File

@@ -0,0 +1,131 @@
package domain
// import (
// "context"
// "fmt"
// "github.com/zitadel/zitadel/backend/v3/storage/database"
// )
// // Commander is the all it needs to implement the command pattern.
// // It is the interface all manipulations need to implement.
// // If possible it should also be used for queries. We will find out if this is possible in the future.
// type Commander interface {
// Execute(ctx context.Context, opts *CommandOpts) (err error)
// fmt.Stringer
// }
// // Invoker is part of the command pattern.
// // It is the interface that is used to execute commands.
// type Invoker interface {
// Invoke(ctx context.Context, command Commander, opts *CommandOpts) error
// }
// // CommandOpts are passed to each command
// // the provide common fields used by commands like the database client.
// type CommandOpts struct {
// DB database.QueryExecutor
// Invoker Invoker
// }
// type ensureTxOpts struct {
// *database.TransactionOptions
// }
// type EnsureTransactionOpt func(*ensureTxOpts)
// // EnsureTx ensures that the DB is a transaction. If it is not, it will start a new transaction.
// // The returned close function will end the transaction. If the DB is already a transaction, the close function
// // will do nothing because another [Commander] is already responsible for ending the transaction.
// func (o *CommandOpts) EnsureTx(ctx context.Context, opts ...EnsureTransactionOpt) (close func(context.Context, error) error, err error) {
// beginner, ok := o.DB.(database.Beginner)
// if !ok {
// // db is already a transaction
// return func(_ context.Context, err error) error {
// return err
// }, nil
// }
// txOpts := &ensureTxOpts{
// TransactionOptions: new(database.TransactionOptions),
// }
// for _, opt := range opts {
// opt(txOpts)
// }
// tx, err := beginner.Begin(ctx, txOpts.TransactionOptions)
// if err != nil {
// return nil, err
// }
// o.DB = tx
// return func(ctx context.Context, err error) error {
// return tx.End(ctx, err)
// }, nil
// }
// // EnsureClient ensures that the o.DB is a client. If it is not, it will get a new client from the [database.Pool].
// // The returned close function will release the client. If the o.DB is already a client or transaction, the close function
// // will do nothing because another [Commander] is already responsible for releasing the client.
// func (o *CommandOpts) EnsureClient(ctx context.Context) (close func(_ context.Context) error, err error) {
// pool, ok := o.DB.(database.Pool)
// if !ok {
// // o.DB is already a client
// return func(_ context.Context) error {
// return nil
// }, nil
// }
// client, err := pool.Acquire(ctx)
// if err != nil {
// return nil, err
// }
// o.DB = client
// return func(ctx context.Context) error {
// return client.Release(ctx)
// }, nil
// }
// func (o *CommandOpts) Invoke(ctx context.Context, command Commander) error {
// if o.Invoker == nil {
// return command.Execute(ctx, o)
// }
// return o.Invoker.Invoke(ctx, command, o)
// }
// func DefaultOpts(invoker Invoker) *CommandOpts {
// if invoker == nil {
// invoker = &noopInvoker{}
// }
// return &CommandOpts{
// DB: pool,
// Invoker: invoker,
// }
// }
// // commandBatch is a batch of commands.
// // It uses the [Invoker] provided by the opts to execute each command.
// type commandBatch struct {
// Commands []Commander
// }
// func BatchCommands(cmds ...Commander) *commandBatch {
// return &commandBatch{
// Commands: cmds,
// }
// }
// // String implements [Commander].
// func (cmd *commandBatch) String() string {
// return "commandBatch"
// }
// func (b *commandBatch) Execute(ctx context.Context, opts *CommandOpts) (err error) {
// for _, cmd := range b.Commands {
// if err = opts.Invoke(ctx, cmd); err != nil {
// return err
// }
// }
// return nil
// }
// var _ Commander = (*commandBatch)(nil)

View File

@@ -0,0 +1,90 @@
package domain
// import (
// "context"
// "github.com/zitadel/zitadel/backend/v3/storage/eventstore"
// )
// // CreateUserCommand adds a new user including the email verification for humans.
// // In the future it might make sense to separate the command into two commands:
// // - CreateHumanCommand: creates a new human user
// // - CreateMachineCommand: creates a new machine user
// type CreateUserCommand struct {
// user *User
// email *SetEmailCommand
// }
// var (
// _ Commander = (*CreateUserCommand)(nil)
// _ eventer = (*CreateUserCommand)(nil)
// )
// // opts heavily reduces the complexity for email verification because each type of verification is a simple option which implements the [Commander] interface.
// func NewCreateHumanCommand(username string, opts ...CreateHumanOpt) *CreateUserCommand {
// cmd := &CreateUserCommand{
// user: &User{
// Username: username,
// Traits: &Human{},
// },
// }
// for _, opt := range opts {
// opt.applyOnCreateHuman(cmd)
// }
// return cmd
// }
// // String implements [Commander].
// func (cmd *CreateUserCommand) String() string {
// return "CreateUserCommand"
// }
// // Events implements [eventer].
// func (c *CreateUserCommand) Events() []*eventstore.Event {
// return []*eventstore.Event{
// {
// AggregateType: "user",
// AggregateID: c.user.ID,
// Type: "user.added",
// Payload: c.user,
// },
// }
// }
// // Execute implements [Commander].
// func (c *CreateUserCommand) Execute(ctx context.Context, opts *CommandOpts) error {
// if err := c.ensureUserID(); err != nil {
// return err
// }
// c.email.UserID = c.user.ID
// if err := opts.Invoke(ctx, c.email); err != nil {
// return err
// }
// return nil
// }
// type CreateHumanOpt interface {
// applyOnCreateHuman(*CreateUserCommand)
// }
// type createHumanIDOpt string
// // applyOnCreateHuman implements [CreateHumanOpt].
// func (c createHumanIDOpt) applyOnCreateHuman(cmd *CreateUserCommand) {
// cmd.user.ID = string(c)
// }
// var _ CreateHumanOpt = (*createHumanIDOpt)(nil)
// func CreateHumanWithID(id string) CreateHumanOpt {
// return createHumanIDOpt(id)
// }
// func (c *CreateUserCommand) ensureUserID() (err error) {
// if c.user.ID != "" {
// return nil
// }
// c.user.ID, err = generateID()
// return err
// }

View File

@@ -0,0 +1,37 @@
package domain
// import (
// "context"
// "github.com/zitadel/zitadel/internal/crypto"
// )
// type generateCodeCommand struct {
// code string
// value *crypto.CryptoValue
// }
// // I didn't update this repository to the solution proposed please view one of the following interfaces for correct usage:
// // - [UserRepository]
// // - [InstanceRepository]
// // - [OrgRepository]
// type CryptoRepository interface {
// GetEncryptionConfig(ctx context.Context) (*crypto.GeneratorConfig, error)
// }
// // String implements [Commander].
// func (cmd *generateCodeCommand) String() string {
// return "generateCodeCommand"
// }
// func (cmd *generateCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error {
// config, err := cryptoRepo(opts.DB).GetEncryptionConfig(ctx)
// if err != nil {
// return err
// }
// generator := crypto.NewEncryptionGenerator(*config, userCodeAlgorithm)
// cmd.value, cmd.code, err = crypto.NewCode(generator)
// return err
// }
// var _ Commander = (*generateCodeCommand)(nil)

127
backend/v3/domain/domain.go Normal file
View File

@@ -0,0 +1,127 @@
package domain
import (
"time"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
//go:generate enumer -type DomainValidationType -transform lower -trimprefix DomainValidationType -sql
type DomainValidationType uint8
const (
DomainValidationTypeDNS DomainValidationType = iota
DomainValidationTypeHTTP
)
//go:generate enumer -type DomainType -transform lower -trimprefix DomainType -sql
type DomainType uint8
const (
DomainTypeCustom DomainType = iota
DomainTypeTrusted
)
type domainColumns interface {
// InstanceIDColumn returns the column for the instance id field.
InstanceIDColumn() database.Column
// DomainColumn returns the column for the domain field.
DomainColumn() database.Column
// IsPrimaryColumn returns the column for the is primary field.
IsPrimaryColumn() database.Column
// CreatedAtColumn returns the column for the created at field.
CreatedAtColumn() database.Column
// UpdatedAtColumn returns the column for the updated at field.
UpdatedAtColumn() database.Column
}
type domainConditions interface {
// InstanceIDCondition returns a filter on the instance id field.
InstanceIDCondition(instanceID string) database.Condition
// DomainCondition returns a filter on the domain field.
DomainCondition(op database.TextOperation, domain string) database.Condition
// IsPrimaryCondition returns a filter on the is primary field.
IsPrimaryCondition(isPrimary bool) database.Condition
}
type domainChanges interface {
// SetPrimary sets a domain as primary based on the condition.
// All other domains will be set to non-primary.
//
// An error is returned if:
// - The condition identifies multiple domains.
// - The condition does not identify any domain.
//
// This is a no-op if:
// - The domain is already primary.
// - No domain matches the condition.
SetPrimary() database.Change
// SetUpdatedAt sets the updated at column.
// This is used for reducing events.
SetUpdatedAt(t time.Time) database.Change
}
// import (
// "math/rand/v2"
// "strconv"
// "github.com/zitadel/zitadel/backend/v3/storage/cache"
// "github.com/zitadel/zitadel/backend/v3/storage/database"
// // "github.com/zitadel/zitadel/backend/v3/telemetry/logging"
// "github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
// "github.com/zitadel/zitadel/internal/crypto"
// )
// // The variables could also be moved to a struct.
// // I just started with the singleton pattern and kept it like this.
// var (
// pool database.Pool
// userCodeAlgorithm crypto.EncryptionAlgorithm
// tracer tracing.Tracer
// // logger logging.Logger
// userRepo func(database.QueryExecutor) UserRepository
// // instanceRepo func(database.QueryExecutor) InstanceRepository
// cryptoRepo func(database.QueryExecutor) CryptoRepository
// orgRepo func(database.QueryExecutor) OrgRepository
// // instanceCache cache.Cache[instanceCacheIndex, string, *Instance]
// orgCache cache.Cache[orgCacheIndex, string, *Org]
// generateID func() (string, error) = func() (string, error) {
// return strconv.FormatUint(rand.Uint64(), 10), nil
// }
// )
// func SetPool(p database.Pool) {
// pool = p
// }
// func SetUserCodeAlgorithm(algorithm crypto.EncryptionAlgorithm) {
// userCodeAlgorithm = algorithm
// }
// func SetTracer(t tracing.Tracer) {
// tracer = t
// }
// // func SetLogger(l logging.Logger) {
// // logger = l
// // }
// func SetUserRepository(repo func(database.QueryExecutor) UserRepository) {
// userRepo = repo
// }
// func SetOrgRepository(repo func(database.QueryExecutor) OrgRepository) {
// orgRepo = repo
// }
// // func SetInstanceRepository(repo func(database.QueryExecutor) InstanceRepository) {
// // instanceRepo = repo
// // }
// func SetCryptoRepository(repo func(database.QueryExecutor) CryptoRepository) {
// cryptoRepo = repo
// }

View File

@@ -0,0 +1,67 @@
package domain_test
// import (
// "context"
// "log/slog"
// "testing"
// "github.com/stretchr/testify/assert"
// "github.com/stretchr/testify/require"
// "go.opentelemetry.io/otel"
// "go.opentelemetry.io/otel/exporters/stdout/stdouttrace"
// sdktrace "go.opentelemetry.io/otel/sdk/trace"
// "go.uber.org/mock/gomock"
// . "github.com/zitadel/zitadel/backend/v3/domain"
// "github.com/zitadel/zitadel/backend/v3/storage/database/dbmock"
// "github.com/zitadel/zitadel/backend/v3/storage/database/repository"
// "github.com/zitadel/zitadel/backend/v3/telemetry/logging"
// "github.com/zitadel/zitadel/backend/v3/telemetry/tracing"
// )
// These tests give an overview of how to use the domain package.
// func TestExample(t *testing.T) {
// t.Skip("skip example test because it is not a real test")
// ctx := context.Background()
// ctrl := gomock.NewController(t)
// pool := dbmock.NewMockPool(ctrl)
// tx := dbmock.NewMockTransaction(ctrl)
// pool.EXPECT().Begin(gomock.Any(), gomock.Any()).Return(tx, nil)
// tx.EXPECT().End(gomock.Any(), gomock.Any()).Return(nil)
// SetPool(pool)
// exporter, err := stdouttrace.New(stdouttrace.WithPrettyPrint())
// require.NoError(t, err)
// tracerProvider := sdktrace.NewTracerProvider(
// sdktrace.WithSyncer(exporter),
// )
// otel.SetTracerProvider(tracerProvider)
// SetTracer(tracing.Tracer{Tracer: tracerProvider.Tracer("test")})
// defer func() { assert.NoError(t, tracerProvider.Shutdown(ctx)) }()
// SetLogger(logging.Logger{Logger: slog.Default()})
// SetUserRepository(repository.UserRepository)
// SetOrgRepository(repository.OrgRepository)
// // SetInstanceRepository(repository.Instance)
// // SetCryptoRepository(repository.Crypto)
// t.Run("create org", func(t *testing.T) {
// org := NewAddOrgCommand("testorg", NewAddMemberCommand("testuser", "ORG_OWNER"))
// user := NewCreateHumanCommand("testuser")
// err := Invoke(ctx, BatchCommands(org, user))
// assert.NoError(t, err)
// })
// t.Run("verified email", func(t *testing.T) {
// err := Invoke(ctx, NewSetEmailCommand("u1", "test@example.com", NewEmailVerifiedCommand("u1", true)))
// assert.NoError(t, err)
// })
// t.Run("unverified email", func(t *testing.T) {
// err := Invoke(ctx, NewSetEmailCommand("u2", "test2@example.com", NewEmailVerifiedCommand("u2", false)))
// assert.NoError(t, err)
// })
// }

View File

@@ -0,0 +1,109 @@
// Code generated by "enumer -type DomainType -transform lower -trimprefix DomainType -sql"; DO NOT EDIT.
package domain
import (
"database/sql/driver"
"fmt"
"strings"
)
const _DomainTypeName = "customtrusted"
var _DomainTypeIndex = [...]uint8{0, 6, 13}
const _DomainTypeLowerName = "customtrusted"
func (i DomainType) String() string {
if i >= DomainType(len(_DomainTypeIndex)-1) {
return fmt.Sprintf("DomainType(%d)", i)
}
return _DomainTypeName[_DomainTypeIndex[i]:_DomainTypeIndex[i+1]]
}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _DomainTypeNoOp() {
var x [1]struct{}
_ = x[DomainTypeCustom-(0)]
_ = x[DomainTypeTrusted-(1)]
}
var _DomainTypeValues = []DomainType{DomainTypeCustom, DomainTypeTrusted}
var _DomainTypeNameToValueMap = map[string]DomainType{
_DomainTypeName[0:6]: DomainTypeCustom,
_DomainTypeLowerName[0:6]: DomainTypeCustom,
_DomainTypeName[6:13]: DomainTypeTrusted,
_DomainTypeLowerName[6:13]: DomainTypeTrusted,
}
var _DomainTypeNames = []string{
_DomainTypeName[0:6],
_DomainTypeName[6:13],
}
// DomainTypeString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func DomainTypeString(s string) (DomainType, error) {
if val, ok := _DomainTypeNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _DomainTypeNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to DomainType values", s)
}
// DomainTypeValues returns all values of the enum
func DomainTypeValues() []DomainType {
return _DomainTypeValues
}
// DomainTypeStrings returns a slice of all String values of the enum
func DomainTypeStrings() []string {
strs := make([]string, len(_DomainTypeNames))
copy(strs, _DomainTypeNames)
return strs
}
// IsADomainType returns "true" if the value is listed in the enum definition. "false" otherwise
func (i DomainType) IsADomainType() bool {
for _, v := range _DomainTypeValues {
if i == v {
return true
}
}
return false
}
func (i DomainType) Value() (driver.Value, error) {
return i.String(), nil
}
func (i *DomainType) Scan(value interface{}) error {
if value == nil {
return nil
}
var str string
switch v := value.(type) {
case []byte:
str = string(v)
case string:
str = v
case fmt.Stringer:
str = v.String()
default:
return fmt.Errorf("invalid value of DomainType: %[1]T(%[1]v)", value)
}
val, err := DomainTypeString(str)
if err != nil {
return err
}
*i = val
return nil
}

View File

@@ -0,0 +1,109 @@
// Code generated by "enumer -type DomainValidationType -transform lower -trimprefix DomainValidationType -sql"; DO NOT EDIT.
package domain
import (
"database/sql/driver"
"fmt"
"strings"
)
const _DomainValidationTypeName = "dnshttp"
var _DomainValidationTypeIndex = [...]uint8{0, 3, 7}
const _DomainValidationTypeLowerName = "dnshttp"
func (i DomainValidationType) String() string {
if i >= DomainValidationType(len(_DomainValidationTypeIndex)-1) {
return fmt.Sprintf("DomainValidationType(%d)", i)
}
return _DomainValidationTypeName[_DomainValidationTypeIndex[i]:_DomainValidationTypeIndex[i+1]]
}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _DomainValidationTypeNoOp() {
var x [1]struct{}
_ = x[DomainValidationTypeDNS-(0)]
_ = x[DomainValidationTypeHTTP-(1)]
}
var _DomainValidationTypeValues = []DomainValidationType{DomainValidationTypeDNS, DomainValidationTypeHTTP}
var _DomainValidationTypeNameToValueMap = map[string]DomainValidationType{
_DomainValidationTypeName[0:3]: DomainValidationTypeDNS,
_DomainValidationTypeLowerName[0:3]: DomainValidationTypeDNS,
_DomainValidationTypeName[3:7]: DomainValidationTypeHTTP,
_DomainValidationTypeLowerName[3:7]: DomainValidationTypeHTTP,
}
var _DomainValidationTypeNames = []string{
_DomainValidationTypeName[0:3],
_DomainValidationTypeName[3:7],
}
// DomainValidationTypeString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func DomainValidationTypeString(s string) (DomainValidationType, error) {
if val, ok := _DomainValidationTypeNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _DomainValidationTypeNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to DomainValidationType values", s)
}
// DomainValidationTypeValues returns all values of the enum
func DomainValidationTypeValues() []DomainValidationType {
return _DomainValidationTypeValues
}
// DomainValidationTypeStrings returns a slice of all String values of the enum
func DomainValidationTypeStrings() []string {
strs := make([]string, len(_DomainValidationTypeNames))
copy(strs, _DomainValidationTypeNames)
return strs
}
// IsADomainValidationType returns "true" if the value is listed in the enum definition. "false" otherwise
func (i DomainValidationType) IsADomainValidationType() bool {
for _, v := range _DomainValidationTypeValues {
if i == v {
return true
}
}
return false
}
func (i DomainValidationType) Value() (driver.Value, error) {
return i.String(), nil
}
func (i *DomainValidationType) Scan(value interface{}) error {
if value == nil {
return nil
}
var str string
switch v := value.(type) {
case []byte:
str = string(v)
case string:
str = v
case fmt.Stringer:
str = v.String()
default:
return fmt.Errorf("invalid value of DomainValidationType: %[1]T(%[1]v)", value)
}
val, err := DomainValidationTypeString(str)
if err != nil {
return err
}
*i = val
return nil
}

View File

@@ -0,0 +1,175 @@
package domain
// import (
// "context"
// "time"
// )
// // EmailVerifiedCommand verifies an email address for a user.
// type EmailVerifiedCommand struct {
// UserID string `json:"userId"`
// Email *Email `json:"email"`
// }
// func NewEmailVerifiedCommand(userID string, isVerified bool) *EmailVerifiedCommand {
// return &EmailVerifiedCommand{
// UserID: userID,
// Email: &Email{
// VerifiedAt: time.Time{},
// },
// }
// }
// // String implements [Commander].
// func (cmd *EmailVerifiedCommand) String() string {
// return "EmailVerifiedCommand"
// }
// var (
// _ Commander = (*EmailVerifiedCommand)(nil)
// _ SetEmailOpt = (*EmailVerifiedCommand)(nil)
// )
// // Execute implements [Commander]
// func (cmd *EmailVerifiedCommand) Execute(ctx context.Context, opts *CommandOpts) error {
// repo := userRepo(opts.DB).Human()
// return repo.Update(ctx, repo.IDCondition(cmd.UserID), repo.SetEmailVerifiedAt(time.Time{}))
// }
// // applyOnSetEmail implements [SetEmailOpt]
// func (cmd *EmailVerifiedCommand) applyOnSetEmail(setEmailCmd *SetEmailCommand) {
// cmd.UserID = setEmailCmd.UserID
// cmd.Email.Address = setEmailCmd.Email
// setEmailCmd.verification = cmd
// }
// // SendCodeCommand sends a verification code to the user's email address.
// // If the URLTemplate is not set it will use the default of the organization / instance.
// type SendCodeCommand struct {
// UserID string `json:"userId"`
// Email string `json:"email"`
// URLTemplate *string `json:"urlTemplate"`
// generator *generateCodeCommand
// }
// var (
// _ Commander = (*SendCodeCommand)(nil)
// _ SetEmailOpt = (*SendCodeCommand)(nil)
// )
// func NewSendCodeCommand(userID string, urlTemplate *string) *SendCodeCommand {
// return &SendCodeCommand{
// UserID: userID,
// generator: &generateCodeCommand{},
// URLTemplate: urlTemplate,
// }
// }
// // String implements [Commander].
// func (cmd *SendCodeCommand) String() string {
// return "SendCodeCommand"
// }
// // Execute implements [Commander]
// func (cmd *SendCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error {
// if err := cmd.ensureEmail(ctx, opts); err != nil {
// return err
// }
// if err := cmd.ensureURL(ctx, opts); err != nil {
// return err
// }
// if err := opts.Invoker.Invoke(ctx, cmd.generator, opts); err != nil {
// return err
// }
// // TODO: queue notification
// return nil
// }
// func (cmd *SendCodeCommand) ensureEmail(ctx context.Context, opts *CommandOpts) error {
// if cmd.Email != "" {
// return nil
// }
// repo := userRepo(opts.DB).Human()
// email, err := repo.GetEmail(ctx, repo.IDCondition(cmd.UserID))
// if err != nil || !email.VerifiedAt.IsZero() {
// return err
// }
// cmd.Email = email.Address
// return nil
// }
// func (cmd *SendCodeCommand) ensureURL(ctx context.Context, opts *CommandOpts) error {
// if cmd.URLTemplate != nil && *cmd.URLTemplate != "" {
// return nil
// }
// _, _ = ctx, opts
// // TODO: load default template
// return nil
// }
// // applyOnSetEmail implements [SetEmailOpt]
// func (cmd *SendCodeCommand) applyOnSetEmail(setEmailCmd *SetEmailCommand) {
// cmd.UserID = setEmailCmd.UserID
// cmd.Email = setEmailCmd.Email
// setEmailCmd.verification = cmd
// }
// // ReturnCodeCommand creates the code and returns it to the caller.
// // The caller gets the code by calling the Code field after the command got executed.
// type ReturnCodeCommand struct {
// UserID string `json:"userId"`
// Email string `json:"email"`
// Code string `json:"code"`
// generator *generateCodeCommand
// }
// var (
// _ Commander = (*ReturnCodeCommand)(nil)
// _ SetEmailOpt = (*ReturnCodeCommand)(nil)
// )
// func NewReturnCodeCommand(userID string) *ReturnCodeCommand {
// return &ReturnCodeCommand{
// UserID: userID,
// generator: &generateCodeCommand{},
// }
// }
// // String implements [Commander].
// func (cmd *ReturnCodeCommand) String() string {
// return "ReturnCodeCommand"
// }
// // Execute implements [Commander]
// func (cmd *ReturnCodeCommand) Execute(ctx context.Context, opts *CommandOpts) error {
// if err := cmd.ensureEmail(ctx, opts); err != nil {
// return err
// }
// if err := opts.Invoker.Invoke(ctx, cmd.generator, opts); err != nil {
// return err
// }
// cmd.Code = cmd.generator.code
// return nil
// }
// func (cmd *ReturnCodeCommand) ensureEmail(ctx context.Context, opts *CommandOpts) error {
// if cmd.Email != "" {
// return nil
// }
// repo := userRepo(opts.DB).Human()
// email, err := repo.GetEmail(ctx, repo.IDCondition(cmd.UserID))
// if err != nil || !email.VerifiedAt.IsZero() {
// return err
// }
// cmd.Email = email.Address
// return nil
// }
// // applyOnSetEmail implements [SetEmailOpt]
// func (cmd *ReturnCodeCommand) applyOnSetEmail(setEmailCmd *SetEmailCommand) {
// cmd.UserID = setEmailCmd.UserID
// cmd.Email = setEmailCmd.Email
// setEmailCmd.verification = cmd
// }

View File

@@ -0,0 +1,7 @@
package domain
import "errors"
var (
ErrNoAdminSpecified = errors.New("at least one admin must be specified")
)

View File

@@ -0,0 +1,117 @@
package domain
import (
"context"
"time"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/backend/v3/storage/cache"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type Instance struct {
ID string `json:"id,omitempty" db:"id"`
Name string `json:"name,omitempty" db:"name"`
DefaultOrgID string `json:"defaultOrgId,omitempty" db:"default_org_id"`
IAMProjectID string `json:"iamProjectId,omitempty" db:"iam_project_id"`
ConsoleClientID string `json:"consoleClientId,omitempty" db:"console_client_id"`
ConsoleAppID string `json:"consoleAppId,omitempty" db:"console_app_id"`
DefaultLanguage string `json:"defaultLanguage,omitempty" db:"default_language"`
CreatedAt time.Time `json:"createdAt" db:"created_at"`
UpdatedAt time.Time `json:"updatedAt" db:"updated_at"`
Domains []*InstanceDomain `json:"domains,omitempty" db:"-"`
}
type instanceCacheIndex uint8
const (
instanceCacheIndexUndefined instanceCacheIndex = iota
instanceCacheIndexID
)
// Keys implements the [cache.Entry].
func (i *Instance) Keys(index instanceCacheIndex) (key []string) {
if index == instanceCacheIndexID {
return []string{i.ID}
}
return nil
}
var _ cache.Entry[instanceCacheIndex, string] = (*Instance)(nil)
// instanceColumns define all the columns of the instance table.
type instanceColumns interface {
// IDColumn returns the column for the id field.
IDColumn() database.Column
// NameColumn returns the column for the name field.
NameColumn() database.Column
// DefaultOrgIDColumn returns the column for the default org id field
DefaultOrgIDColumn() database.Column
// IAMProjectIDColumn returns the column for the default IAM org id field
IAMProjectIDColumn() database.Column
// ConsoleClientIDColumn returns the column for the default IAM org id field
ConsoleClientIDColumn() database.Column
// ConsoleAppIDColumn returns the column for the console client id field
ConsoleAppIDColumn() database.Column
// DefaultLanguageColumn returns the column for the default language field
DefaultLanguageColumn() database.Column
// CreatedAtColumn returns the column for the created at field.
CreatedAtColumn() database.Column
// UpdatedAtColumn returns the column for the updated at field.
UpdatedAtColumn() database.Column
}
// instanceConditions define all the conditions for the instance table.
type instanceConditions interface {
// IDCondition returns an equal filter on the id field.
IDCondition(instanceID string) database.Condition
// NameCondition returns a filter on the name field.
NameCondition(op database.TextOperation, name string) database.Condition
}
// instanceChanges define all the changes for the instance table.
type instanceChanges interface {
// SetName sets the name column.
SetName(name string) database.Change
// SetUpdatedAt sets the updated at column.
SetUpdatedAt(time time.Time) database.Change
// SetIAMProject sets the iam project column.
SetIAMProject(id string) database.Change
// SetDefaultOrg sets the default org column.
SetDefaultOrg(id string) database.Change
// SetDefaultLanguage sets the default language column.
SetDefaultLanguage(language language.Tag) database.Change
// SetConsoleClientID sets the console client id column.
SetConsoleClientID(id string) database.Change
// SetConsoleAppID sets the console app id column.
SetConsoleAppID(id string) database.Change
}
// InstanceRepository is the interface for the instance repository.
type InstanceRepository interface {
instanceColumns
instanceConditions
instanceChanges
// TODO
// Member returns the member repository which is a sub repository of the instance repository.
// Member() MemberRepository
Get(ctx context.Context, opts ...database.QueryOption) (*Instance, error)
List(ctx context.Context, opts ...database.QueryOption) ([]*Instance, error)
Create(ctx context.Context, instance *Instance) error
Update(ctx context.Context, id string, changes ...database.Change) (int64, error)
Delete(ctx context.Context, id string) (int64, error)
// Domains returns the domain sub repository for the instance.
// If shouldLoad is true, the domains will be loaded from the database and written to the [Instance].Domains field.
// If shouldLoad is set to true once, the Domains field will be set even if shouldLoad is false in the future.
Domains(shouldLoad bool) InstanceDomainRepository
}
type CreateInstance struct {
Name string `json:"name"`
}

View File

@@ -0,0 +1,79 @@
package domain
import (
"context"
"time"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type InstanceDomain struct {
InstanceID string `json:"instanceId,omitempty" db:"instance_id"`
Domain string `json:"domain,omitempty" db:"domain"`
// IsPrimary indicates if the domain is the primary domain of the instance.
// It is only set for custom domains.
IsPrimary *bool `json:"isPrimary,omitempty" db:"is_primary"`
// IsGenerated indicates if the domain is a generated domain.
// It is only set for custom domains.
IsGenerated *bool `json:"isGenerated,omitempty" db:"is_generated"`
Type DomainType `json:"type,omitempty" db:"type"`
CreatedAt time.Time `json:"createdAt,omitzero" db:"created_at"`
UpdatedAt time.Time `json:"updatedAt,omitzero" db:"updated_at"`
}
type AddInstanceDomain struct {
InstanceID string `json:"instanceId,omitempty" db:"instance_id"`
Domain string `json:"domain,omitempty" db:"domain"`
IsPrimary *bool `json:"isPrimary,omitempty" db:"is_primary"`
IsGenerated *bool `json:"isGenerated,omitempty" db:"is_generated"`
Type DomainType `json:"type,omitempty" db:"type"`
// CreatedAt is the time when the domain was added.
// It is set by the repository and should not be set by the caller.
CreatedAt time.Time `json:"createdAt,omitzero" db:"created_at"`
// UpdatedAt is the time when the domain was last updated.
// It is set by the repository and should not be set by the caller.
UpdatedAt time.Time `json:"updatedAt,omitzero" db:"updated_at"`
}
type instanceDomainColumns interface {
domainColumns
// IsGeneratedColumn returns the column for the is generated field.
IsGeneratedColumn() database.Column
// TypeColumn returns the column for the type field.
TypeColumn() database.Column
}
type instanceDomainConditions interface {
domainConditions
// TypeCondition returns a filter for the type field.
TypeCondition(typ DomainType) database.Condition
}
type instanceDomainChanges interface {
domainChanges
// SetType sets the type column.
SetType(typ DomainType) database.Change
}
type InstanceDomainRepository interface {
instanceDomainColumns
instanceDomainConditions
instanceDomainChanges
// Get returns a single domain based on the criteria.
// If no domain is found, it returns an error of type [database.ErrNotFound].
// If multiple domains are found, it returns an error of type [database.ErrMultipleRows].
Get(ctx context.Context, opts ...database.QueryOption) (*InstanceDomain, error)
// List returns a list of domains based on the criteria.
// If no domains are found, it returns an empty slice.
List(ctx context.Context, opts ...database.QueryOption) ([]*InstanceDomain, error)
// Add adds a new domain to the instance.
Add(ctx context.Context, domain *AddInstanceDomain) error
// Update updates an existing domain in the instance.
Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error)
// Remove removes a domain from the instance.
Remove(ctx context.Context, condition database.Condition) (int64, error)
}

158
backend/v3/domain/invoke.go Normal file
View File

@@ -0,0 +1,158 @@
package domain
// import (
// "context"
// "fmt"
// "github.com/zitadel/zitadel/backend/v3/storage/eventstore"
// )
// // Invoke provides a way to execute commands within the domain package.
// // It uses a chain of responsibility pattern to handle the command execution.
// // The default chain includes logging, tracing, and event publishing.
// // If you want to invoke multiple commands in a single transaction, you can use the [commandBatch].
// func Invoke(ctx context.Context, cmd Commander) error {
// invoker := newEventStoreInvoker(newLoggingInvoker(newTraceInvoker(nil)))
// opts := &CommandOpts{
// Invoker: invoker.collector,
// DB: pool,
// }
// return invoker.Invoke(ctx, cmd, opts)
// }
// // eventStoreInvoker checks if the command implements the [eventer] interface.
// // If it does, it collects the events and publishes them to the event store.
// type eventStoreInvoker struct {
// collector *eventCollector
// }
// func newEventStoreInvoker(next Invoker) *eventStoreInvoker {
// return &eventStoreInvoker{collector: &eventCollector{next: next}}
// }
// func (i *eventStoreInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
// err = i.collector.Invoke(ctx, command, opts)
// if err != nil {
// return err
// }
// if len(i.collector.events) > 0 {
// err = eventstore.Publish(ctx, i.collector.events, opts.DB)
// if err != nil {
// return err
// }
// }
// return nil
// }
// // eventCollector collects events from all commands. The [eventStoreInvoker] pushes the collected events after all commands are executed.
// type eventCollector struct {
// next Invoker
// events []*eventstore.Event
// }
// type eventer interface {
// Events() []*eventstore.Event
// }
// func (i *eventCollector) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
// if e, ok := command.(eventer); ok && len(e.Events()) > 0 {
// // we need to ensure all commands are executed in the same transaction
// close, err := opts.EnsureTx(ctx)
// if err != nil {
// return err
// }
// defer func() { err = close(ctx, err) }()
// i.events = append(i.events, e.Events()...)
// }
// if i.next != nil {
// return i.next.Invoke(ctx, command, opts)
// }
// return command.Execute(ctx, opts)
// }
// // traceInvoker decorates each command with tracing.
// type traceInvoker struct {
// next Invoker
// }
// func newTraceInvoker(next Invoker) *traceInvoker {
// return &traceInvoker{next: next}
// }
// func (i *traceInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
// ctx, span := tracer.Start(ctx, fmt.Sprintf("%T", command))
// defer func() {
// if err != nil {
// span.RecordError(err)
// }
// span.End()
// }()
// if i.next != nil {
// return i.next.Invoke(ctx, command, opts)
// }
// return command.Execute(ctx, opts)
// }
// // loggingInvoker decorates each command with logging.
// // It is an example implementation and logs the command name at the beginning and success or failure after the command got executed.
// type loggingInvoker struct {
// next Invoker
// }
// func newLoggingInvoker(next Invoker) *loggingInvoker {
// return &loggingInvoker{next: next}
// }
// func (i *loggingInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
// logger.InfoContext(ctx, "Invoking command", "command", command.String())
// if i.next != nil {
// err = i.next.Invoke(ctx, command, opts)
// } else {
// err = command.Execute(ctx, opts)
// }
// if err != nil {
// logger.ErrorContext(ctx, "Command invocation failed", "command", command.String(), "error", err)
// return err
// }
// logger.InfoContext(ctx, "Command invocation succeeded", "command", command.String())
// return nil
// }
// type noopInvoker struct {
// next Invoker
// }
// func (i *noopInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) error {
// if i.next != nil {
// return i.next.Invoke(ctx, command, opts)
// }
// return command.Execute(ctx, opts)
// }
// // cacheInvoker could be used in the future to do the caching.
// // My goal would be to have two interfaces:
// // - cacheSetter: which caches an object
// // - cacheGetter: which gets an object from the cache, this should also skip the command execution
// type cacheInvoker struct {
// next Invoker
// }
// type cacher interface {
// Cache(opts *CommandOpts)
// }
// func (i *cacheInvoker) Invoke(ctx context.Context, command Commander, opts *CommandOpts) (err error) {
// if c, ok := command.(cacher); ok {
// c.Cache(opts)
// }
// if i.next != nil {
// err = i.next.Invoke(ctx, command, opts)
// } else {
// err = command.Execute(ctx, opts)
// }
// return err
// }

View File

@@ -0,0 +1,137 @@
package domain
// import (
// "context"
// "github.com/zitadel/zitadel/backend/v3/storage/eventstore"
// )
// // AddOrgCommand adds a new organization.
// // I'm unsure if we should add the Admins here or if this should be a separate command.
// type AddOrgCommand struct {
// ID string `json:"id"`
// Name string `json:"name"`
// Admins []*AddMemberCommand `json:"admins"`
// }
// func NewAddOrgCommand(name string, admins ...*AddMemberCommand) *AddOrgCommand {
// return &AddOrgCommand{
// Name: name,
// Admins: admins,
// }
// }
// // String implements [Commander].
// func (cmd *AddOrgCommand) String() string {
// return "AddOrgCommand"
// }
// // Execute implements Commander.
// func (cmd *AddOrgCommand) Execute(ctx context.Context, opts *CommandOpts) (err error) {
// if len(cmd.Admins) == 0 {
// return ErrNoAdminSpecified
// }
// if err = cmd.ensureID(); err != nil {
// return err
// }
// close, err := opts.EnsureTx(ctx)
// if err != nil {
// return err
// }
// defer func() { err = close(ctx, err) }()
// err = orgRepo(opts.DB).Create(ctx, &Org{
// ID: cmd.ID,
// Name: cmd.Name,
// })
// if err != nil {
// return err
// }
// for _, admin := range cmd.Admins {
// admin.orgID = cmd.ID
// if err = opts.Invoke(ctx, admin); err != nil {
// return err
// }
// }
// orgCache.Set(ctx, &Org{
// ID: cmd.ID,
// Name: cmd.Name,
// })
// return nil
// }
// // Events implements [eventer].
// func (cmd *AddOrgCommand) Events() []*eventstore.Event {
// return []*eventstore.Event{
// {
// AggregateType: "org",
// AggregateID: cmd.ID,
// Type: "org.added",
// Payload: cmd,
// },
// }
// }
// var (
// _ Commander = (*AddOrgCommand)(nil)
// _ eventer = (*AddOrgCommand)(nil)
// )
// func (cmd *AddOrgCommand) ensureID() (err error) {
// if cmd.ID != "" {
// return nil
// }
// cmd.ID, err = generateID()
// return err
// }
// // AddMemberCommand adds a new member to an organization.
// // I'm not sure if we should make it more generic to also use it for instances.
// type AddMemberCommand struct {
// orgID string
// UserID string `json:"userId"`
// Roles []string `json:"roles"`
// }
// func NewAddMemberCommand(userID string, roles ...string) *AddMemberCommand {
// return &AddMemberCommand{
// UserID: userID,
// Roles: roles,
// }
// }
// // String implements [Commander].
// func (cmd *AddMemberCommand) String() string {
// return "AddMemberCommand"
// }
// // Execute implements Commander.
// func (a *AddMemberCommand) Execute(ctx context.Context, opts *CommandOpts) (err error) {
// close, err := opts.EnsureTx(ctx)
// if err != nil {
// return err
// }
// defer func() { err = close(ctx, err) }()
// return orgRepo(opts.DB).Member().AddMember(ctx, a.orgID, a.UserID, a.Roles)
// }
// // Events implements [eventer].
// func (a *AddMemberCommand) Events() []*eventstore.Event {
// return []*eventstore.Event{
// {
// AggregateType: "org",
// AggregateID: a.UserID,
// Type: "member.added",
// Payload: a,
// },
// }
// }
// var (
// _ Commander = (*AddMemberCommand)(nil)
// _ eventer = (*AddMemberCommand)(nil)
// )

View File

@@ -0,0 +1,100 @@
package domain
import (
"context"
"time"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
//go:generate enumer -type OrgState -transform lower -trimprefix OrgState -sql
type OrgState uint8
const (
OrgStateActive OrgState = iota
OrgStateInactive
)
type Organization struct {
ID string `json:"id,omitempty" db:"id"`
Name string `json:"name,omitempty" db:"name"`
InstanceID string `json:"instanceId,omitempty" db:"instance_id"`
State OrgState `json:"state,omitempty" db:"state"`
CreatedAt time.Time `json:"createdAt,omitzero" db:"created_at"`
UpdatedAt time.Time `json:"updatedAt,omitzero" db:"updated_at"`
Domains []*OrganizationDomain `json:"domains,omitempty" db:"-"` // domains need to be handled separately
}
// OrgIdentifierCondition is used to help specify a single Organization,
// it will either be used as the organization ID or organization name,
// as organizations can be identified either using (instanceID + ID) OR (instanceID + name)
type OrgIdentifierCondition interface {
database.Condition
}
// organizationColumns define all the columns of the instance table.
type organizationColumns interface {
// IDColumn returns the column for the id field.
IDColumn() database.Column
// NameColumn returns the column for the name field.
NameColumn() database.Column
// InstanceIDColumn returns the column for the default org id field
InstanceIDColumn() database.Column
// StateColumn returns the column for the name field.
StateColumn() database.Column
// CreatedAtColumn returns the column for the created at field.
CreatedAtColumn() database.Column
// UpdatedAtColumn returns the column for the updated at field.
UpdatedAtColumn() database.Column
}
// organizationConditions define all the conditions for the instance table.
type organizationConditions interface {
// IDCondition returns an equal filter on the id field.
IDCondition(instanceID string) OrgIdentifierCondition
// NameCondition returns a filter on the name field.
NameCondition(name string) OrgIdentifierCondition
// InstanceIDCondition returns a filter on the instance id field.
InstanceIDCondition(instanceID string) database.Condition
// StateCondition returns a filter on the name field.
StateCondition(state OrgState) database.Condition
}
// organizationChanges define all the changes for the instance table.
type organizationChanges interface {
// SetName sets the name column.
SetName(name string) database.Change
// SetState sets the name column.
SetState(state OrgState) database.Change
}
// OrganizationRepository is the interface for the instance repository.
type OrganizationRepository interface {
organizationColumns
organizationConditions
organizationChanges
Get(ctx context.Context, opts ...database.QueryOption) (*Organization, error)
List(ctx context.Context, opts ...database.QueryOption) ([]*Organization, error)
Create(ctx context.Context, instance *Organization) error
Update(ctx context.Context, id OrgIdentifierCondition, instance_id string, changes ...database.Change) (int64, error)
Delete(ctx context.Context, id OrgIdentifierCondition, instance_id string) (int64, error)
// Domains returns the domain sub repository for the organization.
// If shouldLoad is true, the domains will be loaded from the database and written to the [Instance].Domains field.
// If shouldLoad is set to true once, the Domains field will be set event if shouldLoad is false in the future.
Domains(shouldLoad bool) OrganizationDomainRepository
}
type CreateOrganization struct {
Name string `json:"name"`
}
// MemberRepository is a sub repository of the org repository and maybe the instance repository.
type MemberRepository interface {
AddMember(ctx context.Context, orgID, userID string, roles []string) error
SetMemberRoles(ctx context.Context, orgID, userID string, roles []string) error
RemoveMember(ctx context.Context, orgID, userID string) error
}

View File

@@ -0,0 +1,84 @@
package domain
import (
"context"
"time"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type OrganizationDomain struct {
InstanceID string `json:"instanceId,omitempty" db:"instance_id"`
OrgID string `json:"orgId,omitempty" db:"org_id"`
Domain string `json:"domain,omitempty" db:"domain"`
IsVerified bool `json:"isVerified,omitempty" db:"is_verified"`
IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"`
ValidationType *DomainValidationType `json:"validationType,omitempty" db:"validation_type"`
CreatedAt time.Time `json:"createdAt,omitzero" db:"created_at"`
UpdatedAt time.Time `json:"updatedAt,omitzero" db:"updated_at"`
}
type AddOrganizationDomain struct {
InstanceID string `json:"instanceId,omitempty" db:"instance_id"`
OrgID string `json:"orgId,omitempty" db:"org_id"`
Domain string `json:"domain,omitempty" db:"domain"`
IsVerified bool `json:"isVerified,omitempty" db:"is_verified"`
IsPrimary bool `json:"isPrimary,omitempty" db:"is_primary"`
ValidationType *DomainValidationType `json:"validationType,omitempty" db:"validation_type"`
// CreatedAt is the time when the domain was added.
// It is set by the repository and should not be set by the caller.
CreatedAt time.Time `json:"createdAt,omitzero" db:"created_at"`
// UpdatedAt is the time when the domain was added.
// It is set by the repository and should not be set by the caller.
UpdatedAt time.Time `json:"updatedAt,omitzero" db:"updated_at"`
}
type organizationDomainColumns interface {
domainColumns
// OrgIDColumn returns the column for the org id field.
OrgIDColumn() database.Column
// IsVerifiedColumn returns the column for the is verified field.
IsVerifiedColumn() database.Column
// ValidationTypeColumn returns the column for the verification type field.
ValidationTypeColumn() database.Column
}
type organizationDomainConditions interface {
domainConditions
// OrgIDCondition returns a filter on the org id field.
OrgIDCondition(orgID string) database.Condition
// IsVerifiedCondition returns a filter on the is verified field.
IsVerifiedCondition(isVerified bool) database.Condition
}
type organizationDomainChanges interface {
domainChanges
// SetVerified sets the is verified column to true.
SetVerified() database.Change
// SetValidationType sets the verification type column.
// If the domain is already verified, this is a no-op.
SetValidationType(verificationType DomainValidationType) database.Change
}
type OrganizationDomainRepository interface {
organizationDomainColumns
organizationDomainConditions
organizationDomainChanges
// Get returns a single domain based on the criteria.
// If no domain is found, it returns an error of type [database.ErrNotFound].
// If multiple domains are found, it returns an error of type [database.ErrMultipleRows].
Get(ctx context.Context, opts ...database.QueryOption) (*OrganizationDomain, error)
// List returns a list of domains based on the criteria.
// If no domains are found, it returns an empty slice.
List(ctx context.Context, opts ...database.QueryOption) ([]*OrganizationDomain, error)
// Add adds a new domain to the organization.
Add(ctx context.Context, domain *AddOrganizationDomain) error
// Update updates an existing domain in the organization.
Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error)
// Remove removes a domain from the organization.
Remove(ctx context.Context, condition database.Condition) (int64, error)
}

View File

@@ -0,0 +1,109 @@
// Code generated by "enumer -type OrgState -transform lower -trimprefix OrgState -sql"; DO NOT EDIT.
package domain
import (
"database/sql/driver"
"fmt"
"strings"
)
const _OrgStateName = "activeinactive"
var _OrgStateIndex = [...]uint8{0, 6, 14}
const _OrgStateLowerName = "activeinactive"
func (i OrgState) String() string {
if i >= OrgState(len(_OrgStateIndex)-1) {
return fmt.Sprintf("OrgState(%d)", i)
}
return _OrgStateName[_OrgStateIndex[i]:_OrgStateIndex[i+1]]
}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _OrgStateNoOp() {
var x [1]struct{}
_ = x[OrgStateActive-(0)]
_ = x[OrgStateInactive-(1)]
}
var _OrgStateValues = []OrgState{OrgStateActive, OrgStateInactive}
var _OrgStateNameToValueMap = map[string]OrgState{
_OrgStateName[0:6]: OrgStateActive,
_OrgStateLowerName[0:6]: OrgStateActive,
_OrgStateName[6:14]: OrgStateInactive,
_OrgStateLowerName[6:14]: OrgStateInactive,
}
var _OrgStateNames = []string{
_OrgStateName[0:6],
_OrgStateName[6:14],
}
// OrgStateString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func OrgStateString(s string) (OrgState, error) {
if val, ok := _OrgStateNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _OrgStateNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to OrgState values", s)
}
// OrgStateValues returns all values of the enum
func OrgStateValues() []OrgState {
return _OrgStateValues
}
// OrgStateStrings returns a slice of all String values of the enum
func OrgStateStrings() []string {
strs := make([]string, len(_OrgStateNames))
copy(strs, _OrgStateNames)
return strs
}
// IsAOrgState returns "true" if the value is listed in the enum definition. "false" otherwise
func (i OrgState) IsAOrgState() bool {
for _, v := range _OrgStateValues {
if i == v {
return true
}
}
return false
}
func (i OrgState) Value() (driver.Value, error) {
return i.String(), nil
}
func (i *OrgState) Scan(value interface{}) error {
if value == nil {
return nil
}
var str string
switch v := value.(type) {
case []byte:
str = string(v)
case string:
str = v
case fmt.Stringer:
str = v.String()
default:
return fmt.Errorf("invalid value of OrgState: %[1]T(%[1]v)", value)
}
val, err := OrgStateString(str)
if err != nil {
return err
}
*i = val
return nil
}

View File

@@ -0,0 +1,74 @@
package domain
// import (
// "context"
// "github.com/zitadel/zitadel/backend/v3/storage/eventstore"
// )
// // SetEmailCommand sets the email address of a user.
// // If allows verification as a sub command.
// // The verification command is executed after the email address is set.
// // The verification command is executed in the same transaction as the email address update.
// type SetEmailCommand struct {
// UserID string `json:"userId"`
// Email string `json:"email"`
// verification Commander
// }
// var (
// _ Commander = (*SetEmailCommand)(nil)
// _ eventer = (*SetEmailCommand)(nil)
// _ CreateHumanOpt = (*SetEmailCommand)(nil)
// )
// type SetEmailOpt interface {
// applyOnSetEmail(*SetEmailCommand)
// }
// func NewSetEmailCommand(userID, email string, verificationType SetEmailOpt) *SetEmailCommand {
// cmd := &SetEmailCommand{
// UserID: userID,
// Email: email,
// }
// verificationType.applyOnSetEmail(cmd)
// return cmd
// }
// // String implements [Commander].
// func (cmd *SetEmailCommand) String() string {
// return "SetEmailCommand"
// }
// func (cmd *SetEmailCommand) Execute(ctx context.Context, opts *CommandOpts) error {
// close, err := opts.EnsureTx(ctx)
// if err != nil {
// return err
// }
// defer func() { err = close(ctx, err) }()
// // userStatement(opts.DB).Human().ByID(cmd.UserID).SetEmail(ctx, cmd.Email)
// repo := userRepo(opts.DB).Human()
// err = repo.Update(ctx, repo.IDCondition(cmd.UserID), repo.SetEmailAddress(cmd.Email))
// if err != nil {
// return err
// }
// return opts.Invoke(ctx, cmd.verification)
// }
// // Events implements [eventer].
// func (cmd *SetEmailCommand) Events() []*eventstore.Event {
// return []*eventstore.Event{
// {
// AggregateType: "user",
// AggregateID: cmd.UserID,
// Type: "user.email.set",
// Payload: cmd,
// },
// }
// }
// // applyOnCreateHuman implements [CreateHumanOpt].
// func (cmd *SetEmailCommand) applyOnCreateHuman(createUserCmd *CreateUserCommand) {
// createUserCmd.email = cmd
// }

241
backend/v3/domain/user.go Normal file
View File

@@ -0,0 +1,241 @@
package domain
import (
"context"
"time"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
// userColumns define all the columns of the user table.
type userColumns interface {
// InstanceIDColumn returns the column for the instance id field.
InstanceIDColumn() database.Column
// OrgIDColumn returns the column for the org id field.
OrgIDColumn() database.Column
// IDColumn returns the column for the id field.
IDColumn() database.Column
// UsernameColumn returns the column for the username field.
UsernameColumn() database.Column
// CreatedAtColumn returns the column for the created at field.
CreatedAtColumn() database.Column
// UpdatedAtColumn returns the column for the updated at field.
UpdatedAtColumn() database.Column
// DeletedAtColumn returns the column for the deleted at field.
DeletedAtColumn() database.Column
}
// userConditions define all the conditions for the user table.
type userConditions interface {
// InstanceIDCondition returns an equal filter on the instance id field.
InstanceIDCondition(instanceID string) database.Condition
// OrgIDCondition returns an equal filter on the org id field.
OrgIDCondition(orgID string) database.Condition
// IDCondition returns an equal filter on the id field.
IDCondition(userID string) database.Condition
// UsernameCondition returns a filter on the username field.
UsernameCondition(op database.TextOperation, username string) database.Condition
// CreatedAtCondition returns a filter on the created at field.
CreatedAtCondition(op database.NumberOperation, createdAt time.Time) database.Condition
// UpdatedAtCondition returns a filter on the updated at field.
UpdatedAtCondition(op database.NumberOperation, updatedAt time.Time) database.Condition
// DeletedAtCondition filters for deleted users is isDeleted is set to true otherwise only not deleted users must be filtered.
DeletedCondition(isDeleted bool) database.Condition
// DeletedAtCondition filters for deleted users based on the given parameters.
DeletedAtCondition(op database.NumberOperation, deletedAt time.Time) database.Condition
}
// userChanges define all the changes for the user table.
type userChanges interface {
// SetUsername sets the username column.
SetUsername(username string) database.Change
}
// UserRepository is the interface for the user repository.
type UserRepository interface {
userColumns
userConditions
userChanges
// Get returns a user based on the given condition.
Get(ctx context.Context, opts ...database.QueryOption) (*User, error)
// List returns a list of users based on the given condition.
List(ctx context.Context, opts ...database.QueryOption) ([]*User, error)
// Create creates a new user.
Create(ctx context.Context, user *User) error
// Delete removes users based on the given condition.
Delete(ctx context.Context, condition database.Condition) error
// Human returns the [HumanRepository].
Human() HumanRepository
// Machine returns the [MachineRepository].
Machine() MachineRepository
}
// humanColumns define all the columns of the human table which inherits the user table.
type humanColumns interface {
userColumns
// FirstNameColumn returns the column for the first name field.
FirstNameColumn() database.Column
// LastNameColumn returns the column for the last name field.
LastNameColumn() database.Column
// EmailAddressColumn returns the column for the email address field.
EmailAddressColumn() database.Column
// EmailVerifiedAtColumn returns the column for the email verified at field.
EmailVerifiedAtColumn() database.Column
// PhoneNumberColumn returns the column for the phone number field.
PhoneNumberColumn() database.Column
// PhoneVerifiedAtColumn returns the column for the phone verified at field.
PhoneVerifiedAtColumn() database.Column
}
// humanConditions define all the conditions for the human table which inherits the user table.
type humanConditions interface {
userConditions
// FirstNameCondition returns a filter on the first name field.
FirstNameCondition(op database.TextOperation, firstName string) database.Condition
// LastNameCondition returns a filter on the last name field.
LastNameCondition(op database.TextOperation, lastName string) database.Condition
// EmailAddressCondition returns a filter on the email address field.
EmailAddressCondition(op database.TextOperation, email string) database.Condition
// EmailVerifiedCondition returns a filter that checks if the email is verified or not.
EmailVerifiedCondition(isVerified bool) database.Condition
// EmailVerifiedAtCondition returns a filter on the email verified at field.
EmailVerifiedAtCondition(op database.NumberOperation, emailVerifiedAt time.Time) database.Condition
// PhoneNumberCondition returns a filter on the phone number field.
PhoneNumberCondition(op database.TextOperation, phoneNumber string) database.Condition
// PhoneVerifiedCondition returns a filter that checks if the phone is verified or not.
PhoneVerifiedCondition(isVerified bool) database.Condition
// PhoneVerifiedAtCondition returns a filter on the phone verified at field.
PhoneVerifiedAtCondition(op database.NumberOperation, phoneVerifiedAt time.Time) database.Condition
}
// humanChanges define all the changes for the human table which inherits the user table.
type humanChanges interface {
userChanges
// SetFirstName sets the first name field of the human.
SetFirstName(firstName string) database.Change
// SetLastName sets the last name field of the human.
SetLastName(lastName string) database.Change
// SetEmail sets the email address and verified field of the email
// if verifiedAt is nil the email is not verified
SetEmail(address string, verifiedAt *time.Time) database.Change
// SetEmailAddress sets the email address field of the email
SetEmailAddress(email string) database.Change
// SetEmailVerifiedAt sets the verified column of the email
// if at is zero the statement uses the database timestamp
SetEmailVerifiedAt(at time.Time) database.Change
// SetPhone sets the phone number and verified field
// if verifiedAt is nil the phone is not verified
SetPhone(number string, verifiedAt *time.Time) database.Change
// SetPhoneNumber sets the phone number field
SetPhoneNumber(phoneNumber string) database.Change
// SetPhoneVerifiedAt sets the verified field of the phone
// if at is zero the statement uses the database timestamp
SetPhoneVerifiedAt(at time.Time) database.Change
}
// HumanRepository is the interface for the human repository it inherits the user repository.
type HumanRepository interface {
humanColumns
humanConditions
humanChanges
// Get returns an email based on the given condition.
GetEmail(ctx context.Context, condition database.Condition) (*Email, error)
// Update updates human users based on the given condition and changes.
Update(ctx context.Context, condition database.Condition, changes ...database.Change) error
}
// machineColumns define all the columns of the machine table which inherits the user table.
type machineColumns interface {
userColumns
// DescriptionColumn returns the column for the description field.
DescriptionColumn() database.Column
}
// machineConditions define all the conditions for the machine table which inherits the user table.
type machineConditions interface {
userConditions
// DescriptionCondition returns a filter on the description field.
DescriptionCondition(op database.TextOperation, description string) database.Condition
}
// machineChanges define all the changes for the machine table which inherits the user table.
type machineChanges interface {
userChanges
// SetDescription sets the description field of the machine.
SetDescription(description string) database.Change
}
// MachineRepository is the interface for the machine repository it inherits the user repository.
type MachineRepository interface {
// Update updates machine users based on the given condition and changes.
Update(ctx context.Context, condition database.Condition, changes ...database.Change) error
machineColumns
machineConditions
machineChanges
}
// UserTraits is implemented by [Human] and [Machine].
type UserTraits interface {
Type() UserType
}
type UserType string
const (
UserTypeHuman UserType = "human"
UserTypeMachine UserType = "machine"
)
type User struct {
InstanceID string
OrgID string
ID string
CreatedAt time.Time
UpdatedAt time.Time
DeletedAt time.Time
Username string
Traits UserTraits
}
type Human struct {
FirstName string `json:"firstName"`
LastName string `json:"lastName"`
Email *Email `json:"email,omitempty"`
Phone *Phone `json:"phone,omitempty"`
}
// Type implements [UserTraits].
func (h *Human) Type() UserType {
return UserTypeHuman
}
var _ UserTraits = (*Human)(nil)
type Email struct {
Address string `json:"address"`
VerifiedAt time.Time `json:"verifiedAt"`
}
type Phone struct {
Number string `json:"number"`
VerifiedAt time.Time `json:"verifiedAt"`
}
type Machine struct {
Description string `json:"description"`
}
// Type implements [UserTraits].
func (m *Machine) Type() UserType {
return UserTypeMachine
}
var _ UserTraits = (*Machine)(nil)

112
backend/v3/storage/cache/cache.go vendored Normal file
View File

@@ -0,0 +1,112 @@
// Package cache provides abstraction of cache implementations that can be used by zitadel.
package cache
import (
"context"
"time"
"github.com/zitadel/logging"
)
// Purpose describes which object types are stored by a cache.
type Purpose int
//go:generate enumer -type Purpose -transform snake -trimprefix Purpose
const (
PurposeUnspecified Purpose = iota
PurposeAuthzInstance
PurposeMilestones
PurposeOrganization
PurposeIdPFormCallback
)
// Cache stores objects with a value of type `V`.
// Objects may be referred to by one or more indices.
// Implementations may encode the value for storage.
// This means non-exported fields may be lost and objects
// with function values may fail to encode.
// See https://pkg.go.dev/encoding/json#Marshal for example.
//
// `I` is the type by which indices are identified,
// typically an enum for type-safe access.
// Indices are defined when calling the constructor of an implementation of this interface.
// It is illegal to refer to an idex not defined during construction.
//
// `K` is the type used as key in each index.
// Due to the limitations in type constraints, all indices use the same key type.
//
// Implementations are free to use stricter type constraints or fixed typing.
type Cache[I, K comparable, V Entry[I, K]] interface {
// Get an object through specified index.
// An [IndexUnknownError] may be returned if the index is unknown.
// [ErrCacheMiss] is returned if the key was not found in the index,
// or the object is not valid.
Get(ctx context.Context, index I, key K) (V, bool)
// Set an object.
// Keys are created on each index based in the [Entry.Keys] method.
// If any key maps to an existing object, the object is invalidated,
// regardless if the object has other keys defined in the new entry.
// This to prevent ghost objects when an entry reduces the amount of keys
// for a given index.
Set(ctx context.Context, value V)
// Invalidate an object through specified index.
// Implementations may choose to instantly delete the object,
// defer until prune or a separate cleanup routine.
// Invalidated object are no longer returned from Get.
// It is safe to call Invalidate multiple times or on non-existing entries.
Invalidate(ctx context.Context, index I, key ...K) error
// Delete one or more keys from a specific index.
// An [IndexUnknownError] may be returned if the index is unknown.
// The referred object is not invalidated and may still be accessible though
// other indices and keys.
// It is safe to call Delete multiple times or on non-existing entries
Delete(ctx context.Context, index I, key ...K) error
// Truncate deletes all cached objects.
Truncate(ctx context.Context) error
}
// Entry contains a value of type `V` to be cached.
//
// `I` is the type by which indices are identified,
// typically an enum for type-safe access.
//
// `K` is the type used as key in an index.
// Due to the limitations in type constraints, all indices use the same key type.
type Entry[I, K comparable] interface {
// Keys returns which keys map to the object in a specified index.
// May return nil if the index in unknown or when there are no keys.
Keys(index I) (key []K)
}
type Connector int
//go:generate enumer -type Connector -transform snake -trimprefix Connector -linecomment -text
const (
// Empty line comment ensures empty string for unspecified value
ConnectorUnspecified Connector = iota //
ConnectorMemory
ConnectorPostgres
ConnectorRedis
)
type Config struct {
Connector Connector
// Age since an object was added to the cache,
// after which the object is considered invalid.
// 0 disables max age checks.
MaxAge time.Duration
// Age since last use (Get) of an object,
// after which the object is considered invalid.
// 0 disables last use age checks.
LastUseAge time.Duration
// Log allows logging of the specific cache.
// By default only errors are logged to stdout.
Log *logging.Config
}

View File

@@ -0,0 +1,49 @@
// Package connector provides glue between the [cache.Cache] interface and implementations from the connector sub-packages.
package connector
import (
"context"
"fmt"
"github.com/zitadel/zitadel/backend/v3/storage/cache"
"github.com/zitadel/zitadel/backend/v3/storage/cache/connector/gomap"
"github.com/zitadel/zitadel/backend/v3/storage/cache/connector/noop"
)
type CachesConfig struct {
Connectors struct {
Memory gomap.Config
}
Instance *cache.Config
Milestones *cache.Config
Organization *cache.Config
IdPFormCallbacks *cache.Config
}
type Connectors struct {
Config CachesConfig
Memory *gomap.Connector
}
func StartConnectors(conf *CachesConfig) (Connectors, error) {
if conf == nil {
return Connectors{}, nil
}
return Connectors{
Config: *conf,
Memory: gomap.NewConnector(conf.Connectors.Memory),
}, nil
}
func StartCache[I ~int, K ~string, V cache.Entry[I, K]](background context.Context, indices []I, purpose cache.Purpose, conf *cache.Config, connectors Connectors) (cache.Cache[I, K, V], error) {
if conf == nil || conf.Connector == cache.ConnectorUnspecified {
return noop.NewCache[I, K, V](), nil
}
if conf.Connector == cache.ConnectorMemory && connectors.Memory != nil {
c := gomap.NewCache[I, K, V](background, indices, *conf)
connectors.Memory.Config.StartAutoPrune(background, c, purpose)
return c, nil
}
return nil, fmt.Errorf("cache connector %q not enabled", conf.Connector)
}

View File

@@ -0,0 +1,23 @@
package gomap
import (
"github.com/zitadel/zitadel/backend/v3/storage/cache"
)
type Config struct {
Enabled bool
AutoPrune cache.AutoPruneConfig
}
type Connector struct {
Config cache.AutoPruneConfig
}
func NewConnector(config Config) *Connector {
if !config.Enabled {
return nil
}
return &Connector{
Config: config.AutoPrune,
}
}

View File

@@ -0,0 +1,200 @@
package gomap
import (
"context"
"errors"
"log/slog"
"maps"
"os"
"sync"
"sync/atomic"
"time"
"github.com/zitadel/zitadel/backend/v3/storage/cache"
)
type mapCache[I, K comparable, V cache.Entry[I, K]] struct {
config *cache.Config
indexMap map[I]*index[K, V]
logger *slog.Logger
}
// NewCache returns an in-memory Cache implementation based on the builtin go map type.
// Object values are stored as-is and there is no encoding or decoding involved.
func NewCache[I, K comparable, V cache.Entry[I, K]](background context.Context, indices []I, config cache.Config) cache.PrunerCache[I, K, V] {
m := &mapCache[I, K, V]{
config: &config,
indexMap: make(map[I]*index[K, V], len(indices)),
logger: slog.New(slog.NewTextHandler(os.Stderr, &slog.HandlerOptions{
AddSource: true,
Level: slog.LevelError,
})),
}
if config.Log != nil {
m.logger = config.Log.Slog()
}
m.logger.InfoContext(background, "map cache logging enabled")
for _, name := range indices {
m.indexMap[name] = &index[K, V]{
config: m.config,
entries: make(map[K]*entry[V]),
}
}
return m
}
func (c *mapCache[I, K, V]) Get(ctx context.Context, index I, key K) (value V, ok bool) {
i, ok := c.indexMap[index]
if !ok {
c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key)
return value, false
}
entry, err := i.Get(key)
if err == nil {
c.logger.DebugContext(ctx, "map cache get", "index", index, "key", key)
return entry.value, true
}
if errors.Is(err, cache.ErrCacheMiss) {
c.logger.InfoContext(ctx, "map cache get", "err", err, "index", index, "key", key)
return value, false
}
c.logger.ErrorContext(ctx, "map cache get", "err", cache.NewIndexUnknownErr(index), "index", index, "key", key)
return value, false
}
func (c *mapCache[I, K, V]) Set(ctx context.Context, value V) {
now := time.Now()
entry := &entry[V]{
value: value,
created: now,
}
entry.lastUse.Store(now.UnixMicro())
for name, i := range c.indexMap {
keys := value.Keys(name)
i.Set(keys, entry)
c.logger.DebugContext(ctx, "map cache set", "index", name, "keys", keys)
}
}
func (c *mapCache[I, K, V]) Invalidate(ctx context.Context, index I, keys ...K) error {
i, ok := c.indexMap[index]
if !ok {
return cache.NewIndexUnknownErr(index)
}
i.Invalidate(keys)
c.logger.DebugContext(ctx, "map cache invalidate", "index", index, "keys", keys)
return nil
}
func (c *mapCache[I, K, V]) Delete(ctx context.Context, index I, keys ...K) error {
i, ok := c.indexMap[index]
if !ok {
return cache.NewIndexUnknownErr(index)
}
i.Delete(keys)
c.logger.DebugContext(ctx, "map cache delete", "index", index, "keys", keys)
return nil
}
func (c *mapCache[I, K, V]) Prune(ctx context.Context) error {
for name, index := range c.indexMap {
index.Prune()
c.logger.DebugContext(ctx, "map cache prune", "index", name)
}
return nil
}
func (c *mapCache[I, K, V]) Truncate(ctx context.Context) error {
for name, index := range c.indexMap {
index.Truncate()
c.logger.DebugContext(ctx, "map cache truncate", "index", name)
}
return nil
}
type index[K comparable, V any] struct {
mutex sync.RWMutex
config *cache.Config
entries map[K]*entry[V]
}
func (i *index[K, V]) Get(key K) (*entry[V], error) {
i.mutex.RLock()
entry, ok := i.entries[key]
i.mutex.RUnlock()
if ok && entry.isValid(i.config) {
return entry, nil
}
return nil, cache.ErrCacheMiss
}
func (c *index[K, V]) Set(keys []K, entry *entry[V]) {
c.mutex.Lock()
for _, key := range keys {
c.entries[key] = entry
}
c.mutex.Unlock()
}
func (i *index[K, V]) Invalidate(keys []K) {
i.mutex.RLock()
for _, key := range keys {
if entry, ok := i.entries[key]; ok {
entry.invalid.Store(true)
}
}
i.mutex.RUnlock()
}
func (c *index[K, V]) Delete(keys []K) {
c.mutex.Lock()
for _, key := range keys {
delete(c.entries, key)
}
c.mutex.Unlock()
}
func (c *index[K, V]) Prune() {
c.mutex.Lock()
maps.DeleteFunc(c.entries, func(_ K, entry *entry[V]) bool {
return !entry.isValid(c.config)
})
c.mutex.Unlock()
}
func (c *index[K, V]) Truncate() {
c.mutex.Lock()
c.entries = make(map[K]*entry[V])
c.mutex.Unlock()
}
type entry[V any] struct {
value V
created time.Time
invalid atomic.Bool
lastUse atomic.Int64 // UnixMicro time
}
func (e *entry[V]) isValid(c *cache.Config) bool {
if e.invalid.Load() {
return false
}
now := time.Now()
if c.MaxAge > 0 {
if e.created.Add(c.MaxAge).Before(now) {
e.invalid.Store(true)
return false
}
}
if c.LastUseAge > 0 {
lastUse := e.lastUse.Load()
if time.UnixMicro(lastUse).Add(c.LastUseAge).Before(now) {
e.invalid.Store(true)
return false
}
e.lastUse.CompareAndSwap(lastUse, now.UnixMicro())
}
return true
}

View File

@@ -0,0 +1,329 @@
package gomap
import (
"context"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/backend/v3/storage/cache"
)
type testIndex int
const (
testIndexID testIndex = iota
testIndexName
)
var testIndices = []testIndex{
testIndexID,
testIndexName,
}
type testObject struct {
id string
names []string
}
func (o *testObject) Keys(index testIndex) []string {
switch index {
case testIndexID:
return []string{o.id}
case testIndexName:
return o.names
default:
return nil
}
}
func Test_mapCache_Get(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
obj := &testObject{
id: "id",
names: []string{"foo", "bar"},
}
c.Set(context.Background(), obj)
type args struct {
index testIndex
key string
}
tests := []struct {
name string
args args
want *testObject
wantOk bool
}{
{
name: "ok",
args: args{
index: testIndexID,
key: "id",
},
want: obj,
wantOk: true,
},
{
name: "miss",
args: args{
index: testIndexID,
key: "spanac",
},
want: nil,
wantOk: false,
},
{
name: "unknown index",
args: args{
index: 99,
key: "id",
},
want: nil,
wantOk: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, ok := c.Get(context.Background(), tt.args.index, tt.args.key)
assert.Equal(t, tt.want, got)
assert.Equal(t, tt.wantOk, ok)
})
}
}
func Test_mapCache_Invalidate(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
obj := &testObject{
id: "id",
names: []string{"foo", "bar"},
}
c.Set(context.Background(), obj)
err := c.Invalidate(context.Background(), testIndexName, "bar")
require.NoError(t, err)
got, ok := c.Get(context.Background(), testIndexID, "id")
assert.Nil(t, got)
assert.False(t, ok)
}
func Test_mapCache_Delete(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
obj := &testObject{
id: "id",
names: []string{"foo", "bar"},
}
c.Set(context.Background(), obj)
err := c.Delete(context.Background(), testIndexName, "bar")
require.NoError(t, err)
// Shouldn't find object by deleted name
got, ok := c.Get(context.Background(), testIndexName, "bar")
assert.Nil(t, got)
assert.False(t, ok)
// Should find object by other name
got, ok = c.Get(context.Background(), testIndexName, "foo")
assert.Equal(t, obj, got)
assert.True(t, ok)
// Should find object by id
got, ok = c.Get(context.Background(), testIndexID, "id")
assert.Equal(t, obj, got)
assert.True(t, ok)
}
func Test_mapCache_Prune(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
objects := []*testObject{
{
id: "id1",
names: []string{"foo", "bar"},
},
{
id: "id2",
names: []string{"hello"},
},
}
for _, obj := range objects {
c.Set(context.Background(), obj)
}
// invalidate one entry
err := c.Invalidate(context.Background(), testIndexName, "bar")
require.NoError(t, err)
err = c.(cache.Pruner).Prune(context.Background())
require.NoError(t, err)
// Other object should still be found
got, ok := c.Get(context.Background(), testIndexID, "id2")
assert.Equal(t, objects[1], got)
assert.True(t, ok)
}
func Test_mapCache_Truncate(t *testing.T) {
c := NewCache[testIndex, string, *testObject](context.Background(), testIndices, cache.Config{
MaxAge: time.Second,
LastUseAge: time.Second / 4,
Log: &logging.Config{
Level: "debug",
AddSource: true,
},
})
objects := []*testObject{
{
id: "id1",
names: []string{"foo", "bar"},
},
{
id: "id2",
names: []string{"hello"},
},
}
for _, obj := range objects {
c.Set(context.Background(), obj)
}
err := c.Truncate(context.Background())
require.NoError(t, err)
mc := c.(*mapCache[testIndex, string, *testObject])
for _, index := range mc.indexMap {
index.mutex.RLock()
assert.Len(t, index.entries, 0)
index.mutex.RUnlock()
}
}
func Test_entry_isValid(t *testing.T) {
type fields struct {
created time.Time
invalid bool
lastUse time.Time
}
tests := []struct {
name string
fields fields
config *cache.Config
want bool
}{
{
name: "invalid",
fields: fields{
created: time.Now(),
invalid: true,
lastUse: time.Now(),
},
config: &cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
want: false,
},
{
name: "max age exceeded",
fields: fields{
created: time.Now().Add(-(time.Minute + time.Second)),
invalid: false,
lastUse: time.Now(),
},
config: &cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
want: false,
},
{
name: "max age disabled",
fields: fields{
created: time.Now().Add(-(time.Minute + time.Second)),
invalid: false,
lastUse: time.Now(),
},
config: &cache.Config{
LastUseAge: time.Second,
},
want: true,
},
{
name: "last use age exceeded",
fields: fields{
created: time.Now().Add(-(time.Minute / 2)),
invalid: false,
lastUse: time.Now().Add(-(time.Second * 2)),
},
config: &cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
want: false,
},
{
name: "last use age disabled",
fields: fields{
created: time.Now().Add(-(time.Minute / 2)),
invalid: false,
lastUse: time.Now().Add(-(time.Second * 2)),
},
config: &cache.Config{
MaxAge: time.Minute,
},
want: true,
},
{
name: "valid",
fields: fields{
created: time.Now(),
invalid: false,
lastUse: time.Now(),
},
config: &cache.Config{
MaxAge: time.Minute,
LastUseAge: time.Second,
},
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
e := &entry[any]{
created: tt.fields.created,
}
e.invalid.Store(tt.fields.invalid)
e.lastUse.Store(tt.fields.lastUse.UnixMicro())
got := e.isValid(tt.config)
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -0,0 +1,21 @@
package noop
import (
"context"
"github.com/zitadel/zitadel/backend/v3/storage/cache"
)
type noop[I, K comparable, V cache.Entry[I, K]] struct{}
// NewCache returns a cache that does nothing
func NewCache[I, K comparable, V cache.Entry[I, K]]() cache.Cache[I, K, V] {
return noop[I, K, V]{}
}
func (noop[I, K, V]) Set(context.Context, V) {}
func (noop[I, K, V]) Get(context.Context, I, K) (value V, ok bool) { return }
func (noop[I, K, V]) Invalidate(context.Context, I, ...K) (err error) { return }
func (noop[I, K, V]) Delete(context.Context, I, ...K) (err error) { return }
func (noop[I, K, V]) Prune(context.Context) (err error) { return }
func (noop[I, K, V]) Truncate(context.Context) (err error) { return }

View File

@@ -0,0 +1,98 @@
// Code generated by "enumer -type Connector -transform snake -trimprefix Connector -linecomment -text"; DO NOT EDIT.
package cache
import (
"fmt"
"strings"
)
const _ConnectorName = "memorypostgresredis"
var _ConnectorIndex = [...]uint8{0, 0, 6, 14, 19}
const _ConnectorLowerName = "memorypostgresredis"
func (i Connector) String() string {
if i < 0 || i >= Connector(len(_ConnectorIndex)-1) {
return fmt.Sprintf("Connector(%d)", i)
}
return _ConnectorName[_ConnectorIndex[i]:_ConnectorIndex[i+1]]
}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _ConnectorNoOp() {
var x [1]struct{}
_ = x[ConnectorUnspecified-(0)]
_ = x[ConnectorMemory-(1)]
_ = x[ConnectorPostgres-(2)]
_ = x[ConnectorRedis-(3)]
}
var _ConnectorValues = []Connector{ConnectorUnspecified, ConnectorMemory, ConnectorPostgres, ConnectorRedis}
var _ConnectorNameToValueMap = map[string]Connector{
_ConnectorName[0:0]: ConnectorUnspecified,
_ConnectorLowerName[0:0]: ConnectorUnspecified,
_ConnectorName[0:6]: ConnectorMemory,
_ConnectorLowerName[0:6]: ConnectorMemory,
_ConnectorName[6:14]: ConnectorPostgres,
_ConnectorLowerName[6:14]: ConnectorPostgres,
_ConnectorName[14:19]: ConnectorRedis,
_ConnectorLowerName[14:19]: ConnectorRedis,
}
var _ConnectorNames = []string{
_ConnectorName[0:0],
_ConnectorName[0:6],
_ConnectorName[6:14],
_ConnectorName[14:19],
}
// ConnectorString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func ConnectorString(s string) (Connector, error) {
if val, ok := _ConnectorNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _ConnectorNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to Connector values", s)
}
// ConnectorValues returns all values of the enum
func ConnectorValues() []Connector {
return _ConnectorValues
}
// ConnectorStrings returns a slice of all String values of the enum
func ConnectorStrings() []string {
strs := make([]string, len(_ConnectorNames))
copy(strs, _ConnectorNames)
return strs
}
// IsAConnector returns "true" if the value is listed in the enum definition. "false" otherwise
func (i Connector) IsAConnector() bool {
for _, v := range _ConnectorValues {
if i == v {
return true
}
}
return false
}
// MarshalText implements the encoding.TextMarshaler interface for Connector
func (i Connector) MarshalText() ([]byte, error) {
return []byte(i.String()), nil
}
// UnmarshalText implements the encoding.TextUnmarshaler interface for Connector
func (i *Connector) UnmarshalText(text []byte) error {
var err error
*i, err = ConnectorString(string(text))
return err
}

2
backend/v3/storage/cache/doc.go vendored Normal file
View File

@@ -0,0 +1,2 @@
// this package is copy pasted from the internal/cache package
package cache

29
backend/v3/storage/cache/error.go vendored Normal file
View File

@@ -0,0 +1,29 @@
package cache
import (
"errors"
"fmt"
)
type IndexUnknownError[I comparable] struct {
index I
}
func NewIndexUnknownErr[I comparable](index I) error {
return IndexUnknownError[I]{index}
}
func (i IndexUnknownError[I]) Error() string {
return fmt.Sprintf("index %v unknown", i.index)
}
func (a IndexUnknownError[I]) Is(err error) bool {
if b, ok := err.(IndexUnknownError[I]); ok {
return a.index == b.index
}
return false
}
var (
ErrCacheMiss = errors.New("cache miss")
)

76
backend/v3/storage/cache/pruner.go vendored Normal file
View File

@@ -0,0 +1,76 @@
package cache
import (
"context"
"math/rand"
"time"
"github.com/jonboulle/clockwork"
"github.com/zitadel/logging"
)
// Pruner is an optional [Cache] interface.
type Pruner interface {
// Prune deletes all invalidated or expired objects.
Prune(ctx context.Context) error
}
type PrunerCache[I, K comparable, V Entry[I, K]] interface {
Cache[I, K, V]
Pruner
}
type AutoPruneConfig struct {
// Interval at which the cache is automatically pruned.
// 0 or lower disables automatic pruning.
Interval time.Duration
// Timeout for an automatic prune.
// It is recommended to keep the value shorter than AutoPruneInterval
// 0 or lower disables automatic pruning.
Timeout time.Duration
}
func (c AutoPruneConfig) StartAutoPrune(background context.Context, pruner Pruner, purpose Purpose) (close func()) {
return c.startAutoPrune(background, pruner, purpose, clockwork.NewRealClock())
}
func (c *AutoPruneConfig) startAutoPrune(background context.Context, pruner Pruner, purpose Purpose, clock clockwork.Clock) (close func()) {
if c.Interval <= 0 {
return func() {}
}
background, cancel := context.WithCancel(background)
// randomize the first interval
timer := clock.NewTimer(time.Duration(rand.Int63n(int64(c.Interval))))
go c.pruneTimer(background, pruner, purpose, timer)
return cancel
}
func (c *AutoPruneConfig) pruneTimer(background context.Context, pruner Pruner, purpose Purpose, timer clockwork.Timer) {
defer func() {
if !timer.Stop() {
<-timer.Chan()
}
}()
for {
select {
case <-background.Done():
return
case <-timer.Chan():
err := c.doPrune(background, pruner)
logging.OnError(err).WithField("purpose", purpose).Error("cache auto prune")
timer.Reset(c.Interval)
}
}
}
func (c *AutoPruneConfig) doPrune(background context.Context, pruner Pruner) error {
ctx, cancel := context.WithCancel(background)
defer cancel()
if c.Timeout > 0 {
ctx, cancel = context.WithTimeout(background, c.Timeout)
defer cancel()
}
return pruner.Prune(ctx)
}

43
backend/v3/storage/cache/pruner_test.go vendored Normal file
View File

@@ -0,0 +1,43 @@
package cache
import (
"context"
"testing"
"time"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/assert"
)
type testPruner struct {
called chan struct{}
}
func (p *testPruner) Prune(context.Context) error {
p.called <- struct{}{}
return nil
}
func TestAutoPruneConfig_startAutoPrune(t *testing.T) {
c := AutoPruneConfig{
Interval: time.Second,
Timeout: time.Millisecond,
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
pruner := testPruner{
called: make(chan struct{}),
}
clock := clockwork.NewFakeClock()
close := c.startAutoPrune(ctx, &pruner, PurposeAuthzInstance, clock)
defer close()
clock.Advance(time.Second)
select {
case _, ok := <-pruner.called:
assert.True(t, ok)
case <-ctx.Done():
t.Fatal(ctx.Err())
}
}

View File

@@ -0,0 +1,90 @@
// Code generated by "enumer -type Purpose -transform snake -trimprefix Purpose"; DO NOT EDIT.
package cache
import (
"fmt"
"strings"
)
const _PurposeName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callback"
var _PurposeIndex = [...]uint8{0, 11, 25, 35, 47, 65}
const _PurposeLowerName = "unspecifiedauthz_instancemilestonesorganizationid_p_form_callback"
func (i Purpose) String() string {
if i < 0 || i >= Purpose(len(_PurposeIndex)-1) {
return fmt.Sprintf("Purpose(%d)", i)
}
return _PurposeName[_PurposeIndex[i]:_PurposeIndex[i+1]]
}
// An "invalid array index" compiler error signifies that the constant values have changed.
// Re-run the stringer command to generate them again.
func _PurposeNoOp() {
var x [1]struct{}
_ = x[PurposeUnspecified-(0)]
_ = x[PurposeAuthzInstance-(1)]
_ = x[PurposeMilestones-(2)]
_ = x[PurposeOrganization-(3)]
_ = x[PurposeIdPFormCallback-(4)]
}
var _PurposeValues = []Purpose{PurposeUnspecified, PurposeAuthzInstance, PurposeMilestones, PurposeOrganization, PurposeIdPFormCallback}
var _PurposeNameToValueMap = map[string]Purpose{
_PurposeName[0:11]: PurposeUnspecified,
_PurposeLowerName[0:11]: PurposeUnspecified,
_PurposeName[11:25]: PurposeAuthzInstance,
_PurposeLowerName[11:25]: PurposeAuthzInstance,
_PurposeName[25:35]: PurposeMilestones,
_PurposeLowerName[25:35]: PurposeMilestones,
_PurposeName[35:47]: PurposeOrganization,
_PurposeLowerName[35:47]: PurposeOrganization,
_PurposeName[47:65]: PurposeIdPFormCallback,
_PurposeLowerName[47:65]: PurposeIdPFormCallback,
}
var _PurposeNames = []string{
_PurposeName[0:11],
_PurposeName[11:25],
_PurposeName[25:35],
_PurposeName[35:47],
_PurposeName[47:65],
}
// PurposeString retrieves an enum value from the enum constants string name.
// Throws an error if the param is not part of the enum.
func PurposeString(s string) (Purpose, error) {
if val, ok := _PurposeNameToValueMap[s]; ok {
return val, nil
}
if val, ok := _PurposeNameToValueMap[strings.ToLower(s)]; ok {
return val, nil
}
return 0, fmt.Errorf("%s does not belong to Purpose values", s)
}
// PurposeValues returns all values of the enum
func PurposeValues() []Purpose {
return _PurposeValues
}
// PurposeStrings returns a slice of all String values of the enum
func PurposeStrings() []string {
strs := make([]string, len(_PurposeNames))
copy(strs, _PurposeNames)
return strs
}
// IsAPurpose returns "true" if the value is listed in the enum definition. "false" otherwise
func (i Purpose) IsAPurpose() bool {
for _, v := range _PurposeValues {
if i == v {
return true
}
}
return false
}

View File

@@ -0,0 +1,53 @@
package database
// Change represents a change to a column in a database table.
// Its written in the SET clause of an UPDATE statement.
type Change interface {
Write(builder *StatementBuilder)
}
type change[V Value] struct {
column Column
value V
}
var _ Change = (*change[string])(nil)
func NewChange[V Value](col Column, value V) Change {
return &change[V]{
column: col,
value: value,
}
}
func NewChangePtr[V Value](col Column, value *V) Change {
if value == nil {
return NewChange(col, NullInstruction)
}
return NewChange(col, *value)
}
// Write implements [Change].
func (c change[V]) Write(builder *StatementBuilder) {
c.column.WriteUnqualified(builder)
builder.WriteString(" = ")
builder.WriteArg(c.value)
}
type Changes []Change
func NewChanges(cols ...Change) Change {
return Changes(cols)
}
// Write implements [Change].
func (m Changes) Write(builder *StatementBuilder) {
for i, col := range m {
if i > 0 {
builder.WriteString(", ")
}
col.Write(builder)
}
}
var _ Change = Changes(nil)

View File

@@ -0,0 +1,85 @@
package database
type Columns []Column
// WriteQualified implements [Column].
func (m Columns) WriteQualified(builder *StatementBuilder) {
for i, col := range m {
if i > 0 {
builder.WriteString(", ")
}
col.WriteQualified(builder)
}
}
// WriteUnqualified implements [Column].
func (m Columns) WriteUnqualified(builder *StatementBuilder) {
for i, col := range m {
if i > 0 {
builder.WriteString(", ")
}
col.WriteUnqualified(builder)
}
}
// Column represents a column in a database table.
type Column interface {
// Write(builder *StatementBuilder)
WriteQualified(builder *StatementBuilder)
WriteUnqualified(builder *StatementBuilder)
}
type column struct {
table string
name string
}
func NewColumn(table, name string) Column {
return column{table: table, name: name}
}
// WriteQualified implements [Column].
func (c column) WriteQualified(builder *StatementBuilder) {
builder.Grow(len(c.table) + len(c.name) + 1)
builder.WriteString(c.table)
builder.WriteRune('.')
builder.WriteString(c.name)
}
// WriteUnqualified implements [Column].
func (c column) WriteUnqualified(builder *StatementBuilder) {
builder.WriteString(c.name)
}
var _ Column = (*column)(nil)
// // ignoreCaseColumn represents two database columns, one for the
// // original value and one for the lower case value.
// type ignoreCaseColumn interface {
// Column
// WriteIgnoreCase(builder *StatementBuilder)
// }
// func NewIgnoreCaseColumn(col Column, suffix string) ignoreCaseColumn {
// return ignoreCaseCol{
// column: col,
// suffix: suffix,
// }
// }
// type ignoreCaseCol struct {
// column Column
// suffix string
// }
// // WriteIgnoreCase implements [ignoreCaseColumn].
// func (c ignoreCaseCol) WriteIgnoreCase(builder *StatementBuilder) {
// c.column.WriteQualified(builder)
// builder.WriteString(c.suffix)
// }
// // WriteQualified implements [ignoreCaseColumn].
// func (c ignoreCaseCol) WriteQualified(builder *StatementBuilder) {
// c.column.WriteQualified(builder)
// builder.WriteString(c.suffix)
// }

View File

@@ -0,0 +1,130 @@
package database
// Condition represents a SQL condition.
// Its written after the WHERE keyword in a SQL statement.
type Condition interface {
Write(builder *StatementBuilder)
}
type and struct {
conditions []Condition
}
// Write implements [Condition].
func (a *and) Write(builder *StatementBuilder) {
if len(a.conditions) > 1 {
builder.WriteString("(")
defer builder.WriteString(")")
}
for i, condition := range a.conditions {
if i > 0 {
builder.WriteString(" AND ")
}
condition.Write(builder)
}
}
// And combines multiple conditions with AND.
func And(conditions ...Condition) *and {
return &and{conditions: conditions}
}
var _ Condition = (*and)(nil)
type or struct {
conditions []Condition
}
// Write implements [Condition].
func (o *or) Write(builder *StatementBuilder) {
if len(o.conditions) > 1 {
builder.WriteString("(")
defer builder.WriteString(")")
}
for i, condition := range o.conditions {
if i > 0 {
builder.WriteString(" OR ")
}
condition.Write(builder)
}
}
// Or combines multiple conditions with OR.
func Or(conditions ...Condition) *or {
return &or{conditions: conditions}
}
var _ Condition = (*or)(nil)
type isNull struct {
column Column
}
// Write implements [Condition].
func (i *isNull) Write(builder *StatementBuilder) {
i.column.WriteQualified(builder)
builder.WriteString(" IS NULL")
}
// IsNull creates a condition that checks if a column is NULL.
func IsNull(column Column) *isNull {
return &isNull{column: column}
}
var _ Condition = (*isNull)(nil)
type isNotNull struct {
column Column
}
// Write implements [Condition].
func (i *isNotNull) Write(builder *StatementBuilder) {
i.column.WriteQualified(builder)
builder.WriteString(" IS NOT NULL")
}
// IsNotNull creates a condition that checks if a column is NOT NULL.
func IsNotNull(column Column) *isNotNull {
return &isNotNull{column: column}
}
var _ Condition = (*isNotNull)(nil)
type valueCondition func(builder *StatementBuilder)
// NewTextCondition creates a condition that compares a text column with a value.
func NewTextCondition[V Text](col Column, op TextOperation, value V) Condition {
return valueCondition(func(builder *StatementBuilder) {
writeTextOperation(builder, col, op, value)
})
}
// NewDateCondition creates a condition that compares a numeric column with a value.
func NewNumberCondition[V Number](col Column, op NumberOperation, value V) Condition {
return valueCondition(func(builder *StatementBuilder) {
writeNumberOperation(builder, col, op, value)
})
}
// NewDateCondition creates a condition that compares a boolean column with a value.
func NewBooleanCondition[V Boolean](col Column, value V) Condition {
return valueCondition(func(builder *StatementBuilder) {
writeBooleanOperation(builder, col, value)
})
}
// NewColumnCondition creates a condition that compares two columns on equality.
func NewColumnCondition(col1, col2 Column) Condition {
return valueCondition(func(builder *StatementBuilder) {
col1.WriteQualified(builder)
builder.WriteString(" = ")
col2.WriteQualified(builder)
})
}
// Write implements [Condition].
func (c valueCondition) Write(builder *StatementBuilder) {
c(builder)
}
var _ Condition = (*valueCondition)(nil)

View File

@@ -0,0 +1,10 @@
package database
import (
"context"
)
// Connector abstracts the database driver.
type Connector interface {
Connect(ctx context.Context) (Pool, error)
}

View File

@@ -0,0 +1,83 @@
package database
import (
"context"
)
// Pool is a connection pool. e.g. pgxpool
type Pool interface {
Beginner
QueryExecutor
Migrator
Acquire(ctx context.Context) (Client, error)
Close(ctx context.Context) error
}
type PoolTest interface {
Pool
// MigrateTest is the same as [Migrator] but executes the migrations multiple times instead of only once.
MigrateTest(ctx context.Context) error
}
// Client is a single database connection which can be released back to the pool.
type Client interface {
Beginner
QueryExecutor
Migrator
Release(ctx context.Context) error
}
// Querier is a database client that can execute queries and return rows.
type Querier interface {
Query(ctx context.Context, stmt string, args ...any) (Rows, error)
QueryRow(ctx context.Context, stmt string, args ...any) Row
}
// Executor is a database client that can execute statements.
// It returns the number of rows affected or an error
type Executor interface {
Exec(ctx context.Context, stmt string, args ...any) (int64, error)
}
// QueryExecutor is a database client that can execute queries and statements.
type QueryExecutor interface {
Querier
Executor
}
// Scanner scans a single row of data into the destination.
type Scanner interface {
Scan(dest ...any) error
}
// Row is an abstraction of sql.Row.
type Row interface {
Scanner
}
// Rows is an abstraction of sql.Rows.
type Rows interface {
Scanner
Next() bool
Close() error
Err() error
}
type CollectableRows interface {
// Collect collects all rows and scans them into dest.
// dest must be a pointer to a slice of pointer to structs
// e.g. *[]*MyStruct
// Rows are closed after this call.
Collect(dest any) error
// CollectFirst collects the first row and scans it into dest.
// dest must be a pointer to a struct
// e.g. *MyStruct{}
// Rows are closed after this call.
CollectFirst(dest any) error
// CollectExactlyOneRow collects exactly one row and scans it into dest.
// e.g. *MyStruct{}
// Rows are closed after this call.
CollectExactlyOneRow(dest any) error
}

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,92 @@
package dialect
import (
"context"
"errors"
"reflect"
"github.com/mitchellh/mapstructure"
"github.com/spf13/viper"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres"
)
type Hook struct {
Match func(string) bool
Decode func(config any) (database.Connector, error)
Name string
Constructor func() database.Connector
}
var hooks = []Hook{
{
Match: postgres.NameMatcher,
Decode: postgres.DecodeConfig,
Name: postgres.Name,
Constructor: func() database.Connector { return new(postgres.Config) },
},
// {
// Match: gosql.NameMatcher,
// Decode: gosql.DecodeConfig,
// Name: gosql.Name,
// Constructor: func() database.Connector { return new(gosql.Config) },
// },
}
type Config struct {
Dialects map[string]any `mapstructure:",remain" yaml:",inline"`
connector database.Connector
}
func (c Config) Connect(ctx context.Context) (database.Pool, error) {
if len(c.Dialects) != 1 {
return nil, errors.New("exactly one dialect must be configured")
}
return c.connector.Connect(ctx)
}
// Hooks implements [configure.Unmarshaller].
func (c Config) Hooks() []viper.DecoderConfigOption {
return []viper.DecoderConfigOption{
viper.DecodeHook(decodeHook),
}
}
func decodeHook(from, to reflect.Value) (_ any, err error) {
if to.Type() != reflect.TypeOf(Config{}) {
return from.Interface(), nil
}
config := new(Config)
if err = mapstructure.Decode(from.Interface(), config); err != nil {
return nil, err
}
if err = config.decodeDialect(); err != nil {
return nil, err
}
return config, nil
}
func (c *Config) decodeDialect() error {
for _, hook := range hooks {
for name, config := range c.Dialects {
if !hook.Match(name) {
continue
}
connector, err := hook.Decode(config)
if err != nil {
return err
}
c.connector = connector
return nil
}
}
return errors.New("no dialect found")
}

View File

@@ -0,0 +1,89 @@
package postgres
import (
"context"
"errors"
"slices"
"strings"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/mitchellh/mapstructure"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
var (
_ database.Connector = (*Config)(nil)
Name = "postgres"
isMigrated bool
)
type Config struct {
*pgxpool.Config
*pgxpool.Pool
// Host string
// Port int32
// Database string
// MaxOpenConns uint32
// MaxIdleConns uint32
// MaxConnLifetime time.Duration
// MaxConnIdleTime time.Duration
// User User
// // Additional options to be appended as options=<Options>
// // The value will be taken as is. Multiple options are space separated.
// Options string
// configuredFields []string
}
// Connect implements [database.Connector].
func (c *Config) Connect(ctx context.Context) (database.Pool, error) {
pool, err := c.getPool(ctx)
if err != nil {
return nil, wrapError(err)
}
if err = pool.Ping(ctx); err != nil {
return nil, wrapError(err)
}
return &pgxPool{Pool: pool}, nil
}
func (c *Config) getPool(ctx context.Context) (*pgxpool.Pool, error) {
if c.Pool != nil {
return c.Pool, nil
}
return pgxpool.NewWithConfig(ctx, c.Config)
}
func NameMatcher(name string) bool {
return slices.Contains([]string{"postgres", "pg"}, strings.ToLower(name))
}
func DecodeConfig(input any) (database.Connector, error) {
switch c := input.(type) {
case string:
config, err := pgxpool.ParseConfig(c)
if err != nil {
return nil, err
}
return &Config{Config: config}, nil
case map[string]any:
connector := new(Config)
decoder, err := mapstructure.NewDecoder(&mapstructure.DecoderConfig{
DecodeHook: mapstructure.StringToTimeDurationHookFunc(),
WeaklyTypedInput: true,
Result: connector,
})
if err != nil {
return nil, err
}
if err = decoder.Decode(c); err != nil {
return nil, err
}
return &Config{
Config: &pgxpool.Config{},
}, nil
}
return nil, errors.New("invalid configuration")
}

View File

@@ -0,0 +1,67 @@
package postgres
import (
"context"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres/migration"
)
type pgxConn struct {
*pgxpool.Conn
}
var _ database.Client = (*pgxConn)(nil)
// Release implements [database.Client].
func (c *pgxConn) Release(_ context.Context) error {
c.Conn.Release()
return nil
}
// Begin implements [database.Client].
func (c *pgxConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
tx, err := c.BeginTx(ctx, transactionOptionsToPgx(opts))
if err != nil {
return nil, wrapError(err)
}
return &pgxTx{tx}, nil
}
// Query implements sql.Client.
// Subtle: this method shadows the method (*Conn).Query of pgxConn.Conn.
func (c *pgxConn) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
rows, err := c.Conn.Query(ctx, sql, args...)
if err != nil {
return nil, wrapError(err)
}
return &Rows{rows}, nil
}
// QueryRow implements sql.Client.
// Subtle: this method shadows the method (*Conn).QueryRow of pgxConn.Conn.
func (c *pgxConn) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return &Row{c.Conn.QueryRow(ctx, sql, args...)}
}
// Exec implements [database.Pool].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (c *pgxConn) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
res, err := c.Conn.Exec(ctx, sql, args...)
if err != nil {
return 0, wrapError(err)
}
return res.RowsAffected(), nil
}
// Migrate implements [database.Migrator].
func (c *pgxConn) Migrate(ctx context.Context) error {
if isMigrated {
return nil
}
err := migration.Migrate(ctx, c.Conn.Conn())
isMigrated = err == nil
return wrapError(err)
}

View File

@@ -0,0 +1,2 @@
// pgxpool v5 implementation of the interfaces defined in the database package.
package postgres

View File

@@ -0,0 +1,50 @@
// embedded is used for testing purposes
package embedded
import (
"net"
"os"
embeddedpostgres "github.com/fergusstrange/embedded-postgres"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres"
)
// StartEmbedded starts an embedded postgres v16 instance and returns a database connector and a stop function
// the database is started on a random port and data are stored in a temporary directory
// its used for testing purposes only
func StartEmbedded() (connector database.Connector, stop func(), err error) {
path, err := os.MkdirTemp("", "zitadel-embedded-postgres-*")
logging.OnError(err).Fatal("unable to create temp dir")
port, close := getPort()
config := embeddedpostgres.DefaultConfig().Version(embeddedpostgres.V16).Port(uint32(port)).RuntimePath(path)
embedded := embeddedpostgres.NewDatabase(config)
close()
err = embedded.Start()
logging.OnError(err).Fatal("unable to start db")
connector, err = postgres.DecodeConfig(config.GetConnectionURL())
if err != nil {
return nil, nil, err
}
return connector, func() {
logging.OnError(embedded.Stop()).Error("unable to stop db")
}, nil
}
// getPort returns a free port and locks it until close is called
func getPort() (port uint16, close func()) {
l, err := net.Listen("tcp", ":0")
logging.OnError(err).Fatal("unable to get port")
port = uint16(l.Addr().(*net.TCPAddr).Port)
logging.WithFields("port", port).Info("Port is available")
return port, func() {
logging.OnError(l.Close()).Error("unable to close port listener")
}
}

View File

@@ -0,0 +1,38 @@
package postgres
import (
"errors"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
func wrapError(err error) error {
if err == nil {
return nil
}
if errors.Is(err, pgx.ErrNoRows) {
return database.NewNoRowFoundError(err)
}
var pgxErr *pgconn.PgError
if !errors.As(err, &pgxErr) {
return database.NewUnknownError(err)
}
switch pgxErr.Code {
// 23514: check_violation - A value violates a CHECK constraint.
case "23514":
return database.NewCheckError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr)
// 23505: unique_violation - A value violates a UNIQUE constraint.
case "23505":
return database.NewUniqueError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr)
// 23503: foreign_key_violation - A value violates a foreign key constraint.
case "23503":
return database.NewForeignKeyError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr)
// 23502: not_null_violation - A value violates a NOT NULL constraint.
case "23502":
return database.NewNotNullError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr)
}
return database.NewUnknownError(err)
}

View File

@@ -0,0 +1,16 @@
package migration
import (
_ "embed"
)
var (
//go:embed 001_instance_table/up.sql
up001InstanceTable string
//go:embed 001_instance_table/down.sql
down001InstanceTable string
)
func init() {
registerSQLMigration(1, up001InstanceTable, down001InstanceTable)
}

View File

@@ -0,0 +1 @@
DROP TABLE zitadel.instances;

View File

@@ -0,0 +1,25 @@
CREATE TABLE IF NOT EXISTS zitadel.instances(
id TEXT NOT NULL CHECK (id <> '') PRIMARY KEY,
name TEXT NOT NULL CHECK (name <> ''),
default_org_id TEXT, -- NOT NULL,
iam_project_id TEXT, -- NOT NULL,
console_client_id TEXT, -- NOT NULL,
console_app_id TEXT, -- NOT NULL,
default_language TEXT, -- NOT NULL,
created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL,
updated_at TIMESTAMPTZ DEFAULT NOW() NOT NULL
);
CREATE OR REPLACE FUNCTION zitadel.set_updated_at()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at := NOW();
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER trigger_set_updated_at
BEFORE UPDATE ON zitadel.instances
FOR EACH ROW
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
EXECUTE FUNCTION zitadel.set_updated_at();

View File

@@ -0,0 +1,16 @@
package migration
import (
_ "embed"
)
var (
//go:embed 002_organization_table/up.sql
up002OrganizationTable string
//go:embed 002_organization_table/down.sql
down002OrganizationTable string
)
func init() {
registerSQLMigration(2, up002OrganizationTable, down002OrganizationTable)
}

View File

@@ -0,0 +1,2 @@
DROP TABLE zitadel.organizations;
DROP Type zitadel.organization_state;

View File

@@ -0,0 +1,24 @@
CREATE TYPE zitadel.organization_state AS ENUM (
'active',
'inactive'
);
CREATE TABLE zitadel.organizations(
id TEXT NOT NULL CHECK (id <> ''),
name TEXT NOT NULL CHECK (name <> ''),
instance_id TEXT NOT NULL REFERENCES zitadel.instances (id) ON DELETE CASCADE,
state zitadel.organization_state NOT NULL,
created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL,
updated_at TIMESTAMPTZ DEFAULT NOW() NOT NULL,
PRIMARY KEY (instance_id, id)
);
CREATE UNIQUE INDEX org_unique_instance_id_name_idx
ON zitadel.organizations (instance_id, name);
CREATE TRIGGER trigger_set_updated_at
BEFORE UPDATE ON zitadel.organizations
FOR EACH ROW
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
EXECUTE FUNCTION zitadel.set_updated_at();

View File

@@ -0,0 +1,16 @@
package migration
import (
_ "embed"
)
var (
//go:embed 003_domains_table/up.sql
up003DomainsTable string
//go:embed 003_domains_table/down.sql
down003DomainsTable string
)
func init() {
registerSQLMigration(3, up003DomainsTable, down003DomainsTable)
}

View File

@@ -0,0 +1,6 @@
DROP TABLE IF EXISTS zitadel.instance_domains;
DROP TABLE IF EXISTS zitadel.org_domains;
DROP TYPE IF EXISTS zitadel.domain_type;
DROP TYPE IF EXISTS zitadel.domain_validation_type;
DROP FUNCTION IF EXISTS zitadel.check_verified_org_domain();
DROP FUNCTION IF EXISTS zitadel.ensure_single_primary_instance_domain();

View File

@@ -0,0 +1,137 @@
CREATE TYPE zitadel.domain_validation_type AS ENUM (
'dns'
, 'http'
);
CREATE TYPE zitadel.domain_type AS ENUM (
'custom'
, 'trusted'
);
CREATE TABLE zitadel.instance_domains(
instance_id TEXT NOT NULL
, domain TEXT NOT NULL CHECK (LENGTH(domain) BETWEEN 1 AND 255)
, is_primary BOOLEAN
, is_generated BOOLEAN
, type zitadel.domain_type NOT NULL
, created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL
, updated_at TIMESTAMPTZ DEFAULT NOW() NOT NULL
, PRIMARY KEY (domain)
, FOREIGN KEY (instance_id) REFERENCES zitadel.instances(id) ON DELETE CASCADE
, CONSTRAINT primary_cannot_be_trusted CHECK (is_primary IS NULL OR type != 'trusted')
, CONSTRAINT generated_cannot_be_trusted CHECK (is_generated IS NULL OR type != 'trusted')
, CONSTRAINT custom_values_set CHECK ((is_primary IS NOT NULL AND is_generated IS NOT NULL) OR type != 'custom')
);
CREATE INDEX idx_instance_domain_instance ON zitadel.instance_domains(instance_id);
CREATE TABLE zitadel.org_domains(
instance_id TEXT NOT NULL
, org_id TEXT NOT NULL
, domain TEXT NOT NULL CHECK (LENGTH(domain) BETWEEN 1 AND 255)
, is_verified BOOLEAN NOT NULL DEFAULT FALSE
, is_primary BOOLEAN NOT NULL DEFAULT FALSE
, validation_type zitadel.domain_validation_type
, created_at TIMESTAMPTZ DEFAULT NOW() NOT NULL
, updated_at TIMESTAMPTZ DEFAULT NOW() NOT NULL
, PRIMARY KEY (instance_id, org_id, domain)
, FOREIGN KEY (instance_id, org_id) REFERENCES zitadel.organizations(instance_id, id) ON DELETE CASCADE
, UNIQUE (instance_id, org_id, domain)
);
CREATE INDEX idx_org_domain ON zitadel.org_domains(instance_id, domain);
-- Trigger to update the updated_at timestamp on instance_domains
CREATE TRIGGER trg_set_updated_at_instance_domains
BEFORE UPDATE ON zitadel.instance_domains
FOR EACH ROW
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
EXECUTE FUNCTION zitadel.set_updated_at();
-- Trigger to update the updated_at timestamp on org_domains
CREATE TRIGGER trg_set_updated_at_org_domains
BEFORE UPDATE ON zitadel.org_domains
FOR EACH ROW
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
EXECUTE FUNCTION zitadel.set_updated_at();
-- Function to check for already verified org domains
CREATE OR REPLACE FUNCTION zitadel.check_verified_org_domain()
RETURNS TRIGGER AS $$
BEGIN
-- Check if there's already a verified domain within this instance (excluding the current record being updated)
IF EXISTS (
SELECT 1
FROM zitadel.org_domains
WHERE instance_id = NEW.instance_id
AND domain = NEW.domain
AND is_verified = TRUE
AND (TG_OP = 'INSERT' OR (org_id != NEW.org_id))
) THEN
RAISE EXCEPTION 'org domain is already taken';
END IF;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
-- Trigger to enforce verified domain constraint on org_domains
CREATE TRIGGER trg_check_verified_org_domain
BEFORE INSERT OR UPDATE ON zitadel.org_domains
FOR EACH ROW
WHEN (NEW.is_verified IS TRUE)
EXECUTE FUNCTION zitadel.check_verified_org_domain();
-- Function to ensure only one primary domain per instance in instance_domains
CREATE OR REPLACE FUNCTION zitadel.ensure_single_primary_instance_domain()
RETURNS TRIGGER AS $$
BEGIN
-- If setting this domain as primary, update all other domains in the same instance to non-primary
UPDATE zitadel.instance_domains
SET is_primary = FALSE, updated_at = NOW()
WHERE instance_id = NEW.instance_id
AND domain != NEW.domain
AND is_primary = TRUE
AND type = 'custom';
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
-- Trigger to enforce single primary domain constraint on instance_domains
CREATE TRIGGER trg_ensure_single_primary_instance_domain
BEFORE INSERT OR UPDATE ON zitadel.instance_domains
FOR EACH ROW
WHEN (NEW.is_primary IS TRUE)
EXECUTE FUNCTION zitadel.ensure_single_primary_instance_domain();
-- Function to ensure only one primary domain per organization in org_domains
CREATE OR REPLACE FUNCTION zitadel.ensure_single_primary_org_domain()
RETURNS TRIGGER AS $$
BEGIN
-- If setting this domain as primary, update all other domains in the same organization to non-primary
UPDATE zitadel.org_domains
SET is_primary = FALSE, updated_at = NOW()
WHERE instance_id = NEW.instance_id
AND org_id = NEW.org_id
AND domain != NEW.domain
AND is_primary = TRUE;
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
-- Trigger to enforce single primary domain constraint on org_domains
CREATE TRIGGER trg_ensure_single_primary_org_domain
BEFORE INSERT OR UPDATE ON zitadel.org_domains
FOR EACH ROW
WHEN (NEW.is_primary IS TRUE)
EXECUTE FUNCTION zitadel.ensure_single_primary_org_domain();

View File

@@ -0,0 +1,13 @@
// This package contains the migration logic for the PostgreSQL dialect.
// It uses the [github.com/jackc/tern/v2/migrate] package to handle the migration process.
//
// **Developer Note**:
//
// Each migration MUST be registered in an init function.
// Create a go file for each migration with the sequence of the migration as prefix and some descriptive name.
// The file name MUST be in the format <sequence>_<name>.go.
// Each migration SHOULD provide an up and down migration.
// Prefer to write SQL statements instead of funcs if it is reasonable.
// To keep the folder clean create a folder to store the sql files with the following format: <sequence>_<name>/{up/down}.sql.
// And use the go embed directive to embed the sql files.
package migration

View File

@@ -0,0 +1,33 @@
package migration
import (
"context"
"github.com/jackc/pgx/v5"
"github.com/jackc/tern/v2/migrate"
)
var migrations []*migrate.Migration
func Migrate(ctx context.Context, conn *pgx.Conn) error {
// we need to ensure that the schema exists before we can run the migration
// because creating the migrations table already required the schema
_, err := conn.Exec(ctx, "CREATE SCHEMA IF NOT EXISTS zitadel")
if err != nil {
return err
}
migrator, err := migrate.NewMigrator(ctx, conn, "zitadel.migrations")
if err != nil {
return err
}
migrator.Migrations = migrations
return migrator.Migrate(ctx)
}
func registerSQLMigration(sequence int32, up, down string) {
migrations = append(migrations, &migrate.Migration{
Sequence: sequence,
UpSQL: up,
DownSQL: down,
})
}

View File

@@ -0,0 +1,60 @@
package migration_test
import (
"context"
"testing"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres/embedded"
)
func TestMigrate(t *testing.T) {
tests := []struct {
name string
stmt string
args []any
res []any
}{
{
name: "schema",
stmt: "SELECT EXISTS(SELECT 1 FROM information_schema.schemata where schema_name = 'zitadel') ;",
res: []any{true},
},
{
name: "001",
stmt: "SELECT EXISTS(SELECT 1 FROM pg_catalog.pg_tables WHERE schemaname = 'zitadel' and tablename=$1)",
args: []any{"instances"},
res: []any{true},
},
}
ctx := context.Background()
connector, stop, err := embedded.StartEmbedded()
require.NoError(t, err, "failed to start embedded postgres")
defer stop()
client, err := connector.Connect(ctx)
require.NoError(t, err, "failed to connect to embedded postgres")
err = client.(database.Migrator).Migrate(ctx)
require.NoError(t, err, "failed to execute migration steps")
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := make([]any, len(tt.res))
for i := range got {
got[i] = new(any)
tt.res[i] = gu.Ptr(tt.res[i])
}
require.NoError(t, client.QueryRow(ctx, tt.stmt, tt.args...).Scan(got...), "failed to execute check query")
assert.Equal(t, tt.res, got, "query result does not match")
})
}
}

View File

@@ -0,0 +1,100 @@
package postgres
import (
"context"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres/migration"
)
type pgxPool struct {
*pgxpool.Pool
}
var _ database.Pool = (*pgxPool)(nil)
func PGxPool(pool *pgxpool.Pool) *pgxPool {
return &pgxPool{
Pool: pool,
}
}
// Acquire implements [database.Pool].
func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) {
conn, err := c.Pool.Acquire(ctx)
if err != nil {
return nil, wrapError(err)
}
return &pgxConn{Conn: conn}, nil
}
// Query implements [database.Pool].
// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool.
func (c *pgxPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
rows, err := c.Pool.Query(ctx, sql, args...)
if err != nil {
return nil, wrapError(err)
}
return &Rows{rows}, nil
}
// QueryRow implements [database.Pool].
// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool.
func (c *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return &Row{c.Pool.QueryRow(ctx, sql, args...)}
}
// Exec implements [database.Pool].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
res, err := c.Pool.Exec(ctx, sql, args...)
if err != nil {
return 0, wrapError(err)
}
return res.RowsAffected(), nil
}
// Begin implements [database.Pool].
func (c *pgxPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
tx, err := c.BeginTx(ctx, transactionOptionsToPgx(opts))
if err != nil {
return nil, wrapError(err)
}
return &pgxTx{tx}, nil
}
// Close implements [database.Pool].
func (c *pgxPool) Close(_ context.Context) error {
c.Pool.Close()
return nil
}
// Migrate implements [database.Migrator].
func (c *pgxPool) Migrate(ctx context.Context) error {
if isMigrated {
return nil
}
client, err := c.Pool.Acquire(ctx)
if err != nil {
return err
}
err = migration.Migrate(ctx, client.Conn())
isMigrated = err == nil
return wrapError(err)
}
// Migrate implements [database.PoolTest].
func (c *pgxPool) MigrateTest(ctx context.Context) error {
client, err := c.Pool.Acquire(ctx)
if err != nil {
return err
}
err = migration.Migrate(ctx, client.Conn())
isMigrated = err == nil
return err
}

View File

@@ -0,0 +1,77 @@
package postgres
import (
"github.com/georgysavva/scany/v2/pgxscan"
"github.com/jackc/pgx/v5"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
var (
_ database.Rows = (*Rows)(nil)
_ database.CollectableRows = (*Rows)(nil)
_ database.Row = (*Row)(nil)
)
type Row struct{ pgx.Row }
// Scan implements [database.Row].
// Subtle: this method shadows the method ([pgx.Row]).Scan of Row.Row.
func (r *Row) Scan(dest ...any) error {
return wrapError(r.Row.Scan(dest...))
}
type Rows struct{ pgx.Rows }
// Err implements [database.Rows].
// Subtle: this method shadows the method ([pgx.Rows]).Err of Rows.Rows.
func (r *Rows) Err() error {
return wrapError(r.Rows.Err())
}
func (r *Rows) Scan(dest ...any) error {
return wrapError(r.Rows.Scan(dest...))
}
// Collect implements [database.CollectableRows].
// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details.
func (r *Rows) Collect(dest any) (err error) {
defer func() {
closeErr := r.Close()
if err == nil {
err = closeErr
}
}()
return wrapError(pgxscan.ScanAll(dest, r.Rows))
}
// CollectFirst implements [database.CollectableRows].
// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details.
func (r *Rows) CollectFirst(dest any) (err error) {
defer func() {
closeErr := r.Close()
if err == nil {
err = closeErr
}
}()
return wrapError(pgxscan.ScanRow(dest, r.Rows))
}
// CollectExactlyOneRow implements [database.CollectableRows].
// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details.
func (r *Rows) CollectExactlyOneRow(dest any) (err error) {
defer func() {
closeErr := r.Close()
if err == nil {
err = closeErr
}
}()
return wrapError(pgxscan.ScanOne(dest, r.Rows))
}
// Close implements [database.Rows].
// Subtle: this method shadows the method (Rows).Close of Rows.Rows.
func (r *Rows) Close() error {
r.Rows.Close()
return nil
}

View File

@@ -0,0 +1,107 @@
package postgres
import (
"context"
"errors"
"github.com/jackc/pgx/v5"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type pgxTx struct{ pgx.Tx }
var _ database.Transaction = (*pgxTx)(nil)
// Commit implements [database.Transaction].
func (tx *pgxTx) Commit(ctx context.Context) error {
err := tx.Tx.Commit(ctx)
return wrapError(err)
}
// Rollback implements [database.Transaction].
func (tx *pgxTx) Rollback(ctx context.Context) error {
err := tx.Tx.Rollback(ctx)
return wrapError(err)
}
// End implements [database.Transaction].
func (tx *pgxTx) End(ctx context.Context, err error) error {
if err != nil {
rollbackErr := tx.Rollback(ctx)
if rollbackErr != nil {
err = errors.Join(err, rollbackErr)
}
return err
}
return tx.Commit(ctx)
}
// Query implements [database.Transaction].
// Subtle: this method shadows the method (Tx).Query of pgxTx.Tx.
func (tx *pgxTx) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
rows, err := tx.Tx.Query(ctx, sql, args...)
if err != nil {
return nil, wrapError(err)
}
return &Rows{rows}, nil
}
// QueryRow implements [database.Transaction].
// Subtle: this method shadows the method (Tx).QueryRow of pgxTx.Tx.
func (tx *pgxTx) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return &Row{tx.Tx.QueryRow(ctx, sql, args...)}
}
// Exec implements [database.Transaction].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (tx *pgxTx) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
res, err := tx.Tx.Exec(ctx, sql, args...)
if err != nil {
return 0, wrapError(err)
}
return res.RowsAffected(), nil
}
// Begin implements [database.Transaction].
// As postgres does not support nested transactions we use savepoints to emulate them.
func (tx *pgxTx) Begin(ctx context.Context) (database.Transaction, error) {
savepoint, err := tx.Tx.Begin(ctx)
if err != nil {
return nil, wrapError(err)
}
return &pgxTx{savepoint}, nil
}
func transactionOptionsToPgx(opts *database.TransactionOptions) pgx.TxOptions {
if opts == nil {
return pgx.TxOptions{}
}
return pgx.TxOptions{
IsoLevel: isolationToPgx(opts.IsolationLevel),
AccessMode: accessModeToPgx(opts.AccessMode),
}
}
func isolationToPgx(isolation database.IsolationLevel) pgx.TxIsoLevel {
switch isolation {
case database.IsolationLevelSerializable:
return pgx.Serializable
case database.IsolationLevelReadCommitted:
return pgx.ReadCommitted
default:
return pgx.Serializable
}
}
func accessModeToPgx(accessMode database.AccessMode) pgx.TxAccessMode {
switch accessMode {
case database.AccessModeReadWrite:
return pgx.ReadWrite
case database.AccessModeReadOnly:
return pgx.ReadOnly
default:
return pgx.ReadWrite
}
}

View File

@@ -0,0 +1,60 @@
package sql
import (
"context"
"database/sql"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type sqlConn struct {
*sql.Conn
}
var _ database.Client = (*sqlConn)(nil)
// Release implements [database.Client].
func (c *sqlConn) Release(_ context.Context) error {
return c.Close()
}
// Begin implements [database.Client].
func (c *sqlConn) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
tx, err := c.BeginTx(ctx, transactionOptionsToSQL(opts))
if err != nil {
return nil, wrapError(err)
}
return &sqlTx{tx}, nil
}
// Query implements sql.Client.
// Subtle: this method shadows the method (*Conn).Query of pgxConn.Conn.
func (c *sqlConn) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
//nolint:rowserrcheck // Rows.Close is called by the caller
rows, err := c.QueryContext(ctx, sql, args...)
if err != nil {
return nil, wrapError(err)
}
return &Rows{rows}, nil
}
// QueryRow implements sql.Client.
// Subtle: this method shadows the method (*Conn).QueryRow of pgxConn.Conn.
func (c *sqlConn) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return &Row{c.QueryRowContext(ctx, sql, args...)}
}
// Exec implements [database.Pool].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (c *sqlConn) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
res, err := c.ExecContext(ctx, sql, args...)
if err != nil {
return 0, wrapError(err)
}
return res.RowsAffected()
}
// Migrate implements [database.Migrator].
func (c *sqlConn) Migrate(ctx context.Context) error {
return ErrMigrate
}

View File

@@ -0,0 +1,3 @@
// [database/sql] implementation of the interfaces defined in the database package.
// This package is used to migrate from event driven to relational Zitadel.
package sql

View File

@@ -0,0 +1,41 @@
package sql
import (
"database/sql"
"errors"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
var ErrMigrate = errors.New("sql does not support migrations, use a different dialect")
func wrapError(err error) error {
if err == nil {
return nil
}
if errors.Is(err, pgx.ErrNoRows) || errors.Is(err, sql.ErrNoRows) {
return database.NewNoRowFoundError(err)
}
var pgxErr *pgconn.PgError
if !errors.As(err, &pgxErr) {
return database.NewUnknownError(err)
}
switch pgxErr.Code {
// 23514: check_violation - A value violates a CHECK constraint.
case "23514":
return database.NewCheckError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr)
// 23505: unique_violation - A value violates a UNIQUE constraint.
case "23505":
return database.NewUniqueError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr)
// 23503: foreign_key_violation - A value violates a foreign key constraint.
case "23503":
return database.NewForeignKeyError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr)
// 23502: not_null_violation - A value violates a NOT NULL constraint.
case "23502":
return database.NewNotNullError(pgxErr.TableName, pgxErr.ConstraintName, pgxErr)
}
return database.NewUnknownError(err)
}

View File

@@ -0,0 +1,75 @@
package sql
import (
"context"
"database/sql"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type sqlPool struct {
*sql.DB
}
var _ database.Pool = (*sqlPool)(nil)
func SQLPool(db *sql.DB) *sqlPool {
return &sqlPool{
DB: db,
}
}
// Acquire implements [database.Pool].
func (c *sqlPool) Acquire(ctx context.Context) (database.Client, error) {
conn, err := c.Conn(ctx)
if err != nil {
return nil, wrapError(err)
}
return &sqlConn{Conn: conn}, nil
}
// Query implements [database.Pool].
// Subtle: this method shadows the method (Pool).Query of pgxPool.Pool.
func (c *sqlPool) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
//nolint:rowserrcheck // Rows.Close is called by the caller
rows, err := c.QueryContext(ctx, sql, args...)
if err != nil {
return nil, wrapError(err)
}
return &Rows{rows}, nil
}
// QueryRow implements [database.Pool].
// Subtle: this method shadows the method (Pool).QueryRow of pgxPool.Pool.
func (c *sqlPool) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return &Row{c.QueryRowContext(ctx, sql, args...)}
}
// Exec implements [database.Pool].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (c *sqlPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
res, err := c.ExecContext(ctx, sql, args...)
if err != nil {
return 0, wrapError(err)
}
return res.RowsAffected()
}
// Begin implements [database.Pool].
func (c *sqlPool) Begin(ctx context.Context, opts *database.TransactionOptions) (database.Transaction, error) {
tx, err := c.BeginTx(ctx, transactionOptionsToSQL(opts))
if err != nil {
return nil, wrapError(err)
}
return &sqlTx{tx}, nil
}
// Close implements [database.Pool].
func (c *sqlPool) Close(_ context.Context) error {
return c.DB.Close()
}
// Migrate implements [database.Migrator].
func (c *sqlPool) Migrate(ctx context.Context) error {
return ErrMigrate
}

View File

@@ -0,0 +1,78 @@
package sql
import (
"database/sql"
pgxscan "github.com/georgysavva/scany/v2/dbscan"
"github.com/jackc/pgx/v5"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
var (
_ database.Rows = (*Rows)(nil)
_ database.CollectableRows = (*Rows)(nil)
_ database.Row = (*Row)(nil)
)
type Row struct{ pgx.Row }
// Scan implements [database.Row].
// Subtle: this method shadows the method ([pgx.Row]).Scan of Row.Row.
func (r *Row) Scan(dest ...any) error {
return wrapError(r.Row.Scan(dest...))
}
type Rows struct{ *sql.Rows }
// Err implements [database.Rows].
// Subtle: this method shadows the method ([pgx.Rows]).Err of Rows.Rows.
func (r *Rows) Err() error {
return wrapError(r.Rows.Err())
}
func (r *Rows) Scan(dest ...any) error {
return wrapError(r.Rows.Scan(dest...))
}
// Collect implements [database.CollectableRows].
// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details.
func (r *Rows) Collect(dest any) (err error) {
defer func() {
closeErr := r.Close()
if err == nil {
err = closeErr
}
}()
return wrapError(pgxscan.ScanAll(dest, r.Rows))
}
// CollectFirst implements [database.CollectableRows].
// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details.
func (r *Rows) CollectFirst(dest any) (err error) {
defer func() {
closeErr := r.Close()
if err == nil {
err = closeErr
}
}()
return wrapError(pgxscan.ScanRow(dest, r.Rows))
}
// CollectExactlyOneRow implements [database.CollectableRows].
// See [this page](https://github.com/georgysavva/scany/blob/master/dbscan/doc.go#L8) for additional details.
func (r *Rows) CollectExactlyOneRow(dest any) (err error) {
defer func() {
closeErr := r.Close()
if err == nil {
err = closeErr
}
}()
return wrapError(pgxscan.ScanOne(dest, r.Rows))
}
// Close implements [database.Rows].
// Subtle: this method shadows the method (Rows).Close of Rows.Rows.
func (r *Rows) Close() error {
return r.Rows.Close()
}

View File

@@ -0,0 +1,69 @@
package sql
import (
"context"
"errors"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
var _ database.Transaction = (*sqlSavepoint)(nil)
const (
savepointName = "zitadel_savepoint"
createSavepoint = "SAVEPOINT " + savepointName
rollbackToSavepoint = "ROLLBACK TO SAVEPOINT " + savepointName
commitSavepoint = "RELEASE SAVEPOINT " + savepointName
)
type sqlSavepoint struct {
parent database.Transaction
}
// Commit implements [database.Transaction].
func (s *sqlSavepoint) Commit(ctx context.Context) error {
_, err := s.parent.Exec(ctx, commitSavepoint)
return wrapError(err)
}
// Rollback implements [database.Transaction].
func (s *sqlSavepoint) Rollback(ctx context.Context) error {
_, err := s.parent.Exec(ctx, rollbackToSavepoint)
return wrapError(err)
}
// End implements [database.Transaction].
func (s *sqlSavepoint) End(ctx context.Context, err error) error {
if err != nil {
rollbackErr := s.Rollback(ctx)
if rollbackErr != nil {
err = errors.Join(err, rollbackErr)
}
return err
}
return s.Commit(ctx)
}
// Query implements [database.Transaction].
// Subtle: this method shadows the method (Tx).Query of pgxTx.Tx.
func (s *sqlSavepoint) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
return s.parent.Query(ctx, sql, args...)
}
// QueryRow implements [database.Transaction].
// Subtle: this method shadows the method (Tx).QueryRow of pgxTx.Tx.
func (s *sqlSavepoint) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return s.parent.QueryRow(ctx, sql, args...)
}
// Exec implements [database.Transaction].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (s *sqlSavepoint) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
return s.parent.Exec(ctx, sql, args...)
}
// Begin implements [database.Transaction].
// As postgres does not support nested transactions we use savepoints to emulate them.
func (s *sqlSavepoint) Begin(ctx context.Context) (database.Transaction, error) {
return s.parent.Begin(ctx)
}

View File

@@ -0,0 +1,104 @@
package sql
import (
"context"
"database/sql"
"errors"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type sqlTx struct{ *sql.Tx }
var _ database.Transaction = (*sqlTx)(nil)
func SQLTx(tx *sql.Tx) *sqlTx {
return &sqlTx{
Tx: tx,
}
}
// Commit implements [database.Transaction].
func (tx *sqlTx) Commit(ctx context.Context) error {
return wrapError(tx.Tx.Commit())
}
// Rollback implements [database.Transaction].
func (tx *sqlTx) Rollback(ctx context.Context) error {
return wrapError(tx.Tx.Rollback())
}
// End implements [database.Transaction].
func (tx *sqlTx) End(ctx context.Context, err error) error {
if err != nil {
rollbackErr := tx.Rollback(ctx)
if rollbackErr != nil {
err = errors.Join(err, rollbackErr)
}
return err
}
return tx.Commit(ctx)
}
// Query implements [database.Transaction].
// Subtle: this method shadows the method (Tx).Query of pgxTx.Tx.
func (tx *sqlTx) Query(ctx context.Context, sql string, args ...any) (database.Rows, error) {
//nolint:rowserrcheck // Rows.Close is called by the caller
rows, err := tx.QueryContext(ctx, sql, args...)
if err != nil {
return nil, wrapError(err)
}
return &Rows{rows}, nil
}
// QueryRow implements [database.Transaction].
// Subtle: this method shadows the method (Tx).QueryRow of pgxTx.Tx.
func (tx *sqlTx) QueryRow(ctx context.Context, sql string, args ...any) database.Row {
return &Row{tx.QueryRowContext(ctx, sql, args...)}
}
// Exec implements [database.Transaction].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (tx *sqlTx) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
res, err := tx.ExecContext(ctx, sql, args...)
if err != nil {
return 0, wrapError(err)
}
return res.RowsAffected()
}
// Begin implements [database.Transaction].
// As postgres does not support nested transactions we use savepoints to emulate them.
func (tx *sqlTx) Begin(ctx context.Context) (database.Transaction, error) {
_, err := tx.ExecContext(ctx, createSavepoint)
if err != nil {
return nil, wrapError(err)
}
return &sqlSavepoint{tx}, nil
}
func transactionOptionsToSQL(opts *database.TransactionOptions) *sql.TxOptions {
if opts == nil {
return nil
}
return &sql.TxOptions{
Isolation: isolationToSQL(opts.IsolationLevel),
ReadOnly: accessModeToSQL(opts.AccessMode),
}
}
func isolationToSQL(isolation database.IsolationLevel) sql.IsolationLevel {
switch isolation {
case database.IsolationLevelSerializable:
return sql.LevelSerializable
case database.IsolationLevelReadCommitted:
return sql.LevelReadCommitted
default:
return sql.LevelSerializable
}
}
func accessModeToSQL(accessMode database.AccessMode) bool {
return accessMode == database.AccessModeReadOnly
}

View File

@@ -0,0 +1,230 @@
package database
import (
"errors"
"fmt"
)
var ErrNoChanges = errors.New("update must contain a change")
// NoRowFoundError is returned when QueryRow does not find any row.
// It wraps the dialect specific original error to provide more context.
type NoRowFoundError struct {
original error
}
func NewNoRowFoundError(original error) error {
return &NoRowFoundError{
original: original,
}
}
func (e *NoRowFoundError) Error() string {
return "no row found"
}
func (e *NoRowFoundError) Is(target error) bool {
_, ok := target.(*NoRowFoundError)
return ok
}
func (e *NoRowFoundError) Unwrap() error {
return e.original
}
// MultipleRowsFoundError is returned when QueryRow finds multiple rows.
// It wraps the dialect specific original error to provide more context.
type MultipleRowsFoundError struct {
original error
count int
}
func NewMultipleRowsFoundError(original error, count int) error {
return &MultipleRowsFoundError{
original: original,
count: count,
}
}
func (e *MultipleRowsFoundError) Error() string {
return fmt.Sprintf("multiple rows found: %d", e.count)
}
func (e *MultipleRowsFoundError) Is(target error) bool {
_, ok := target.(*MultipleRowsFoundError)
return ok
}
func (e *MultipleRowsFoundError) Unwrap() error {
return e.original
}
type IntegrityType string
const (
IntegrityTypeCheck IntegrityType = "check"
IntegrityTypeUnique IntegrityType = "unique"
IntegrityTypeForeign IntegrityType = "foreign"
IntegrityTypeNotNull IntegrityType = "not null"
)
// IntegrityViolationError represents a generic integrity violation error.
// It wraps the dialect specific original error to provide more context.
type IntegrityViolationError struct {
integrityType IntegrityType
table string
constraint string
original error
}
func NewIntegrityViolationError(typ IntegrityType, table, constraint string, original error) error {
return &IntegrityViolationError{
integrityType: typ,
table: table,
constraint: constraint,
original: original,
}
}
func (e *IntegrityViolationError) Error() string {
return fmt.Sprintf("integrity violation of type %q on %q (constraint: %q): %v", e.integrityType, e.table, e.constraint, e.original)
}
func (e *IntegrityViolationError) Is(target error) bool {
_, ok := target.(*IntegrityViolationError)
return ok
}
// CheckError is returned when a check constraint fails.
// It wraps the [IntegrityViolationError] to provide more context.
// It is used to indicate that a check constraint was violated during an insert or update operation.
type CheckError struct {
IntegrityViolationError
}
func NewCheckError(table, constraint string, original error) error {
return &CheckError{
IntegrityViolationError: IntegrityViolationError{
integrityType: IntegrityTypeCheck,
table: table,
constraint: constraint,
original: original,
},
}
}
func (e *CheckError) Is(target error) bool {
_, ok := target.(*CheckError)
return ok
}
func (e *CheckError) Unwrap() error {
return &e.IntegrityViolationError
}
// UniqueError is returned when a unique constraint fails.
// It wraps the [IntegrityViolationError] to provide more context.
// It is used to indicate that a unique constraint was violated during an insert or update operation.
type UniqueError struct {
IntegrityViolationError
}
func NewUniqueError(table, constraint string, original error) error {
return &UniqueError{
IntegrityViolationError: IntegrityViolationError{
integrityType: IntegrityTypeUnique,
table: table,
constraint: constraint,
original: original,
},
}
}
func (e *UniqueError) Is(target error) bool {
_, ok := target.(*UniqueError)
return ok
}
func (e *UniqueError) Unwrap() error {
return &e.IntegrityViolationError
}
// ForeignKeyError is returned when a foreign key constraint fails.
// It wraps the [IntegrityViolationError] to provide more context.
// It is used to indicate that a foreign key constraint was violated during an insert or update operation
type ForeignKeyError struct {
IntegrityViolationError
}
func NewForeignKeyError(table, constraint string, original error) error {
return &ForeignKeyError{
IntegrityViolationError: IntegrityViolationError{
integrityType: IntegrityTypeForeign,
table: table,
constraint: constraint,
original: original,
},
}
}
func (e *ForeignKeyError) Is(target error) bool {
_, ok := target.(*ForeignKeyError)
return ok
}
func (e *ForeignKeyError) Unwrap() error {
return &e.IntegrityViolationError
}
// NotNullError is returned when a not null constraint fails.
// It wraps the [IntegrityViolationError] to provide more context.
// It is used to indicate that a not null constraint was violated during an insert or update operation.
type NotNullError struct {
IntegrityViolationError
}
func NewNotNullError(table, constraint string, original error) error {
return &NotNullError{
IntegrityViolationError: IntegrityViolationError{
integrityType: IntegrityTypeNotNull,
table: table,
constraint: constraint,
original: original,
},
}
}
func (e *NotNullError) Is(target error) bool {
_, ok := target.(*NotNullError)
return ok
}
func (e *NotNullError) Unwrap() error {
return &e.IntegrityViolationError
}
// UnknownError is returned when an unknown error occurs.
// It wraps the dialect specific original error to provide more context.
// It is used to indicate that an error occurred that does not fit into any of the other categories.
type UnknownError struct {
original error
}
func NewUnknownError(original error) error {
return &UnknownError{
original: original,
}
}
func (e *UnknownError) Error() string {
return fmt.Sprintf("unknown database error: %v", e.original)
}
func (e *UnknownError) Is(target error) bool {
_, ok := target.(*UnknownError)
return ok
}
func (e *UnknownError) Unwrap() error {
return e.original
}

View File

@@ -0,0 +1,78 @@
//go:build integration
package events_test
import (
"context"
"log"
"os"
"testing"
"time"
"github.com/jackc/pgx/v5"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres"
"github.com/zitadel/zitadel/internal/integration"
v2beta "github.com/zitadel/zitadel/pkg/grpc/instance/v2beta"
v2beta_org "github.com/zitadel/zitadel/pkg/grpc/org/v2beta"
"github.com/zitadel/zitadel/pkg/grpc/system"
)
const ConnString = "host=localhost port=5432 user=zitadel password=zitadel dbname=zitadel sslmode=disable"
var (
dbPool *pgxpool.Pool
CTX context.Context
Instance *integration.Instance
SystemClient system.SystemServiceClient
OrgClient v2beta_org.OrganizationServiceClient
)
var pool database.Pool
func TestMain(m *testing.M) {
os.Exit(func() int {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
defer cancel()
CTX = integration.WithSystemAuthorization(ctx)
Instance = integration.NewInstance(CTX)
SystemClient = integration.SystemClient()
OrgClient = Instance.Client.OrgV2beta
defer func() {
_, err := Instance.Client.InstanceV2Beta.DeleteInstance(CTX, &v2beta.DeleteInstanceRequest{
InstanceId: Instance.Instance.Id,
})
if err != nil {
log.Printf("Failed to delete instance on cleanup: %v\n", err)
}
}()
var err error
dbConfig, err := pgxpool.ParseConfig(ConnString)
if err != nil {
panic(err)
}
dbConfig.AfterConnect = func(ctx context.Context, conn *pgx.Conn) error {
orgState, err := conn.LoadType(ctx, "zitadel.organization_state")
if err != nil {
return err
}
conn.TypeMap().RegisterType(orgState)
return nil
}
dbPool, err = pgxpool.NewWithConfig(context.Background(), dbConfig)
if err != nil {
panic(err)
}
pool = postgres.PGxPool(dbPool)
return m.Run()
}())
}

View File

@@ -0,0 +1,310 @@
//go:build integration
package events_test
import (
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
"github.com/zitadel/zitadel/internal/integration"
v2beta "github.com/zitadel/zitadel/pkg/grpc/instance/v2beta"
"github.com/zitadel/zitadel/pkg/grpc/system"
)
func TestServer_TestInstanceDomainReduces(t *testing.T) {
instance := integration.NewInstance(CTX)
instanceRepo := repository.InstanceRepository(pool)
instanceDomainRepo := instanceRepo.Domains(true)
t.Cleanup(func() {
_, err := instance.Client.InstanceV2Beta.DeleteInstance(CTX, &v2beta.DeleteInstanceRequest{
InstanceId: instance.Instance.Id,
})
if err != nil {
t.Logf("Failed to delete instance on cleanup: %v", err)
}
})
// Wait for instance to be created
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
_, err := instanceRepo.Get(CTX,
database.WithCondition(instanceRepo.IDCondition(instance.Instance.Id)),
)
assert.NoError(ttt, err)
}, retryDuration, tick)
t.Run("test instance custom domain add reduces", func(t *testing.T) {
// Add a domain to the instance
domainName := gofakeit.DomainName()
beforeAdd := time.Now()
_, err := instance.Client.InstanceV2Beta.AddCustomDomain(CTX, &v2beta.AddCustomDomainRequest{
InstanceId: instance.Instance.Id,
Domain: domainName,
})
require.NoError(t, err)
afterAdd := time.Now()
t.Cleanup(func() {
_, err := instance.Client.InstanceV2Beta.RemoveCustomDomain(CTX, &v2beta.RemoveCustomDomainRequest{
InstanceId: instance.Instance.Id,
Domain: domainName,
})
if err != nil {
t.Logf("Failed to delete instance domain on cleanup: %v", err)
}
})
// Test that domain add reduces
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
domain, err := instanceDomainRepo.Get(CTX,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
instanceDomainRepo.DomainCondition(database.TextOperationEqual, domainName),
instanceDomainRepo.TypeCondition(domain.DomainTypeCustom),
),
),
)
require.NoError(ttt, err)
// event instance.domain.added
assert.Equal(ttt, domainName, domain.Domain)
assert.Equal(ttt, instance.Instance.Id, domain.InstanceID)
assert.False(ttt, *domain.IsPrimary)
assert.WithinRange(ttt, domain.CreatedAt, beforeAdd, afterAdd)
assert.WithinRange(ttt, domain.UpdatedAt, beforeAdd, afterAdd)
}, retryDuration, tick)
})
t.Run("test instance custom domain set primary reduces", func(t *testing.T) {
// Add a domain to the instance
domainName := gofakeit.DomainName()
_, err := instance.Client.InstanceV2Beta.AddCustomDomain(CTX, &v2beta.AddCustomDomainRequest{
InstanceId: instance.Instance.Id,
Domain: domainName,
})
require.NoError(t, err)
t.Cleanup(func() {
// first we change the primary domain to something else
domain, err := instanceDomainRepo.Get(CTX,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
instanceDomainRepo.TypeCondition(domain.DomainTypeCustom),
instanceDomainRepo.IsPrimaryCondition(false),
),
),
database.WithLimit(1),
)
require.NoError(t, err)
_, err = SystemClient.SetPrimaryDomain(CTX, &system.SetPrimaryDomainRequest{
InstanceId: instance.Instance.Id,
Domain: domain.Domain,
})
require.NoError(t, err)
_, err = instance.Client.InstanceV2Beta.RemoveCustomDomain(CTX, &v2beta.RemoveCustomDomainRequest{
InstanceId: instance.Instance.Id,
Domain: domainName,
})
if err != nil {
t.Logf("Failed to delete instance domain on cleanup: %v", err)
}
})
// Wait for domain to be created
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
domain, err := instanceDomainRepo.Get(CTX,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
instanceDomainRepo.DomainCondition(database.TextOperationEqual, domainName),
instanceDomainRepo.TypeCondition(domain.DomainTypeCustom),
),
),
)
require.NoError(ttt, err)
require.False(ttt, *domain.IsPrimary)
assert.Equal(ttt, domainName, domain.Domain)
}, retryDuration, tick)
// Set domain as primary
beforeSetPrimary := time.Now()
_, err = SystemClient.SetPrimaryDomain(CTX, &system.SetPrimaryDomainRequest{
InstanceId: instance.Instance.Id,
Domain: domainName,
})
require.NoError(t, err)
afterSetPrimary := time.Now()
// Test that set primary reduces
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
domain, err := instanceDomainRepo.Get(CTX,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
instanceDomainRepo.IsPrimaryCondition(true),
instanceDomainRepo.TypeCondition(domain.DomainTypeCustom),
),
),
)
require.NoError(ttt, err)
// event instance.domain.primary.set
assert.Equal(ttt, domainName, domain.Domain)
assert.True(ttt, *domain.IsPrimary)
assert.WithinRange(ttt, domain.UpdatedAt, beforeSetPrimary, afterSetPrimary)
}, retryDuration, tick)
})
t.Run("test instance custom domain remove reduces", func(t *testing.T) {
// Add a domain to the instance
domainName := gofakeit.DomainName()
_, err := instance.Client.InstanceV2Beta.AddCustomDomain(CTX, &v2beta.AddCustomDomainRequest{
InstanceId: instance.Instance.Id,
Domain: domainName,
})
require.NoError(t, err)
// Wait for domain to be created and verify it exists
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
_, err := instanceDomainRepo.Get(CTX,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
instanceDomainRepo.DomainCondition(database.TextOperationEqual, domainName),
instanceDomainRepo.TypeCondition(domain.DomainTypeCustom),
),
),
)
require.NoError(ttt, err)
}, retryDuration, tick)
// Remove the domain
_, err = instance.Client.InstanceV2Beta.RemoveCustomDomain(CTX, &v2beta.RemoveCustomDomainRequest{
InstanceId: instance.Instance.Id,
Domain: domainName,
})
require.NoError(t, err)
// Test that domain remove reduces
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
domain, err := instanceDomainRepo.Get(CTX,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
instanceDomainRepo.DomainCondition(database.TextOperationEqual, domainName),
instanceDomainRepo.TypeCondition(domain.DomainTypeCustom),
),
),
)
// event instance.domain.removed
assert.Nil(ttt, domain)
require.ErrorIs(ttt, err, new(database.NoRowFoundError))
}, retryDuration, tick)
})
t.Run("test instance trusted domain add reduces", func(t *testing.T) {
// Add a domain to the instance
domainName := gofakeit.DomainName()
beforeAdd := time.Now()
_, err := instance.Client.InstanceV2Beta.AddTrustedDomain(CTX, &v2beta.AddTrustedDomainRequest{
InstanceId: instance.Instance.Id,
Domain: domainName,
})
require.NoError(t, err)
afterAdd := time.Now()
t.Cleanup(func() {
_, err := instance.Client.InstanceV2Beta.RemoveTrustedDomain(CTX, &v2beta.RemoveTrustedDomainRequest{
InstanceId: instance.Instance.Id,
Domain: domainName,
})
if err != nil {
t.Logf("Failed to delete instance domain on cleanup: %v", err)
}
})
// Test that domain add reduces
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
domain, err := instanceDomainRepo.Get(CTX,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
instanceDomainRepo.DomainCondition(database.TextOperationEqual, domainName),
instanceDomainRepo.TypeCondition(domain.DomainTypeTrusted),
),
),
)
require.NoError(ttt, err)
// event instance.domain.added
assert.Equal(ttt, domainName, domain.Domain)
assert.Equal(ttt, instance.Instance.Id, domain.InstanceID)
assert.WithinRange(ttt, domain.CreatedAt, beforeAdd, afterAdd)
assert.WithinRange(ttt, domain.UpdatedAt, beforeAdd, afterAdd)
}, retryDuration, tick)
})
t.Run("test instance trusted domain remove reduces", func(t *testing.T) {
// Add a domain to the instance
domainName := gofakeit.DomainName()
_, err := instance.Client.InstanceV2Beta.AddTrustedDomain(CTX, &v2beta.AddTrustedDomainRequest{
InstanceId: instance.Instance.Id,
Domain: domainName,
})
require.NoError(t, err)
// Wait for domain to be created and verify it exists
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
_, err := instanceDomainRepo.Get(CTX,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
instanceDomainRepo.DomainCondition(database.TextOperationEqual, domainName),
instanceDomainRepo.TypeCondition(domain.DomainTypeTrusted),
),
),
)
require.NoError(ttt, err)
}, retryDuration, tick)
// Remove the domain
_, err = instance.Client.InstanceV2Beta.RemoveTrustedDomain(CTX, &v2beta.RemoveTrustedDomainRequest{
InstanceId: instance.Instance.Id,
Domain: domainName,
})
require.NoError(t, err)
// Test that domain remove reduces
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
domain, err := instanceDomainRepo.Get(CTX,
database.WithCondition(
database.And(
instanceDomainRepo.InstanceIDCondition(instance.Instance.Id),
instanceDomainRepo.DomainCondition(database.TextOperationEqual, domainName),
instanceDomainRepo.TypeCondition(domain.DomainTypeTrusted),
),
),
)
// event instance.domain.removed
assert.Nil(ttt, domain)
require.ErrorIs(ttt, err, new(database.NoRowFoundError))
}, retryDuration, tick)
})
}

View File

@@ -0,0 +1,162 @@
//go:build integration
package events_test
import (
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
"github.com/zitadel/zitadel/internal/integration"
"github.com/zitadel/zitadel/pkg/grpc/system"
)
func TestServer_TestInstanceReduces(t *testing.T) {
instanceRepo := repository.InstanceRepository(pool)
t.Run("test instance add reduces", func(t *testing.T) {
instanceName := gofakeit.Name()
beforeCreate := time.Now()
instance, err := SystemClient.CreateInstance(CTX, &system.CreateInstanceRequest{
InstanceName: instanceName,
Owner: &system.CreateInstanceRequest_Machine_{
Machine: &system.CreateInstanceRequest_Machine{
UserName: "owner",
Name: "owner",
PersonalAccessToken: &system.CreateInstanceRequest_PersonalAccessToken{},
},
},
})
afterCreate := time.Now()
require.NoError(t, err)
t.Cleanup(func() {
_, err = SystemClient.RemoveInstance(CTX, &system.RemoveInstanceRequest{
InstanceId: instance.GetInstanceId(),
})
if err != nil {
t.Logf("Failed to delete instance on cleanup: %v", err)
}
})
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
instance, err := instanceRepo.Get(CTX,
database.WithCondition(instanceRepo.IDCondition(instance.GetInstanceId())),
)
require.NoError(ttt, err)
// event instance.added
assert.Equal(ttt, instanceName, instance.Name)
// event instance.default.org.set
assert.NotNil(t, instance.DefaultOrgID)
// event instance.iam.project.set
assert.NotNil(t, instance.IAMProjectID)
// event instance.iam.console.set
assert.NotNil(t, instance.ConsoleAppID)
// event instance.default.language.set
assert.NotNil(t, instance.DefaultLanguage)
// event instance.added
assert.WithinRange(t, instance.CreatedAt, beforeCreate, afterCreate)
// event instance.added
assert.WithinRange(t, instance.UpdatedAt, beforeCreate, afterCreate)
}, retryDuration, tick)
})
t.Run("test instance update reduces", func(t *testing.T) {
instanceName := gofakeit.Name()
res, err := SystemClient.CreateInstance(CTX, &system.CreateInstanceRequest{
InstanceName: instanceName,
Owner: &system.CreateInstanceRequest_Machine_{
Machine: &system.CreateInstanceRequest_Machine{
UserName: "owner",
Name: "owner",
PersonalAccessToken: &system.CreateInstanceRequest_PersonalAccessToken{},
},
},
})
require.NoError(t, err)
t.Cleanup(func() {
_, err = SystemClient.RemoveInstance(CTX, &system.RemoveInstanceRequest{
InstanceId: res.GetInstanceId(),
})
if err != nil {
t.Logf("Failed to delete instance on cleanup: %v", err)
}
})
// check instance exists
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
instance, err := instanceRepo.Get(CTX,
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
)
require.NoError(t, err)
assert.Equal(t, instanceName, instance.Name)
}, retryDuration, tick)
instanceName += "new"
beforeUpdate := time.Now()
_, err = SystemClient.UpdateInstance(CTX, &system.UpdateInstanceRequest{
InstanceId: res.InstanceId,
InstanceName: instanceName,
})
afterUpdate := time.Now()
require.NoError(t, err)
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
instance, err := instanceRepo.Get(CTX,
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
)
require.NoError(t, err)
// event instance.changed
assert.Equal(t, instanceName, instance.Name)
assert.WithinRange(t, instance.UpdatedAt, beforeUpdate, afterUpdate)
}, retryDuration, tick)
})
t.Run("test instance delete reduces", func(t *testing.T) {
instanceName := gofakeit.Name()
res, err := SystemClient.CreateInstance(CTX, &system.CreateInstanceRequest{
InstanceName: instanceName,
Owner: &system.CreateInstanceRequest_Machine_{
Machine: &system.CreateInstanceRequest_Machine{
UserName: "owner",
Name: "owner",
PersonalAccessToken: &system.CreateInstanceRequest_PersonalAccessToken{},
},
},
})
require.NoError(t, err)
// check instance exists
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
instance, err := instanceRepo.Get(CTX,
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
)
require.NoError(t, err)
assert.Equal(t, instanceName, instance.Name)
}, retryDuration, tick)
_, err = SystemClient.RemoveInstance(CTX, &system.RemoveInstanceRequest{
InstanceId: res.InstanceId,
})
require.NoError(t, err)
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
instance, err := instanceRepo.Get(CTX,
database.WithCondition(instanceRepo.IDCondition(res.GetInstanceId())),
)
// event instance.removed
assert.Nil(t, instance)
require.ErrorIs(t, err, new(database.NoRowFoundError))
}, retryDuration, tick)
})
}

View File

@@ -0,0 +1,133 @@
//go:build integration
package events_test
import (
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
"github.com/zitadel/zitadel/internal/integration"
v2beta "github.com/zitadel/zitadel/pkg/grpc/org/v2beta"
)
func TestServer_TestOrgDomainReduces(t *testing.T) {
org, err := OrgClient.CreateOrganization(CTX, &v2beta.CreateOrganizationRequest{
Name: gofakeit.Name(),
})
require.NoError(t, err)
orgRepo := repository.OrganizationRepository(pool)
orgDomainRepo := orgRepo.Domains(false)
t.Cleanup(func() {
_, err := OrgClient.DeleteOrganization(CTX, &v2beta.DeleteOrganizationRequest{
Id: org.GetId(),
})
if err != nil {
t.Logf("Failed to delete organization on cleanup: %v", err)
}
})
// Wait for org to be created
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
_, err := orgRepo.Get(CTX,
database.WithCondition(orgRepo.IDCondition(org.GetId())),
)
assert.NoError(ttt, err)
}, retryDuration, tick)
// The API call also sets the domain as primary, so we don't do a separate test for that.
t.Run("test organization domain add reduces", func(t *testing.T) {
// Add a domain to the organization
domainName := gofakeit.DomainName()
beforeAdd := time.Now()
_, err := OrgClient.AddOrganizationDomain(CTX, &v2beta.AddOrganizationDomainRequest{
OrganizationId: org.GetId(),
Domain: domainName,
})
require.NoError(t, err)
afterAdd := time.Now()
t.Cleanup(func() {
_, err := OrgClient.DeleteOrganizationDomain(CTX, &v2beta.DeleteOrganizationDomainRequest{
OrganizationId: org.GetId(),
Domain: domainName,
})
if err != nil {
t.Logf("Failed to delete domain on cleanup: %v", err)
}
})
// Test that domain add reduces
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
gottenDomain, err := orgDomainRepo.Get(CTX,
database.WithCondition(
database.And(
orgDomainRepo.InstanceIDCondition(Instance.Instance.Id),
orgDomainRepo.OrgIDCondition(org.Id),
orgDomainRepo.DomainCondition(database.TextOperationEqual, domainName),
),
),
)
require.NoError(ttt, err)
// event org.domain.added
assert.Equal(t, domainName, gottenDomain.Domain)
assert.Equal(t, Instance.Instance.Id, gottenDomain.InstanceID)
assert.Equal(t, org.Id, gottenDomain.OrgID)
assert.WithinRange(t, gottenDomain.CreatedAt, beforeAdd, afterAdd)
assert.WithinRange(t, gottenDomain.UpdatedAt, beforeAdd, afterAdd)
}, retryDuration, tick)
})
t.Run("test org domain remove reduces", func(t *testing.T) {
// Add a domain to the organization
domainName := gofakeit.DomainName()
_, err := OrgClient.AddOrganizationDomain(CTX, &v2beta.AddOrganizationDomainRequest{
OrganizationId: org.GetId(),
Domain: domainName,
})
require.NoError(t, err)
t.Cleanup(func() {
_, err := OrgClient.DeleteOrganizationDomain(CTX, &v2beta.DeleteOrganizationDomainRequest{
OrganizationId: org.GetId(),
Domain: domainName,
})
if err != nil {
t.Logf("Failed to delete domain on cleanup: %v", err)
}
})
// Remove the domain
_, err = OrgClient.DeleteOrganizationDomain(CTX, &v2beta.DeleteOrganizationDomainRequest{
OrganizationId: org.GetId(),
Domain: domainName,
})
require.NoError(t, err)
// Test that domain remove reduces
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
domain, err := orgDomainRepo.Get(CTX,
database.WithCondition(
database.And(
orgDomainRepo.InstanceIDCondition(Instance.Instance.Id),
orgDomainRepo.DomainCondition(database.TextOperationEqual, domainName),
),
),
)
// event instance.domain.removed
assert.Nil(ttt, domain)
require.ErrorIs(ttt, err, new(database.NoRowFoundError))
}, retryDuration, tick)
})
}

View File

@@ -0,0 +1,269 @@
//go:build integration
package events_test
import (
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
"github.com/zitadel/zitadel/internal/integration"
v2beta_org "github.com/zitadel/zitadel/pkg/grpc/org/v2beta"
)
func TestServer_TestOrganizationReduces(t *testing.T) {
instanceID := Instance.ID()
orgRepo := repository.OrganizationRepository(pool)
t.Run("test org add reduces", func(t *testing.T) {
beforeCreate := time.Now()
orgName := gofakeit.Name()
org, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{
Name: orgName,
})
require.NoError(t, err)
afterCreate := time.Now()
t.Cleanup(func() {
_, err = OrgClient.DeleteOrganization(CTX, &v2beta_org.DeleteOrganizationRequest{
Id: org.GetId(),
})
if err != nil {
t.Logf("Failed to delete organization on cleanup: %v", err)
}
})
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(tt *assert.CollectT) {
organization, err := orgRepo.Get(CTX,
database.WithCondition(
database.And(
orgRepo.IDCondition(org.GetId()),
orgRepo.InstanceIDCondition(instanceID),
),
),
)
require.NoError(tt, err)
// event org.added
assert.NotNil(t, organization.ID)
assert.Equal(t, orgName, organization.Name)
assert.NotNil(t, organization.InstanceID)
assert.Equal(t, domain.OrgStateActive, organization.State)
assert.WithinRange(t, organization.CreatedAt, beforeCreate, afterCreate)
assert.WithinRange(t, organization.UpdatedAt, beforeCreate, afterCreate)
}, retryDuration, tick)
})
t.Run("test org change reduces", func(t *testing.T) {
orgName := gofakeit.Name()
// 1. create org
organization, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{
Name: orgName,
})
require.NoError(t, err)
t.Cleanup(func() {
_, err = OrgClient.DeleteOrganization(CTX, &v2beta_org.DeleteOrganizationRequest{
Id: organization.Id,
})
if err != nil {
t.Logf("Failed to delete organization on cleanup: %v", err)
}
})
// 2. update org name
beforeUpdate := time.Now()
orgName = orgName + "_new"
_, err = OrgClient.UpdateOrganization(CTX, &v2beta_org.UpdateOrganizationRequest{
Id: organization.Id,
Name: orgName,
})
require.NoError(t, err)
afterUpdate := time.Now()
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
organization, err := orgRepo.Get(CTX,
database.WithCondition(
database.And(
orgRepo.IDCondition(organization.Id),
orgRepo.InstanceIDCondition(instanceID),
),
),
)
require.NoError(t, err)
// event org.changed
assert.Equal(t, orgName, organization.Name)
assert.WithinRange(t, organization.UpdatedAt, beforeUpdate, afterUpdate)
}, retryDuration, tick)
})
t.Run("test org deactivate reduces", func(t *testing.T) {
orgName := gofakeit.Name()
// 1. create org
organization, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{
Name: orgName,
})
require.NoError(t, err)
t.Cleanup(func() {
// Cleanup: delete the organization
_, err = OrgClient.DeleteOrganization(CTX, &v2beta_org.DeleteOrganizationRequest{
Id: organization.Id,
})
if err != nil {
t.Logf("Failed to delete organization on cleanup: %v", err)
}
})
// 2. deactivate org name
beforeDeactivate := time.Now()
_, err = OrgClient.DeactivateOrganization(CTX, &v2beta_org.DeactivateOrganizationRequest{
Id: organization.Id,
})
require.NoError(t, err)
afterDeactivate := time.Now()
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
organization, err := orgRepo.Get(CTX,
database.WithCondition(
database.And(
orgRepo.IDCondition(organization.Id),
orgRepo.InstanceIDCondition(instanceID),
),
),
)
require.NoError(t, err)
// event org.deactivate
assert.Equal(t, domain.OrgStateInactive, organization.State)
assert.WithinRange(t, organization.UpdatedAt, beforeDeactivate, afterDeactivate)
}, retryDuration, tick)
})
t.Run("test org activate reduces", func(t *testing.T) {
orgName := gofakeit.Name()
// 1. create org
organization, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{
Name: orgName,
})
require.NoError(t, err)
t.Cleanup(func() {
// Cleanup: delete the organization
_, err = OrgClient.DeleteOrganization(CTX, &v2beta_org.DeleteOrganizationRequest{
Id: organization.Id,
})
if err != nil {
t.Logf("Failed to delete organization on cleanup: %v", err)
}
})
// 2. deactivate org name
_, err = OrgClient.DeactivateOrganization(CTX, &v2beta_org.DeactivateOrganizationRequest{
Id: organization.Id,
})
require.NoError(t, err)
orgRepo := repository.OrganizationRepository(pool)
// 3. check org deactivated
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
organization, err := orgRepo.Get(CTX,
database.WithCondition(
database.And(
orgRepo.IDCondition(organization.Id),
orgRepo.InstanceIDCondition(instanceID),
),
),
)
require.NoError(t, err)
assert.Equal(t, domain.OrgStateInactive, organization.State)
}, retryDuration, tick)
// 4. activate org name
beforeActivate := time.Now()
_, err = OrgClient.ActivateOrganization(CTX, &v2beta_org.ActivateOrganizationRequest{
Id: organization.Id,
})
require.NoError(t, err)
afterActivate := time.Now()
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
organization, err := orgRepo.Get(CTX,
database.WithCondition(
database.And(
orgRepo.IDCondition(organization.Id),
orgRepo.InstanceIDCondition(instanceID),
),
),
)
require.NoError(t, err)
// event org.reactivate
assert.Equal(t, orgName, organization.Name)
assert.Equal(t, domain.OrgStateActive, organization.State)
assert.WithinRange(t, organization.UpdatedAt, beforeActivate, afterActivate)
}, retryDuration, tick)
})
t.Run("test org remove reduces", func(t *testing.T) {
orgName := gofakeit.Name()
// 1. create org
organization, err := OrgClient.CreateOrganization(CTX, &v2beta_org.CreateOrganizationRequest{
Name: orgName,
})
require.NoError(t, err)
// 2. check org retrievable
orgRepo := repository.OrganizationRepository(pool)
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
_, err := orgRepo.Get(CTX,
database.WithCondition(
database.And(
orgRepo.IDCondition(organization.Id),
orgRepo.InstanceIDCondition(instanceID),
),
),
)
require.NoError(t, err)
}, retryDuration, tick)
// 3. delete org
_, err = OrgClient.DeleteOrganization(CTX, &v2beta_org.DeleteOrganizationRequest{
Id: organization.Id,
})
require.NoError(t, err)
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(t *assert.CollectT) {
organization, err := orgRepo.Get(CTX,
database.WithCondition(
database.And(
orgRepo.IDCondition(organization.Id),
orgRepo.InstanceIDCondition(instanceID),
),
),
)
require.ErrorIs(t, err, new(database.NoRowFoundError))
// event org.remove
assert.Nil(t, organization)
}, retryDuration, tick)
})
}

View File

@@ -0,0 +1,3 @@
package database
//go:generate mockgen -typed -package dbmock -destination ./dbmock/database.mock.go github.com/zitadel/zitadel/backend/v3/storage/database Pool,Client,Row,Rows,Transaction

View File

@@ -0,0 +1,9 @@
package database
import "context"
type Migrator interface {
// Migrate executes migrations to setup the database.
// The method can be called once per running Zitadel.
Migrate(ctx context.Context) error
}

View File

@@ -0,0 +1,136 @@
package database
import (
"time"
"golang.org/x/exp/constraints"
)
type Value interface {
Boolean | Number | Text | Instruction
}
type Operation interface {
BooleanOperation | NumberOperation | TextOperation
}
type Text interface {
~string | ~[]byte
}
// TextOperation are operations that can be performed on text values.
type TextOperation uint8
const (
// TextOperationEqual compares two strings for equality.
TextOperationEqual TextOperation = iota + 1
// TextOperationEqualIgnoreCase compares two strings for equality, ignoring case.
TextOperationEqualIgnoreCase
// TextOperationNotEqual compares two strings for inequality.
TextOperationNotEqual
// TextOperationNotEqualIgnoreCase compares two strings for inequality, ignoring case.
TextOperationNotEqualIgnoreCase
// TextOperationStartsWith checks if the first string starts with the second.
TextOperationStartsWith
// TextOperationStartsWithIgnoreCase checks if the first string starts with the second, ignoring case.
TextOperationStartsWithIgnoreCase
)
var textOperations = map[TextOperation]string{
TextOperationEqual: " = ",
TextOperationEqualIgnoreCase: " LIKE ",
TextOperationNotEqual: " <> ",
TextOperationNotEqualIgnoreCase: " NOT LIKE ",
TextOperationStartsWith: " LIKE ",
TextOperationStartsWithIgnoreCase: " LIKE ",
}
func writeTextOperation[T Text](builder *StatementBuilder, col Column, op TextOperation, value T) {
switch op {
case TextOperationEqual, TextOperationNotEqual:
col.WriteQualified(builder)
builder.WriteString(textOperations[op])
builder.WriteArg(value)
case TextOperationEqualIgnoreCase, TextOperationNotEqualIgnoreCase:
builder.WriteString("LOWER(")
col.WriteQualified(builder)
builder.WriteString(")")
builder.WriteString(textOperations[op])
builder.WriteString("LOWER(")
builder.WriteArg(value)
builder.WriteString(")")
case TextOperationStartsWith:
col.WriteQualified(builder)
builder.WriteString(textOperations[op])
builder.WriteArg(value)
builder.WriteString(" || '%'")
case TextOperationStartsWithIgnoreCase:
builder.WriteString("LOWER(")
col.WriteQualified(builder)
builder.WriteString(")")
builder.WriteString(textOperations[op])
builder.WriteString("LOWER(")
builder.WriteArg(value)
builder.WriteString(")")
builder.WriteString(" || '%'")
default:
panic("unsupported text operation")
}
}
type Number interface {
constraints.Integer | constraints.Float | constraints.Complex | time.Time | time.Duration
}
// NumberOperation are operations that can be performed on number values.
type NumberOperation uint8
const (
// NumberOperationEqual compares two numbers for equality.
NumberOperationEqual NumberOperation = iota + 1
// NumberOperationNotEqual compares two numbers for inequality.
NumberOperationNotEqual
// NumberOperationLessThan compares two numbers to check if the first is less than the second.
NumberOperationLessThan
// NumberOperationLessThanOrEqual compares two numbers to check if the first is less than or equal to the second.
NumberOperationAtLeast
// NumberOperationGreaterThan compares two numbers to check if the first is greater than the second.
NumberOperationGreaterThan
// NumberOperationGreaterThanOrEqual compares two numbers to check if the first is greater than or equal to the second.
NumberOperationAtMost
)
var numberOperations = map[NumberOperation]string{
NumberOperationEqual: " = ",
NumberOperationNotEqual: " <> ",
NumberOperationLessThan: " < ",
NumberOperationAtLeast: " <= ",
NumberOperationGreaterThan: " > ",
NumberOperationAtMost: " >= ",
}
func writeNumberOperation[T Number](builder *StatementBuilder, col Column, op NumberOperation, value T) {
col.WriteQualified(builder)
builder.WriteString(numberOperations[op])
builder.WriteArg(value)
}
type Boolean interface {
~bool
}
// BooleanOperation are operations that can be performed on boolean values.
type BooleanOperation uint8
const (
BooleanOperationIsTrue BooleanOperation = iota + 1
BooleanOperationIsFalse
)
func writeBooleanOperation[T Boolean](builder *StatementBuilder, col Column, value T) {
col.WriteQualified(builder)
builder.WriteString(" = ")
builder.WriteArg(value)
}

View File

@@ -0,0 +1,21 @@
package database
// Order represents a SQL condition.
// Its written after the ORDER BY keyword in a SQL statement.
type Order interface {
Write(builder *StatementBuilder)
}
type orderBy struct {
column Column
}
func OrderBy(column Column) Order {
return &orderBy{column: column}
}
// Write implements [Order].
func (o *orderBy) Write(builder *StatementBuilder) {
builder.WriteString(" ORDER BY ")
o.column.WriteQualified(builder)
}

View File

@@ -0,0 +1,174 @@
package database
type QueryOption func(opts *QueryOpts)
// WithCondition sets the condition for the query.
func WithCondition(condition Condition) QueryOption {
return func(opts *QueryOpts) {
opts.Condition = condition
}
}
// WithOrderBy sets the columns to order the results by.
func WithOrderBy(ordering OrderDirection, orderBy ...Column) QueryOption {
return func(opts *QueryOpts) {
opts.OrderBy = orderBy
opts.Ordering = ordering
}
}
func WithOrderByAscending(columns ...Column) QueryOption {
return WithOrderBy(OrderDirectionAsc, columns...)
}
func WithOrderByDescending(columns ...Column) QueryOption {
return WithOrderBy(OrderDirectionDesc, columns...)
}
// WithLimit sets the maximum number of results to return.
func WithLimit(limit uint32) QueryOption {
return func(opts *QueryOpts) {
opts.Limit = limit
}
}
// WithOffset sets the number of results to skip before returning the results.
func WithOffset(offset uint32) QueryOption {
return func(opts *QueryOpts) {
opts.Offset = offset
}
}
// WithGroupBy sets the columns to group the results by.
func WithGroupBy(groupBy ...Column) QueryOption {
return func(opts *QueryOpts) {
opts.GroupBy = groupBy
}
}
// WithLeftJoin adds a LEFT JOIN to the query.
func WithLeftJoin(table string, columns Condition) QueryOption {
return func(opts *QueryOpts) {
opts.Joins = append(opts.Joins, join{
table: table,
typ: JoinTypeLeft,
columns: columns,
})
}
}
type joinType string
const (
JoinTypeLeft joinType = "LEFT"
)
type join struct {
table string
typ joinType
columns Condition
}
type OrderDirection uint8
const (
OrderDirectionAsc OrderDirection = iota
OrderDirectionDesc
)
// QueryOpts holds the options for a query.
// It is used to build the SQL SELECT statement.
type QueryOpts struct {
// Condition is the condition to filter the results.
// It is used to build the WHERE clause of the SQL statement.
Condition Condition
// OrderBy is the columns to order the results by.
// It is used to build the ORDER BY clause of the SQL statement.
OrderBy Columns
// Ordering defines if the columns should be ordered ascending or descending.
// Default is ascending.
Ordering OrderDirection
// Limit is the maximum number of results to return.
// It is used to build the LIMIT clause of the SQL statement.
Limit uint32
// Offset is the number of results to skip before returning the results.
// It is used to build the OFFSET clause of the SQL statement.
Offset uint32
// GroupBy is the columns to group the results by.
// It is used to build the GROUP BY clause of the SQL statement.
GroupBy Columns
// Joins is a list of joins to be applied to the query.
// It is used to build the JOIN clauses of the SQL statement.
Joins []join
}
func (opts *QueryOpts) Write(builder *StatementBuilder) {
opts.WriteLeftJoins(builder)
opts.WriteCondition(builder)
opts.WriteGroupBy(builder)
opts.WriteOrderBy(builder)
opts.WriteLimit(builder)
opts.WriteOffset(builder)
}
func (opts *QueryOpts) WriteCondition(builder *StatementBuilder) {
if opts.Condition == nil {
return
}
builder.WriteString(" WHERE ")
opts.Condition.Write(builder)
}
func (opts *QueryOpts) WriteOrderBy(builder *StatementBuilder) {
if len(opts.OrderBy) == 0 {
return
}
builder.WriteString(" ORDER BY ")
for i, col := range opts.OrderBy {
if i > 0 {
builder.WriteString(", ")
}
col.WriteQualified(builder)
if opts.Ordering == OrderDirectionDesc {
builder.WriteString(" DESC")
}
}
}
func (opts *QueryOpts) WriteLimit(builder *StatementBuilder) {
if opts.Limit == 0 {
return
}
builder.WriteString(" LIMIT ")
builder.WriteArg(opts.Limit)
}
func (opts *QueryOpts) WriteOffset(builder *StatementBuilder) {
if opts.Offset == 0 {
return
}
builder.WriteString(" OFFSET ")
builder.WriteArg(opts.Offset)
}
func (opts *QueryOpts) WriteGroupBy(builder *StatementBuilder) {
if len(opts.GroupBy) == 0 {
return
}
builder.WriteString(" GROUP BY ")
opts.GroupBy.WriteQualified(builder)
}
func (opts *QueryOpts) WriteLeftJoins(builder *StatementBuilder) {
if len(opts.Joins) == 0 {
return
}
for _, join := range opts.Joins {
builder.WriteString(" ")
builder.WriteString(string(join.typ))
builder.WriteString(" JOIN ")
builder.WriteString(join.table)
builder.WriteString(" ON ")
join.columns.Write(builder)
}
}

View File

@@ -0,0 +1,5 @@
// Package implements the repositories defined in the domain package.
// The repositories are used by the domain package to access the database.
// the inheritance.sql file is me over-engineering table inheritance.
// I would create a user table which is inherited by human_user and machine_user and the same for objects like idps.
package repository

View File

@@ -0,0 +1,135 @@
CREATE TABLE objects (
id SERIAL PRIMARY KEY,
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL,
deleted_at TIMESTAMP
);
CREATE OR REPLACE FUNCTION update_updated_at_column()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at = NOW();
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TABLE instances(
name VARCHAR(50) NOT NULL
, PRIMARY KEY (id)
) INHERITS (objects);
CREATE TRIGGER set_updated_at
BEFORE UPDATE
ON instances
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
CREATE TABLE instance_objects(
instance_id INT NOT NULL
, PRIMARY KEY (instance_id, id)
-- as foreign keys are not inherited we need to define them on the child tables
--, CONSTRAINT fk_instance FOREIGN KEY (instance_id) REFERENCES instances(id)
) INHERITS (objects);
CREATE TABLE orgs(
name VARCHAR(50) NOT NULL
, PRIMARY KEY (instance_id, id)
, CONSTRAINT fk_instance FOREIGN KEY (instance_id) REFERENCES instances(id)
) INHERITS (instance_objects);
CREATE TRIGGER set_updated_at
BEFORE UPDATE
ON orgs
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
CREATE TABLE org_objects(
org_id INT NOT NULL
, PRIMARY KEY (instance_id, org_id, id)
-- as foreign keys are not inherited we need to define them on the child tables
-- CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id),
-- CONSTRAINT fk_instance FOREIGN KEY (instance_id) REFERENCES instances(id)
) INHERITS (instance_objects);
CREATE TABLE users (
username VARCHAR(50) NOT NULL
, PRIMARY KEY (instance_id, org_id, id)
-- as foreign keys are not inherited we need to define them on the child tables
-- , CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id)
-- , CONSTRAINT fk_instances FOREIGN KEY (instance_id) REFERENCES instances(id)
) INHERITS (org_objects);
CREATE INDEX idx_users_username ON users(username);
CREATE TRIGGER set_updated_at
BEFORE UPDATE
ON users
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
CREATE TABLE human_users(
first_name VARCHAR(50)
, last_name VARCHAR(50)
, PRIMARY KEY (instance_id, org_id, id)
-- CONSTRAINT fk_user FOREIGN KEY (instance_id, org_id, id) REFERENCES users(instance_id, org_id, id),
, CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id)
, CONSTRAINT fk_instances FOREIGN KEY (instance_id) REFERENCES instances(id)
) INHERITS (users);
CREATE INDEX idx_human_users_username ON human_users(username);
CREATE TRIGGER set_updated_at
BEFORE UPDATE
ON human_users
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
CREATE TABLE machine_users(
description VARCHAR(50)
, PRIMARY KEY (instance_id, org_id, id)
-- , CONSTRAINT fk_user FOREIGN KEY (instance_id, org_id, id) REFERENCES users(instance_id, org_id, id)
, CONSTRAINT fk_org FOREIGN KEY (instance_id, org_id) REFERENCES orgs(instance_id, id)
, CONSTRAINT fk_instances FOREIGN KEY (instance_id) REFERENCES instances(id)
) INHERITS (users);
CREATE INDEX idx_machine_users_username ON machine_users(username);
CREATE TRIGGER set_updated_at
BEFORE UPDATE
ON machine_users
FOR EACH ROW
EXECUTE FUNCTION update_updated_at_column();
CREATE VIEW users_view AS (
SELECT
id
, created_at
, updated_at
, deleted_at
, instance_id
, org_id
, username
, tableoid::regclass::TEXT AS type
, first_name
, last_name
, NULL AS description
FROM
human_users
UNION
SELECT
id
, created_at
, updated_at
, deleted_at
, instance_id
, org_id
, username
, tableoid::regclass::TEXT AS type
, NULL AS first_name
, NULL AS last_name
, description
FROM
machine_users
);

View File

@@ -0,0 +1,306 @@
package repository
import (
"context"
"database/sql"
"encoding/json"
"time"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
var _ domain.InstanceRepository = (*instance)(nil)
type instance struct {
repository
shouldLoadDomains bool
domainRepo *instanceDomain
}
func InstanceRepository(client database.QueryExecutor) domain.InstanceRepository {
return &instance{
repository: repository{
client: client,
},
}
}
// -------------------------------------------------------------
// repository
// -------------------------------------------------------------
const (
queryInstanceStmt = `SELECT instances.id, instances.name, instances.default_org_id, instances.iam_project_id, instances.console_client_id, instances.console_app_id, instances.default_language, instances.created_at, instances.updated_at` +
` , CASE WHEN count(instance_domains.domain) > 0 THEN jsonb_agg(json_build_object('domain', instance_domains.domain, 'isPrimary', instance_domains.is_primary, 'isGenerated', instance_domains.is_generated, 'createdAt', instance_domains.created_at, 'updatedAt', instance_domains.updated_at)) ELSE NULL::JSONB END domains` +
` FROM zitadel.instances`
)
// Get implements [domain.InstanceRepository].
func (i *instance) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Instance, error) {
opts = append(opts,
i.joinDomains(),
database.WithGroupBy(i.IDColumn()),
)
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
}
var builder database.StatementBuilder
builder.WriteString(queryInstanceStmt)
options.Write(&builder)
return scanInstance(ctx, i.client, &builder)
}
// List implements [domain.InstanceRepository].
func (i *instance) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Instance, error) {
opts = append(opts,
i.joinDomains(),
database.WithGroupBy(i.IDColumn()),
)
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
}
var builder database.StatementBuilder
builder.WriteString(queryInstanceStmt)
options.Write(&builder)
return scanInstances(ctx, i.client, &builder)
}
func (i *instance) joinDomains() database.QueryOption {
columns := make([]database.Condition, 0, 2)
columns = append(columns, database.NewColumnCondition(i.IDColumn(), i.Domains(false).InstanceIDColumn()))
// If domains should not be joined, we make sure to return null for the domain columns
// the query optimizer of the dialect should optimize this away if no domains are requested
if !i.shouldLoadDomains {
columns = append(columns, database.IsNull(i.Domains(false).InstanceIDColumn()))
}
return database.WithLeftJoin(
"zitadel.instance_domains",
database.And(columns...),
)
}
// Create implements [domain.InstanceRepository].
func (i *instance) Create(ctx context.Context, instance *domain.Instance) error {
var (
builder database.StatementBuilder
createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction
)
if !instance.CreatedAt.IsZero() {
createdAt = instance.CreatedAt
}
if !instance.UpdatedAt.IsZero() {
updatedAt = instance.UpdatedAt
}
builder.WriteString(`INSERT INTO zitadel.instances (id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language, created_at, updated_at) VALUES (`)
builder.WriteArgs(instance.ID, instance.Name, instance.DefaultOrgID, instance.IAMProjectID, instance.ConsoleClientID, instance.ConsoleAppID, instance.DefaultLanguage, createdAt, updatedAt)
builder.WriteString(`) RETURNING created_at, updated_at`)
return i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&instance.CreatedAt, &instance.UpdatedAt)
}
// Update implements [domain.InstanceRepository].
func (i instance) Update(ctx context.Context, id string, changes ...database.Change) (int64, error) {
if len(changes) == 0 {
return 0, database.ErrNoChanges
}
var builder database.StatementBuilder
builder.WriteString(`UPDATE zitadel.instances SET `)
database.Changes(changes).Write(&builder)
idCondition := i.IDCondition(id)
writeCondition(&builder, idCondition)
stmt := builder.String()
return i.client.Exec(ctx, stmt, builder.Args()...)
}
// Delete implements [domain.InstanceRepository].
func (i instance) Delete(ctx context.Context, id string) (int64, error) {
var builder database.StatementBuilder
builder.WriteString(`DELETE FROM zitadel.instances`)
idCondition := i.IDCondition(id)
writeCondition(&builder, idCondition)
return i.client.Exec(ctx, builder.String(), builder.Args()...)
}
// -------------------------------------------------------------
// changes
// -------------------------------------------------------------
// SetName implements [domain.instanceChanges].
func (i instance) SetName(name string) database.Change {
return database.NewChange(i.NameColumn(), name)
}
// SetUpdatedAt implements [domain.instanceChanges].
func (i instance) SetUpdatedAt(time time.Time) database.Change {
return database.NewChange(i.UpdatedAtColumn(), time)
}
func (i instance) SetIAMProject(id string) database.Change {
return database.NewChange(i.IAMProjectIDColumn(), id)
}
func (i instance) SetDefaultOrg(id string) database.Change {
return database.NewChange(i.DefaultOrgIDColumn(), id)
}
func (i instance) SetDefaultLanguage(lang language.Tag) database.Change {
return database.NewChange(i.DefaultLanguageColumn(), lang.String())
}
func (i instance) SetConsoleClientID(id string) database.Change {
return database.NewChange(i.ConsoleClientIDColumn(), id)
}
func (i instance) SetConsoleAppID(id string) database.Change {
return database.NewChange(i.ConsoleAppIDColumn(), id)
}
// -------------------------------------------------------------
// conditions
// -------------------------------------------------------------
// IDCondition implements [domain.instanceConditions].
func (i instance) IDCondition(id string) database.Condition {
return database.NewTextCondition(i.IDColumn(), database.TextOperationEqual, id)
}
// NameCondition implements [domain.instanceConditions].
func (i instance) NameCondition(op database.TextOperation, name string) database.Condition {
return database.NewTextCondition(i.NameColumn(), op, name)
}
// -------------------------------------------------------------
// columns
// -------------------------------------------------------------
// IDColumn implements [domain.instanceColumns].
func (instance) IDColumn() database.Column {
return database.NewColumn("instances", "id")
}
// NameColumn implements [domain.instanceColumns].
func (instance) NameColumn() database.Column {
return database.NewColumn("instances", "name")
}
// CreatedAtColumn implements [domain.instanceColumns].
func (instance) CreatedAtColumn() database.Column {
return database.NewColumn("instances", "created_at")
}
// DefaultOrgIdColumn implements [domain.instanceColumns].
func (instance) DefaultOrgIDColumn() database.Column {
return database.NewColumn("instances", "default_org_id")
}
// IAMProjectIDColumn implements [domain.instanceColumns].
func (instance) IAMProjectIDColumn() database.Column {
return database.NewColumn("instances", "iam_project_id")
}
// ConsoleClientIDColumn implements [domain.instanceColumns].
func (instance) ConsoleClientIDColumn() database.Column {
return database.NewColumn("instances", "console_client_id")
}
// ConsoleAppIDColumn implements [domain.instanceColumns].
func (instance) ConsoleAppIDColumn() database.Column {
return database.NewColumn("instances", "console_app_id")
}
// DefaultLanguageColumn implements [domain.instanceColumns].
func (instance) DefaultLanguageColumn() database.Column {
return database.NewColumn("instances", "default_language")
}
// UpdatedAtColumn implements [domain.instanceColumns].
func (instance) UpdatedAtColumn() database.Column {
return database.NewColumn("instances", "updated_at")
}
type rawInstance struct {
*domain.Instance
RawDomains sql.Null[json.RawMessage] `json:"domains,omitzero" db:"domains"`
}
func scanInstance(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Instance, error) {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
var instance rawInstance
if err := rows.(database.CollectableRows).CollectExactlyOneRow(&instance); err != nil {
return nil, err
}
if instance.RawDomains.Valid {
if err := json.Unmarshal(instance.RawDomains.V, &instance.Domains); err != nil {
return nil, err
}
}
return instance.Instance, nil
}
func scanInstances(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.Instance, error) {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
var rawInstances []*rawInstance
if err := rows.(database.CollectableRows).Collect(&rawInstances); err != nil {
return nil, err
}
instances := make([]*domain.Instance, len(rawInstances))
for i, instance := range rawInstances {
if instance.RawDomains.Valid {
if err := json.Unmarshal(instance.RawDomains.V, &instance.Domains); err != nil {
return nil, err
}
}
instances[i] = instance.Instance
}
return instances, nil
}
// -------------------------------------------------------------
// sub repositories
// -------------------------------------------------------------
// Domains implements [domain.InstanceRepository].
func (i *instance) Domains(shouldLoad bool) domain.InstanceDomainRepository {
if !i.shouldLoadDomains {
i.shouldLoadDomains = shouldLoad
}
if i.domainRepo != nil {
return i.domainRepo
}
i.domainRepo = &instanceDomain{
repository: i.repository,
instance: i,
}
return i.domainRepo
}

View File

@@ -0,0 +1,214 @@
package repository
import (
"context"
"time"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
var _ domain.InstanceDomainRepository = (*instanceDomain)(nil)
type instanceDomain struct {
repository
*instance
}
// -------------------------------------------------------------
// repository
// -------------------------------------------------------------
const queryInstanceDomainStmt = `SELECT instance_domains.instance_id, instance_domains.domain, instance_domains.is_primary, instance_domains.created_at, instance_domains.updated_at ` +
`FROM zitadel.instance_domains`
// Get implements [domain.InstanceDomainRepository].
// Subtle: this method shadows the method ([domain.InstanceRepository]).Get of instanceDomain.instance.
func (i *instanceDomain) Get(ctx context.Context, opts ...database.QueryOption) (*domain.InstanceDomain, error) {
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
}
var builder database.StatementBuilder
builder.WriteString(queryInstanceDomainStmt)
options.Write(&builder)
return scanInstanceDomain(ctx, i.client, &builder)
}
// List implements [domain.InstanceDomainRepository].
// Subtle: this method shadows the method ([domain.InstanceRepository]).List of instanceDomain.instance.
func (i *instanceDomain) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.InstanceDomain, error) {
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
}
var builder database.StatementBuilder
builder.WriteString(queryInstanceDomainStmt)
options.Write(&builder)
return scanInstanceDomains(ctx, i.client, &builder)
}
// Add implements [domain.InstanceDomainRepository].
func (i *instanceDomain) Add(ctx context.Context, domain *domain.AddInstanceDomain) error {
var (
builder database.StatementBuilder
createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction
)
if !domain.CreatedAt.IsZero() {
createdAt = domain.CreatedAt
}
if !domain.UpdatedAt.IsZero() {
updatedAt = domain.UpdatedAt
}
builder.WriteString(`INSERT INTO zitadel.instance_domains (instance_id, domain, is_primary, is_generated, type, created_at, updated_at) VALUES (`)
builder.WriteArgs(domain.InstanceID, domain.Domain, domain.IsPrimary, domain.IsGenerated, domain.Type, createdAt, updatedAt)
builder.WriteString(`) RETURNING created_at, updated_at`)
return i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt)
}
// Update implements [domain.InstanceDomainRepository].
// Subtle: this method shadows the method ([domain.InstanceRepository]).Update of instanceDomain.instance.
func (i *instanceDomain) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) {
if len(changes) == 0 {
return 0, database.ErrNoChanges
}
var builder database.StatementBuilder
builder.WriteString(`UPDATE zitadel.instance_domains SET `)
database.Changes(changes).Write(&builder)
writeCondition(&builder, condition)
return i.client.Exec(ctx, builder.String(), builder.Args()...)
}
// Remove implements [domain.InstanceDomainRepository].
func (i *instanceDomain) Remove(ctx context.Context, condition database.Condition) (int64, error) {
var builder database.StatementBuilder
builder.WriteString(`DELETE FROM zitadel.instance_domains WHERE `)
condition.Write(&builder)
return i.client.Exec(ctx, builder.String(), builder.Args()...)
}
// -------------------------------------------------------------
// changes
// -------------------------------------------------------------
// SetPrimary implements [domain.InstanceDomainRepository].
func (i instanceDomain) SetPrimary() database.Change {
return database.NewChange(i.IsPrimaryColumn(), true)
}
// SetUpdatedAt implements [domain.OrganizationDomainRepository].
func (i instanceDomain) SetUpdatedAt(updatedAt time.Time) database.Change {
return database.NewChange(i.UpdatedAtColumn(), updatedAt)
}
// SetType implements [domain.InstanceDomainRepository].
func (i instanceDomain) SetType(typ domain.DomainType) database.Change {
return database.NewChange(i.TypeColumn(), typ)
}
// -------------------------------------------------------------
// conditions
// -------------------------------------------------------------
// DomainCondition implements [domain.InstanceDomainRepository].
func (i instanceDomain) DomainCondition(op database.TextOperation, domain string) database.Condition {
return database.NewTextCondition(i.DomainColumn(), op, domain)
}
// InstanceIDCondition implements [domain.InstanceDomainRepository].
func (i instanceDomain) InstanceIDCondition(instanceID string) database.Condition {
return database.NewTextCondition(i.InstanceIDColumn(), database.TextOperationEqual, instanceID)
}
// IsPrimaryCondition implements [domain.InstanceDomainRepository].
func (i instanceDomain) IsPrimaryCondition(isPrimary bool) database.Condition {
return database.NewBooleanCondition(i.IsPrimaryColumn(), isPrimary)
}
// TypeCondition implements [domain.InstanceDomainRepository].
func (i instanceDomain) TypeCondition(typ domain.DomainType) database.Condition {
return database.NewTextCondition(i.TypeColumn(), database.TextOperationEqual, typ.String())
}
// -------------------------------------------------------------
// columns
// -------------------------------------------------------------
// CreatedAtColumn implements [domain.InstanceDomainRepository].
// Subtle: this method shadows the method ([domain.InstanceRepository]).CreatedAtColumn of instanceDomain.instance.
func (instanceDomain) CreatedAtColumn() database.Column {
return database.NewColumn("instance_domains", "created_at")
}
// DomainColumn implements [domain.InstanceDomainRepository].
func (instanceDomain) DomainColumn() database.Column {
return database.NewColumn("instance_domains", "domain")
}
// InstanceIDColumn implements [domain.InstanceDomainRepository].
func (instanceDomain) InstanceIDColumn() database.Column {
return database.NewColumn("instance_domains", "instance_id")
}
// IsPrimaryColumn implements [domain.InstanceDomainRepository].
func (instanceDomain) IsPrimaryColumn() database.Column {
return database.NewColumn("instance_domains", "is_primary")
}
// UpdatedAtColumn implements [domain.InstanceDomainRepository].
// Subtle: this method shadows the method ([domain.InstanceRepository]).UpdatedAtColumn of instanceDomain.instance.
func (instanceDomain) UpdatedAtColumn() database.Column {
return database.NewColumn("instance_domains", "updated_at")
}
// IsGeneratedColumn implements [domain.InstanceDomainRepository].
func (instanceDomain) IsGeneratedColumn() database.Column {
return database.NewColumn("instance_domains", "is_generated")
}
// TypeColumn implements [domain.InstanceDomainRepository].
func (instanceDomain) TypeColumn() database.Column {
return database.NewColumn("instance_domains", "type")
}
// -------------------------------------------------------------
// scanners
// -------------------------------------------------------------
func scanInstanceDomains(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.InstanceDomain, error) {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
var domains []*domain.InstanceDomain
if err := rows.(database.CollectableRows).Collect(&domains); err != nil {
return nil, err
}
return domains, nil
}
func scanInstanceDomain(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.InstanceDomain, error) {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
domain := new(domain.InstanceDomain)
if err := rows.(database.CollectableRows).CollectExactlyOneRow(domain); err != nil {
return nil, err
}
return domain, nil
}

View File

@@ -0,0 +1,701 @@
package repository_test
import (
"context"
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
)
func TestAddInstanceDomain(t *testing.T) {
// create instance
instanceID := gofakeit.UUID()
instance := domain.Instance{
ID: instanceID,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleClient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
instanceRepo := repository.InstanceRepository(pool)
err := instanceRepo.Create(t.Context(), &instance)
require.NoError(t, err)
tests := []struct {
name string
testFunc func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain
instanceDomain domain.AddInstanceDomain
err error
}{
{
name: "happy path custom domain",
instanceDomain: domain.AddInstanceDomain{
InstanceID: instanceID,
Domain: gofakeit.DomainName(),
Type: domain.DomainTypeCustom,
IsPrimary: gu.Ptr(false),
IsGenerated: gu.Ptr(false),
},
},
{
name: "happy path trusted domain",
instanceDomain: domain.AddInstanceDomain{
InstanceID: instanceID,
Domain: gofakeit.DomainName(),
Type: domain.DomainTypeTrusted,
},
},
{
name: "add primary domain",
instanceDomain: domain.AddInstanceDomain{
InstanceID: instanceID,
Domain: gofakeit.DomainName(),
Type: domain.DomainTypeCustom,
IsPrimary: gu.Ptr(true),
IsGenerated: gu.Ptr(false),
},
},
{
name: "add custom domain without domain name",
instanceDomain: domain.AddInstanceDomain{
InstanceID: instanceID,
Domain: "",
Type: domain.DomainTypeCustom,
IsPrimary: gu.Ptr(false),
IsGenerated: gu.Ptr(false),
},
err: new(database.CheckError),
},
{
name: "add trusted domain without domain name",
instanceDomain: domain.AddInstanceDomain{
InstanceID: instanceID,
Domain: "",
Type: domain.DomainTypeTrusted,
},
err: new(database.CheckError),
},
{
name: "add custom domain with same domain twice",
testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain {
domainName := gofakeit.DomainName()
instanceDomain := &domain.AddInstanceDomain{
InstanceID: instanceID,
Domain: domainName,
Type: domain.DomainTypeCustom,
IsPrimary: gu.Ptr(false),
IsGenerated: gu.Ptr(false),
}
err := domainRepo.Add(ctx, instanceDomain)
require.NoError(t, err)
// return same domain again
return &domain.AddInstanceDomain{
InstanceID: instanceID,
Domain: domainName,
Type: domain.DomainTypeCustom,
IsPrimary: gu.Ptr(false),
IsGenerated: gu.Ptr(false),
}
},
err: new(database.UniqueError),
},
{
name: "add trusted domain with same domain twice",
testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.InstanceDomainRepository) *domain.AddInstanceDomain {
domainName := gofakeit.DomainName()
instanceDomain := &domain.AddInstanceDomain{
InstanceID: instanceID,
Domain: domainName,
Type: domain.DomainTypeTrusted,
}
err := domainRepo.Add(ctx, instanceDomain)
require.NoError(t, err)
// return same domain again
return &domain.AddInstanceDomain{
InstanceID: instanceID,
Domain: domainName,
Type: domain.DomainTypeTrusted,
}
},
err: new(database.UniqueError),
},
{
name: "add domain with non-existent instance",
instanceDomain: domain.AddInstanceDomain{
InstanceID: "non-existent-instance",
Domain: gofakeit.DomainName(),
Type: domain.DomainTypeCustom,
IsPrimary: gu.Ptr(false),
IsGenerated: gu.Ptr(false),
},
err: new(database.ForeignKeyError),
},
{
name: "add domain without instance id",
instanceDomain: domain.AddInstanceDomain{
Domain: gofakeit.DomainName(),
Type: domain.DomainTypeCustom,
IsPrimary: gu.Ptr(false),
IsGenerated: gu.Ptr(false),
},
err: new(database.ForeignKeyError),
},
{
name: "add custom domain without primary",
instanceDomain: domain.AddInstanceDomain{
Domain: gofakeit.DomainName(),
Type: domain.DomainTypeCustom,
IsGenerated: gu.Ptr(false),
},
err: new(database.CheckError),
},
{
name: "add custom domain without generated",
instanceDomain: domain.AddInstanceDomain{
Domain: gofakeit.DomainName(),
Type: domain.DomainTypeCustom,
IsPrimary: gu.Ptr(false),
},
err: new(database.CheckError),
},
{
name: "add trusted domain with primary",
instanceDomain: domain.AddInstanceDomain{
Domain: gofakeit.DomainName(),
Type: domain.DomainTypeTrusted,
IsPrimary: gu.Ptr(false),
},
err: new(database.CheckError),
},
{
name: "add trusted domain with generated",
instanceDomain: domain.AddInstanceDomain{
Domain: gofakeit.DomainName(),
Type: domain.DomainTypeTrusted,
IsGenerated: gu.Ptr(false),
},
err: new(database.CheckError),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctx := t.Context()
// we take now here because the timestamp of the transaction is used to set the createdAt and updatedAt fields
beforeAdd := time.Now()
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
require.NoError(t, tx.Rollback(t.Context()))
}()
instanceRepo := repository.InstanceRepository(tx)
domainRepo := instanceRepo.Domains(false)
var instanceDomain *domain.AddInstanceDomain
if test.testFunc != nil {
instanceDomain = test.testFunc(ctx, t, domainRepo)
} else {
instanceDomain = &test.instanceDomain
}
err = domainRepo.Add(ctx, instanceDomain)
afterAdd := time.Now()
if test.err != nil {
assert.ErrorIs(t, err, test.err)
return
}
require.NoError(t, err)
assert.NotZero(t, instanceDomain.CreatedAt)
assert.NotZero(t, instanceDomain.UpdatedAt)
assert.WithinRange(t, instanceDomain.CreatedAt, beforeAdd, afterAdd)
assert.WithinRange(t, instanceDomain.UpdatedAt, beforeAdd, afterAdd)
})
}
}
func TestGetInstanceDomain(t *testing.T) {
// create instance
instanceID := gofakeit.UUID()
instance := domain.Instance{
ID: instanceID,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleClient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
require.NoError(t, tx.Rollback(t.Context()))
}()
instanceRepo := repository.InstanceRepository(tx)
err = instanceRepo.Create(t.Context(), &instance)
require.NoError(t, err)
// add domains
domainRepo := instanceRepo.Domains(false)
domainName1 := gofakeit.DomainName()
domainName2 := gofakeit.DomainName()
domain1 := &domain.AddInstanceDomain{
InstanceID: instanceID,
Domain: domainName1,
IsPrimary: gu.Ptr(true),
IsGenerated: gu.Ptr(false),
Type: domain.DomainTypeCustom,
}
domain2 := &domain.AddInstanceDomain{
InstanceID: instanceID,
Domain: domainName2,
IsPrimary: gu.Ptr(false),
IsGenerated: gu.Ptr(false),
Type: domain.DomainTypeCustom,
}
err = domainRepo.Add(t.Context(), domain1)
require.NoError(t, err)
err = domainRepo.Add(t.Context(), domain2)
require.NoError(t, err)
tests := []struct {
name string
opts []database.QueryOption
expected *domain.InstanceDomain
err error
}{
{
name: "get primary domain",
opts: []database.QueryOption{
database.WithCondition(domainRepo.IsPrimaryCondition(true)),
},
expected: &domain.InstanceDomain{
InstanceID: instanceID,
Domain: domainName1,
IsPrimary: gu.Ptr(true),
},
},
{
name: "get by domain name",
opts: []database.QueryOption{
database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, domainName2)),
},
expected: &domain.InstanceDomain{
InstanceID: instanceID,
Domain: domainName2,
IsPrimary: gu.Ptr(false),
},
},
{
name: "get non-existent domain",
opts: []database.QueryOption{
database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com")),
},
err: new(database.NoRowFoundError),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctx := t.Context()
result, err := domainRepo.Get(ctx, test.opts...)
if test.err != nil {
assert.ErrorIs(t, err, test.err)
return
}
require.NoError(t, err)
assert.Equal(t, test.expected.InstanceID, result.InstanceID)
assert.Equal(t, test.expected.Domain, result.Domain)
assert.Equal(t, test.expected.IsPrimary, result.IsPrimary)
assert.NotEmpty(t, result.CreatedAt)
assert.NotEmpty(t, result.UpdatedAt)
})
}
}
func TestListInstanceDomains(t *testing.T) {
// create instance
instanceID := gofakeit.UUID()
instance := domain.Instance{
ID: instanceID,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleClient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
require.NoError(t, tx.Rollback(t.Context()))
}()
instanceRepo := repository.InstanceRepository(tx)
err = instanceRepo.Create(t.Context(), &instance)
require.NoError(t, err)
// add multiple domains
domainRepo := instanceRepo.Domains(false)
domains := []domain.AddInstanceDomain{
{
InstanceID: instanceID,
Domain: gofakeit.DomainName(),
IsPrimary: gu.Ptr(true),
IsGenerated: gu.Ptr(false),
Type: domain.DomainTypeCustom,
},
{
InstanceID: instanceID,
Domain: gofakeit.DomainName(),
IsPrimary: gu.Ptr(false),
IsGenerated: gu.Ptr(false),
Type: domain.DomainTypeCustom,
},
{
InstanceID: instanceID,
Domain: gofakeit.DomainName(),
Type: domain.DomainTypeTrusted,
},
}
for i := range domains {
err = domainRepo.Add(t.Context(), &domains[i])
require.NoError(t, err)
}
tests := []struct {
name string
opts []database.QueryOption
expectedCount int
}{
{
name: "list all domains",
opts: []database.QueryOption{},
expectedCount: 3,
},
{
name: "list primary domains",
opts: []database.QueryOption{
database.WithCondition(domainRepo.IsPrimaryCondition(true)),
},
expectedCount: 1,
},
{
name: "list by instance",
opts: []database.QueryOption{
database.WithCondition(domainRepo.InstanceIDCondition(instanceID)),
},
expectedCount: 3,
},
{
name: "list non-existent instance",
opts: []database.QueryOption{
database.WithCondition(domainRepo.InstanceIDCondition("non-existent")),
},
expectedCount: 0,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctx := t.Context()
results, err := domainRepo.List(ctx, test.opts...)
require.NoError(t, err)
assert.Len(t, results, test.expectedCount)
for _, result := range results {
assert.Equal(t, instanceID, result.InstanceID)
assert.NotEmpty(t, result.Domain)
assert.NotEmpty(t, result.CreatedAt)
assert.NotEmpty(t, result.UpdatedAt)
}
})
}
}
func TestUpdateInstanceDomain(t *testing.T) {
// create instance
instanceID := gofakeit.UUID()
instance := domain.Instance{
ID: instanceID,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleClient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
require.NoError(t, tx.Rollback(t.Context()))
}()
instanceRepo := repository.InstanceRepository(tx)
err = instanceRepo.Create(t.Context(), &instance)
require.NoError(t, err)
// add domain
domainRepo := instanceRepo.Domains(false)
domainName := gofakeit.DomainName()
instanceDomain := &domain.AddInstanceDomain{
InstanceID: instanceID,
Domain: domainName,
IsPrimary: gu.Ptr(false),
IsGenerated: gu.Ptr(false),
Type: domain.DomainTypeCustom,
}
err = domainRepo.Add(t.Context(), instanceDomain)
require.NoError(t, err)
tests := []struct {
name string
condition database.Condition
changes []database.Change
expected int64
err error
}{
{
name: "set primary",
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
changes: []database.Change{domainRepo.SetPrimary()},
expected: 1,
},
{
name: "update non-existent domain",
condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
changes: []database.Change{domainRepo.SetPrimary()},
expected: 0,
},
{
name: "no changes",
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
changes: []database.Change{},
expected: 0,
err: database.ErrNoChanges,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctx := t.Context()
rowsAffected, err := domainRepo.Update(ctx, test.condition, test.changes...)
if test.err != nil {
assert.ErrorIs(t, err, test.err)
return
}
require.NoError(t, err)
assert.Equal(t, test.expected, rowsAffected)
// verify changes were applied if rows were affected
if rowsAffected > 0 && len(test.changes) > 0 {
result, err := domainRepo.Get(ctx, database.WithCondition(test.condition))
require.NoError(t, err)
// We know changes were applied since rowsAffected > 0
// The specific verification of what changed is less important
// than knowing the operation succeeded
assert.NotNil(t, result)
}
})
}
}
func TestRemoveInstanceDomain(t *testing.T) {
// create instance
instanceID := gofakeit.UUID()
instance := domain.Instance{
ID: instanceID,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleClient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
require.NoError(t, tx.Rollback(t.Context()))
}()
instanceRepo := repository.InstanceRepository(tx)
err = instanceRepo.Create(t.Context(), &instance)
require.NoError(t, err)
// add domains
domainRepo := instanceRepo.Domains(false)
domainName1 := gofakeit.DomainName()
domain1 := &domain.AddInstanceDomain{
InstanceID: instanceID,
Domain: domainName1,
IsPrimary: gu.Ptr(true),
IsGenerated: gu.Ptr(false),
Type: domain.DomainTypeCustom,
}
domain2 := &domain.AddInstanceDomain{
InstanceID: instanceID,
Domain: gofakeit.DomainName(),
IsPrimary: gu.Ptr(false),
IsGenerated: gu.Ptr(false),
Type: domain.DomainTypeCustom,
}
err = domainRepo.Add(t.Context(), domain1)
require.NoError(t, err)
err = domainRepo.Add(t.Context(), domain2)
require.NoError(t, err)
tests := []struct {
name string
condition database.Condition
expected int64
}{
{
name: "remove by domain name",
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName1),
expected: 1,
},
{
name: "remove by primary condition",
condition: domainRepo.IsPrimaryCondition(false),
expected: 1, // domain2 should still exist and be non-primary
},
{
name: "remove non-existent domain",
condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
expected: 0,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctx := t.Context()
// count before removal
beforeCount, err := domainRepo.List(ctx)
require.NoError(t, err)
rowsAffected, err := domainRepo.Remove(ctx, test.condition)
require.NoError(t, err)
assert.Equal(t, test.expected, rowsAffected)
// verify removal
afterCount, err := domainRepo.List(ctx)
require.NoError(t, err)
assert.Equal(t, len(beforeCount)-int(test.expected), len(afterCount))
})
}
}
func TestInstanceDomainConditions(t *testing.T) {
instanceRepo := repository.InstanceRepository(pool)
domainRepo := instanceRepo.Domains(false)
tests := []struct {
name string
condition database.Condition
expected string
}{
{
name: "domain condition equal",
condition: domainRepo.DomainCondition(database.TextOperationEqual, "example.com"),
expected: "instance_domains.domain = $1",
},
{
name: "domain condition starts with",
condition: domainRepo.DomainCondition(database.TextOperationStartsWith, "example"),
expected: "instance_domains.domain LIKE $1 || '%'",
},
{
name: "instance id condition",
condition: domainRepo.InstanceIDCondition("instance-123"),
expected: "instance_domains.instance_id = $1",
},
{
name: "is primary true",
condition: domainRepo.IsPrimaryCondition(true),
expected: "instance_domains.is_primary = $1",
},
{
name: "is primary false",
condition: domainRepo.IsPrimaryCondition(false),
expected: "instance_domains.is_primary = $1",
},
{
name: "is type custom",
condition: domainRepo.TypeCondition(domain.DomainTypeCustom),
expected: "instance_domains.type = $1",
},
{
name: "is type trusted",
condition: domainRepo.TypeCondition(domain.DomainTypeTrusted),
expected: "instance_domains.type = $1",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var builder database.StatementBuilder
test.condition.Write(&builder)
assert.Equal(t, test.expected, builder.String())
})
}
}
func TestInstanceDomainChanges(t *testing.T) {
instanceRepo := repository.InstanceRepository(pool)
domainRepo := instanceRepo.Domains(false)
tests := []struct {
name string
change database.Change
expected string
}{
{
name: "set primary",
change: domainRepo.SetPrimary(),
expected: "is_primary = $1",
},
{
name: "set type",
change: domainRepo.SetType(domain.DomainTypeCustom),
expected: "type = $1",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var builder database.StatementBuilder
test.change.Write(&builder)
assert.Equal(t, test.expected, builder.String())
})
}
}

View File

@@ -0,0 +1,723 @@
package repository_test
import (
"context"
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
)
func TestCreateInstance(t *testing.T) {
tests := []struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
instance domain.Instance
err error
}{
{
name: "happy path",
instance: func() domain.Instance {
instanceId := gofakeit.Name()
instanceName := gofakeit.Name()
instance := domain.Instance{
ID: instanceId,
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
return instance
}(),
},
{
name: "create instance without name",
instance: func() domain.Instance {
instanceId := gofakeit.Name()
// instanceName := gofakeit.Name()
instance := domain.Instance{
ID: instanceId,
Name: "",
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
return instance
}(),
err: new(database.CheckError),
},
{
name: "adding same instance twice",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
instanceId := gofakeit.Name()
instanceName := gofakeit.Name()
inst := domain.Instance{
ID: instanceId,
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
err := instanceRepo.Create(ctx, &inst)
// change the name to make sure same only the id clashes
inst.Name = gofakeit.Name()
require.NoError(t, err)
return &inst
},
err: new(database.UniqueError),
},
func() struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
instance domain.Instance
err error
} {
instanceId := gofakeit.Name()
instanceName := gofakeit.Name()
return struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
instance domain.Instance
err error
}{
name: "adding instance with same name twice",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
inst := domain.Instance{
ID: gofakeit.Name(),
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
// change the id
inst.ID = instanceId
inst.CreatedAt = time.Time{}
inst.UpdatedAt = time.Time{}
return &inst
},
instance: domain.Instance{
ID: instanceId,
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
},
// two instances can have the sane name
err: nil,
}
}(),
{
name: "adding instance with no id",
instance: func() domain.Instance {
// instanceId := gofakeit.Name()
instanceName := gofakeit.Name()
instance := domain.Instance{
// ID: instanceId,
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
return instance
}(),
err: new(database.CheckError),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
var instance *domain.Instance
if tt.testFunc != nil {
instance = tt.testFunc(ctx, t)
} else {
instance = &tt.instance
}
instanceRepo := repository.InstanceRepository(pool)
// create instance
beforeCreate := time.Now()
err := instanceRepo.Create(ctx, instance)
assert.ErrorIs(t, err, tt.err)
if err != nil {
return
}
afterCreate := time.Now()
// check instance values
instance, err = instanceRepo.Get(ctx,
database.WithCondition(
instanceRepo.IDCondition(instance.ID),
),
)
require.NoError(t, err)
assert.Equal(t, tt.instance.ID, instance.ID)
assert.Equal(t, tt.instance.Name, instance.Name)
assert.Equal(t, tt.instance.DefaultOrgID, instance.DefaultOrgID)
assert.Equal(t, tt.instance.IAMProjectID, instance.IAMProjectID)
assert.Equal(t, tt.instance.ConsoleClientID, instance.ConsoleClientID)
assert.Equal(t, tt.instance.ConsoleAppID, instance.ConsoleAppID)
assert.Equal(t, tt.instance.DefaultLanguage, instance.DefaultLanguage)
assert.WithinRange(t, instance.CreatedAt, beforeCreate, afterCreate)
assert.WithinRange(t, instance.UpdatedAt, beforeCreate, afterCreate)
})
}
}
func TestUpdateInstance(t *testing.T) {
tests := []struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
rowsAffected int64
getErr error
}{
{
name: "happy path",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
instanceId := gofakeit.Name()
instanceName := gofakeit.Name()
inst := domain.Instance{
ID: instanceId,
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
return &inst
},
rowsAffected: 1,
},
{
name: "update deleted instance",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
instanceId := gofakeit.Name()
instanceName := gofakeit.Name()
inst := domain.Instance{
ID: instanceId,
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
// delete instance
affectedRows, err := instanceRepo.Delete(ctx,
inst.ID,
)
require.NoError(t, err)
assert.Equal(t, int64(1), affectedRows)
return &inst
},
rowsAffected: 0,
},
{
name: "update non existent instance",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceId := gofakeit.Name()
inst := domain.Instance{
ID: instanceId,
}
return &inst
},
rowsAffected: 0,
getErr: new(database.NoRowFoundError),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
instanceRepo := repository.InstanceRepository(pool)
instance := tt.testFunc(ctx, t)
beforeUpdate := time.Now()
// update name
newName := "new_" + instance.Name
rowsAffected, err := instanceRepo.Update(ctx,
instance.ID,
instanceRepo.SetName(newName),
)
afterUpdate := time.Now()
require.NoError(t, err)
assert.Equal(t, tt.rowsAffected, rowsAffected)
if rowsAffected == 0 {
return
}
// check instance values
instance, err = instanceRepo.Get(ctx,
database.WithCondition(
instanceRepo.IDCondition(instance.ID),
),
)
require.Equal(t, tt.getErr, err)
assert.Equal(t, newName, instance.Name)
assert.WithinRange(t, instance.UpdatedAt, beforeUpdate, afterUpdate)
})
}
}
func TestGetInstance(t *testing.T) {
instanceRepo := repository.InstanceRepository(pool)
type test struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
err error
}
tests := []test{
func() test {
instanceId := gofakeit.Name()
return test{
name: "happy path get using id",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceName := gofakeit.Name()
inst := domain.Instance{
ID: instanceId,
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
return &inst
},
}
}(),
{
name: "happy path including domains",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
instanceId := gofakeit.Name()
instanceName := gofakeit.Name()
inst := domain.Instance{
ID: instanceId,
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
domainRepo := instanceRepo.Domains(false)
d := &domain.AddInstanceDomain{
InstanceID: inst.ID,
Domain: gofakeit.DomainName(),
IsPrimary: gu.Ptr(true),
IsGenerated: gu.Ptr(false),
Type: domain.DomainTypeCustom,
}
err = domainRepo.Add(ctx, d)
require.NoError(t, err)
inst.Domains = append(inst.Domains, &domain.InstanceDomain{
InstanceID: d.InstanceID,
Domain: d.Domain,
IsPrimary: d.IsPrimary,
IsGenerated: d.IsGenerated,
Type: d.Type,
CreatedAt: d.CreatedAt,
UpdatedAt: d.UpdatedAt,
})
return &inst
},
},
{
name: "get non existent instance",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
inst := domain.Instance{
ID: "get non existent instance",
}
return &inst
},
err: new(database.NoRowFoundError),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
instanceRepo := repository.InstanceRepository(pool)
var instance *domain.Instance
if tt.testFunc != nil {
instance = tt.testFunc(ctx, t)
}
// check instance values
returnedInstance, err := instanceRepo.Get(ctx,
database.WithCondition(
instanceRepo.IDCondition(instance.ID),
),
)
if tt.err != nil {
require.ErrorIs(t, err, tt.err)
return
}
if instance.ID == "get non existent instance" {
assert.Nil(t, returnedInstance)
return
}
assert.Equal(t, returnedInstance.ID, instance.ID)
assert.Equal(t, returnedInstance.Name, instance.Name)
assert.Equal(t, returnedInstance.DefaultOrgID, instance.DefaultOrgID)
assert.Equal(t, returnedInstance.IAMProjectID, instance.IAMProjectID)
assert.Equal(t, returnedInstance.ConsoleClientID, instance.ConsoleClientID)
assert.Equal(t, returnedInstance.ConsoleAppID, instance.ConsoleAppID)
assert.Equal(t, returnedInstance.DefaultLanguage, instance.DefaultLanguage)
})
}
}
func TestListInstance(t *testing.T) {
ctx := context.Background()
pool, stop, err := newEmbeddedDB(ctx)
require.NoError(t, err)
defer stop()
type test struct {
name string
testFunc func(ctx context.Context, t *testing.T) []*domain.Instance
conditionClauses []database.Condition
noInstanceReturned bool
}
tests := []test{
{
name: "happy path single instance no filter",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
noOfInstances := 1
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: gofakeit.Name(),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
return instances
},
},
{
name: "happy path multiple instance no filter",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
noOfInstances := 5
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: gofakeit.Name(),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
return instances
},
},
func() test {
instanceRepo := repository.InstanceRepository(pool)
instanceId := gofakeit.Name()
return test{
name: "instance filter on id",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
noOfInstances := 1
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: instanceId,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
return instances
},
conditionClauses: []database.Condition{instanceRepo.IDCondition(instanceId)},
}
}(),
func() test {
instanceRepo := repository.InstanceRepository(pool)
instanceName := gofakeit.Name()
return test{
name: "multiple instance filter on name",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
noOfInstances := 5
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: gofakeit.Name(),
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
return instances
},
conditionClauses: []database.Condition{instanceRepo.NameCondition(database.TextOperationEqual, instanceName)},
}
}(),
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Cleanup(func() {
_, err := pool.Exec(ctx, "DELETE FROM zitadel.instances")
require.NoError(t, err)
})
instances := tt.testFunc(ctx, t)
instanceRepo := repository.InstanceRepository(pool)
var condition database.Condition
if len(tt.conditionClauses) > 0 {
condition = database.And(tt.conditionClauses...)
}
// check instance values
returnedInstances, err := instanceRepo.List(ctx,
database.WithCondition(condition),
database.WithOrderByAscending(instanceRepo.CreatedAtColumn()),
)
require.NoError(t, err)
if tt.noInstanceReturned {
assert.Nil(t, returnedInstances)
return
}
assert.Equal(t, len(instances), len(returnedInstances))
for i, instance := range instances {
assert.Equal(t, returnedInstances[i].ID, instance.ID)
assert.Equal(t, returnedInstances[i].Name, instance.Name)
assert.Equal(t, returnedInstances[i].DefaultOrgID, instance.DefaultOrgID)
assert.Equal(t, returnedInstances[i].IAMProjectID, instance.IAMProjectID)
assert.Equal(t, returnedInstances[i].ConsoleClientID, instance.ConsoleClientID)
assert.Equal(t, returnedInstances[i].ConsoleAppID, instance.ConsoleAppID)
assert.Equal(t, returnedInstances[i].DefaultLanguage, instance.DefaultLanguage)
}
})
}
}
func TestDeleteInstance(t *testing.T) {
type test struct {
name string
testFunc func(ctx context.Context, t *testing.T)
instanceID string
noOfDeletedRows int64
}
tests := []test{
func() test {
instanceRepo := repository.InstanceRepository(pool)
instanceId := gofakeit.Name()
var noOfInstances int64 = 1
return test{
name: "happy path delete single instance filter id",
testFunc: func(ctx context.Context, t *testing.T) {
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: instanceId,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
},
instanceID: instanceId,
noOfDeletedRows: noOfInstances,
}
}(),
func() test {
non_existent_instance_name := gofakeit.Name()
return test{
name: "delete non existent instance",
instanceID: non_existent_instance_name,
}
}(),
func() test {
instanceRepo := repository.InstanceRepository(pool)
instanceName := gofakeit.Name()
return test{
name: "deleted already deleted instance",
testFunc: func(ctx context.Context, t *testing.T) {
noOfInstances := 1
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: gofakeit.Name(),
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
// delete instance
affectedRows, err := instanceRepo.Delete(ctx,
instances[0].ID,
)
require.NoError(t, err)
assert.Equal(t, int64(1), affectedRows)
},
instanceID: instanceName,
// this test should return 0 affected rows as the instance was already deleted
noOfDeletedRows: 0,
}
}(),
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
instanceRepo := repository.InstanceRepository(pool)
if tt.testFunc != nil {
tt.testFunc(ctx, t)
}
// delete instance
noOfDeletedRows, err := instanceRepo.Delete(ctx,
tt.instanceID,
)
require.NoError(t, err)
assert.Equal(t, noOfDeletedRows, tt.noOfDeletedRows)
// check instance was deleted
instance, err := instanceRepo.Get(ctx,
database.WithCondition(
instanceRepo.IDCondition(tt.instanceID),
),
)
require.ErrorIs(t, err, new(database.NoRowFoundError))
assert.Nil(t, instance)
})
}
}

View File

@@ -0,0 +1,282 @@
package repository
import (
"context"
"encoding/json"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
// -------------------------------------------------------------
// repository
// -------------------------------------------------------------
var _ domain.OrganizationRepository = (*org)(nil)
type org struct {
repository
shouldLoadDomains bool
domainRepo domain.OrganizationDomainRepository
}
func OrganizationRepository(client database.QueryExecutor) domain.OrganizationRepository {
return &org{
repository: repository{
client: client,
},
}
}
const queryOrganizationStmt = `SELECT organizations.id, organizations.name, organizations.instance_id, organizations.state, organizations.created_at, organizations.updated_at` +
` , CASE WHEN count(org_domains.domain) > 0 THEN jsonb_agg(json_build_object('domain', org_domains.domain, 'isVerified', org_domains.is_verified, 'isPrimary', org_domains.is_primary, 'validationType', org_domains.validation_type, 'createdAt', org_domains.created_at, 'updatedAt', org_domains.updated_at)) ELSE NULL::JSONB END domains` +
` FROM zitadel.organizations`
// Get implements [domain.OrganizationRepository].
func (o *org) Get(ctx context.Context, opts ...database.QueryOption) (*domain.Organization, error) {
opts = append(opts,
o.joinDomains(),
database.WithGroupBy(o.InstanceIDColumn(), o.IDColumn()),
)
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
}
var builder database.StatementBuilder
builder.WriteString(queryOrganizationStmt)
options.Write(&builder)
return scanOrganization(ctx, o.client, &builder)
}
// List implements [domain.OrganizationRepository].
func (o *org) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.Organization, error) {
opts = append(opts,
o.joinDomains(),
database.WithGroupBy(o.InstanceIDColumn(), o.IDColumn()),
)
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
}
var builder database.StatementBuilder
builder.WriteString(queryOrganizationStmt)
options.Write(&builder)
return scanOrganizations(ctx, o.client, &builder)
}
func (o *org) joinDomains() database.QueryOption {
columns := make([]database.Condition, 0, 3)
columns = append(columns,
database.NewColumnCondition(o.InstanceIDColumn(), o.Domains(false).InstanceIDColumn()),
database.NewColumnCondition(o.IDColumn(), o.Domains(false).OrgIDColumn()),
)
// If domains should not be joined, we make sure to return null for the domain columns
// the query optimizer of the dialect should optimize this away if no domains are requested
if !o.shouldLoadDomains {
columns = append(columns, database.IsNull(o.domainRepo.OrgIDColumn()))
}
return database.WithLeftJoin(
"zitadel.org_domains",
database.And(columns...),
)
}
const createOrganizationStmt = `INSERT INTO zitadel.organizations (id, name, instance_id, state)` +
` VALUES ($1, $2, $3, $4)` +
` RETURNING created_at, updated_at`
// Create implements [domain.OrganizationRepository].
func (o *org) Create(ctx context.Context, organization *domain.Organization) error {
builder := database.StatementBuilder{}
builder.AppendArgs(organization.ID, organization.Name, organization.InstanceID, organization.State)
builder.WriteString(createOrganizationStmt)
return o.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&organization.CreatedAt, &organization.UpdatedAt)
}
// Update implements [domain.OrganizationRepository].
func (o *org) Update(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string, changes ...database.Change) (int64, error) {
if len(changes) == 0 {
return 0, database.ErrNoChanges
}
builder := database.StatementBuilder{}
builder.WriteString(`UPDATE zitadel.organizations SET `)
instanceIDCondition := o.InstanceIDCondition(instanceID)
conditions := []database.Condition{id, instanceIDCondition}
database.Changes(changes).Write(&builder)
writeCondition(&builder, database.And(conditions...))
stmt := builder.String()
rowsAffected, err := o.client.Exec(ctx, stmt, builder.Args()...)
return rowsAffected, err
}
// Delete implements [domain.OrganizationRepository].
func (o *org) Delete(ctx context.Context, id domain.OrgIdentifierCondition, instanceID string) (int64, error) {
builder := database.StatementBuilder{}
builder.WriteString(`DELETE FROM zitadel.organizations`)
instanceIDCondition := o.InstanceIDCondition(instanceID)
conditions := []database.Condition{id, instanceIDCondition}
writeCondition(&builder, database.And(conditions...))
return o.client.Exec(ctx, builder.String(), builder.Args()...)
}
// -------------------------------------------------------------
// changes
// -------------------------------------------------------------
// SetName implements [domain.organizationChanges].
func (o org) SetName(name string) database.Change {
return database.NewChange(o.NameColumn(), name)
}
// SetState implements [domain.organizationChanges].
func (o org) SetState(state domain.OrgState) database.Change {
return database.NewChange(o.StateColumn(), state)
}
// -------------------------------------------------------------
// conditions
// -------------------------------------------------------------
// IDCondition implements [domain.organizationConditions].
func (o org) IDCondition(id string) domain.OrgIdentifierCondition {
return database.NewTextCondition(o.IDColumn(), database.TextOperationEqual, id)
}
// NameCondition implements [domain.organizationConditions].
func (o org) NameCondition(name string) domain.OrgIdentifierCondition {
return database.NewTextCondition(o.NameColumn(), database.TextOperationEqual, name)
}
// InstanceIDCondition implements [domain.organizationConditions].
func (o org) InstanceIDCondition(instanceID string) database.Condition {
return database.NewTextCondition(o.InstanceIDColumn(), database.TextOperationEqual, instanceID)
}
// StateCondition implements [domain.organizationConditions].
func (o org) StateCondition(state domain.OrgState) database.Condition {
return database.NewTextCondition(o.StateColumn(), database.TextOperationEqual, state.String())
}
// -------------------------------------------------------------
// columns
// -------------------------------------------------------------
// IDColumn implements [domain.organizationColumns].
func (org) IDColumn() database.Column {
return database.NewColumn("organizations", "id")
}
// NameColumn implements [domain.organizationColumns].
func (org) NameColumn() database.Column {
return database.NewColumn("organizations", "name")
}
// InstanceIDColumn implements [domain.organizationColumns].
func (org) InstanceIDColumn() database.Column {
return database.NewColumn("organizations", "instance_id")
}
// StateColumn implements [domain.organizationColumns].
func (org) StateColumn() database.Column {
return database.NewColumn("organizations", "state")
}
// CreatedAtColumn implements [domain.organizationColumns].
func (org) CreatedAtColumn() database.Column {
return database.NewColumn("organizations", "created_at")
}
// UpdatedAtColumn implements [domain.organizationColumns].
func (org) UpdatedAtColumn() database.Column {
return database.NewColumn("organizations", "updated_at")
}
// -------------------------------------------------------------
// scanners
// -------------------------------------------------------------
type rawOrganization struct {
*domain.Organization
RawDomains json.RawMessage `json:"domains,omitempty" db:"domains"`
}
func scanOrganization(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) (*domain.Organization, error) {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
var org rawOrganization
if err := rows.(database.CollectableRows).CollectExactlyOneRow(&org); err != nil {
return nil, err
}
if len(org.RawDomains) > 0 {
if err := json.Unmarshal(org.RawDomains, &org.Domains); err != nil {
return nil, err
}
}
return org.Organization, nil
}
func scanOrganizations(ctx context.Context, querier database.Querier, builder *database.StatementBuilder) ([]*domain.Organization, error) {
rows, err := querier.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
var rawOrgs []*rawOrganization
if err := rows.(database.CollectableRows).Collect(&rawOrgs); err != nil {
return nil, err
}
organizations := make([]*domain.Organization, len(rawOrgs))
for i, org := range rawOrgs {
if len(org.RawDomains) > 0 {
if err := json.Unmarshal(org.RawDomains, &org.Domains); err != nil {
return nil, err
}
}
organizations[i] = org.Organization
}
return organizations, nil
}
// -------------------------------------------------------------
// sub repositories
// -------------------------------------------------------------
// Domains implements [domain.OrganizationRepository].
func (o *org) Domains(shouldLoad bool) domain.OrganizationDomainRepository {
if !o.shouldLoadDomains {
o.shouldLoadDomains = shouldLoad
}
if o.domainRepo != nil {
return o.domainRepo
}
o.domainRepo = &orgDomain{
repository: o.repository,
org: o,
}
return o.domainRepo
}

View File

@@ -0,0 +1,230 @@
package repository
import (
"context"
"time"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
var _ domain.OrganizationDomainRepository = (*orgDomain)(nil)
type orgDomain struct {
repository
*org
}
// -------------------------------------------------------------
// repository
// -------------------------------------------------------------
const queryOrganizationDomainStmt = `SELECT instance_id, org_id, domain, is_verified, is_primary, validation_type, created_at, updated_at ` +
`FROM zitadel.org_domains`
// Get implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).Get of orgDomain.org.
func (o *orgDomain) Get(ctx context.Context, opts ...database.QueryOption) (*domain.OrganizationDomain, error) {
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
}
var builder database.StatementBuilder
builder.WriteString(queryOrganizationDomainStmt)
options.Write(&builder)
return scanOrganizationDomain(ctx, o.client, &builder)
}
// List implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).List of orgDomain.org.
func (o *orgDomain) List(ctx context.Context, opts ...database.QueryOption) ([]*domain.OrganizationDomain, error) {
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
}
var builder database.StatementBuilder
builder.WriteString(queryOrganizationDomainStmt)
options.Write(&builder)
return scanOrganizationDomains(ctx, o.client, &builder)
}
// Add implements [domain.OrganizationDomainRepository].
func (o *orgDomain) Add(ctx context.Context, domain *domain.AddOrganizationDomain) error {
var (
builder database.StatementBuilder
createdAt, updatedAt any = database.DefaultInstruction, database.DefaultInstruction
)
if !domain.CreatedAt.IsZero() {
createdAt = domain.CreatedAt
}
if !domain.UpdatedAt.IsZero() {
updatedAt = domain.UpdatedAt
}
builder.WriteString(`INSERT INTO zitadel.org_domains (instance_id, org_id, domain, is_verified, is_primary, validation_type, created_at, updated_at) VALUES (`)
builder.WriteArgs(domain.InstanceID, domain.OrgID, domain.Domain, domain.IsVerified, domain.IsPrimary, domain.ValidationType, createdAt, updatedAt)
builder.WriteString(`) RETURNING created_at, updated_at`)
return o.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&domain.CreatedAt, &domain.UpdatedAt)
}
// Update implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).Update of orgDomain.org.
func (o *orgDomain) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) {
if len(changes) == 0 {
return 0, database.ErrNoChanges
}
var builder database.StatementBuilder
builder.WriteString(`UPDATE zitadel.org_domains SET `)
database.Changes(changes).Write(&builder)
writeCondition(&builder, condition)
return o.client.Exec(ctx, builder.String(), builder.Args()...)
}
// Remove implements [domain.OrganizationDomainRepository].
func (o *orgDomain) Remove(ctx context.Context, condition database.Condition) (int64, error) {
var builder database.StatementBuilder
builder.WriteString(`DELETE FROM zitadel.org_domains `)
writeCondition(&builder, condition)
return o.client.Exec(ctx, builder.String(), builder.Args()...)
}
// -------------------------------------------------------------
// changes
// -------------------------------------------------------------
// SetPrimary implements [domain.OrganizationDomainRepository].
func (o orgDomain) SetPrimary() database.Change {
return database.NewChange(o.IsPrimaryColumn(), true)
}
// SetValidationType implements [domain.OrganizationDomainRepository].
func (o orgDomain) SetValidationType(verificationType domain.DomainValidationType) database.Change {
return database.NewChange(o.ValidationTypeColumn(), verificationType)
}
// SetVerified implements [domain.OrganizationDomainRepository].
func (o orgDomain) SetVerified() database.Change {
return database.NewChange(o.IsVerifiedColumn(), true)
}
// SetUpdatedAt implements [domain.OrganizationDomainRepository].
func (o orgDomain) SetUpdatedAt(updatedAt time.Time) database.Change {
return database.NewChange(o.UpdatedAtColumn(), updatedAt)
}
// -------------------------------------------------------------
// conditions
// -------------------------------------------------------------
// DomainCondition implements [domain.OrganizationDomainRepository].
func (o orgDomain) DomainCondition(op database.TextOperation, domain string) database.Condition {
return database.NewTextCondition(o.DomainColumn(), op, domain)
}
// InstanceIDCondition implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).InstanceIDCondition of orgDomain.org.
func (o orgDomain) InstanceIDCondition(instanceID string) database.Condition {
return database.NewTextCondition(o.InstanceIDColumn(), database.TextOperationEqual, instanceID)
}
// IsPrimaryCondition implements [domain.OrganizationDomainRepository].
func (o orgDomain) IsPrimaryCondition(isPrimary bool) database.Condition {
return database.NewBooleanCondition(o.IsPrimaryColumn(), isPrimary)
}
// IsVerifiedCondition implements [domain.OrganizationDomainRepository].
func (o orgDomain) IsVerifiedCondition(isVerified bool) database.Condition {
return database.NewBooleanCondition(o.IsVerifiedColumn(), isVerified)
}
// OrgIDCondition implements [domain.OrganizationDomainRepository].
func (o orgDomain) OrgIDCondition(orgID string) database.Condition {
return database.NewTextCondition(o.OrgIDColumn(), database.TextOperationEqual, orgID)
}
// -------------------------------------------------------------
// columns
// -------------------------------------------------------------
// CreatedAtColumn implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).CreatedAtColumn of orgDomain.org.
func (orgDomain) CreatedAtColumn() database.Column {
return database.NewColumn("org_domains", "created_at")
}
// DomainColumn implements [domain.OrganizationDomainRepository].
func (orgDomain) DomainColumn() database.Column {
return database.NewColumn("org_domains", "domain")
}
// InstanceIDColumn implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).InstanceIDColumn of orgDomain.org.
func (orgDomain) InstanceIDColumn() database.Column {
return database.NewColumn("org_domains", "instance_id")
}
// IsPrimaryColumn implements [domain.OrganizationDomainRepository].
func (orgDomain) IsPrimaryColumn() database.Column {
return database.NewColumn("org_domains", "is_primary")
}
// IsVerifiedColumn implements [domain.OrganizationDomainRepository].
func (orgDomain) IsVerifiedColumn() database.Column {
return database.NewColumn("org_domains", "is_verified")
}
// OrgIDColumn implements [domain.OrganizationDomainRepository].
func (orgDomain) OrgIDColumn() database.Column {
return database.NewColumn("org_domains", "org_id")
}
// UpdatedAtColumn implements [domain.OrganizationDomainRepository].
// Subtle: this method shadows the method ([domain.OrganizationRepository]).UpdatedAtColumn of orgDomain.org.
func (orgDomain) UpdatedAtColumn() database.Column {
return database.NewColumn("org_domains", "updated_at")
}
// ValidationTypeColumn implements [domain.OrganizationDomainRepository].
func (orgDomain) ValidationTypeColumn() database.Column {
return database.NewColumn("org_domains", "validation_type")
}
// -------------------------------------------------------------
// scanners
// -------------------------------------------------------------
func scanOrganizationDomain(ctx context.Context, client database.Querier, builder *database.StatementBuilder) (*domain.OrganizationDomain, error) {
rows, err := client.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
domain := &domain.OrganizationDomain{}
if err := rows.(database.CollectableRows).CollectExactlyOneRow(domain); err != nil {
return nil, err
}
return domain, nil
}
func scanOrganizationDomains(ctx context.Context, client database.Querier, builder *database.StatementBuilder) ([]*domain.OrganizationDomain, error) {
rows, err := client.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
var domains []*domain.OrganizationDomain
if err := rows.(database.CollectableRows).Collect(&domains); err != nil {
return nil, err
}
return domains, nil
}

View File

@@ -0,0 +1,972 @@
package repository_test
import (
"context"
"testing"
"github.com/brianvoe/gofakeit/v6"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
)
func TestAddOrganizationDomain(t *testing.T) {
// create instance
instanceID := gofakeit.UUID()
instance := domain.Instance{
ID: instanceID,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleClient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
instanceRepo := repository.InstanceRepository(pool)
err := instanceRepo.Create(t.Context(), &instance)
require.NoError(t, err)
// create organization
orgID := gofakeit.UUID()
organization := domain.Organization{
ID: orgID,
Name: gofakeit.Name(),
InstanceID: instanceID,
State: domain.OrgStateActive,
}
tests := []struct {
name string
testFunc func(ctx context.Context, t *testing.T, domainRepo domain.OrganizationDomainRepository) *domain.AddOrganizationDomain
organizationDomain domain.AddOrganizationDomain
err error
}{
{
name: "happy path",
organizationDomain: domain.AddOrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
Domain: gofakeit.DomainName(),
IsVerified: false,
IsPrimary: false,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
},
},
{
name: "add verified domain",
organizationDomain: domain.AddOrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
Domain: gofakeit.DomainName(),
IsVerified: true,
IsPrimary: false,
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
},
},
{
name: "add primary domain",
organizationDomain: domain.AddOrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
Domain: gofakeit.DomainName(),
IsVerified: true,
IsPrimary: true,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
},
},
{
name: "add domain without domain name",
organizationDomain: domain.AddOrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
Domain: "",
IsVerified: false,
IsPrimary: false,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
},
err: new(database.CheckError),
},
{
name: "add domain with same domain twice",
testFunc: func(ctx context.Context, t *testing.T, domainRepo domain.OrganizationDomainRepository) *domain.AddOrganizationDomain {
domainName := gofakeit.DomainName()
organizationDomain := &domain.AddOrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
Domain: domainName,
IsVerified: false,
IsPrimary: false,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
}
err := domainRepo.Add(ctx, organizationDomain)
require.NoError(t, err)
// return same domain again
return &domain.AddOrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
Domain: domainName,
IsVerified: true,
IsPrimary: true,
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
}
},
err: new(database.UniqueError),
},
{
name: "add domain with non-existent instance",
organizationDomain: domain.AddOrganizationDomain{
InstanceID: "non-existent-instance",
OrgID: orgID,
Domain: gofakeit.DomainName(),
IsVerified: false,
IsPrimary: false,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
},
err: new(database.ForeignKeyError),
},
{
name: "add domain with non-existent organization",
organizationDomain: domain.AddOrganizationDomain{
InstanceID: instanceID,
OrgID: "non-existent-org",
Domain: gofakeit.DomainName(),
IsVerified: false,
IsPrimary: false,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
},
err: new(database.ForeignKeyError),
},
{
name: "add domain without instance id",
organizationDomain: domain.AddOrganizationDomain{
OrgID: orgID,
Domain: gofakeit.DomainName(),
IsVerified: false,
IsPrimary: false,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
},
err: new(database.ForeignKeyError),
},
{
name: "add domain without org id",
organizationDomain: domain.AddOrganizationDomain{
InstanceID: instanceID,
Domain: gofakeit.DomainName(),
IsVerified: false,
IsPrimary: false,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
},
err: new(database.ForeignKeyError),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctx := t.Context()
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
require.NoError(t, tx.Rollback(t.Context()))
}()
orgRepo := repository.OrganizationRepository(tx)
err = orgRepo.Create(t.Context(), &organization)
require.NoError(t, err)
domainRepo := orgRepo.Domains(false)
var organizationDomain *domain.AddOrganizationDomain
if test.testFunc != nil {
organizationDomain = test.testFunc(ctx, t, domainRepo)
} else {
organizationDomain = &test.organizationDomain
}
err = domainRepo.Add(ctx, organizationDomain)
if test.err != nil {
assert.ErrorIs(t, err, test.err)
return
}
require.NoError(t, err)
assert.NotZero(t, organizationDomain.CreatedAt)
assert.NotZero(t, organizationDomain.UpdatedAt)
})
}
}
func TestGetOrganizationDomain(t *testing.T) {
// create instance
instanceID := gofakeit.UUID()
instance := domain.Instance{
ID: instanceID,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleClient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create organization
orgID := gofakeit.UUID()
organization := domain.Organization{
ID: orgID,
Name: gofakeit.Name(),
InstanceID: instanceID,
State: domain.OrgStateActive,
}
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
require.NoError(t, tx.Rollback(t.Context()))
}()
instanceRepo := repository.InstanceRepository(tx)
err = instanceRepo.Create(t.Context(), &instance)
require.NoError(t, err)
orgRepo := repository.OrganizationRepository(tx)
err = orgRepo.Create(t.Context(), &organization)
require.NoError(t, err)
// add domains
domainRepo := orgRepo.Domains(false)
domainName1 := gofakeit.DomainName()
domainName2 := gofakeit.DomainName()
domain1 := &domain.AddOrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
Domain: domainName1,
IsVerified: true,
IsPrimary: true,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
}
domain2 := &domain.AddOrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
Domain: domainName2,
IsVerified: false,
IsPrimary: false,
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
}
err = domainRepo.Add(t.Context(), domain1)
require.NoError(t, err)
err = domainRepo.Add(t.Context(), domain2)
require.NoError(t, err)
tests := []struct {
name string
opts []database.QueryOption
expected *domain.OrganizationDomain
err error
}{
{
name: "get primary domain",
opts: []database.QueryOption{
database.WithCondition(domainRepo.IsPrimaryCondition(true)),
},
expected: &domain.OrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
Domain: domainName1,
IsVerified: true,
IsPrimary: true,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
},
},
{
name: "get by domain name",
opts: []database.QueryOption{
database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, domainName2)),
},
expected: &domain.OrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
Domain: domainName2,
IsVerified: false,
IsPrimary: false,
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
},
},
{
name: "get by org ID",
opts: []database.QueryOption{
database.WithCondition(domainRepo.OrgIDCondition(orgID)),
database.WithCondition(domainRepo.IsPrimaryCondition(true)),
},
expected: &domain.OrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
Domain: domainName1,
IsVerified: true,
IsPrimary: true,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
},
},
{
name: "get verified domain",
opts: []database.QueryOption{
database.WithCondition(domainRepo.IsVerifiedCondition(true)),
},
expected: &domain.OrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
Domain: domainName1,
IsVerified: true,
IsPrimary: true,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
},
},
{
name: "get non-existent domain",
opts: []database.QueryOption{
database.WithCondition(domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com")),
},
err: new(database.NoRowFoundError),
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctx := t.Context()
result, err := domainRepo.Get(ctx, test.opts...)
if test.err != nil {
assert.ErrorIs(t, err, test.err)
return
}
require.NoError(t, err)
assert.Equal(t, test.expected.InstanceID, result.InstanceID)
assert.Equal(t, test.expected.OrgID, result.OrgID)
assert.Equal(t, test.expected.Domain, result.Domain)
assert.Equal(t, test.expected.IsVerified, result.IsVerified)
assert.Equal(t, test.expected.IsPrimary, result.IsPrimary)
assert.Equal(t, test.expected.ValidationType, result.ValidationType)
assert.NotEmpty(t, result.CreatedAt)
assert.NotEmpty(t, result.UpdatedAt)
})
}
}
func TestListOrganizationDomains(t *testing.T) {
// create instance
instanceID := gofakeit.UUID()
instance := domain.Instance{
ID: instanceID,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleClient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create organization
orgID := gofakeit.UUID()
organization := domain.Organization{
ID: orgID,
Name: gofakeit.Name(),
InstanceID: instanceID,
State: domain.OrgStateActive,
}
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
require.NoError(t, tx.Rollback(t.Context()))
}()
instanceRepo := repository.InstanceRepository(tx)
err = instanceRepo.Create(t.Context(), &instance)
require.NoError(t, err)
orgRepo := repository.OrganizationRepository(tx)
err = orgRepo.Create(t.Context(), &organization)
require.NoError(t, err)
// add multiple domains
domainRepo := orgRepo.Domains(false)
domains := []domain.AddOrganizationDomain{
{
InstanceID: instanceID,
OrgID: orgID,
Domain: gofakeit.DomainName(),
IsVerified: true,
IsPrimary: true,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
},
{
InstanceID: instanceID,
OrgID: orgID,
Domain: gofakeit.DomainName(),
IsVerified: false,
IsPrimary: false,
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
},
{
InstanceID: instanceID,
OrgID: orgID,
Domain: gofakeit.DomainName(),
IsVerified: true,
IsPrimary: false,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
},
}
for i := range domains {
err = domainRepo.Add(t.Context(), &domains[i])
require.NoError(t, err)
}
tests := []struct {
name string
opts []database.QueryOption
expectedCount int
}{
{
name: "list all domains",
opts: []database.QueryOption{},
expectedCount: 3,
},
{
name: "list verified domains",
opts: []database.QueryOption{
database.WithCondition(domainRepo.IsVerifiedCondition(true)),
},
expectedCount: 2,
},
{
name: "list primary domains",
opts: []database.QueryOption{
database.WithCondition(domainRepo.IsPrimaryCondition(true)),
},
expectedCount: 1,
},
{
name: "list by organization",
opts: []database.QueryOption{
database.WithCondition(domainRepo.OrgIDCondition(orgID)),
},
expectedCount: 3,
},
{
name: "list by instance",
opts: []database.QueryOption{
database.WithCondition(domainRepo.InstanceIDCondition(instanceID)),
},
expectedCount: 3,
},
{
name: "list non-existent organization",
opts: []database.QueryOption{
database.WithCondition(domainRepo.OrgIDCondition("non-existent")),
},
expectedCount: 0,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctx := t.Context()
results, err := domainRepo.List(ctx, test.opts...)
require.NoError(t, err)
assert.Len(t, results, test.expectedCount)
for _, result := range results {
assert.Equal(t, instanceID, result.InstanceID)
assert.Equal(t, orgID, result.OrgID)
assert.NotEmpty(t, result.Domain)
assert.NotEmpty(t, result.CreatedAt)
assert.NotEmpty(t, result.UpdatedAt)
}
})
}
}
func TestUpdateOrganizationDomain(t *testing.T) {
// create instance
instanceID := gofakeit.UUID()
instance := domain.Instance{
ID: instanceID,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleClient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create organization
orgID := gofakeit.UUID()
organization := domain.Organization{
ID: orgID,
Name: gofakeit.Name(),
InstanceID: instanceID,
State: domain.OrgStateActive,
}
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
require.NoError(t, tx.Rollback(t.Context()))
}()
instanceRepo := repository.InstanceRepository(tx)
err = instanceRepo.Create(t.Context(), &instance)
require.NoError(t, err)
orgRepo := repository.OrganizationRepository(tx)
err = orgRepo.Create(t.Context(), &organization)
require.NoError(t, err)
// add domain
domainRepo := orgRepo.Domains(false)
domainName := gofakeit.DomainName()
organizationDomain := &domain.AddOrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
Domain: domainName,
IsVerified: false,
IsPrimary: false,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
}
err = domainRepo.Add(t.Context(), organizationDomain)
require.NoError(t, err)
tests := []struct {
name string
condition database.Condition
changes []database.Change
expected int64
err error
}{
{
name: "set verified",
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
changes: []database.Change{domainRepo.SetVerified()},
expected: 1,
},
{
name: "set primary",
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
changes: []database.Change{domainRepo.SetPrimary()},
expected: 1,
},
{
name: "set validation type",
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
changes: []database.Change{domainRepo.SetValidationType(domain.DomainValidationTypeHTTP)},
expected: 1,
},
{
name: "multiple changes",
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
changes: []database.Change{
domainRepo.SetVerified(),
domainRepo.SetPrimary(),
domainRepo.SetValidationType(domain.DomainValidationTypeDNS),
},
expected: 1,
},
{
name: "update by org ID and domain",
condition: database.And(domainRepo.OrgIDCondition(orgID), domainRepo.DomainCondition(database.TextOperationEqual, domainName)),
changes: []database.Change{domainRepo.SetVerified()},
expected: 1,
},
{
name: "update non-existent domain",
condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
changes: []database.Change{domainRepo.SetVerified()},
expected: 0,
},
{
name: "no changes",
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName),
changes: []database.Change{},
expected: 0,
err: database.ErrNoChanges,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctx := t.Context()
rowsAffected, err := domainRepo.Update(ctx, test.condition, test.changes...)
if test.err != nil {
assert.ErrorIs(t, err, test.err)
return
}
require.NoError(t, err)
assert.Equal(t, test.expected, rowsAffected)
// verify changes were applied if rows were affected
if rowsAffected > 0 && len(test.changes) > 0 {
result, err := domainRepo.Get(ctx, database.WithCondition(test.condition))
require.NoError(t, err)
// We know changes were applied since rowsAffected > 0
// The specific verification of what changed is less important
// than knowing the operation succeeded
assert.NotNil(t, result)
}
})
}
}
func TestRemoveOrganizationDomain(t *testing.T) {
// create instance
instanceID := gofakeit.UUID()
instance := domain.Instance{
ID: instanceID,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleClient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create organization
orgID := gofakeit.UUID()
organization := domain.Organization{
ID: orgID,
Name: gofakeit.Name(),
InstanceID: instanceID,
State: domain.OrgStateActive,
}
tx, err := pool.Begin(t.Context(), nil)
require.NoError(t, err)
defer func() {
require.NoError(t, tx.Rollback(t.Context()))
}()
instanceRepo := repository.InstanceRepository(tx)
err = instanceRepo.Create(t.Context(), &instance)
require.NoError(t, err)
orgRepo := repository.OrganizationRepository(tx)
err = orgRepo.Create(t.Context(), &organization)
require.NoError(t, err)
// add domains
domainRepo := orgRepo.Domains(false)
domainName1 := gofakeit.DomainName()
domainName2 := gofakeit.DomainName()
domain1 := &domain.AddOrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
Domain: domainName1,
IsVerified: true,
IsPrimary: true,
ValidationType: gu.Ptr(domain.DomainValidationTypeDNS),
}
domain2 := &domain.AddOrganizationDomain{
InstanceID: instanceID,
OrgID: orgID,
Domain: domainName2,
IsVerified: false,
IsPrimary: false,
ValidationType: gu.Ptr(domain.DomainValidationTypeHTTP),
}
err = domainRepo.Add(t.Context(), domain1)
require.NoError(t, err)
err = domainRepo.Add(t.Context(), domain2)
require.NoError(t, err)
tests := []struct {
name string
condition database.Condition
expected int64
}{
{
name: "remove by domain name",
condition: domainRepo.DomainCondition(database.TextOperationEqual, domainName1),
expected: 1,
},
{
name: "remove by primary condition",
condition: domainRepo.IsPrimaryCondition(false),
expected: 1, // domain2 should still exist and be non-primary
},
{
name: "remove by org ID and domain",
condition: database.And(domainRepo.OrgIDCondition(orgID), domainRepo.DomainCondition(database.TextOperationEqual, domainName2)),
expected: 1,
},
{
name: "remove non-existent domain",
condition: domainRepo.DomainCondition(database.TextOperationEqual, "non-existent.com"),
expected: 0,
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
ctx := t.Context()
snapshot, err := tx.Begin(ctx)
require.NoError(t, err)
defer func() {
require.NoError(t, snapshot.Rollback(ctx))
}()
orgRepo := repository.OrganizationRepository(snapshot)
domainRepo := orgRepo.Domains(false)
// count before removal
beforeCount, err := domainRepo.List(ctx)
require.NoError(t, err)
rowsAffected, err := domainRepo.Remove(ctx, test.condition)
require.NoError(t, err)
assert.Equal(t, test.expected, rowsAffected)
// verify removal
afterCount, err := domainRepo.List(ctx)
require.NoError(t, err)
assert.Equal(t, len(beforeCount)-int(test.expected), len(afterCount))
})
}
}
func TestOrganizationDomainConditions(t *testing.T) {
orgRepo := repository.OrganizationRepository(pool)
domainRepo := orgRepo.Domains(false)
tests := []struct {
name string
condition database.Condition
expected string
}{
{
name: "domain condition equal",
condition: domainRepo.DomainCondition(database.TextOperationEqual, "example.com"),
expected: "org_domains.domain = $1",
},
{
name: "domain condition starts with",
condition: domainRepo.DomainCondition(database.TextOperationStartsWith, "example"),
expected: "org_domains.domain LIKE $1 || '%'",
},
{
name: "instance id condition",
condition: domainRepo.InstanceIDCondition("instance-123"),
expected: "org_domains.instance_id = $1",
},
{
name: "org id condition",
condition: domainRepo.OrgIDCondition("org-123"),
expected: "org_domains.org_id = $1",
},
{
name: "is primary true",
condition: domainRepo.IsPrimaryCondition(true),
expected: "org_domains.is_primary = $1",
},
{
name: "is primary false",
condition: domainRepo.IsPrimaryCondition(false),
expected: "org_domains.is_primary = $1",
},
{
name: "is verified true",
condition: domainRepo.IsVerifiedCondition(true),
expected: "org_domains.is_verified = $1",
},
{
name: "is verified false",
condition: domainRepo.IsVerifiedCondition(false),
expected: "org_domains.is_verified = $1",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var builder database.StatementBuilder
test.condition.Write(&builder)
assert.Equal(t, test.expected, builder.String())
})
}
}
func TestOrganizationDomainChanges(t *testing.T) {
orgRepo := repository.OrganizationRepository(pool)
domainRepo := orgRepo.Domains(false)
tests := []struct {
name string
change database.Change
expected string
}{
{
name: "set verified",
change: domainRepo.SetVerified(),
expected: "is_verified = $1",
},
{
name: "set primary",
change: domainRepo.SetPrimary(),
expected: "is_primary = $1",
},
{
name: "set validation type DNS",
change: domainRepo.SetValidationType(domain.DomainValidationTypeDNS),
expected: "validation_type = $1",
},
{
name: "set validation type HTTP",
change: domainRepo.SetValidationType(domain.DomainValidationTypeHTTP),
expected: "validation_type = $1",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var builder database.StatementBuilder
test.change.Write(&builder)
assert.Equal(t, test.expected, builder.String())
})
}
}
func TestOrganizationDomainColumns(t *testing.T) {
orgRepo := repository.OrganizationRepository(pool)
domainRepo := orgRepo.Domains(false)
tests := []struct {
name string
column database.Column
qualified bool
expected string
}{
{
name: "instance id column qualified",
column: domainRepo.InstanceIDColumn(),
qualified: true,
expected: "org_domains.instance_id",
},
{
name: "instance id column unqualified",
column: domainRepo.InstanceIDColumn(),
qualified: false,
expected: "instance_id",
},
{
name: "org id column qualified",
column: domainRepo.OrgIDColumn(),
qualified: true,
expected: "org_domains.org_id",
},
{
name: "org id column unqualified",
column: domainRepo.OrgIDColumn(),
qualified: false,
expected: "org_id",
},
{
name: "domain column qualified",
column: domainRepo.DomainColumn(),
qualified: true,
expected: "org_domains.domain",
},
{
name: "domain column unqualified",
column: domainRepo.DomainColumn(),
qualified: false,
expected: "domain",
},
{
name: "is verified column qualified",
column: domainRepo.IsVerifiedColumn(),
qualified: true,
expected: "org_domains.is_verified",
},
{
name: "is verified column unqualified",
column: domainRepo.IsVerifiedColumn(),
qualified: false,
expected: "is_verified",
},
{
name: "is primary column qualified",
column: domainRepo.IsPrimaryColumn(),
qualified: true,
expected: "org_domains.is_primary",
},
{
name: "is primary column unqualified",
column: domainRepo.IsPrimaryColumn(),
qualified: false,
expected: "is_primary",
},
{
name: "validation type column qualified",
column: domainRepo.ValidationTypeColumn(),
qualified: true,
expected: "org_domains.validation_type",
},
{
name: "validation type column unqualified",
column: domainRepo.ValidationTypeColumn(),
qualified: false,
expected: "validation_type",
},
{
name: "created at column qualified",
column: domainRepo.CreatedAtColumn(),
qualified: true,
expected: "org_domains.created_at",
},
{
name: "created at column unqualified",
column: domainRepo.CreatedAtColumn(),
qualified: false,
expected: "created_at",
},
{
name: "updated at column qualified",
column: domainRepo.UpdatedAtColumn(),
qualified: true,
expected: "org_domains.updated_at",
},
{
name: "updated at column unqualified",
column: domainRepo.UpdatedAtColumn(),
qualified: false,
expected: "updated_at",
},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
var builder database.StatementBuilder
if test.qualified {
test.column.WriteQualified(&builder)
} else {
test.column.WriteUnqualified(&builder)
}
assert.Equal(t, test.expected, builder.String())
})
}
}

View File

@@ -0,0 +1,28 @@
package repository
import (
"context"
"github.com/zitadel/zitadel/backend/v3/domain"
)
type orgMember struct {
*org
}
// AddMember implements [domain.MemberRepository].
func (o *orgMember) AddMember(ctx context.Context, orgID string, userID string, roles []string) error {
return nil
}
// RemoveMember implements [domain.MemberRepository].
func (o *orgMember) RemoveMember(ctx context.Context, orgID string, userID string) error {
return nil
}
// SetMemberRoles implements [domain.MemberRepository].
func (o *orgMember) SetMemberRoles(ctx context.Context, orgID string, userID string, roles []string) error {
return nil
}
var _ domain.MemberRepository = (*orgMember)(nil)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,20 @@
package repository
import (
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type repository struct {
client database.QueryExecutor
}
func writeCondition(
builder *database.StatementBuilder,
condition database.Condition,
) {
if condition == nil {
return
}
builder.WriteString(" WHERE ")
condition.Write(builder)
}

View File

@@ -0,0 +1,51 @@
package repository_test
import (
"context"
"fmt"
"log"
"os"
"testing"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres/embedded"
)
func TestMain(m *testing.M) {
os.Exit(runTests(m))
}
var pool database.PoolTest
func runTests(m *testing.M) int {
var stop func()
var err error
ctx := context.Background()
pool, stop, err = newEmbeddedDB(ctx)
if err != nil {
log.Printf("error with embedded postgres database: %v", err)
return 1
}
defer stop()
return m.Run()
}
func newEmbeddedDB(ctx context.Context) (pool database.PoolTest, stop func(), err error) {
connector, stop, err := embedded.StartEmbedded()
if err != nil {
return nil, nil, fmt.Errorf("unable to start embedded postgres: %w", err)
}
pool_, err := connector.Connect(ctx)
if err != nil {
return nil, nil, fmt.Errorf("unable to connect to embedded postgres: %w", err)
}
pool = pool_.(database.PoolTest)
err = pool.MigrateTest(ctx)
if err != nil {
return nil, nil, fmt.Errorf("unable to migrate database: %w", err)
}
return pool, stop, err
}

View File

@@ -0,0 +1,282 @@
package repository
import (
"context"
"time"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
const queryUserStmt = `SELECT instance_id, org_id, id, username, type, created_at, updated_at, deleted_at,` +
` first_name, last_name, email_address, email_verified_at, phone_number, phone_verified_at, description` +
` FROM users_view users`
type user struct {
repository
}
func UserRepository(client database.QueryExecutor) domain.UserRepository {
return &user{
repository: repository{
client: client,
},
}
}
var _ domain.UserRepository = (*user)(nil)
// -------------------------------------------------------------
// repository
// -------------------------------------------------------------
// Human implements [domain.UserRepository].
func (u *user) Human() domain.HumanRepository {
return &userHuman{user: u}
}
// Machine implements [domain.UserRepository].
func (u *user) Machine() domain.MachineRepository {
return &userMachine{user: u}
}
// List implements [domain.UserRepository].
func (u *user) List(ctx context.Context, opts ...database.QueryOption) (users []*domain.User, err error) {
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
}
builder := database.StatementBuilder{}
builder.WriteString(queryUserStmt)
options.WriteCondition(&builder)
options.WriteOrderBy(&builder)
options.WriteLimit(&builder)
options.WriteOffset(&builder)
rows, err := u.client.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
defer func() {
closeErr := rows.Close()
if err != nil {
return
}
err = closeErr
}()
for rows.Next() {
user, err := scanUser(rows)
if err != nil {
return nil, err
}
users = append(users, user)
}
if err := rows.Err(); err != nil {
return nil, err
}
return users, nil
}
// Get implements [domain.UserRepository].
func (u *user) Get(ctx context.Context, opts ...database.QueryOption) (*domain.User, error) {
options := new(database.QueryOpts)
for _, opt := range opts {
opt(options)
}
builder := database.StatementBuilder{}
builder.WriteString(queryUserStmt)
options.WriteCondition(&builder)
options.WriteOrderBy(&builder)
options.WriteLimit(&builder)
options.WriteOffset(&builder)
return scanUser(u.client.QueryRow(ctx, builder.String(), builder.Args()...))
}
const (
createHumanStmt = `INSERT INTO human_users (instance_id, org_id, user_id, username, first_name, last_name, email_address, email_verified_at, phone_number, phone_verified_at)` +
` VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)` +
` RETURNING created_at, updated_at`
createMachineStmt = `INSERT INTO user_machines (instance_id, org_id, user_id, username, description)` +
` VALUES ($1, $2, $3, $4, $5)` +
` RETURNING created_at, updated_at`
)
// Create implements [domain.UserRepository].
func (u *user) Create(ctx context.Context, user *domain.User) error {
builder := database.StatementBuilder{}
builder.AppendArgs(user.InstanceID, user.OrgID, user.ID, user.Username, user.Traits.Type())
switch trait := user.Traits.(type) {
case *domain.Human:
builder.WriteString(createHumanStmt)
builder.AppendArgs(trait.FirstName, trait.LastName, trait.Email.Address, trait.Email.VerifiedAt, trait.Phone.Number, trait.Phone.VerifiedAt)
case *domain.Machine:
builder.WriteString(createMachineStmt)
builder.AppendArgs(trait.Description)
}
return u.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&user.CreatedAt, &user.UpdatedAt)
}
// Delete implements [domain.UserRepository].
func (u *user) Delete(ctx context.Context, condition database.Condition) error {
builder := database.StatementBuilder{}
builder.WriteString("DELETE FROM users")
writeCondition(&builder, condition)
_, err := u.client.Exec(ctx, builder.String(), builder.Args()...)
return err
}
// -------------------------------------------------------------
// changes
// -------------------------------------------------------------
// SetUsername implements [domain.userChanges].
func (u user) SetUsername(username string) database.Change {
return database.NewChange(u.UsernameColumn(), username)
}
// -------------------------------------------------------------
// conditions
// -------------------------------------------------------------
// InstanceIDCondition implements [domain.userConditions].
func (u user) InstanceIDCondition(instanceID string) database.Condition {
return database.NewTextCondition(u.InstanceIDColumn(), database.TextOperationEqual, instanceID)
}
// OrgIDCondition implements [domain.userConditions].
func (u user) OrgIDCondition(orgID string) database.Condition {
return database.NewTextCondition(u.OrgIDColumn(), database.TextOperationEqual, orgID)
}
// IDCondition implements [domain.userConditions].
func (u user) IDCondition(userID string) database.Condition {
return database.NewTextCondition(u.IDColumn(), database.TextOperationEqual, userID)
}
// UsernameCondition implements [domain.userConditions].
func (u user) UsernameCondition(op database.TextOperation, username string) database.Condition {
return database.NewTextCondition(u.UsernameColumn(), op, username)
}
// CreatedAtCondition implements [domain.userConditions].
func (u user) CreatedAtCondition(op database.NumberOperation, createdAt time.Time) database.Condition {
return database.NewNumberCondition(u.CreatedAtColumn(), op, createdAt)
}
// UpdatedAtCondition implements [domain.userConditions].
func (u user) UpdatedAtCondition(op database.NumberOperation, updatedAt time.Time) database.Condition {
return database.NewNumberCondition(u.UpdatedAtColumn(), op, updatedAt)
}
// DeletedCondition implements [domain.userConditions].
func (u user) DeletedCondition(isDeleted bool) database.Condition {
if isDeleted {
return database.IsNotNull(u.DeletedAtColumn())
}
return database.IsNull(u.DeletedAtColumn())
}
// DeletedAtCondition implements [domain.userConditions].
func (u user) DeletedAtCondition(op database.NumberOperation, deletedAt time.Time) database.Condition {
return database.NewNumberCondition(u.DeletedAtColumn(), op, deletedAt)
}
// -------------------------------------------------------------
// columns
// -------------------------------------------------------------
// InstanceIDColumn implements [domain.userColumns].
func (user) InstanceIDColumn() database.Column {
return database.NewColumn("users", "instance_id")
}
// OrgIDColumn implements [domain.userColumns].
func (user) OrgIDColumn() database.Column {
return database.NewColumn("users", "org_id")
}
// IDColumn implements [domain.userColumns].
func (user) IDColumn() database.Column {
return database.NewColumn("users", "id")
}
// UsernameColumn implements [domain.userColumns].
func (user) UsernameColumn() database.Column {
return database.NewColumn("users", "username")
}
// FirstNameColumn implements [domain.userColumns].
func (user) CreatedAtColumn() database.Column {
return database.NewColumn("users", "created_at")
}
// UpdatedAtColumn implements [domain.userColumns].
func (user) UpdatedAtColumn() database.Column {
return database.NewColumn("users", "updated_at")
}
// DeletedAtColumn implements [domain.userColumns].
func (user) DeletedAtColumn() database.Column {
return database.NewColumn("users", "deleted_at")
}
func (u user) columns() database.Columns {
return database.Columns{
u.InstanceIDColumn(),
u.OrgIDColumn(),
u.IDColumn(),
u.UsernameColumn(),
u.CreatedAtColumn(),
u.UpdatedAtColumn(),
u.DeletedAtColumn(),
}
}
func scanUser(scanner database.Scanner) (*domain.User, error) {
var (
user domain.User
human domain.Human
email domain.Email
phone domain.Phone
machine domain.Machine
typ domain.UserType
)
err := scanner.Scan(
&user.InstanceID,
&user.OrgID,
&user.ID,
&user.Username,
&typ,
&user.CreatedAt,
&user.UpdatedAt,
&user.DeletedAt,
&human.FirstName,
&human.LastName,
&email.Address,
&email.VerifiedAt,
&phone.Number,
&phone.VerifiedAt,
&machine.Description,
)
if err != nil {
return nil, err
}
switch typ {
case domain.UserTypeHuman:
if email.Address != "" {
human.Email = &email
}
if phone.Number != "" {
human.Phone = &phone
}
user.Traits = &human
case domain.UserTypeMachine:
user.Traits = &machine
}
return &user, nil
}

View File

@@ -0,0 +1,208 @@
package repository
import (
"context"
"time"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
// -------------------------------------------------------------
// repository
// -------------------------------------------------------------
type userHuman struct {
*user
}
var _ domain.HumanRepository = (*userHuman)(nil)
const userEmailQuery = `SELECT h.email_address, h.email_verified_at FROM user_humans h`
// GetEmail implements [domain.HumanRepository].
func (u *userHuman) GetEmail(ctx context.Context, condition database.Condition) (*domain.Email, error) {
var email domain.Email
builder := database.StatementBuilder{}
builder.WriteString(userEmailQuery)
writeCondition(&builder, condition)
err := u.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(
&email.Address,
&email.VerifiedAt,
)
if err != nil {
return nil, err
}
return &email, nil
}
// Update implements [domain.HumanRepository].
func (h userHuman) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error {
builder := database.StatementBuilder{}
builder.WriteString(`UPDATE human_users SET `)
database.Changes(changes).Write(&builder)
writeCondition(&builder, condition)
stmt := builder.String()
_, err := h.client.Exec(ctx, stmt, builder.Args()...)
return err
}
// -------------------------------------------------------------
// changes
// -------------------------------------------------------------
// SetFirstName implements [domain.humanChanges].
func (h userHuman) SetFirstName(firstName string) database.Change {
return database.NewChange(h.FirstNameColumn(), firstName)
}
// SetLastName implements [domain.humanChanges].
func (h userHuman) SetLastName(lastName string) database.Change {
return database.NewChange(h.LastNameColumn(), lastName)
}
// SetEmail implements [domain.humanChanges].
func (h userHuman) SetEmail(address string, verified *time.Time) database.Change {
return database.NewChanges(
h.SetEmailAddress(address),
database.NewChangePtr(h.EmailVerifiedAtColumn(), verified),
)
}
// SetEmailAddress implements [domain.humanChanges].
func (h userHuman) SetEmailAddress(address string) database.Change {
return database.NewChange(h.EmailAddressColumn(), address)
}
// SetEmailVerifiedAt implements [domain.humanChanges].
func (h userHuman) SetEmailVerifiedAt(at time.Time) database.Change {
if at.IsZero() {
return database.NewChange(h.EmailVerifiedAtColumn(), database.NowInstruction)
}
return database.NewChange(h.EmailVerifiedAtColumn(), at)
}
// SetPhone implements [domain.humanChanges].
func (h userHuman) SetPhone(number string, verifiedAt *time.Time) database.Change {
return database.NewChanges(
h.SetPhoneNumber(number),
database.NewChangePtr(h.PhoneVerifiedAtColumn(), verifiedAt),
)
}
// SetPhoneNumber implements [domain.humanChanges].
func (h userHuman) SetPhoneNumber(number string) database.Change {
return database.NewChange(h.PhoneNumberColumn(), number)
}
// SetPhoneVerifiedAt implements [domain.humanChanges].
func (h userHuman) SetPhoneVerifiedAt(at time.Time) database.Change {
if at.IsZero() {
return database.NewChange(h.PhoneVerifiedAtColumn(), database.NowInstruction)
}
return database.NewChange(h.PhoneVerifiedAtColumn(), at)
}
// -------------------------------------------------------------
// conditions
// -------------------------------------------------------------
// FirstNameCondition implements [domain.humanConditions].
func (h userHuman) FirstNameCondition(op database.TextOperation, firstName string) database.Condition {
return database.NewTextCondition(h.FirstNameColumn(), op, firstName)
}
// LastNameCondition implements [domain.humanConditions].
func (h userHuman) LastNameCondition(op database.TextOperation, lastName string) database.Condition {
return database.NewTextCondition(h.LastNameColumn(), op, lastName)
}
// EmailAddressCondition implements [domain.humanConditions].
func (h userHuman) EmailAddressCondition(op database.TextOperation, email string) database.Condition {
return database.NewTextCondition(h.EmailAddressColumn(), op, email)
}
// EmailVerifiedCondition implements [domain.humanConditions].
func (h *userHuman) EmailVerifiedCondition(isVerified bool) database.Condition {
if isVerified {
return database.IsNotNull(h.EmailVerifiedAtColumn())
}
return database.IsNull(h.EmailVerifiedAtColumn())
}
// EmailVerifiedAtCondition implements [domain.humanConditions].
func (h userHuman) EmailVerifiedAtCondition(op database.NumberOperation, verifiedAt time.Time) database.Condition {
return database.NewNumberCondition(h.EmailVerifiedAtColumn(), op, verifiedAt)
}
// PhoneNumberCondition implements [domain.humanConditions].
func (h userHuman) PhoneNumberCondition(op database.TextOperation, phoneNumber string) database.Condition {
return database.NewTextCondition(h.PhoneNumberColumn(), op, phoneNumber)
}
// PhoneVerifiedCondition implements [domain.humanConditions].
func (h userHuman) PhoneVerifiedCondition(isVerified bool) database.Condition {
if isVerified {
return database.IsNotNull(h.PhoneVerifiedAtColumn())
}
return database.IsNull(h.PhoneVerifiedAtColumn())
}
// PhoneVerifiedAtCondition implements [domain.humanConditions].
func (h userHuman) PhoneVerifiedAtCondition(op database.NumberOperation, verifiedAt time.Time) database.Condition {
return database.NewNumberCondition(h.PhoneVerifiedAtColumn(), op, verifiedAt)
}
// -------------------------------------------------------------
// columns
// -------------------------------------------------------------
// FirstNameColumn implements [domain.humanColumns].
func (h userHuman) FirstNameColumn() database.Column {
return database.NewColumn("user_humans", "first_name")
}
// LastNameColumn implements [domain.humanColumns].
func (h userHuman) LastNameColumn() database.Column {
return database.NewColumn("user_humans", "last_name")
}
// EmailAddressColumn implements [domain.humanColumns].
func (h userHuman) EmailAddressColumn() database.Column {
return database.NewColumn("user_humans", "email_address")
}
// EmailVerifiedAtColumn implements [domain.humanColumns].
func (h userHuman) EmailVerifiedAtColumn() database.Column {
return database.NewColumn("user_humans", "email_verified_at")
}
// PhoneNumberColumn implements [domain.humanColumns].
func (h userHuman) PhoneNumberColumn() database.Column {
return database.NewColumn("user_humans", "phone_number")
}
// PhoneVerifiedAtColumn implements [domain.humanColumns].
func (h userHuman) PhoneVerifiedAtColumn() database.Column {
return database.NewColumn("user_humans", "phone_verified_at")
}
// func (h userHuman) columns() database.Columns {
// return append(h.user.columns(),
// h.FirstNameColumn(),
// h.LastNameColumn(),
// h.EmailAddressColumn(),
// h.EmailVerifiedAtColumn(),
// h.PhoneNumberColumn(),
// h.PhoneVerifiedAtColumn(),
// )
// }
// func (h userHuman) writeReturning(builder *database.StatementBuilder) {
// builder.WriteString(" RETURNING ")
// h.columns().Write(builder)
// }

View File

@@ -0,0 +1,67 @@
package repository
import (
"context"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
type userMachine struct {
*user
}
var _ domain.MachineRepository = (*userMachine)(nil)
// -------------------------------------------------------------
// repository
// -------------------------------------------------------------
// Update implements [domain.MachineRepository].
func (m userMachine) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error {
builder := database.StatementBuilder{}
builder.WriteString("UPDATE user_machines SET ")
database.Changes(changes).Write(&builder)
writeCondition(&builder, condition)
m.writeReturning()
_, err := m.client.Exec(ctx, builder.String(), builder.Args()...)
return err
}
// -------------------------------------------------------------
// changes
// -------------------------------------------------------------
// SetDescription implements [domain.machineChanges].
func (m userMachine) SetDescription(description string) database.Change {
return database.NewChange(m.DescriptionColumn(), description)
}
// -------------------------------------------------------------
// conditions
// -------------------------------------------------------------
// DescriptionCondition implements [domain.machineConditions].
func (m userMachine) DescriptionCondition(op database.TextOperation, description string) database.Condition {
return database.NewTextCondition(m.DescriptionColumn(), op, description)
}
// -------------------------------------------------------------
// columns
// -------------------------------------------------------------
// DescriptionColumn implements [domain.machineColumns].
func (m userMachine) DescriptionColumn() database.Column {
return database.NewColumn("user_machines", "description")
}
func (m userMachine) columns() database.Columns {
return append(m.user.columns(), m.DescriptionColumn())
}
func (m *userMachine) writeReturning() {
builder := database.StatementBuilder{}
builder.WriteString(" RETURNING ")
m.columns().WriteQualified(&builder)
}

View File

@@ -0,0 +1,68 @@
package database
import (
"strconv"
"strings"
)
type Instruction string
const (
DefaultInstruction Instruction = "DEFAULT"
NowInstruction Instruction = "NOW()"
NullInstruction Instruction = "NULL"
)
// StatementBuilder is a helper to build SQL statement.
type StatementBuilder struct {
strings.Builder
args []any
existingArgs map[any]string
}
// WriteArgs adds the argument to the statement and writes the placeholder to the query.
func (b *StatementBuilder) WriteArg(arg any) {
b.WriteString(b.AppendArg(arg))
}
// WriteArgs adds the arguments to the statement and writes the placeholders to the query.
// The placeholders are comma separated.
func (b *StatementBuilder) WriteArgs(args ...any) {
for i, arg := range args {
if i > 0 {
b.WriteString(", ")
}
b.WriteArg(arg)
}
}
// AppendArg adds the argument to the statement and returns the placeholder.
func (b *StatementBuilder) AppendArg(arg any) (placeholder string) {
if b.existingArgs == nil {
b.existingArgs = make(map[any]string)
}
if placeholder, ok := b.existingArgs[arg]; ok {
return placeholder
}
if instruction, ok := arg.(Instruction); ok {
return string(instruction)
}
b.args = append(b.args, arg)
placeholder = "$" + strconv.Itoa(len(b.args))
b.existingArgs[arg] = placeholder
return placeholder
}
// AppendArgs adds the arguments to the statement and doesn't return the placeholders.
// If an argument is already added, it will not be added again.
func (b *StatementBuilder) AppendArgs(args ...any) {
for _, arg := range args {
b.AppendArg(arg)
}
}
// Args returns the arguments added to the statement.
func (b *StatementBuilder) Args() []any {
return b.args
}

Some files were not shown because too many files have changed in this diff Show More