feat(db): adding relational instance table (#10007)

<!--
Please inform yourself about the contribution guidelines on submitting a
PR here:
https://github.com/zitadel/zitadel/blob/main/CONTRIBUTING.md#submit-a-pull-request-pr.
Take note of how PR/commit titles should be written and replace the
template texts in the sections below. Don't remove any of the sections.
It is important that the commit history clearly shows what is changed
and why.
Important: By submitting a contribution you agree to the terms from our
Licensing Policy as described here:
https://github.com/zitadel/zitadel/blob/main/LICENSING.md#community-contributions.
-->

# Which Problems Are Solved

Implementing Instance table to new relational database schema

# How the Problems Are Solved


The following fields must be managed in this table:

- `id`
- `name`
- `default_org_id`
- `zitadel_project_id`
- `console_client_id`
- `console_app_id`
- `default_language`
- `created_at`
- `updated_at`
- `deleted_at`

The repository must provide the following functions:

Manipulations:
- create
  - `name`
  - `default_org_id`
  - `zitadel_project_id`
  - `console_client_id`
  - `console_app_id`
  - `default_language`
- update
  - `name`
  - `default_language`
- delete

Queries:
- get returns single instance matching the criteria and pagination,
should return error if multiple instances were found
- list returns list of instances matching the criteria, pagination

Criteria are the following:
- by id

pagination:
- by created_at
- by updated_at
- by name

### instance events

The following events must be applied on the table using a projection
(`internal/query/projection`)

- `instance.added` results in create
- `instance.changed` changes the `name` field
- `instance.removed` sets the the `deleted_at` field
- `instance.default.org.set` sets the `default_org_id` field
- `instance.iam.project.set` sets the `zitadel_project_id` field
- `instance.iam.console.set` sets the `console_client_id` and
`console_app_id` fields
- `instance.default.language.set` sets the `default_language` field
- if answer is yes to discussion: `instance.domain.primary.set` sets the
`primary_domain` field

### acceptance criteria

- [x] migration is implemented and gets executed
- [x] domain interfaces are implemented and documented for service layer
- [x] repository is implemented and implements domain interface
- [x] testing
  - [x] the repository methods
  - [x] events get reduced correctly
  - [x] unique constraints

# Additional Context

- Closes https://github.com/zitadel/zitadel/issues/9935
This commit is contained in:
Iraq
2025-06-17 09:46:01 +02:00
committed by GitHub
parent d75a45ebed
commit 9595a1bcca
23 changed files with 1537 additions and 133 deletions

View File

@@ -14,6 +14,12 @@ type Pool interface {
Close(ctx context.Context) error
}
type PoolTest interface {
Pool
// MigrateTest is the same as [Migrator] but executes the migrations multiple times instead of only once.
MigrateTest(ctx context.Context) error
}
// Client is a single database connection which can be released back to the pool.
type Client interface {
Beginner
@@ -30,8 +36,9 @@ type Querier interface {
}
// Executor is a database client that can execute statements.
// It returns the number of rows affected or an error
type Executor interface {
Exec(ctx context.Context, stmt string, args ...any) error
Exec(ctx context.Context, stmt string, args ...any) (int64, error)
}
// QueryExecutor is a database client that can execute queries and statements.

View File

@@ -157,15 +157,16 @@ func (c *MockPoolCloseCall) DoAndReturn(f func(context.Context) error) *MockPool
}
// Exec mocks base method.
func (m *MockPool) Exec(arg0 context.Context, arg1 string, arg2 ...any) error {
func (m *MockPool) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, error) {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1}
for _, a := range arg2 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(error)
return ret0
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Exec indicates an expected call of Exec.
@@ -182,19 +183,19 @@ type MockPoolExecCall struct {
}
// Return rewrite *gomock.Call.Return
func (c *MockPoolExecCall) Return(arg0 error) *MockPoolExecCall {
c.Call = c.Call.Return(arg0)
func (c *MockPoolExecCall) Return(arg0 int64, arg1 error) *MockPoolExecCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockPoolExecCall) Do(f func(context.Context, string, ...any) error) *MockPoolExecCall {
func (c *MockPoolExecCall) Do(f func(context.Context, string, ...any) (int64, error)) *MockPoolExecCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockPoolExecCall) DoAndReturn(f func(context.Context, string, ...any) error) *MockPoolExecCall {
func (c *MockPoolExecCall) DoAndReturn(f func(context.Context, string, ...any) (int64, error)) *MockPoolExecCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
@@ -387,15 +388,16 @@ func (c *MockClientBeginCall) DoAndReturn(f func(context.Context, *database.Tran
}
// Exec mocks base method.
func (m *MockClient) Exec(arg0 context.Context, arg1 string, arg2 ...any) error {
func (m *MockClient) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, error) {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1}
for _, a := range arg2 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(error)
return ret0
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Exec indicates an expected call of Exec.
@@ -412,19 +414,19 @@ type MockClientExecCall struct {
}
// Return rewrite *gomock.Call.Return
func (c *MockClientExecCall) Return(arg0 error) *MockClientExecCall {
c.Call = c.Call.Return(arg0)
func (c *MockClientExecCall) Return(arg0 int64, arg1 error) *MockClientExecCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockClientExecCall) Do(f func(context.Context, string, ...any) error) *MockClientExecCall {
func (c *MockClientExecCall) Do(f func(context.Context, string, ...any) (int64, error)) *MockClientExecCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockClientExecCall) DoAndReturn(f func(context.Context, string, ...any) error) *MockClientExecCall {
func (c *MockClientExecCall) DoAndReturn(f func(context.Context, string, ...any) (int64, error)) *MockClientExecCall {
c.Call = c.Call.DoAndReturn(f)
return c
}
@@ -975,15 +977,16 @@ func (c *MockTransactionEndCall) DoAndReturn(f func(context.Context, error) erro
}
// Exec mocks base method.
func (m *MockTransaction) Exec(arg0 context.Context, arg1 string, arg2 ...any) error {
func (m *MockTransaction) Exec(arg0 context.Context, arg1 string, arg2 ...any) (int64, error) {
m.ctrl.T.Helper()
varargs := []any{arg0, arg1}
for _, a := range arg2 {
varargs = append(varargs, a)
}
ret := m.ctrl.Call(m, "Exec", varargs...)
ret0, _ := ret[0].(error)
return ret0
ret0, _ := ret[0].(int64)
ret1, _ := ret[1].(error)
return ret0, ret1
}
// Exec indicates an expected call of Exec.
@@ -1000,19 +1003,19 @@ type MockTransactionExecCall struct {
}
// Return rewrite *gomock.Call.Return
func (c *MockTransactionExecCall) Return(arg0 error) *MockTransactionExecCall {
c.Call = c.Call.Return(arg0)
func (c *MockTransactionExecCall) Return(arg0 int64, arg1 error) *MockTransactionExecCall {
c.Call = c.Call.Return(arg0, arg1)
return c
}
// Do rewrite *gomock.Call.Do
func (c *MockTransactionExecCall) Do(f func(context.Context, string, ...any) error) *MockTransactionExecCall {
func (c *MockTransactionExecCall) Do(f func(context.Context, string, ...any) (int64, error)) *MockTransactionExecCall {
c.Call = c.Call.Do(f)
return c
}
// DoAndReturn rewrite *gomock.Call.DoAndReturn
func (c *MockTransactionExecCall) DoAndReturn(f func(context.Context, string, ...any) error) *MockTransactionExecCall {
func (c *MockTransactionExecCall) DoAndReturn(f func(context.Context, string, ...any) (int64, error)) *MockTransactionExecCall {
c.Call = c.Call.DoAndReturn(f)
return c
}

View File

@@ -13,9 +13,7 @@ type pgxConn struct {
*pgxpool.Conn
}
var (
_ database.Client = (*pgxConn)(nil)
)
var _ database.Client = (*pgxConn)(nil)
// Release implements [database.Client].
func (c *pgxConn) Release(_ context.Context) error {
@@ -47,9 +45,9 @@ func (c *pgxConn) QueryRow(ctx context.Context, sql string, args ...any) databas
// Exec implements [database.Pool].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (c *pgxConn) Exec(ctx context.Context, sql string, args ...any) error {
_, err := c.Conn.Exec(ctx, sql, args...)
return err
func (c *pgxConn) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
res, err := c.Conn.Exec(ctx, sql, args...)
return res.RowsAffected(), err
}
// Migrate implements [database.Migrator].

View File

@@ -1,6 +1,26 @@
CREATE TABLE IF NOT EXISTS zitadel.instances(
id VARCHAR(100) NOT NULL
, PRIMARY KEY (id)
id TEXT NOT NULL CHECK (id <> '') PRIMARY KEY,
name TEXT NOT NULL CHECK (name <> ''),
default_org_id TEXT, -- NOT NULL,
iam_project_id TEXT, -- NOT NULL,
console_client_id TEXT, -- NOT NULL,
console_app_id TEXT, -- NOT NULL,
default_language TEXT, -- NOT NULL,
created_at TIMESTAMPTZ DEFAULT NOW(),
updated_at TIMESTAMPTZ DEFAULT NOW(),
deleted_at TIMESTAMPTZ DEFAULT NULL
);
, name VARCHAR(100) NOT NULL
);
CREATE OR REPLACE FUNCTION zitadel.set_updated_at()
RETURNS TRIGGER AS $$
BEGIN
NEW.updated_at := NOW();
RETURN NEW;
END;
$$ LANGUAGE plpgsql;
CREATE TRIGGER trigger_set_updated_at
BEFORE UPDATE ON zitadel.instances
FOR EACH ROW
WHEN (OLD.updated_at IS NOT DISTINCT FROM NEW.updated_at)
EXECUTE FUNCTION zitadel.set_updated_at();

View File

@@ -13,9 +13,13 @@ type pgxPool struct {
*pgxpool.Pool
}
var (
_ database.Pool = (*pgxPool)(nil)
)
var _ database.Pool = (*pgxPool)(nil)
func PGxPool(pool *pgxpool.Pool) *pgxPool {
return &pgxPool{
Pool: pool,
}
}
// Acquire implements [database.Pool].
func (c *pgxPool) Acquire(ctx context.Context) (database.Client, error) {
@@ -41,9 +45,9 @@ func (c *pgxPool) QueryRow(ctx context.Context, sql string, args ...any) databas
// Exec implements [database.Pool].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) error {
_, err := c.Pool.Exec(ctx, sql, args...)
return err
func (c *pgxPool) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
res, err := c.Pool.Exec(ctx, sql, args...)
return res.RowsAffected(), err
}
// Begin implements [database.Pool].
@@ -76,3 +80,15 @@ func (c *pgxPool) Migrate(ctx context.Context) error {
isMigrated = err == nil
return err
}
// Migrate implements [database.PoolTest].
func (c *pgxPool) MigrateTest(ctx context.Context) error {
client, err := c.Pool.Acquire(ctx)
if err != nil {
return err
}
err = migration.Migrate(ctx, client.Conn())
isMigrated = err == nil
return err
}

View File

@@ -50,9 +50,9 @@ func (tx *pgxTx) QueryRow(ctx context.Context, sql string, args ...any) database
// Exec implements [database.Transaction].
// Subtle: this method shadows the method (Pool).Exec of pgxPool.Pool.
func (tx *pgxTx) Exec(ctx context.Context, sql string, args ...any) error {
_, err := tx.Tx.Exec(ctx, sql, args...)
return err
func (tx *pgxTx) Exec(ctx context.Context, sql string, args ...any) (int64, error) {
res, err := tx.Tx.Exec(ctx, sql, args...)
return res.RowsAffected(), err
}
// Begin implements [database.Transaction].

View File

@@ -0,0 +1,171 @@
//go:build integration
package instance_test
import (
"context"
"os"
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/jackc/pgx/v5/pgxpool"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/storage/database/dialect/postgres"
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
"github.com/zitadel/zitadel/internal/integration"
"github.com/zitadel/zitadel/pkg/grpc/system"
)
const ConnString = "host=localhost port=5432 user=zitadel dbname=zitadel sslmode=disable"
var (
dbPool *pgxpool.Pool
CTX context.Context
Instance *integration.Instance
SystemClient system.SystemServiceClient
)
var pool database.Pool
func TestMain(m *testing.M) {
os.Exit(func() int {
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Minute)
defer cancel()
Instance = integration.NewInstance(ctx)
CTX = Instance.WithAuthorization(ctx, integration.UserTypeIAMOwner)
SystemClient = integration.SystemClient()
var err error
dbPool, err = pgxpool.New(context.Background(), ConnString)
if err != nil {
panic(err)
}
pool = postgres.PGxPool(dbPool)
return m.Run()
}())
}
func TestServer_TestInstanceAddReduces(t *testing.T) {
instanceName := gofakeit.Name()
beforeCreate := time.Now()
_, err := SystemClient.CreateInstance(CTX, &system.CreateInstanceRequest{
InstanceName: instanceName,
Owner: &system.CreateInstanceRequest_Machine_{
Machine: &system.CreateInstanceRequest_Machine{
UserName: "owner",
Name: "owner",
PersonalAccessToken: &system.CreateInstanceRequest_PersonalAccessToken{},
},
},
})
afterCreate := time.Now()
require.NoError(t, err)
instanceRepo := repository.InstanceRepository(pool)
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
instance, err := instanceRepo.Get(CTX,
instanceRepo.NameCondition(database.TextOperationEqual, instanceName),
)
require.NoError(ttt, err)
// event instance.added
require.Equal(ttt, instanceName, instance.Name)
// event instance.default.org.set
require.NotNil(t, instance.DefaultOrgID)
// event instance.iam.project.set
require.NotNil(t, instance.IAMProjectID)
// event instance.iam.console.set
require.NotNil(t, instance.ConsoleAppID)
// event instance.default.language.set
require.NotNil(t, instance.DefaultLanguage)
// event instance.added
assert.WithinRange(t, instance.CreatedAt, beforeCreate, afterCreate)
// event instance.added
assert.WithinRange(t, instance.UpdatedAt, beforeCreate, afterCreate)
require.Nil(t, instance.DeletedAt)
}, retryDuration, tick)
}
func TestServer_TestInstanceUpdateNameReduces(t *testing.T) {
instanceName := gofakeit.Name()
res, err := SystemClient.CreateInstance(CTX, &system.CreateInstanceRequest{
InstanceName: instanceName,
Owner: &system.CreateInstanceRequest_Machine_{
Machine: &system.CreateInstanceRequest_Machine{
UserName: "owner",
Name: "owner",
PersonalAccessToken: &system.CreateInstanceRequest_PersonalAccessToken{},
},
},
})
require.NoError(t, err)
instanceName += "new"
_, err = SystemClient.UpdateInstance(CTX, &system.UpdateInstanceRequest{
InstanceId: res.InstanceId,
InstanceName: instanceName,
})
require.NoError(t, err)
instanceRepo := repository.InstanceRepository(pool)
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
instance, err := instanceRepo.Get(CTX,
instanceRepo.NameCondition(database.TextOperationEqual, instanceName),
)
require.NoError(ttt, err)
// event instance.changed
require.Equal(ttt, instanceName, instance.Name)
}, retryDuration, tick)
}
func TestServer_TestInstanceDeleteReduces(t *testing.T) {
instanceName := gofakeit.Name()
res, err := SystemClient.CreateInstance(CTX, &system.CreateInstanceRequest{
InstanceName: instanceName,
Owner: &system.CreateInstanceRequest_Machine_{
Machine: &system.CreateInstanceRequest_Machine{
UserName: "owner",
Name: "owner",
PersonalAccessToken: &system.CreateInstanceRequest_PersonalAccessToken{},
},
},
})
require.NoError(t, err)
instanceRepo := repository.InstanceRepository(pool)
// check instance exists
retryDuration, tick := integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
instance, err := instanceRepo.Get(CTX,
instanceRepo.NameCondition(database.TextOperationEqual, instanceName),
)
require.NoError(ttt, err)
require.Equal(ttt, instanceName, instance.Name)
}, retryDuration, tick)
_, err = SystemClient.RemoveInstance(CTX, &system.RemoveInstanceRequest{
InstanceId: res.InstanceId,
})
require.NoError(t, err)
retryDuration, tick = integration.WaitForAndTickWithMaxDuration(CTX, time.Minute)
assert.EventuallyWithT(t, func(ttt *assert.CollectT) {
instance, err := instanceRepo.Get(CTX,
instanceRepo.NameCondition(database.TextOperationEqual, instanceName),
)
// event instance.removed
require.Nil(t, instance)
require.NoError(ttt, err)
}, retryDuration, tick)
}

View File

@@ -0,0 +1,275 @@
package repository
import (
"context"
"errors"
"time"
"github.com/jackc/pgx/v5/pgconn"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
)
var _ domain.InstanceRepository = (*instance)(nil)
type instance struct {
repository
}
func InstanceRepository(client database.QueryExecutor) domain.InstanceRepository {
return &instance{
repository: repository{
client: client,
},
}
}
// -------------------------------------------------------------
// repository
// -------------------------------------------------------------
const queryInstanceStmt = `SELECT id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language, created_at, updated_at, deleted_at` +
` FROM zitadel.instances`
// Get implements [domain.InstanceRepository].
func (i *instance) Get(ctx context.Context, opts ...database.Condition) (*domain.Instance, error) {
var builder database.StatementBuilder
builder.WriteString(queryInstanceStmt)
// return only non deleted instances
opts = append(opts, database.IsNull(i.DeletedAtColumn()))
i.writeCondition(&builder, database.And(opts...))
return scanInstance(i.client.QueryRow(ctx, builder.String(), builder.Args()...))
}
// List implements [domain.InstanceRepository].
func (i *instance) List(ctx context.Context, opts ...database.Condition) ([]*domain.Instance, error) {
var builder database.StatementBuilder
builder.WriteString(queryInstanceStmt)
// return only non deleted instances
opts = append(opts, database.IsNull(i.DeletedAtColumn()))
notDeletedCondition := database.And(opts...)
i.writeCondition(&builder, notDeletedCondition)
rows, err := i.client.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
defer rows.Close()
return scanInstances(rows)
}
const createInstanceStmt = `INSERT INTO zitadel.instances (id, name, default_org_id, iam_project_id, console_client_id, console_app_id, default_language)` +
` VALUES ($1, $2, $3, $4, $5, $6, $7)` +
` RETURNING created_at, updated_at`
// Create implements [domain.InstanceRepository].
func (i *instance) Create(ctx context.Context, instance *domain.Instance) error {
var builder database.StatementBuilder
builder.AppendArgs(instance.ID, instance.Name, instance.DefaultOrgID, instance.IAMProjectID, instance.ConsoleClientID, instance.ConsoleAppID, instance.DefaultLanguage)
builder.WriteString(createInstanceStmt)
err := i.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&instance.CreatedAt, &instance.UpdatedAt)
if err != nil {
var pgErr *pgconn.PgError
if errors.As(err, &pgErr) {
// constraint violation
if pgErr.Code == "23514" {
if pgErr.ConstraintName == "instances_name_check" {
return errors.New("instance name not provided")
}
if pgErr.ConstraintName == "instances_id_check" {
return errors.New("instance id not provided")
}
}
// duplicate
if pgErr.Code == "23505" {
if pgErr.ConstraintName == "instances_pkey" {
return errors.New("instance id already exists")
}
}
}
}
return err
}
// Update implements [domain.InstanceRepository].
func (i instance) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) {
var builder database.StatementBuilder
builder.WriteString(`UPDATE zitadel.instances SET `)
// don't update deleted instances
conditions := []database.Condition{condition, database.IsNull(i.DeletedAtColumn())}
database.Changes(changes).Write(&builder)
i.writeCondition(&builder, database.And(conditions...))
stmt := builder.String()
rowsAffected, err := i.client.Exec(ctx, stmt, builder.Args()...)
return rowsAffected, err
}
// Delete implements [domain.InstanceRepository].
func (i instance) Delete(ctx context.Context, condition database.Condition) error {
if condition == nil {
return errors.New("Delete must contain a condition") // (otherwise ALL instances will be deleted)
}
var builder database.StatementBuilder
builder.WriteString(`UPDATE zitadel.instances SET deleted_at = $1`)
builder.AppendArgs(time.Now())
i.writeCondition(&builder, condition)
_, err := i.client.Exec(ctx, builder.String(), builder.Args()...)
return err
}
// -------------------------------------------------------------
// changes
// -------------------------------------------------------------
// SetName implements [domain.instanceChanges].
func (i instance) SetName(name string) database.Change {
return database.NewChange(i.NameColumn(), name)
}
// -------------------------------------------------------------
// conditions
// -------------------------------------------------------------
// IDCondition implements [domain.instanceConditions].
func (i instance) IDCondition(id string) database.Condition {
return database.NewTextCondition(i.IDColumn(), database.TextOperationEqual, id)
}
// NameCondition implements [domain.instanceConditions].
func (i instance) NameCondition(op database.TextOperation, name string) database.Condition {
return database.NewTextCondition(i.NameColumn(), op, name)
}
// -------------------------------------------------------------
// columns
// -------------------------------------------------------------
// IDColumn implements [domain.instanceColumns].
func (instance) IDColumn() database.Column {
return database.NewColumn("id")
}
// NameColumn implements [domain.instanceColumns].
func (instance) NameColumn() database.Column {
return database.NewColumn("name")
}
// CreatedAtColumn implements [domain.instanceColumns].
func (instance) CreatedAtColumn() database.Column {
return database.NewColumn("created_at")
}
// DefaultOrgIdColumn implements [domain.instanceColumns].
func (instance) DefaultOrgIDColumn() database.Column {
return database.NewColumn("default_org_id")
}
// IAMProjectIDColumn implements [domain.instanceColumns].
func (instance) IAMProjectIDColumn() database.Column {
return database.NewColumn("iam_project_id")
}
// ConsoleClientIDColumn implements [domain.instanceColumns].
func (instance) ConsoleClientIDColumn() database.Column {
return database.NewColumn("console_client_id")
}
// ConsoleAppIDColumn implements [domain.instanceColumns].
func (instance) ConsoleAppIDColumn() database.Column {
return database.NewColumn("console_app_id")
}
// DefaultLanguageColumn implements [domain.instanceColumns].
func (instance) DefaultLanguageColumn() database.Column {
return database.NewColumn("default_language")
}
// UpdatedAtColumn implements [domain.instanceColumns].
func (instance) UpdatedAtColumn() database.Column {
return database.NewColumn("updated_at")
}
// DeletedAtColumn implements [domain.instanceColumns].
func (instance) DeletedAtColumn() database.Column {
return database.NewColumn("deleted_at")
}
func (i *instance) writeCondition(
builder *database.StatementBuilder,
condition database.Condition,
) {
if condition == nil {
return
}
builder.WriteString(" WHERE ")
condition.Write(builder)
}
func scanInstance(scanner database.Scanner) (*domain.Instance, error) {
var instance domain.Instance
err := scanner.Scan(
&instance.ID,
&instance.Name,
&instance.DefaultOrgID,
&instance.IAMProjectID,
&instance.ConsoleClientID,
&instance.ConsoleAppID,
&instance.DefaultLanguage,
&instance.CreatedAt,
&instance.UpdatedAt,
&instance.DeletedAt,
)
if err != nil {
// if no results returned, this is not a error
// it just means the instance was not found
// the caller should check if the returned instance is nil
if err.Error() == "no rows in result set" {
return nil, nil
}
return nil, err
}
return &instance, nil
}
func scanInstances(rows database.Rows) ([]*domain.Instance, error) {
instances := make([]*domain.Instance, 0)
for rows.Next() {
var instance domain.Instance
err := rows.Scan(
&instance.ID,
&instance.Name,
&instance.DefaultOrgID,
&instance.IAMProjectID,
&instance.ConsoleClientID,
&instance.ConsoleAppID,
&instance.DefaultLanguage,
&instance.CreatedAt,
&instance.UpdatedAt,
&instance.DeletedAt,
)
if err != nil {
return nil, err
}
instances = append(instances, &instance)
}
return instances, nil
}

View File

@@ -0,0 +1,690 @@
package repository_test
import (
"context"
"errors"
"testing"
"time"
"github.com/brianvoe/gofakeit/v6"
"github.com/stretchr/testify/require"
"github.com/zitadel/zitadel/backend/v3/domain"
"github.com/zitadel/zitadel/backend/v3/storage/database"
"github.com/zitadel/zitadel/backend/v3/storage/database/repository"
)
func TestCreateInstance(t *testing.T) {
tests := []struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
instance domain.Instance
err error
}{
{
name: "happy path",
instance: func() domain.Instance {
instanceId := gofakeit.Name()
instanceName := gofakeit.Name()
instance := domain.Instance{
ID: instanceId,
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
return instance
}(),
},
{
name: "create instance without name",
instance: func() domain.Instance {
instanceId := gofakeit.Name()
// instanceName := gofakeit.Name()
instance := domain.Instance{
ID: instanceId,
Name: "",
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
return instance
}(),
err: errors.New("instance name not provided"),
},
{
name: "adding same instance twice",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
instanceId := gofakeit.Name()
instanceName := gofakeit.Name()
inst := domain.Instance{
ID: instanceId,
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
return &inst
},
err: errors.New("instance id already exists"),
},
{
name: "adding instance with no id",
instance: func() domain.Instance {
// instanceId := gofakeit.Name()
instanceName := gofakeit.Name()
instance := domain.Instance{
// ID: instanceId,
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
return instance
}(),
err: errors.New("instance id not provided"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
var instance *domain.Instance
if tt.testFunc != nil {
instance = tt.testFunc(ctx, t)
} else {
instance = &tt.instance
}
instanceRepo := repository.InstanceRepository(pool)
// create instance
beforeCreate := time.Now()
err := instanceRepo.Create(ctx, instance)
require.Equal(t, tt.err, err)
if err != nil {
return
}
afterCreate := time.Now()
// check instance values
instance, err = instanceRepo.Get(ctx,
instanceRepo.NameCondition(database.TextOperationEqual, instance.Name),
)
require.NoError(t, err)
require.Equal(t, tt.instance.ID, instance.ID)
require.Equal(t, tt.instance.Name, instance.Name)
require.Equal(t, tt.instance.DefaultOrgID, instance.DefaultOrgID)
require.Equal(t, tt.instance.IAMProjectID, instance.IAMProjectID)
require.Equal(t, tt.instance.ConsoleClientID, instance.ConsoleClientID)
require.Equal(t, tt.instance.ConsoleAppID, instance.ConsoleAppID)
require.Equal(t, tt.instance.DefaultLanguage, instance.DefaultLanguage)
require.WithinRange(t, instance.CreatedAt, beforeCreate, afterCreate)
require.WithinRange(t, instance.UpdatedAt, beforeCreate, afterCreate)
require.Nil(t, instance.DeletedAt)
})
}
}
func TestUpdateInstance(t *testing.T) {
tests := []struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
rowsAffected int64
}{
{
name: "happy path",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
instanceId := gofakeit.Name()
instanceName := gofakeit.Name()
inst := domain.Instance{
ID: instanceId,
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
return &inst
},
rowsAffected: 1,
},
{
name: "update deleted instance",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
instanceId := gofakeit.Name()
instanceName := gofakeit.Name()
inst := domain.Instance{
ID: instanceId,
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
// delete instance
err = instanceRepo.Delete(ctx,
instanceRepo.IDCondition(inst.ID),
)
require.NoError(t, err)
return &inst
},
rowsAffected: 0,
},
{
name: "update non existent instance",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceId := gofakeit.Name()
inst := domain.Instance{
ID: instanceId,
}
return &inst
},
rowsAffected: 0,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
instanceRepo := repository.InstanceRepository(pool)
instance := tt.testFunc(ctx, t)
beforeUpdate := time.Now()
// update name
newName := "new_" + instance.Name
rowsAffected, err := instanceRepo.Update(ctx,
instanceRepo.IDCondition(instance.ID),
instanceRepo.SetName(newName),
)
afterUpdate := time.Now()
require.NoError(t, err)
require.Equal(t, tt.rowsAffected, rowsAffected)
if rowsAffected == 0 {
return
}
// check instance values
instance, err = instanceRepo.Get(ctx,
instanceRepo.IDCondition(instance.ID),
)
require.NoError(t, err)
require.Equal(t, newName, instance.Name)
require.WithinRange(t, instance.UpdatedAt, beforeUpdate, afterUpdate)
require.Nil(t, instance.DeletedAt)
})
}
}
func TestGetInstance(t *testing.T) {
instanceRepo := repository.InstanceRepository(pool)
type test struct {
name string
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
conditionClauses []database.Condition
}
tests := []test{
func() test {
instanceId := gofakeit.Name()
return test{
name: "happy path get using id",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceName := gofakeit.Name()
inst := domain.Instance{
ID: instanceId,
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
return &inst
},
conditionClauses: []database.Condition{instanceRepo.IDCondition(instanceId)},
}
}(),
func() test {
instanceName := gofakeit.Name()
return test{
name: "happy path get using name",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceId := gofakeit.Name()
inst := domain.Instance{
ID: instanceId,
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
return &inst
},
conditionClauses: []database.Condition{instanceRepo.NameCondition(database.TextOperationEqual, instanceName)},
}
}(),
{
name: "get non existent instance",
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
instanceId := gofakeit.Name()
_ = domain.Instance{
ID: instanceId,
}
return nil
},
conditionClauses: []database.Condition{instanceRepo.NameCondition(database.TextOperationEqual, "non-existent-instance-name")},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
instanceRepo := repository.InstanceRepository(pool)
var instance *domain.Instance
if tt.testFunc != nil {
instance = tt.testFunc(ctx, t)
}
// check instance values
returnedInstance, err := instanceRepo.Get(ctx,
tt.conditionClauses...,
)
require.NoError(t, err)
if instance == nil {
require.Nil(t, instance, returnedInstance)
return
}
require.NoError(t, err)
require.Equal(t, returnedInstance.ID, instance.ID)
require.Equal(t, returnedInstance.Name, instance.Name)
require.Equal(t, returnedInstance.DefaultOrgID, instance.DefaultOrgID)
require.Equal(t, returnedInstance.IAMProjectID, instance.IAMProjectID)
require.Equal(t, returnedInstance.ConsoleClientID, instance.ConsoleClientID)
require.Equal(t, returnedInstance.ConsoleAppID, instance.ConsoleAppID)
require.Equal(t, returnedInstance.DefaultLanguage, instance.DefaultLanguage)
})
}
}
func TestListInstance(t *testing.T) {
ctx := context.Background()
pool, stop, err := newEmbeddedDB(ctx)
require.NoError(t, err)
defer stop()
type test struct {
name string
testFunc func(ctx context.Context, t *testing.T) []*domain.Instance
conditionClauses []database.Condition
noInstanceReturned bool
}
tests := []test{
{
name: "happy path single instance no filter",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
noOfInstances := 1
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: gofakeit.Name(),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
return instances
},
},
{
name: "happy path multiple instance no filter",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
instanceRepo := repository.InstanceRepository(pool)
noOfInstances := 5
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: gofakeit.Name(),
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
return instances
},
},
func() test {
instanceRepo := repository.InstanceRepository(pool)
instanceId := gofakeit.Name()
return test{
name: "instance filter on id",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
noOfInstances := 1
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: instanceId,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
return instances
},
conditionClauses: []database.Condition{instanceRepo.IDCondition(instanceId)},
}
}(),
func() test {
instanceRepo := repository.InstanceRepository(pool)
instanceName := gofakeit.Name()
return test{
name: "multiple instance filter on name",
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
noOfInstances := 5
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: gofakeit.Name(),
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
return instances
},
conditionClauses: []database.Condition{instanceRepo.NameCondition(database.TextOperationEqual, instanceName)},
}
}(),
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Cleanup(func() {
_, err := pool.Exec(ctx, "DELETE FROM zitadel.instances")
require.NoError(t, err)
})
instances := tt.testFunc(ctx, t)
instanceRepo := repository.InstanceRepository(pool)
// check instance values
returnedInstances, err := instanceRepo.List(ctx,
tt.conditionClauses...,
)
require.NoError(t, err)
if tt.noInstanceReturned {
require.Nil(t, returnedInstances)
return
}
require.Equal(t, len(instances), len(returnedInstances))
for i, instance := range instances {
require.Equal(t, returnedInstances[i].ID, instance.ID)
require.Equal(t, returnedInstances[i].Name, instance.Name)
require.Equal(t, returnedInstances[i].DefaultOrgID, instance.DefaultOrgID)
require.Equal(t, returnedInstances[i].IAMProjectID, instance.IAMProjectID)
require.Equal(t, returnedInstances[i].ConsoleClientID, instance.ConsoleClientID)
require.Equal(t, returnedInstances[i].ConsoleAppID, instance.ConsoleAppID)
require.Equal(t, returnedInstances[i].DefaultLanguage, instance.DefaultLanguage)
}
})
}
}
func TestDeleteInstance(t *testing.T) {
type test struct {
name string
testFunc func(ctx context.Context, t *testing.T)
conditionClauses database.Condition
}
tests := []test{
func() test {
instanceRepo := repository.InstanceRepository(pool)
instanceId := gofakeit.Name()
return test{
name: "happy path delete single instance filter id",
testFunc: func(ctx context.Context, t *testing.T) {
noOfInstances := 1
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: instanceId,
Name: gofakeit.Name(),
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
},
conditionClauses: instanceRepo.IDCondition(instanceId),
}
}(),
func() test {
instanceRepo := repository.InstanceRepository(pool)
instanceName := gofakeit.Name()
return test{
name: "happy path delete single instance filter name",
testFunc: func(ctx context.Context, t *testing.T) {
noOfInstances := 1
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: gofakeit.Name(),
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
},
conditionClauses: instanceRepo.NameCondition(database.TextOperationEqual, instanceName),
}
}(),
func() test {
instanceRepo := repository.InstanceRepository(pool)
non_existent_instance_name := gofakeit.Name()
return test{
name: "delete non existent instance",
conditionClauses: instanceRepo.NameCondition(database.TextOperationEqual, non_existent_instance_name),
}
}(),
func() test {
instanceRepo := repository.InstanceRepository(pool)
instanceName := gofakeit.Name()
return test{
name: "multiple instance filter on name",
testFunc: func(ctx context.Context, t *testing.T) {
noOfInstances := 5
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: gofakeit.Name(),
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
},
conditionClauses: instanceRepo.NameCondition(database.TextOperationEqual, instanceName),
}
}(),
func() test {
instanceRepo := repository.InstanceRepository(pool)
instanceName := gofakeit.Name()
return test{
name: "deleted already deleted instance",
testFunc: func(ctx context.Context, t *testing.T) {
noOfInstances := 1
instances := make([]*domain.Instance, noOfInstances)
for i := range noOfInstances {
inst := domain.Instance{
ID: gofakeit.Name(),
Name: instanceName,
DefaultOrgID: "defaultOrgId",
IAMProjectID: "iamProject",
ConsoleClientID: "consoleCLient",
ConsoleAppID: "consoleApp",
DefaultLanguage: "defaultLanguage",
}
// create instance
err := instanceRepo.Create(ctx, &inst)
require.NoError(t, err)
instances[i] = &inst
}
// delete instance
err := instanceRepo.Delete(ctx,
instanceRepo.NameCondition(database.TextOperationEqual, instanceName),
)
require.NoError(t, err)
},
conditionClauses: instanceRepo.NameCondition(database.TextOperationEqual, instanceName),
}
}(),
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
ctx := context.Background()
instanceRepo := repository.InstanceRepository(pool)
if tt.testFunc != nil {
tt.testFunc(ctx, t)
}
// delete instance
err := instanceRepo.Delete(ctx,
tt.conditionClauses,
)
require.NoError(t, err)
// check instance was deleted
instance, err := instanceRepo.Get(ctx,
tt.conditionClauses,
)
require.NoError(t, err)
require.Nil(t, instance)
})
}
}

View File

@@ -1,16 +1,10 @@
package repository
import (
"context"
"testing"
"github.com/stretchr/testify/assert"
)
package repository_test
// iraq: I had to comment out so that the UTs would pass
// TestBla is an example and can be removed later
func TestBla(t *testing.T) {
var count int
err := pool.QueryRow(context.Background(), "select count(*) from zitadel.instances").Scan(&count)
assert.NoError(t, err)
assert.Equal(t, 0, count)
}
// func TestBla(t *testing.T) {
// var count int
// err := pool.QueryRow(context.Background(), "select count(*) from zitadel.instances").Scan(&count)
// assert.NoError(t, err)
// assert.Equal(t, 0, count)
// }

View File

@@ -3,6 +3,5 @@ package repository
import "github.com/zitadel/zitadel/backend/v3/storage/database"
type repository struct {
builder database.StatementBuilder
client database.QueryExecutor
client database.QueryExecutor
}

View File

@@ -1,7 +1,8 @@
package repository
package repository_test
import (
"context"
"fmt"
"log"
"os"
"testing"
@@ -14,28 +15,37 @@ func TestMain(m *testing.M) {
os.Exit(runTests(m))
}
var pool database.Pool
var pool database.PoolTest
func runTests(m *testing.M) int {
connector, stop, err := embedded.StartEmbedded()
var stop func()
var err error
ctx := context.Background()
pool, stop, err = newEmbeddedDB(ctx)
if err != nil {
log.Fatalf("unable to start embedded postgres: %v", err)
log.Printf("error with embedded postgres database: %v", err)
return 1
}
defer stop()
ctx := context.Background()
pool, err = connector.Connect(ctx)
if err != nil {
log.Printf("unable to connect to embedded postgres: %v", err)
return 1
}
err = pool.Migrate(ctx)
if err != nil {
log.Printf("unable to migrate database: %v", err)
return 1
}
return m.Run()
}
func newEmbeddedDB(ctx context.Context) (pool database.PoolTest, stop func(), err error) {
connector, stop, err := embedded.StartEmbedded()
if err != nil {
return nil, nil, fmt.Errorf("unable to start embedded postgres: %w", err)
}
pool_, err := connector.Connect(ctx)
if err != nil {
return nil, nil, fmt.Errorf("unable to connect to embedded postgres: %w", err)
}
pool = pool_.(database.PoolTest)
err = pool.MigrateTest(ctx)
if err != nil {
return nil, nil, fmt.Errorf("unable to migrate database: %w", err)
}
return pool, stop, err
}

View File

@@ -47,13 +47,14 @@ func (u *user) List(ctx context.Context, opts ...database.QueryOption) (users []
opt(options)
}
u.builder.WriteString(queryUserStmt)
options.WriteCondition(&u.builder)
options.WriteOrderBy(&u.builder)
options.WriteLimit(&u.builder)
options.WriteOffset(&u.builder)
builder := database.StatementBuilder{}
builder.WriteString(queryUserStmt)
options.WriteCondition(&builder)
options.WriteOrderBy(&builder)
options.WriteLimit(&builder)
options.WriteOffset(&builder)
rows, err := u.client.Query(ctx, u.builder.String(), u.builder.Args()...)
rows, err := u.client.Query(ctx, builder.String(), builder.Args()...)
if err != nil {
return nil, err
}
@@ -84,13 +85,14 @@ func (u *user) Get(ctx context.Context, opts ...database.QueryOption) (*domain.U
opt(options)
}
u.builder.WriteString(queryUserStmt)
options.WriteCondition(&u.builder)
options.WriteOrderBy(&u.builder)
options.WriteLimit(&u.builder)
options.WriteOffset(&u.builder)
builder := database.StatementBuilder{}
builder.WriteString(queryUserStmt)
options.WriteCondition(&builder)
options.WriteOrderBy(&builder)
options.WriteLimit(&builder)
options.WriteOffset(&builder)
return scanUser(u.client.QueryRow(ctx, u.builder.String(), u.builder.Args()...))
return scanUser(u.client.QueryRow(ctx, builder.String(), builder.Args()...))
}
const (
@@ -104,23 +106,26 @@ const (
// Create implements [domain.UserRepository].
func (u *user) Create(ctx context.Context, user *domain.User) error {
u.builder.AppendArgs(user.InstanceID, user.OrgID, user.ID, user.Username, user.Traits.Type())
builder := database.StatementBuilder{}
builder.AppendArgs(user.InstanceID, user.OrgID, user.ID, user.Username, user.Traits.Type())
switch trait := user.Traits.(type) {
case *domain.Human:
u.builder.WriteString(createHumanStmt)
u.builder.AppendArgs(trait.FirstName, trait.LastName, trait.Email.Address, trait.Email.VerifiedAt, trait.Phone.Number, trait.Phone.VerifiedAt)
builder.WriteString(createHumanStmt)
builder.AppendArgs(trait.FirstName, trait.LastName, trait.Email.Address, trait.Email.VerifiedAt, trait.Phone.Number, trait.Phone.VerifiedAt)
case *domain.Machine:
u.builder.WriteString(createMachineStmt)
u.builder.AppendArgs(trait.Description)
builder.WriteString(createMachineStmt)
builder.AppendArgs(trait.Description)
}
return u.client.QueryRow(ctx, u.builder.String(), u.builder.Args()...).Scan(&user.CreatedAt, &user.UpdatedAt)
return u.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(&user.CreatedAt, &user.UpdatedAt)
}
// Delete implements [domain.UserRepository].
func (u *user) Delete(ctx context.Context, condition database.Condition) error {
u.builder.WriteString("DELETE FROM users")
u.writeCondition(condition)
return u.client.Exec(ctx, u.builder.String(), u.builder.Args()...)
builder := database.StatementBuilder{}
builder.WriteString("DELETE FROM users")
u.writeCondition(builder, condition)
_, err := u.client.Exec(ctx, builder.String(), builder.Args()...)
return err
}
// -------------------------------------------------------------
@@ -218,12 +223,15 @@ func (user) DeletedAtColumn() database.Column {
return database.NewColumn("deleted_at")
}
func (u *user) writeCondition(condition database.Condition) {
func (u *user) writeCondition(
builder database.StatementBuilder,
condition database.Condition,
) {
if condition == nil {
return
}
u.builder.WriteString(" WHERE ")
condition.Write(&u.builder)
builder.WriteString(" WHERE ")
condition.Write(&builder)
}
func (u user) columns() database.Columns {

View File

@@ -24,10 +24,11 @@ const userEmailQuery = `SELECT h.email_address, h.email_verified_at FROM user_hu
func (u *userHuman) GetEmail(ctx context.Context, condition database.Condition) (*domain.Email, error) {
var email domain.Email
u.builder.WriteString(userEmailQuery)
u.writeCondition(condition)
builder := database.StatementBuilder{}
builder.WriteString(userEmailQuery)
u.writeCondition(builder, condition)
err := u.client.QueryRow(ctx, u.builder.String(), u.builder.Args()...).Scan(
err := u.client.QueryRow(ctx, builder.String(), builder.Args()...).Scan(
&email.Address,
&email.VerifiedAt,
)
@@ -39,13 +40,15 @@ func (u *userHuman) GetEmail(ctx context.Context, condition database.Condition)
// Update implements [domain.HumanRepository].
func (h userHuman) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error {
h.builder.WriteString(`UPDATE human_users SET `)
database.Changes(changes).Write(&h.builder)
h.writeCondition(condition)
builder := database.StatementBuilder{}
builder.WriteString(`UPDATE human_users SET `)
database.Changes(changes).Write(&builder)
h.writeCondition(builder, condition)
stmt := h.builder.String()
stmt := builder.String()
return h.client.Exec(ctx, stmt, h.builder.Args()...)
_, err := h.client.Exec(ctx, stmt, builder.Args()...)
return err
}
// -------------------------------------------------------------

View File

@@ -18,13 +18,15 @@ var _ domain.MachineRepository = (*userMachine)(nil)
// -------------------------------------------------------------
// Update implements [domain.MachineRepository].
func (m userMachine) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (err error) {
m.builder.WriteString("UPDATE user_machines SET ")
database.Changes(changes).Write(&m.builder)
m.writeCondition(condition)
func (m userMachine) Update(ctx context.Context, condition database.Condition, changes ...database.Change) error {
builder := database.StatementBuilder{}
builder.WriteString("UPDATE user_machines SET ")
database.Changes(changes).Write(&builder)
m.writeCondition(builder, condition)
m.writeReturning()
return m.client.Exec(ctx, m.builder.String(), m.builder.Args()...)
_, err := m.client.Exec(ctx, builder.String(), builder.Args()...)
return err
}
// -------------------------------------------------------------
@@ -59,6 +61,7 @@ func (m userMachine) columns() database.Columns {
}
func (m *userMachine) writeReturning() {
m.builder.WriteString(" RETURNING ")
m.columns().Write(&m.builder)
builder := database.StatementBuilder{}
builder.WriteString(" RETURNING ")
m.columns().Write(&builder)
}

View File

@@ -74,3 +74,4 @@ package repository_test
// user.Human().Update(context.Background(), user.IDCondition("test"), user.SetUsername("test"))
// })
// }

View File

@@ -15,7 +15,8 @@ type Event struct {
func Publish(ctx context.Context, events []*Event, db database.Executor) error {
for _, event := range events {
if err := db.Exec(ctx, `INSERT INTO events (aggregate_type, aggregate_id) VALUES ($1, $2)`, event.AggregateType, event.AggregateID); err != nil {
_, err := db.Exec(ctx, `INSERT INTO events (aggregate_type, aggregate_id) VALUES ($1, $2)`, event.AggregateType, event.AggregateID)
if err != nil {
return err
}
}