mirror of
https://github.com/zitadel/zitadel.git
synced 2025-12-24 00:07:36 +00:00
refactor: database interaction and error handling (#10762)
This pull request introduces a significant refactoring of the database interaction layer, focusing on improving explicitness, transactional control, and error handling. The core change is the removal of the stateful `QueryExecutor` from repository instances. Instead, it is now passed as an argument to each method that interacts with the database. This change makes transaction management more explicit and flexible, as the same repository instance can be used with a database pool or a specific transaction without needing to be re-instantiated. ### Key Changes - **Explicit `QueryExecutor` Passing:** - All repository methods (`Get`, `List`, `Create`, `Update`, `Delete`, etc.) in `InstanceRepository`, `OrganizationRepository`, `UserRepository`, and their sub-repositories now require a `database.QueryExecutor` (e.g., a `*pgxpool.Pool` or `pgx.Tx`) as the first argument. - Repository constructors no longer accept a `QueryExecutor`. For example, `repository.InstanceRepository(pool)` is now `repository.InstanceRepository()`. - **Enhanced Error Handling:** - A new `database.MissingConditionError` is introduced to enforce required query conditions, such as ensuring an `instance_id` is always present in `UPDATE` and `DELETE` operations. - The database error wrapper in the `postgres` package now correctly identifies and wraps `pgx.ErrTooManyRows` and similar errors from the `scany` library into a `database.MultipleRowsFoundError`. - **Improved Database Conditions:** - The `database.Condition` interface now includes a `ContainsColumn(Column) bool` method. This allows for runtime checks to ensure that critical filters (like `instance_id`) are included in a query, preventing accidental cross-tenant data modification. - A new `database.Exists()` condition has been added to support `EXISTS` subqueries, enabling more complex filtering logic, such as finding an organization that has a specific domain. - **Repository and Interface Refactoring:** - The method for loading related entities (e.g., domains for an organization) has been changed from a boolean flag (`Domains(true)`) to a more explicit, chainable method (`LoadDomains()`). This returns a new repository instance configured to load the sub-resource, promoting immutability. - The custom `OrgIdentifierCondition` has been removed in favor of using the standard `database.Condition` interface, simplifying the API. - **Code Cleanup and Test Updates:** - Unnecessary struct embeddings and metadata have been removed. - All integration and repository tests have been updated to reflect the new method signatures, passing the database pool or transaction object explicitly. - New tests have been added to cover the new `ExistsDomain` functionality and other enhancements. These changes make the data access layer more robust, predictable, and easier to work with, especially in the context of database transactions.
This commit is contained in:
@@ -2,6 +2,7 @@ package repository_test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@@ -16,9 +17,20 @@ import (
|
||||
)
|
||||
|
||||
func TestCreateInstance(t *testing.T) {
|
||||
beforeCreate := time.Now()
|
||||
tx, err := pool.Begin(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := tx.Rollback(context.Background())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
|
||||
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Instance
|
||||
instance domain.Instance
|
||||
err error
|
||||
}{
|
||||
@@ -59,14 +71,10 @@ func TestCreateInstance(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "adding same instance twice",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceId := gofakeit.Name()
|
||||
instanceName := gofakeit.Name()
|
||||
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
|
||||
inst := domain.Instance{
|
||||
ID: instanceId,
|
||||
Name: instanceName,
|
||||
ID: gofakeit.UUID(),
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
@@ -74,7 +82,9 @@ func TestCreateInstance(t *testing.T) {
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
// change the name to make sure same only the id clashes
|
||||
inst.Name = gofakeit.Name()
|
||||
require.NoError(t, err)
|
||||
@@ -84,7 +94,7 @@ func TestCreateInstance(t *testing.T) {
|
||||
},
|
||||
func() struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
|
||||
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Instance
|
||||
instance domain.Instance
|
||||
err error
|
||||
} {
|
||||
@@ -92,14 +102,12 @@ func TestCreateInstance(t *testing.T) {
|
||||
instanceName := gofakeit.Name()
|
||||
return struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
|
||||
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Instance
|
||||
instance domain.Instance
|
||||
err error
|
||||
}{
|
||||
name: "adding instance with same name twice",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
|
||||
inst := domain.Instance{
|
||||
ID: gofakeit.Name(),
|
||||
Name: instanceName,
|
||||
@@ -110,7 +118,7 @@ func TestCreateInstance(t *testing.T) {
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
// change the id
|
||||
@@ -135,11 +143,8 @@ func TestCreateInstance(t *testing.T) {
|
||||
{
|
||||
name: "adding instance with no id",
|
||||
instance: func() domain.Instance {
|
||||
// instanceId := gofakeit.Name()
|
||||
instanceName := gofakeit.Name()
|
||||
instance := domain.Instance{
|
||||
// ID: instanceId,
|
||||
Name: instanceName,
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
@@ -153,19 +158,25 @@ func TestCreateInstance(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
savepoint, err := tx.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = savepoint.Rollback(t.Context())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
var instance *domain.Instance
|
||||
if tt.testFunc != nil {
|
||||
instance = tt.testFunc(ctx, t)
|
||||
instance = tt.testFunc(t, savepoint)
|
||||
} else {
|
||||
instance = &tt.instance
|
||||
}
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
|
||||
// create instance
|
||||
beforeCreate := time.Now()
|
||||
err := instanceRepo.Create(ctx, instance)
|
||||
|
||||
err = instanceRepo.Create(t.Context(), tx, instance)
|
||||
assert.ErrorIs(t, err, tt.err)
|
||||
if err != nil {
|
||||
return
|
||||
@@ -173,7 +184,7 @@ func TestCreateInstance(t *testing.T) {
|
||||
afterCreate := time.Now()
|
||||
|
||||
// check instance values
|
||||
instance, err = instanceRepo.Get(ctx,
|
||||
instance, err = instanceRepo.Get(t.Context(), tx,
|
||||
database.WithCondition(
|
||||
instanceRepo.IDCondition(instance.ID),
|
||||
),
|
||||
@@ -194,22 +205,30 @@ func TestCreateInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestUpdateInstance(t *testing.T) {
|
||||
beforeUpdate := time.Now()
|
||||
tx, err := pool.Begin(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := tx.Rollback(context.Background())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
|
||||
testFunc func(t *testing.T, tx database.QueryExecutor) *domain.Instance
|
||||
rowsAffected int64
|
||||
getErr error
|
||||
}{
|
||||
{
|
||||
name: "happy path",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceId := gofakeit.Name()
|
||||
instanceName := gofakeit.Name()
|
||||
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
|
||||
inst := domain.Instance{
|
||||
ID: instanceId,
|
||||
Name: instanceName,
|
||||
ID: gofakeit.UUID(),
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
@@ -218,7 +237,7 @@ func TestUpdateInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||
require.NoError(t, err)
|
||||
return &inst
|
||||
},
|
||||
@@ -226,14 +245,10 @@ func TestUpdateInstance(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "update deleted instance",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceId := gofakeit.Name()
|
||||
instanceName := gofakeit.Name()
|
||||
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
|
||||
inst := domain.Instance{
|
||||
ID: instanceId,
|
||||
Name: instanceName,
|
||||
ID: gofakeit.UUID(),
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
@@ -242,11 +257,11 @@ func TestUpdateInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
// delete instance
|
||||
affectedRows, err := instanceRepo.Delete(ctx,
|
||||
affectedRows, err := instanceRepo.Delete(t.Context(), tx,
|
||||
inst.ID,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
@@ -258,11 +273,9 @@ func TestUpdateInstance(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "update non existent instance",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
||||
instanceId := gofakeit.Name()
|
||||
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) *domain.Instance {
|
||||
inst := domain.Instance{
|
||||
ID: instanceId,
|
||||
ID: gofakeit.UUID(),
|
||||
}
|
||||
return &inst
|
||||
},
|
||||
@@ -272,15 +285,11 @@ func TestUpdateInstance(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instance := tt.testFunc(t, tx)
|
||||
|
||||
instance := tt.testFunc(ctx, t)
|
||||
|
||||
beforeUpdate := time.Now()
|
||||
// update name
|
||||
newName := "new_" + instance.Name
|
||||
rowsAffected, err := instanceRepo.Update(ctx,
|
||||
rowsAffected, err := instanceRepo.Update(t.Context(), tx,
|
||||
instance.ID,
|
||||
instanceRepo.SetName(newName),
|
||||
)
|
||||
@@ -294,7 +303,7 @@ func TestUpdateInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
// check instance values
|
||||
instance, err = instanceRepo.Get(ctx,
|
||||
instance, err = instanceRepo.Get(t.Context(), tx,
|
||||
database.WithCondition(
|
||||
instanceRepo.IDCondition(instance.ID),
|
||||
),
|
||||
@@ -308,24 +317,31 @@ func TestUpdateInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestGetInstance(t *testing.T) {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
tx, err := pool.Begin(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := tx.Rollback(context.Background())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
domainRepo := repository.InstanceDomainRepository()
|
||||
|
||||
type test struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T) *domain.Instance
|
||||
testFunc func(t *testing.T) *domain.Instance
|
||||
err error
|
||||
}
|
||||
|
||||
tests := []test{
|
||||
func() test {
|
||||
instanceId := gofakeit.Name()
|
||||
return test{
|
||||
name: "happy path get using id",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
||||
instanceName := gofakeit.Name()
|
||||
|
||||
testFunc: func(t *testing.T) *domain.Instance {
|
||||
inst := domain.Instance{
|
||||
ID: instanceId,
|
||||
Name: instanceName,
|
||||
ID: gofakeit.UUID(),
|
||||
Name: gofakeit.BeerName(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
@@ -334,7 +350,7 @@ func TestGetInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||
require.NoError(t, err)
|
||||
return &inst
|
||||
},
|
||||
@@ -342,14 +358,10 @@ func TestGetInstance(t *testing.T) {
|
||||
}(),
|
||||
{
|
||||
name: "happy path including domains",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceId := gofakeit.Name()
|
||||
instanceName := gofakeit.Name()
|
||||
|
||||
testFunc: func(t *testing.T) *domain.Instance {
|
||||
inst := domain.Instance{
|
||||
ID: instanceId,
|
||||
Name: instanceName,
|
||||
ID: gofakeit.NewCrypto().UUID(),
|
||||
Name: gofakeit.BeerName(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
@@ -358,10 +370,9 @@ func TestGetInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
domainRepo := instanceRepo.Domains(false)
|
||||
d := &domain.AddInstanceDomain{
|
||||
InstanceID: inst.ID,
|
||||
Domain: gofakeit.DomainName(),
|
||||
@@ -369,7 +380,7 @@ func TestGetInstance(t *testing.T) {
|
||||
IsGenerated: gu.Ptr(false),
|
||||
Type: domain.DomainTypeCustom,
|
||||
}
|
||||
err = domainRepo.Add(ctx, d)
|
||||
err = domainRepo.Add(t.Context(), tx, d)
|
||||
require.NoError(t, err)
|
||||
|
||||
inst.Domains = append(inst.Domains, &domain.InstanceDomain{
|
||||
@@ -387,7 +398,7 @@ func TestGetInstance(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "get non existent instance",
|
||||
testFunc: func(ctx context.Context, t *testing.T) *domain.Instance {
|
||||
testFunc: func(t *testing.T) *domain.Instance {
|
||||
inst := domain.Instance{
|
||||
ID: "get non existent instance",
|
||||
}
|
||||
@@ -398,16 +409,13 @@ func TestGetInstance(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
|
||||
var instance *domain.Instance
|
||||
if tt.testFunc != nil {
|
||||
instance = tt.testFunc(ctx, t)
|
||||
instance = tt.testFunc(t)
|
||||
}
|
||||
|
||||
// check instance values
|
||||
returnedInstance, err := instanceRepo.Get(ctx,
|
||||
returnedInstance, err := instanceRepo.Get(t.Context(), tx,
|
||||
database.WithCondition(
|
||||
instanceRepo.IDCondition(instance.ID),
|
||||
),
|
||||
@@ -434,28 +442,33 @@ func TestGetInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestListInstance(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
pool, stop, err := newEmbeddedDB(ctx)
|
||||
tx, err := pool.Begin(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
defer stop()
|
||||
defer func() {
|
||||
err := tx.Rollback(context.Background())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
|
||||
type test struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T) []*domain.Instance
|
||||
testFunc func(t *testing.T, tx database.QueryExecutor) []*domain.Instance
|
||||
conditionClauses []database.Condition
|
||||
noInstanceReturned bool
|
||||
}
|
||||
tests := []test{
|
||||
{
|
||||
name: "happy path single instance no filter",
|
||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Instance {
|
||||
noOfInstances := 1
|
||||
instances := make([]*domain.Instance, noOfInstances)
|
||||
for i := range noOfInstances {
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: gofakeit.Name(),
|
||||
ID: strconv.Itoa(i),
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
@@ -465,7 +478,7 @@ func TestListInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
instances[i] = &inst
|
||||
@@ -476,14 +489,13 @@ func TestListInstance(t *testing.T) {
|
||||
},
|
||||
{
|
||||
name: "happy path multiple instance no filter",
|
||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Instance {
|
||||
noOfInstances := 5
|
||||
instances := make([]*domain.Instance, noOfInstances)
|
||||
for i := range noOfInstances {
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: gofakeit.Name(),
|
||||
ID: strconv.Itoa(i),
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
@@ -493,7 +505,7 @@ func TestListInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
instances[i] = &inst
|
||||
@@ -503,17 +515,16 @@ func TestListInstance(t *testing.T) {
|
||||
},
|
||||
},
|
||||
func() test {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceId := gofakeit.Name()
|
||||
instanceID := gofakeit.BeerName()
|
||||
return test{
|
||||
name: "instance filter on id",
|
||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Instance {
|
||||
noOfInstances := 1
|
||||
instances := make([]*domain.Instance, noOfInstances)
|
||||
for i := range noOfInstances {
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: instanceId,
|
||||
ID: instanceID,
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
@@ -523,7 +534,7 @@ func TestListInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
instances[i] = &inst
|
||||
@@ -531,21 +542,20 @@ func TestListInstance(t *testing.T) {
|
||||
|
||||
return instances
|
||||
},
|
||||
conditionClauses: []database.Condition{instanceRepo.IDCondition(instanceId)},
|
||||
conditionClauses: []database.Condition{instanceRepo.IDCondition(instanceID)},
|
||||
}
|
||||
}(),
|
||||
func() test {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceName := gofakeit.Name()
|
||||
instanceName := gofakeit.BeerName()
|
||||
return test{
|
||||
name: "multiple instance filter on name",
|
||||
testFunc: func(ctx context.Context, t *testing.T) []*domain.Instance {
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) []*domain.Instance {
|
||||
noOfInstances := 5
|
||||
instances := make([]*domain.Instance, noOfInstances)
|
||||
for i := range noOfInstances {
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: gofakeit.Name(),
|
||||
ID: strconv.Itoa(i),
|
||||
Name: instanceName,
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
@@ -555,7 +565,7 @@ func TestListInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
instances[i] = &inst
|
||||
@@ -569,14 +579,15 @@ func TestListInstance(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Cleanup(func() {
|
||||
_, err := pool.Exec(ctx, "DELETE FROM zitadel.instances")
|
||||
require.NoError(t, err)
|
||||
})
|
||||
|
||||
instances := tt.testFunc(ctx, t)
|
||||
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
savepoint, err := tx.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = savepoint.Rollback(t.Context())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
instances := tt.testFunc(t, savepoint)
|
||||
|
||||
var condition database.Condition
|
||||
if len(tt.conditionClauses) > 0 {
|
||||
@@ -584,13 +595,13 @@ func TestListInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
// check instance values
|
||||
returnedInstances, err := instanceRepo.List(ctx,
|
||||
returnedInstances, err := instanceRepo.List(t.Context(), tx,
|
||||
database.WithCondition(condition),
|
||||
database.WithOrderByAscending(instanceRepo.CreatedAtColumn()),
|
||||
database.WithOrderByAscending(instanceRepo.IDColumn()),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
if tt.noInstanceReturned {
|
||||
assert.Nil(t, returnedInstances)
|
||||
assert.Len(t, returnedInstances, 0)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -609,42 +620,45 @@ func TestListInstance(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestDeleteInstance(t *testing.T) {
|
||||
tx, err := pool.Begin(context.Background(), nil)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := tx.Rollback(context.Background())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
instanceRepo := repository.InstanceRepository()
|
||||
|
||||
type test struct {
|
||||
name string
|
||||
testFunc func(ctx context.Context, t *testing.T)
|
||||
testFunc func(t *testing.T, tx database.QueryExecutor)
|
||||
instanceID string
|
||||
noOfDeletedRows int64
|
||||
}
|
||||
tests := []test{
|
||||
func() test {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceId := gofakeit.Name()
|
||||
var noOfInstances int64 = 1
|
||||
instanceID := gofakeit.NewCrypto().UUID()
|
||||
return test{
|
||||
name: "happy path delete single instance filter id",
|
||||
testFunc: func(ctx context.Context, t *testing.T) {
|
||||
instances := make([]*domain.Instance, noOfInstances)
|
||||
for i := range noOfInstances {
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: instanceId,
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
instances[i] = &inst
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) {
|
||||
inst := domain.Instance{
|
||||
ID: instanceID,
|
||||
Name: gofakeit.Name(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||
require.NoError(t, err)
|
||||
},
|
||||
instanceID: instanceId,
|
||||
noOfDeletedRows: noOfInstances,
|
||||
instanceID: instanceID,
|
||||
noOfDeletedRows: 1,
|
||||
}
|
||||
}(),
|
||||
func() test {
|
||||
@@ -655,40 +669,33 @@ func TestDeleteInstance(t *testing.T) {
|
||||
}
|
||||
}(),
|
||||
func() test {
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
instanceName := gofakeit.Name()
|
||||
instanceID := gofakeit.Name()
|
||||
return test{
|
||||
name: "deleted already deleted instance",
|
||||
testFunc: func(ctx context.Context, t *testing.T) {
|
||||
noOfInstances := 1
|
||||
instances := make([]*domain.Instance, noOfInstances)
|
||||
for i := range noOfInstances {
|
||||
testFunc: func(t *testing.T, tx database.QueryExecutor) {
|
||||
|
||||
inst := domain.Instance{
|
||||
ID: gofakeit.Name(),
|
||||
Name: instanceName,
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(ctx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
instances[i] = &inst
|
||||
inst := domain.Instance{
|
||||
ID: instanceID,
|
||||
Name: gofakeit.BeerName(),
|
||||
DefaultOrgID: "defaultOrgId",
|
||||
IAMProjectID: "iamProject",
|
||||
ConsoleClientID: "consoleCLient",
|
||||
ConsoleAppID: "consoleApp",
|
||||
DefaultLanguage: "defaultLanguage",
|
||||
}
|
||||
|
||||
// create instance
|
||||
err := instanceRepo.Create(t.Context(), tx, &inst)
|
||||
require.NoError(t, err)
|
||||
|
||||
// delete instance
|
||||
affectedRows, err := instanceRepo.Delete(ctx,
|
||||
instances[0].ID,
|
||||
affectedRows, err := instanceRepo.Delete(t.Context(), tx,
|
||||
inst.ID,
|
||||
)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, int64(1), affectedRows)
|
||||
},
|
||||
instanceID: instanceName,
|
||||
instanceID: instanceID,
|
||||
// this test should return 0 affected rows as the instance was already deleted
|
||||
noOfDeletedRows: 0,
|
||||
}
|
||||
@@ -696,22 +703,26 @@ func TestDeleteInstance(t *testing.T) {
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
instanceRepo := repository.InstanceRepository(pool)
|
||||
savepoint, err := tx.Begin(t.Context())
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err = savepoint.Rollback(t.Context())
|
||||
if err != nil {
|
||||
t.Log("error during rollback:", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if tt.testFunc != nil {
|
||||
tt.testFunc(ctx, t)
|
||||
tt.testFunc(t, savepoint)
|
||||
}
|
||||
|
||||
// delete instance
|
||||
noOfDeletedRows, err := instanceRepo.Delete(ctx,
|
||||
tt.instanceID,
|
||||
)
|
||||
noOfDeletedRows, err := instanceRepo.Delete(t.Context(), savepoint, tt.instanceID)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, noOfDeletedRows, tt.noOfDeletedRows)
|
||||
|
||||
// check instance was deleted
|
||||
instance, err := instanceRepo.Get(ctx,
|
||||
instance, err := instanceRepo.Get(t.Context(), savepoint,
|
||||
database.WithCondition(
|
||||
instanceRepo.IDCondition(tt.instanceID),
|
||||
),
|
||||
|
||||
Reference in New Issue
Block a user