feat(db): adding relational instance table (#10007)

<!--
Please inform yourself about the contribution guidelines on submitting a
PR here:
https://github.com/zitadel/zitadel/blob/main/CONTRIBUTING.md#submit-a-pull-request-pr.
Take note of how PR/commit titles should be written and replace the
template texts in the sections below. Don't remove any of the sections.
It is important that the commit history clearly shows what is changed
and why.
Important: By submitting a contribution you agree to the terms from our
Licensing Policy as described here:
https://github.com/zitadel/zitadel/blob/main/LICENSING.md#community-contributions.
-->

# Which Problems Are Solved

Implementing Instance table to new relational database schema

# How the Problems Are Solved


The following fields must be managed in this table:

- `id`
- `name`
- `default_org_id`
- `zitadel_project_id`
- `console_client_id`
- `console_app_id`
- `default_language`
- `created_at`
- `updated_at`
- `deleted_at`

The repository must provide the following functions:

Manipulations:
- create
  - `name`
  - `default_org_id`
  - `zitadel_project_id`
  - `console_client_id`
  - `console_app_id`
  - `default_language`
- update
  - `name`
  - `default_language`
- delete

Queries:
- get returns single instance matching the criteria and pagination,
should return error if multiple instances were found
- list returns list of instances matching the criteria, pagination

Criteria are the following:
- by id

pagination:
- by created_at
- by updated_at
- by name

### instance events

The following events must be applied on the table using a projection
(`internal/query/projection`)

- `instance.added` results in create
- `instance.changed` changes the `name` field
- `instance.removed` sets the the `deleted_at` field
- `instance.default.org.set` sets the `default_org_id` field
- `instance.iam.project.set` sets the `zitadel_project_id` field
- `instance.iam.console.set` sets the `console_client_id` and
`console_app_id` fields
- `instance.default.language.set` sets the `default_language` field
- if answer is yes to discussion: `instance.domain.primary.set` sets the
`primary_domain` field

### acceptance criteria

- [x] migration is implemented and gets executed
- [x] domain interfaces are implemented and documented for service layer
- [x] repository is implemented and implements domain interface
- [x] testing
  - [x] the repository methods
  - [x] events get reduced correctly
  - [x] unique constraints

# Additional Context

- Closes https://github.com/zitadel/zitadel/issues/9935
This commit is contained in:
Iraq
2025-06-17 09:46:01 +02:00
committed by GitHub
parent d75a45ebed
commit 9595a1bcca
23 changed files with 1537 additions and 133 deletions

View File

@@ -0,0 +1,275 @@
package repository
import (
"context"
"errors"
"time"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
var _ domain.InstanceRepository = (*instance)(nil)
type instance struct {
repository
}
func InstanceRepository(client database.QueryExecutor) domain.InstanceRepository {
return &instance{
repository: repository{
client: client,
},
}
}
// -------------------------------------------------------------
// repository
// -------------------------------------------------------------
const queryInstanceStmt = `SELECT id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language, created_at, updated_at, deleted_at` +
` FROM zitadel.instances`
// Get implements [domain.InstanceRepository].
func (i *instance) Get(ctx context.Context, opts ...database.Condition) (*domain.Instance, error) {
var builder database.StatementBuilder
builder.WriteString(queryInstanceStmt)
// return only non deleted instances
opts = append(opts, database.IsNull(i.DeletedAtColumn()))
i.writeCondition(&builder, database.And(opts...))
return scanInstance(i.client.QueryRow(ctx, builder.String(), builder.Args()...))
}
// List implements [domain.InstanceRepository].
func (i *instance) List(ctx context.Context, opts ...database.Condition) ([]*domain.Instance, error) {
var builder database.StatementBuilder
builder.WriteString(queryInstanceStmt)
// return only non deleted instances
opts = append(opts, database.IsNull(i.DeletedAtColumn()))
notDeletedCondition := database.And(opts...)
i.writeCondition(&builder, notDeletedCondition)
rows, err := i.client.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
defer rows.Close()
return scanInstances(rows)
}
const createInstanceStmt = `INSERT INTO zitadel.instances (id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language)` +
` VALUES ($1, $2, $3, $4, $5, $6, $7)` +
` RETURNING created_at, updated_at`
// Create implements [domain.InstanceRepository].
func (i *instance) Create(ctx context.Context, instance *domain.Instance) error {
var builder database.StatementBuilder
builder.AppendArgs(instance.ID, instance.Name, instance.DefaultOrgID, instance.IAMProjectID, instance.ConsoleClientID, instance.ConsoleAppID, instance.DefaultLanguage)
builder.WriteString(createInstanceStmt)
err := i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&instance.CreatedAt, &instance.UpdatedAt)
if err != nil {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
// constraint violation
if pgErr.Code == "23514" {
if pgErr.ConstraintName == "instances_name_check" {
return errors.New("instance name not provided")
}
if pgErr.ConstraintName == "instances_id_check" {
return errors.New("instance id not provided")
}
}
// duplicate
if pgErr.Code == "23505" {
if pgErr.ConstraintName == "instances_pkey" {
return errors.New("instance id already exists")
}
}
}
}
return err
}
// Update implements [domain.InstanceRepository].
func (i instance) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) {
var builder database.StatementBuilder
builder.WriteString(`UPDATE zitadel.instances SET `)
// don't update deleted instances
conditions := []database.Condition{condition, database.IsNull(i.DeletedAtColumn())}
database.Changes(changes).Write(&builder)
i.writeCondition(&builder, database.And(conditions...))
stmt := builder.String()
rowsAffected, err := i.client.Exec(ctx, stmt, builder.Args()...)
return rowsAffected, err
}
// Delete implements [domain.InstanceRepository].
func (i instance) Delete(ctx context.Context, condition database.Condition) error {
if condition == nil {
return errors.New("Delete must contain a condition") // (otherwise ALL instances will be deleted)
}
var builder database.StatementBuilder
builder.WriteString(`UPDATE zitadel.instances SET deleted_at = $1`)
builder.AppendArgs(time.Now())
i.writeCondition(&builder, condition)
_, err := i.client.Exec(ctx, builder.String(), builder.Args()...)
return err
}
// -------------------------------------------------------------
// changes
// -------------------------------------------------------------
// SetName implements [domain.instanceChanges].
func (i instance) SetName(name string) database.Change {
return database.NewChange(i.NameColumn(), name)
}
// -------------------------------------------------------------
// conditions
// -------------------------------------------------------------
// IDCondition implements [domain.instanceConditions].
func (i instance) IDCondition(id string) database.Condition {
return database.NewTextCondition(i.IDColumn(), database.TextOperationEqual, id)
}
// NameCondition implements [domain.instanceConditions].
func (i instance) NameCondition(op database.TextOperation, name string) database.Condition {
return database.NewTextCondition(i.NameColumn(), op, name)
}
// -------------------------------------------------------------
// columns
// -------------------------------------------------------------
// IDColumn implements [domain.instanceColumns].
func (instance) IDColumn() database.Column {
return database.NewColumn("id")
}
// NameColumn implements [domain.instanceColumns].
func (instance) NameColumn() database.Column {
return database.NewColumn("name")
}
// CreatedAtColumn implements [domain.instanceColumns].
func (instance) CreatedAtColumn() database.Column {
return database.NewColumn("created_at")
}
// DefaultOrgIdColumn implements [domain.instanceColumns].
func (instance) DefaultOrgIDColumn() database.Column {
return database.NewColumn("default_org_id")
}
// IAMProjectIDColumn implements [domain.instanceColumns].
func (instance) IAMProjectIDColumn() database.Column {
return database.NewColumn("iam_project_id")
}
// ConsoleClientIDColumn implements [domain.instanceColumns].
func (instance) ConsoleClientIDColumn() database.Column {
return database.NewColumn("console_client_id")
}
// ConsoleAppIDColumn implements [domain.instanceColumns].
func (instance) ConsoleAppIDColumn() database.Column {
return database.NewColumn("console_app_id")
}
// DefaultLanguageColumn implements [domain.instanceColumns].
func (instance) DefaultLanguageColumn() database.Column {
return database.NewColumn("default_language")
}
// UpdatedAtColumn implements [domain.instanceColumns].
func (instance) UpdatedAtColumn() database.Column {
return database.NewColumn("updated_at")
}
// DeletedAtColumn implements [domain.instanceColumns].
func (instance) DeletedAtColumn() database.Column {
return database.NewColumn("deleted_at")
}
func (i *instance) writeCondition(
builder *database.StatementBuilder,
condition database.Condition,
) {
if condition == nil {
return
}
builder.WriteString(" WHERE ")
condition.Write(builder)
}
func scanInstance(scanner database.Scanner) (*domain.Instance, error) {
var instance domain.Instance
err := scanner.Scan(
&instance.ID,
&instance.Name,
&instance.DefaultOrgID,
&instance.IAMProjectID,
&instance.ConsoleClientID,
&instance.ConsoleAppID,
&instance.DefaultLanguage,
&instance.CreatedAt,
&instance.UpdatedAt,
&instance.DeletedAt,
)
if err != nil {
// if no results returned, this is not a error
// it just means the instance was not found
// the caller should check if the returned instance is nil
if err.Error() == "no rows in result set" {
return nil, nil
}
return nil, err
}
return &instance, nil
}
func scanInstances(rows database.Rows) ([]*domain.Instance, error) {
instances := make([]*domain.Instance, 0)
for rows.Next() {
var instance domain.Instance
err := rows.Scan(
&instance.ID,
&instance.Name,
&instance.DefaultOrgID,
&instance.IAMProjectID,
&instance.ConsoleClientID,
&instance.ConsoleAppID,
&instance.DefaultLanguage,
&instance.CreatedAt,
&instance.UpdatedAt,
&instance.DeletedAt,
)
if err != nil {
return nil, err
}
instances = append(instances, &instance)
}
return instances, nil
}

View File

@@ -0,0 +1,690 @@
package repository_test
import (
"context"
"errors"
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
)
func TestCreateInstance(t *testing.T) {
tests := []struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
instance domain.Instance
err error
}{
{
name: "happy path",
instance: func() domain.Instance {
instanceId := gofakeit.Name()
instanceName := gofakeit.Name()
instance := domain.Instance{
ID: instanceId,
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
return instance
}(),
},
{
name: "create instance without name",
instance: func() domain.Instance {
instanceId := gofakeit.Name()
// instanceName := gofakeit.Name()
instance := domain.Instance{
ID: instanceId,
Name: "",
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
return instance
}(),
err: errors.New("instance name not provided"),
},
{
name: "adding same instance twice",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
instanceId := gofakeit.Name()
instanceName := gofakeit.Name()
inst := domain.Instance{
ID: instanceId,
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
return &inst
},
err: errors.New("instance id already exists"),
},
{
name: "adding instance with no id",
instance: func() domain.Instance {
// instanceId := gofakeit.Name()
instanceName := gofakeit.Name()
instance := domain.Instance{
// ID: instanceId,
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
return instance
}(),
err: errors.New("instance id not provided"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
var instance *domain.Instance
if tt.testFunc != nil {
instance = tt.testFunc(ctx, t)
} else {
instance = &tt.instance
}
instanceRepo := repository.InstanceRepository(pool)
// create instance
beforeCreate := time.Now()
err := instanceRepo.Create(ctx, instance)
require.Equal(t, tt.err, err)
if err != nil {
return
}
afterCreate := time.Now()
// check instance values
instance, err = instanceRepo.Get(ctx,
instanceRepo.NameCondition(database.TextOperationEqual, instance.Name),
)
require.NoError(t, err)
require.Equal(t, tt.instance.ID, instance.ID)
require.Equal(t, tt.instance.Name, instance.Name)
require.Equal(t, tt.instance.DefaultOrgID, instance.DefaultOrgID)
require.Equal(t, tt.instance.IAMProjectID, instance.IAMProjectID)
require.Equal(t, tt.instance.ConsoleClientID, instance.ConsoleClientID)
require.Equal(t, tt.instance.ConsoleAppID, instance.ConsoleAppID)
require.Equal(t, tt.instance.DefaultLanguage, instance.DefaultLanguage)
require.WithinRange(t, instance.CreatedAt, beforeCreate, afterCreate)
require.WithinRange(t, instance.UpdatedAt, beforeCreate, afterCreate)
require.Nil(t, instance.DeletedAt)
})
}
}
func TestUpdateInstance(t *testing.T) {
tests := []struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
rowsAffected int64
}{
{
name: "happy path",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
instanceId := gofakeit.Name()
instanceName := gofakeit.Name()
inst := domain.Instance{
ID: instanceId,
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
return &inst
},
rowsAffected: 1,
},
{
name: "update deleted instance",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
instanceId := gofakeit.Name()
instanceName := gofakeit.Name()
inst := domain.Instance{
ID: instanceId,
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
// delete instance
err = instanceRepo.Delete(ctx,
instanceRepo.IDCondition(inst.ID),
)
require.NoError(t, err)
return &inst
},
rowsAffected: 0,
},
{
name: "update non existent instance",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceId := gofakeit.Name()
inst := domain.Instance{
ID: instanceId,
}
return &inst
},
rowsAffected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
instanceRepo := repository.InstanceRepository(pool)
instance := tt.testFunc(ctx, t)
beforeUpdate := time.Now()
// update name
newName := "new_" + instance.Name
rowsAffected, err := instanceRepo.Update(ctx,
instanceRepo.IDCondition(instance.ID),
instanceRepo.SetName(newName),
)
afterUpdate := time.Now()
require.NoError(t, err)
require.Equal(t, tt.rowsAffected, rowsAffected)
if rowsAffected == 0 {
return
}
// check instance values
instance, err = instanceRepo.Get(ctx,
instanceRepo.IDCondition(instance.ID),
)
require.NoError(t, err)
require.Equal(t, newName, instance.Name)
require.WithinRange(t, instance.UpdatedAt, beforeUpdate, afterUpdate)
require.Nil(t, instance.DeletedAt)
})
}
}
func TestGetInstance(t *testing.T) {
instanceRepo := repository.InstanceRepository(pool)
type test struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
conditionClauses []database.Condition
}
tests := []test{
func() test {
instanceId := gofakeit.Name()
return test{
name: "happy path get using id",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceName := gofakeit.Name()
inst := domain.Instance{
ID: instanceId,
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
return &inst
},
conditionClauses: []database.Condition{instanceRepo.IDCondition(instanceId)},
}
}(),
func() test {
instanceName := gofakeit.Name()
return test{
name: "happy path get using name",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceId := gofakeit.Name()
inst := domain.Instance{
ID: instanceId,
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
return &inst
},
conditionClauses: []database.Condition{instanceRepo.NameCondition(database.TextOperationEqual, instanceName)},
}
}(),
{
name: "get non existent instance",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceId := gofakeit.Name()
_ = domain.Instance{
ID: instanceId,
}
return nil
},
conditionClauses: []database.Condition{instanceRepo.NameCondition(database.TextOperationEqual, "non-existent-instance-name")},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
instanceRepo := repository.InstanceRepository(pool)
var instance *domain.Instance
if tt.testFunc != nil {
instance = tt.testFunc(ctx, t)
}
// check instance values
returnedInstance, err := instanceRepo.Get(ctx,
tt.conditionClauses...,
)
require.NoError(t, err)
if instance == nil {
require.Nil(t, instance, returnedInstance)
return
}
require.NoError(t, err)
require.Equal(t, returnedInstance.ID, instance.ID)
require.Equal(t, returnedInstance.Name, instance.Name)
require.Equal(t, returnedInstance.DefaultOrgID, instance.DefaultOrgID)
require.Equal(t, returnedInstance.IAMProjectID, instance.IAMProjectID)
require.Equal(t, returnedInstance.ConsoleClientID, instance.ConsoleClientID)
require.Equal(t, returnedInstance.ConsoleAppID, instance.ConsoleAppID)
require.Equal(t, returnedInstance.DefaultLanguage, instance.DefaultLanguage)
})
}
}
func TestListInstance(t *testing.T) {
ctx := context.Background()
pool, stop, err := newEmbeddedDB(ctx)
require.NoError(t, err)
defer stop()
type test struct {
name string
testFunc func(ctx context.Context, t *testing.T) []*domain.Instance
conditionClauses []database.Condition
noInstanceReturned bool
}
tests := []test{
{
name: "happy path single instance no filter",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
noOfInstances := 1
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: gofakeit.Name(),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
return instances
},
},
{
name: "happy path multiple instance no filter",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
noOfInstances := 5
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: gofakeit.Name(),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
return instances
},
},
func() test {
instanceRepo := repository.InstanceRepository(pool)
instanceId := gofakeit.Name()
return test{
name: "instance filter on id",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
noOfInstances := 1
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: instanceId,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
return instances
},
conditionClauses: []database.Condition{instanceRepo.IDCondition(instanceId)},
}
}(),
func() test {
instanceRepo := repository.InstanceRepository(pool)
instanceName := gofakeit.Name()
return test{
name: "multiple instance filter on name",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
noOfInstances := 5
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: gofakeit.Name(),
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
return instances
},
conditionClauses: []database.Condition{instanceRepo.NameCondition(database.TextOperationEqual, instanceName)},
}
}(),
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Cleanup(func() {
_, err := pool.Exec(ctx, "DELETE FROM zitadel.instances")
require.NoError(t, err)
})
instances := tt.testFunc(ctx, t)
instanceRepo := repository.InstanceRepository(pool)
// check instance values
returnedInstances, err := instanceRepo.List(ctx,
tt.conditionClauses...,
)
require.NoError(t, err)
if tt.noInstanceReturned {
require.Nil(t, returnedInstances)
return
}
require.Equal(t, len(instances), len(returnedInstances))
for i, instance := range instances {
require.Equal(t, returnedInstances[i].ID, instance.ID)
require.Equal(t, returnedInstances[i].Name, instance.Name)
require.Equal(t, returnedInstances[i].DefaultOrgID, instance.DefaultOrgID)
require.Equal(t, returnedInstances[i].IAMProjectID, instance.IAMProjectID)
require.Equal(t, returnedInstances[i].ConsoleClientID, instance.ConsoleClientID)
require.Equal(t, returnedInstances[i].ConsoleAppID, instance.ConsoleAppID)
require.Equal(t, returnedInstances[i].DefaultLanguage, instance.DefaultLanguage)
}
})
}
}
func TestDeleteInstance(t *testing.T) {
type test struct {
name string
testFunc func(ctx context.Context, t *testing.T)
conditionClauses database.Condition
}
tests := []test{
func() test {
instanceRepo := repository.InstanceRepository(pool)
instanceId := gofakeit.Name()
return test{
name: "happy path delete single instance filter id",
testFunc: func(ctx context.Context, t *testing.T) {
noOfInstances := 1
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: instanceId,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
},
conditionClauses: instanceRepo.IDCondition(instanceId),
}
}(),
func() test {
instanceRepo := repository.InstanceRepository(pool)
instanceName := gofakeit.Name()
return test{
name: "happy path delete single instance filter name",
testFunc: func(ctx context.Context, t *testing.T) {
noOfInstances := 1
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: gofakeit.Name(),
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
},
conditionClauses: instanceRepo.NameCondition(database.TextOperationEqual, instanceName),
}
}(),
func() test {
instanceRepo := repository.InstanceRepository(pool)
non_existent_instance_name := gofakeit.Name()
return test{
name: "delete non existent instance",
conditionClauses: instanceRepo.NameCondition(database.TextOperationEqual, non_existent_instance_name),
}
}(),
func() test {
instanceRepo := repository.InstanceRepository(pool)
instanceName := gofakeit.Name()
return test{
name: "multiple instance filter on name",
testFunc: func(ctx context.Context, t *testing.T) {
noOfInstances := 5
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: gofakeit.Name(),
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
},
conditionClauses: instanceRepo.NameCondition(database.TextOperationEqual, instanceName),
}
}(),
func() test {
instanceRepo := repository.InstanceRepository(pool)
instanceName := gofakeit.Name()
return test{
name: "deleted already deleted instance",
testFunc: func(ctx context.Context, t *testing.T) {
noOfInstances := 1
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: gofakeit.Name(),
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
// delete instance
err := instanceRepo.Delete(ctx,
instanceRepo.NameCondition(database.TextOperationEqual, instanceName),
)
require.NoError(t, err)
},
conditionClauses: instanceRepo.NameCondition(database.TextOperationEqual, instanceName),
}
}(),
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
instanceRepo := repository.InstanceRepository(pool)
if tt.testFunc != nil {
tt.testFunc(ctx, t)
}
// delete instance
err := instanceRepo.Delete(ctx,
tt.conditionClauses,
)
require.NoError(t, err)
// check instance was deleted
instance, err := instanceRepo.Get(ctx,
tt.conditionClauses,
)
require.NoError(t, err)
require.Nil(t, instance)
})
}
}

View File

@@ -1,16 +1,10 @@
package repository
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
package repository_test
// iraq: I had to comment out so that the UTs would pass
// TestBla is an example and can be removed later
func TestBla(t *testing.T) {
var count int
err := pool.QueryRow(context.Background(), "select count(*) from zitadel.instances").Scan(&count)
assert.NoError(t, err)
assert.Equal(t, 0, count)
}
// func TestBla(t *testing.T) {
// var count int
// err := pool.QueryRow(context.Background(), "select count(*) from zitadel.instances").Scan(&count)
// assert.NoError(t, err)
// assert.Equal(t, 0, count)
// }

View File

@@ -3,6 +3,5 @@ package repository
import "github.com/zitadel/zitadel/backend/v3/storage/database"
type repository struct {
builder database.StatementBuilder
client database.QueryExecutor
client database.QueryExecutor
}

View File

@@ -1,7 +1,8 @@
package repository
package repository_test
import (
"context"
"fmt"
"log"
"os"
"testing"
@@ -14,28 +15,37 @@ func TestMain(m *testing.M) {
os.Exit(runTests(m))
}
var pool database.Pool
var pool database.PoolTest
func runTests(m *testing.M) int {
connector, stop, err := embedded.StartEmbedded()
var stop func()
var err error
ctx := context.Background()
pool, stop, err = newEmbeddedDB(ctx)
if err != nil {
log.Fatalf("unable to start embedded postgres: %v", err)
log.Printf("error with embedded postgres database: %v", err)
return 1
}
defer stop()
ctx := context.Background()
pool, err = connector.Connect(ctx)
if err != nil {
log.Printf("unable to connect to embedded postgres: %v", err)
return 1
}
err = pool.Migrate(ctx)
if err != nil {
log.Printf("unable to migrate database: %v", err)
return 1
}
return m.Run()
}
func newEmbeddedDB(ctx context.Context) (pool database.PoolTest, stop func(), err error) {
connector, stop, err := embedded.StartEmbedded()
if err != nil {
return nil, nil, fmt.Errorf("unable to start embedded postgres: %w", err)
}
pool_, err := connector.Connect(ctx)
if err != nil {
return nil, nil, fmt.Errorf("unable to connect to embedded postgres: %w", err)
}
pool = pool_.(database.PoolTest)
err = pool.MigrateTest(ctx)
if err != nil {
return nil, nil, fmt.Errorf("unable to migrate database: %w", err)
}
return pool, stop, err
}

View File

@@ -47,13 +47,14 @@ func (u *user) List(ctx context.Context, opts ...database.QueryOption) (users []
opt(options)
}
u.builder.WriteString(queryUserStmt)
options.WriteCondition(&u.builder)
options.WriteOrderBy(&u.builder)
options.WriteLimit(&u.builder)
options.WriteOffset(&u.builder)
builder := database.StatementBuilder{}
builder.WriteString(queryUserStmt)
options.WriteCondition(&builder)
options.WriteOrderBy(&builder)
options.WriteLimit(&builder)
options.WriteOffset(&builder)
rows, err := u.client.Query(ctx, u.builder.String(), u.builder.Args()...)
rows, err := u.client.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
@@ -84,13 +85,14 @@ func (u *user) Get(ctx context.Context, opts ...database.QueryOption) (*domain.U
opt(options)
}
u.builder.WriteString(queryUserStmt)
options.WriteCondition(&u.builder)
options.WriteOrderBy(&u.builder)
options.WriteLimit(&u.builder)
options.WriteOffset(&u.builder)
builder := database.StatementBuilder{}
builder.WriteString(queryUserStmt)
options.WriteCondition(&builder)
options.WriteOrderBy(&builder)
options.WriteLimit(&builder)
options.WriteOffset(&builder)
return scanUser(u.client.QueryRow(ctx, u.builder.String(), u.builder.Args()...))
return scanUser(u.client.QueryRow(ctx, builder.String(), builder.Args()...))
}
const (
@@ -104,23 +106,26 @@ const (
// Create implements [domain.UserRepository].
func (u *user) Create(ctx context.Context, user *domain.User) error {
u.builder.AppendArgs(user.InstanceID, user.OrgID, user.ID, user.Username, user.Traits.Type())
builder := database.StatementBuilder{}
builder.AppendArgs(user.InstanceID, user.OrgID, user.ID, user.Username, user.Traits.Type())
switch trait := user.Traits.(type) {
case *domain.Human:
u.builder.WriteString(createHumanStmt)
u.builder.AppendArgs(trait.FirstName, trait.LastName, trait.Email.Address, trait.Email.VerifiedAt, trait.Phone.Number, trait.Phone.VerifiedAt)
builder.WriteString(createHumanStmt)
builder.AppendArgs(trait.FirstName, trait.LastName, trait.Email.Address, trait.Email.VerifiedAt, trait.Phone.Number, trait.Phone.VerifiedAt)
case *domain.Machine:
u.builder.WriteString(createMachineStmt)
u.builder.AppendArgs(trait.Description)
builder.WriteString(createMachineStmt)
builder.AppendArgs(trait.Description)
}
return u.client.QueryRow(ctx, u.builder.String(), u.builder.Args()...).Scan(&user.CreatedAt, &user.UpdatedAt)
return u.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&user.CreatedAt, &user.UpdatedAt)
}
// Delete implements [domain.UserRepository].
func (u *user) Delete(ctx context.Context, condition database.Condition) error {
u.builder.WriteString("DELETE FROM users")
u.writeCondition(condition)
return u.client.Exec(ctx, u.builder.String(), u.builder.Args()...)
builder := database.StatementBuilder{}
builder.WriteString("DELETE FROM users")
u.writeCondition(builder, condition)
_, err := u.client.Exec(ctx, builder.String(), builder.Args()...)
return err
}
// -------------------------------------------------------------
@@ -218,12 +223,15 @@ func (user) DeletedAtColumn() database.Column {
return database.NewColumn("deleted_at")
}
func (u *user) writeCondition(condition database.Condition) {
func (u *user) writeCondition(
builder database.StatementBuilder,
condition database.Condition,
) {
if condition == nil {
return
}
u.builder.WriteString(" WHERE ")
condition.Write(&u.builder)
builder.WriteString(" WHERE ")
condition.Write(&builder)
}
func (u user) columns() database.Columns {

View File

@@ -24,10 +24,11 @@ const userEmailQuery = `SELECT h.email_address, h.email_verified_at FROM user_hu
func (u *userHuman) GetEmail(ctx context.Context, condition database.Condition) (*domain.Email, error) {
var email domain.Email
u.builder.WriteString(userEmailQuery)
u.writeCondition(condition)
builder := database.StatementBuilder{}
builder.WriteString(userEmailQuery)
u.writeCondition(builder, condition)
err := u.client.QueryRow(ctx, u.builder.String(), u.builder.Args()...).Scan(
err := u.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(
&email.Address,
&email.VerifiedAt,
)
@@ -39,13 +40,15 @@ func (u *userHuman) GetEmail(ctx context.Context, condition database.Condition)
// Update implements [domain.HumanRepository].
func (h userHuman) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error {
h.builder.WriteString(`UPDATE human_users SET `)
database.Changes(changes).Write(&h.builder)
h.writeCondition(condition)
builder := database.StatementBuilder{}
builder.WriteString(`UPDATE human_users SET `)
database.Changes(changes).Write(&builder)
h.writeCondition(builder, condition)
stmt := h.builder.String()
stmt := builder.String()
return h.client.Exec(ctx, stmt, h.builder.Args()...)
_, err := h.client.Exec(ctx, stmt, builder.Args()...)
return err
}
// -------------------------------------------------------------

View File

@@ -18,13 +18,15 @@ var _ domain.MachineRepository = (*userMachine)(nil)
// -------------------------------------------------------------
// Update implements [domain.MachineRepository].
func (m userMachine) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (err error) {
m.builder.WriteString("UPDATE user_machines SET ")
database.Changes(changes).Write(&m.builder)
m.writeCondition(condition)
func (m userMachine) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error {
builder := database.StatementBuilder{}
builder.WriteString("UPDATE user_machines SET ")
database.Changes(changes).Write(&builder)
m.writeCondition(builder, condition)
m.writeReturning()
return m.client.Exec(ctx, m.builder.String(), m.builder.Args()...)
_, err := m.client.Exec(ctx, builder.String(), builder.Args()...)
return err
}
// -------------------------------------------------------------
@@ -59,6 +61,7 @@ func (m userMachine) columns() database.Columns {
}
func (m *userMachine) writeReturning() {
m.builder.WriteString(" RETURNING ")
m.columns().Write(&m.builder)
builder := database.StatementBuilder{}
builder.WriteString(" RETURNING ")
m.columns().Write(&builder)
}

View File

@@ -74,3 +74,4 @@ package repository_test
// user.Human().Update(context.Background(), user.IDCondition("test"), user.SetUsername("test"))
// })
// }