diff --git a/backend/v3/storage/database/database.go b/backend/v3/storage/database/database.go index f962132fa0..8d890b2637 100644 --- a/backend/v3/storage/database/database.go +++ b/backend/v3/storage/database/database.go @@ -35,6 +35,7 @@ type Querier interface { } // Executor is a database client that can execute statements. +// It returns the number of rows affected or an error type Executor interface { Exec(ctx context.Context, stmt string, args ...any) (int64, error) } diff --git a/backend/v3/storage/database/repository/instance.go b/backend/v3/storage/database/repository/instance.go index 8278eeb24e..03fc5a8149 100644 --- a/backend/v3/storage/database/repository/instance.go +++ b/backend/v3/storage/database/repository/instance.go @@ -33,21 +33,20 @@ const queryInstanceStmt = `SELECT id, name, default_org_id, iam_project_id, cons // Get implements [domain.InstanceRepository]. func (i *instance) Get(ctx context.Context, opts ...database.Condition) (*domain.Instance, error) { - builder := database.StatementBuilder{} + var builder database.StatementBuilder builder.WriteString(queryInstanceStmt) - // return only non deleted isntances + // return only non deleted instances opts = append(opts, database.IsNull(i.DeletedAtColumn())) - andCondition := database.And(opts...) - i.writeCondition(&builder, andCondition) + i.writeCondition(&builder, database.And(opts...)) return scanInstance(i.client.QueryRow(ctx, builder.String(), builder.Args()...)) } // List implements [domain.InstanceRepository]. func (i *instance) List(ctx context.Context, opts ...database.Condition) ([]*domain.Instance, error) { - builder := database.StatementBuilder{} + var builder database.StatementBuilder builder.WriteString(queryInstanceStmt) @@ -71,7 +70,8 @@ const createInstanceStmt = `INSERT INTO zitadel.instances (id, name, default_org // Create implements [domain.InstanceRepository]. func (i *instance) Create(ctx context.Context, instance *domain.Instance) error { - builder := database.StatementBuilder{} + var builder database.StatementBuilder + builder.AppendArgs(instance.ID, instance.Name, instance.DefaultOrgID, instance.IAMProjectID, instance.ConsoleClientID, instance.ConsoleAppID, instance.DefaultLanguage) builder.WriteString(createInstanceStmt) @@ -101,7 +101,8 @@ func (i *instance) Create(ctx context.Context, instance *domain.Instance) error // Update implements [domain.InstanceRepository]. func (i instance) Update(ctx context.Context, condition database.Condition, changes ...database.Change) (int64, error) { - builder := database.StatementBuilder{} + var builder database.StatementBuilder + builder.WriteString(`UPDATE zitadel.instances SET `) database.Changes(changes).Write(&builder) i.writeCondition(&builder, condition) @@ -118,6 +119,7 @@ func (i instance) Delete(ctx context.Context, condition database.Condition) erro return errors.New("Delete must contain a condition") // (otherwise ALL instances will be deleted) } builder := database.StatementBuilder{} + builder.WriteString(`UPDATE zitadel.instances SET deleted_at = $1`) builder.AppendArgs(time.Now()) diff --git a/backend/v3/storage/database/repository/instance_test.go b/backend/v3/storage/database/repository/instance_test.go index 107cc81002..55a62c5e6e 100644 --- a/backend/v3/storage/database/repository/instance_test.go +++ b/backend/v3/storage/database/repository/instance_test.go @@ -7,7 +7,7 @@ import ( "time" "github.com/brianvoe/gofakeit/v6" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/zitadel/zitadel/backend/v3/domain" "github.com/zitadel/zitadel/backend/v3/storage/database" @@ -75,7 +75,7 @@ func TestCreateInstance(t *testing.T) { } err := instanceRepo.Create(ctx, &inst) - assert.NoError(t, err) + require.NoError(t, err) return &inst }, err: errors.New("instnace id already exists"), @@ -114,7 +114,7 @@ func TestCreateInstance(t *testing.T) { // create instance beforeCreate := time.Now() err := instanceRepo.Create(ctx, instance) - assert.Equal(t, tt.err, err) + require.Equal(t, tt.err, err) if err != nil { return } @@ -124,17 +124,17 @@ func TestCreateInstance(t *testing.T) { instance, err = instanceRepo.Get(ctx, instanceRepo.NameCondition(database.TextOperationEqual, instance.Name), ) - assert.Equal(t, tt.instance.ID, instance.ID) - assert.Equal(t, tt.instance.Name, instance.Name) - assert.Equal(t, tt.instance.DefaultOrgID, instance.DefaultOrgID) - assert.Equal(t, tt.instance.IAMProjectID, instance.IAMProjectID) - assert.Equal(t, tt.instance.ConsoleClientID, instance.ConsoleClientID) - assert.Equal(t, tt.instance.ConsoleAppID, instance.ConsoleAppID) - assert.Equal(t, tt.instance.DefaultLanguage, instance.DefaultLanguage) - assert.WithinRange(t, instance.CreatedAt, beforeCreate, afterCreate) - assert.WithinRange(t, instance.UpdatedAt, beforeCreate, afterCreate) - assert.Nil(t, instance.DeletedAt) - assert.NoError(t, err) + require.Equal(t, tt.instance.ID, instance.ID) + require.Equal(t, tt.instance.Name, instance.Name) + require.Equal(t, tt.instance.DefaultOrgID, instance.DefaultOrgID) + require.Equal(t, tt.instance.IAMProjectID, instance.IAMProjectID) + require.Equal(t, tt.instance.ConsoleClientID, instance.ConsoleClientID) + require.Equal(t, tt.instance.ConsoleAppID, instance.ConsoleAppID) + require.Equal(t, tt.instance.DefaultLanguage, instance.DefaultLanguage) + require.WithinRange(t, instance.CreatedAt, beforeCreate, afterCreate) + require.WithinRange(t, instance.UpdatedAt, beforeCreate, afterCreate) + require.Nil(t, instance.DeletedAt) + require.NoError(t, err) }) } } @@ -165,7 +165,7 @@ func TestUpdateInstance(t *testing.T) { // create instance err := instanceRepo.Create(ctx, &inst) - assert.NoError(t, err) + require.NoError(t, err) return &inst }, rowsAffected: 1, @@ -199,9 +199,9 @@ func TestUpdateInstance(t *testing.T) { instanceRepo.SetName(newName), ) afterUpdate := time.Now() - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, tt.rowsAffected, rowsAffected) + require.Equal(t, tt.rowsAffected, rowsAffected) if rowsAffected == 0 { return @@ -211,11 +211,11 @@ func TestUpdateInstance(t *testing.T) { instance, err = instanceRepo.Get(ctx, instanceRepo.IDCondition(instance.ID), ) - assert.NoError(t, err) + require.NoError(t, err) - assert.Equal(t, newName, instance.Name) - assert.WithinRange(t, instance.UpdatedAt, beforeUpdate, afterUpdate) - assert.Nil(t, instance.DeletedAt) + require.Equal(t, newName, instance.Name) + require.WithinRange(t, instance.UpdatedAt, beforeUpdate, afterUpdate) + require.Nil(t, instance.DeletedAt) }) } } @@ -250,7 +250,7 @@ func TestGetInstance(t *testing.T) { // create instance err := instanceRepo.Create(ctx, &inst) - assert.NoError(t, err) + require.NoError(t, err) return &inst }, conditionClauses: []database.Condition{instanceRepo.IDCondition(instanceId)}, @@ -276,7 +276,7 @@ func TestGetInstance(t *testing.T) { // create instance err := instanceRepo.Create(ctx, &inst) - assert.NoError(t, err) + require.NoError(t, err) return &inst }, conditionClauses: []database.Condition{instanceRepo.NameCondition(database.TextOperationEqual, instanceName)}, @@ -310,20 +310,20 @@ func TestGetInstance(t *testing.T) { returnedInstance, err := instanceRepo.Get(ctx, tt.conditionClauses..., ) - assert.NoError(t, err) + require.NoError(t, err) if tt.noInstanceReturned { - assert.Nil(t, returnedInstance) + require.Nil(t, returnedInstance) return } - assert.Equal(t, returnedInstance.ID, instance.ID) - assert.Equal(t, returnedInstance.Name, instance.Name) - assert.Equal(t, returnedInstance.DefaultOrgID, instance.DefaultOrgID) - assert.Equal(t, returnedInstance.IAMProjectID, instance.IAMProjectID) - assert.Equal(t, returnedInstance.ConsoleClientID, instance.ConsoleClientID) - assert.Equal(t, returnedInstance.ConsoleAppID, instance.ConsoleAppID) - assert.Equal(t, returnedInstance.DefaultLanguage, instance.DefaultLanguage) - assert.NoError(t, err) + require.Equal(t, returnedInstance.ID, instance.ID) + require.Equal(t, returnedInstance.Name, instance.Name) + require.Equal(t, returnedInstance.DefaultOrgID, instance.DefaultOrgID) + require.Equal(t, returnedInstance.IAMProjectID, instance.IAMProjectID) + require.Equal(t, returnedInstance.ConsoleClientID, instance.ConsoleClientID) + require.Equal(t, returnedInstance.ConsoleAppID, instance.ConsoleAppID) + require.Equal(t, returnedInstance.DefaultLanguage, instance.DefaultLanguage) + require.NoError(t, err) }) } } @@ -342,19 +342,16 @@ func TestListInstance(t *testing.T) { ctx := context.Background() // create new db to make sure no instances exist pool, stop, err := newEmbeededDB() - assert.NoError(t, err) + require.NoError(t, err) instanceRepo := repository.InstanceRepository(pool) noOfInstances := 1 instances := make([]*domain.Instance, noOfInstances) for i := range noOfInstances { - instanceId := gofakeit.Name() - instanceName := gofakeit.Name() - inst := domain.Instance{ - ID: instanceId, - Name: instanceName, + ID: gofakeit.Name(), + Name: gofakeit.Name(), DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", ConsoleClientID: "consoleCLient", @@ -364,7 +361,7 @@ func TestListInstance(t *testing.T) { // create instance err := instanceRepo.Create(ctx, &inst) - assert.NoError(t, err) + require.NoError(t, err) instances[i] = &inst } @@ -378,19 +375,16 @@ func TestListInstance(t *testing.T) { ctx := context.Background() // create new db to make sure no instances exist pool, stop, err := newEmbeededDB() - assert.NoError(t, err) + require.NoError(t, err) instanceRepo := repository.InstanceRepository(pool) noOfInstances := 5 instances := make([]*domain.Instance, noOfInstances) for i := range noOfInstances { - instanceId := gofakeit.Name() - instanceName := gofakeit.Name() - inst := domain.Instance{ - ID: instanceId, - Name: instanceName, + ID: gofakeit.Name(), + Name: gofakeit.Name(), DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", ConsoleClientID: "consoleCLient", @@ -400,7 +394,7 @@ func TestListInstance(t *testing.T) { // create instance err := instanceRepo.Create(ctx, &inst) - assert.NoError(t, err) + require.NoError(t, err) instances[i] = &inst } @@ -420,11 +414,9 @@ func TestListInstance(t *testing.T) { instances := make([]*domain.Instance, noOfInstances) for i := range noOfInstances { - instanceName := gofakeit.Name() - inst := domain.Instance{ ID: instanceId, - Name: instanceName, + Name: gofakeit.Name(), DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", ConsoleClientID: "consoleCLient", @@ -434,7 +426,7 @@ func TestListInstance(t *testing.T) { // create instance err := instanceRepo.Create(ctx, &inst) - assert.NoError(t, err) + require.NoError(t, err) instances[i] = &inst } @@ -456,10 +448,8 @@ func TestListInstance(t *testing.T) { instances := make([]*domain.Instance, noOfInstances) for i := range noOfInstances { - instanceId := gofakeit.Name() - inst := domain.Instance{ - ID: instanceId, + ID: gofakeit.Name(), Name: instanceName, DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", @@ -470,7 +460,7 @@ func TestListInstance(t *testing.T) { // create instance err := instanceRepo.Create(ctx, &inst) - assert.NoError(t, err) + require.NoError(t, err) instances[i] = &inst } @@ -503,22 +493,22 @@ func TestListInstance(t *testing.T) { returnedInstances, err := instanceRepo.List(ctx, tt.conditionClauses..., ) - assert.NoError(t, err) + require.NoError(t, err) if tt.noInstanceReturned { - assert.Nil(t, returnedInstances) + require.Nil(t, returnedInstances) return } - assert.Equal(t, len(instances), len(returnedInstances)) + require.Equal(t, len(instances), len(returnedInstances)) for i, instance := range instances { - assert.Equal(t, returnedInstances[i].ID, instance.ID) - assert.Equal(t, returnedInstances[i].Name, instance.Name) - assert.Equal(t, returnedInstances[i].DefaultOrgID, instance.DefaultOrgID) - assert.Equal(t, returnedInstances[i].IAMProjectID, instance.IAMProjectID) - assert.Equal(t, returnedInstances[i].ConsoleClientID, instance.ConsoleClientID) - assert.Equal(t, returnedInstances[i].ConsoleAppID, instance.ConsoleAppID) - assert.Equal(t, returnedInstances[i].DefaultLanguage, instance.DefaultLanguage) - assert.NoError(t, err) + require.Equal(t, returnedInstances[i].ID, instance.ID) + require.Equal(t, returnedInstances[i].Name, instance.Name) + require.Equal(t, returnedInstances[i].DefaultOrgID, instance.DefaultOrgID) + require.Equal(t, returnedInstances[i].IAMProjectID, instance.IAMProjectID) + require.Equal(t, returnedInstances[i].ConsoleClientID, instance.ConsoleClientID) + require.Equal(t, returnedInstances[i].ConsoleAppID, instance.ConsoleAppID) + require.Equal(t, returnedInstances[i].DefaultLanguage, instance.DefaultLanguage) + require.NoError(t, err) } }) } @@ -543,11 +533,9 @@ func TestDeleteInstance(t *testing.T) { instances := make([]*domain.Instance, noOfInstances) for i := range noOfInstances { - instanceName := gofakeit.Name() - inst := domain.Instance{ ID: instanceId, - Name: instanceName, + Name: gofakeit.Name(), DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", ConsoleClientID: "consoleCLient", @@ -557,7 +545,7 @@ func TestDeleteInstance(t *testing.T) { // create instance err := instanceRepo.Create(ctx, &inst) - assert.NoError(t, err) + require.NoError(t, err) instances[i] = &inst } @@ -577,10 +565,8 @@ func TestDeleteInstance(t *testing.T) { instances := make([]*domain.Instance, noOfInstances) for i := range noOfInstances { - instanceId := gofakeit.Name() - inst := domain.Instance{ - ID: instanceId, + ID: gofakeit.Name(), Name: instanceName, DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", @@ -591,7 +577,7 @@ func TestDeleteInstance(t *testing.T) { // create instance err := instanceRepo.Create(ctx, &inst) - assert.NoError(t, err) + require.NoError(t, err) instances[i] = &inst } @@ -619,10 +605,8 @@ func TestDeleteInstance(t *testing.T) { instances := make([]*domain.Instance, noOfInstances) for i := range noOfInstances { - instanceId := gofakeit.Name() - inst := domain.Instance{ - ID: instanceId, + ID: gofakeit.Name(), Name: instanceName, DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", @@ -633,7 +617,7 @@ func TestDeleteInstance(t *testing.T) { // create instance err := instanceRepo.Create(ctx, &inst) - assert.NoError(t, err) + require.NoError(t, err) instances[i] = &inst } @@ -653,10 +637,8 @@ func TestDeleteInstance(t *testing.T) { instances := make([]*domain.Instance, noOfInstances) for i := range noOfInstances { - instanceId := gofakeit.Name() - inst := domain.Instance{ - ID: instanceId, + ID: gofakeit.Name(), Name: instanceName, DefaultOrgID: "defaultOrgId", IAMProjectID: "iamProject", @@ -667,7 +649,7 @@ func TestDeleteInstance(t *testing.T) { // create instance err := instanceRepo.Create(ctx, &inst) - assert.NoError(t, err) + require.NoError(t, err) instances[i] = &inst } @@ -676,7 +658,7 @@ func TestDeleteInstance(t *testing.T) { err := instanceRepo.Delete(ctx, instanceRepo.NameCondition(database.TextOperationEqual, instanceName), ) - assert.NoError(t, err) + require.NoError(t, err) }, conditionClauses: instanceRepo.NameCondition(database.TextOperationEqual, instanceName), } @@ -695,14 +677,14 @@ func TestDeleteInstance(t *testing.T) { err := instanceRepo.Delete(ctx, tt.conditionClauses, ) - assert.NoError(t, err) + require.NoError(t, err) // check instance was deleted instance, err := instanceRepo.Get(ctx, tt.conditionClauses, ) - assert.NoError(t, err) - assert.Nil(t, instance) + require.NoError(t, err) + require.Nil(t, instance) }) } } diff --git a/backend/v3/storage/database/repository/user_test.go_ b/backend/v3/storage/database/repository/user_test.go_ new file mode 100644 index 0000000000..b25cdfd220 --- /dev/null +++ b/backend/v3/storage/database/repository/user_test.go_ @@ -0,0 +1,77 @@ +package repository_test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" + "go.uber.org/mock/gomock" + + "github.com/zitadel/zitadel/backend/v3/storage/database" + "github.com/zitadel/zitadel/backend/v3/storage/database/dbmock" + "github.com/zitadel/zitadel/backend/v3/storage/database/repository" +) + +func TestQueryUser(t *testing.T) { + t.Skip("tests are meant as examples and are not real tests") + t.Run("User filters", func(t *testing.T) { + client := dbmock.NewMockClient(gomock.NewController(t)) + + user := repository.UserRepository(client) + u, err := user.Get(context.Background(), + database.WithCondition( + database.And( + database.Or( + user.IDCondition("test"), + user.IDCondition("2"), + ), + user.UsernameCondition(database.TextOperationStartsWithIgnoreCase, "test"), + ), + ), + database.WithOrderBy(user.CreatedAtColumn()), + ) + + assert.NoError(t, err) + assert.NotNil(t, u) + }) + + t.Run("machine and human filters", func(t *testing.T) { + client := dbmock.NewMockClient(gomock.NewController(t)) + + user := repository.UserRepository(client) + machine := user.Machine() + human := user.Human() + email, err := human.GetEmail(context.Background(), database.And( + user.UsernameCondition(database.TextOperationStartsWithIgnoreCase, "test"), + database.Or( + machine.DescriptionCondition(database.TextOperationStartsWithIgnoreCase, "test"), + human.EmailVerifiedCondition(true), + database.IsNotNull(machine.DescriptionColumn()), + ), + )) + + assert.NoError(t, err) + assert.NotNil(t, email) + }) +} + +type dbInstruction string + +func TestArg(t *testing.T) { + var bla any = "asdf" + instr, ok := bla.(dbInstruction) + assert.False(t, ok) + assert.Empty(t, instr) + bla = dbInstruction("asdf") + instr, ok = bla.(dbInstruction) + assert.True(t, ok) + assert.Equal(t, instr, dbInstruction("asdf")) +} + +func TestWriteUser(t *testing.T) { + t.Skip("tests are meant as examples and are not real tests") + t.Run("update user", func(t *testing.T) { + user := repository.UserRepository(nil) + user.Human().Update(context.Background(), user.IDCondition("test"), user.SetUsername("test")) + }) +}