From 9e1a2eada90c7ff2f3ba50def41557f82f0305c0 Mon Sep 17 00:00:00 2001 From: Marco Ardizzone Date: Tue, 16 Sep 2025 19:00:19 +0200 Subject: [PATCH] Add extra checks for organization update --- backend/v3/domain/errors.go | 14 ++++ backend/v3/domain/org_update.go | 20 ++++- backend/v3/domain/org_update_test.go | 115 +++++++++++++++++++++++++-- 3 files changed, 138 insertions(+), 11 deletions(-) diff --git a/backend/v3/domain/errors.go b/backend/v3/domain/errors.go index b0e432b6ef6..3dd9129e3ae 100644 --- a/backend/v3/domain/errors.go +++ b/backend/v3/domain/errors.go @@ -43,3 +43,17 @@ func NewMultipleOrgsUpdatedError(id string, expected, actual int64) error { func (err *MultipleOrgsUpdatedError) Error() string { return fmt.Sprintf("ID=%s Message=expecting %d row(s) updated, got %d", err.ID, err.Expected, err.Actual) } + +type OrgNameNotChangedError struct { + ID string +} + +func NewOrgNameNotChangedError(errID string) error { + return &OrgNameNotChangedError{ + ID: errID, + } +} + +func (err *OrgNameNotChangedError) Error() string { + return fmt.Sprintf("ID=%s Message=organization name has not changed", err.ID) +} diff --git a/backend/v3/domain/org_update.go b/backend/v3/domain/org_update.go index 280e207e846..da9e22cd9fb 100644 --- a/backend/v3/domain/org_update.go +++ b/backend/v3/domain/org_update.go @@ -4,7 +4,6 @@ import ( "context" "github.com/zitadel/zitadel/backend/v3/storage/database" - "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/zerrors" ) @@ -31,10 +30,25 @@ func (u *UpdateOrgCommand) Execute(ctx context.Context, opts *CommandOpts) (err organizationRepo := opts.orgRepo() + org, err := organizationRepo.Get(ctx, database.WithCondition(organizationRepo.IDCondition(u.ID))) + if err != nil { + return err + } + + if org.Name == u.Name { + err = NewOrgNameNotChangedError("DOM-nDzwIu") + return err + } + + if org.State == OrgStateInactive { + err = NewOrgNotFoundError("DOM-OcA1jq") + return err + } + updateCount, err := organizationRepo.Update( ctx, - organizationRepo.IDCondition(u.ID), - authz.GetInstance(ctx).InstanceID(), + organizationRepo.IDCondition(org.ID), + org.InstanceID, database.NewChange(organizationRepo.NameColumn(), u.Name), ) if err != nil { diff --git a/backend/v3/domain/org_update_test.go b/backend/v3/domain/org_update_test.go index c0ebdb6348b..3d3a1c4850b 100644 --- a/backend/v3/domain/org_update_test.go +++ b/backend/v3/domain/org_update_test.go @@ -17,7 +17,10 @@ import ( ) func TestUpdateOrgCommand_Execute(t *testing.T) { + t.Parallel() + txInitErr := errors.New("tx init error") + getErr := errors.New("get error") updateErr := errors.New("update error") tt := []struct { @@ -43,14 +46,79 @@ func TestUpdateOrgCommand_Execute(t *testing.T) { }, expectedError: txInitErr, }, + { + testName: "when retrieving org fails should return error", + orgRepo: func(ctrl *gomock.Controller) func(database.QueryExecutor) domain.OrganizationRepository { + repo := domainmock.NewOrgRepo(ctrl) + repo.EXPECT(). + Get(gomock.Any(), gomock.Any()). + Times(1). + Return(nil, getErr) + return func(_ database.QueryExecutor) domain.OrganizationRepository { + return repo + } + }, + inputID: "org-1", + inputName: "test org update", + expectedError: getErr, + }, + { + testName: "when org name is not changed should return name not changed error", + orgRepo: func(ctrl *gomock.Controller) func(database.QueryExecutor) domain.OrganizationRepository { + repo := domainmock.NewOrgRepo(ctrl) + repo.EXPECT(). + Get(gomock.Any(), gomock.Any()). + Times(1). + Return(&domain.Organization{ + ID: "org-1", + Name: "test org update", + }, nil) + return func(_ database.QueryExecutor) domain.OrganizationRepository { + return repo + } + }, + inputID: "org-1", + inputName: "test org update", + expectedError: domain.NewOrgNameNotChangedError("DOM-nDzwIu"), + }, + { + testName: "when org is inactive should return not found error", + orgRepo: func(ctrl *gomock.Controller) func(database.QueryExecutor) domain.OrganizationRepository { + repo := domainmock.NewOrgRepo(ctrl) + repo.EXPECT(). + Get(gomock.Any(), gomock.Any()). + Times(1). + Return(&domain.Organization{ + ID: "org-1", + Name: "old org name", + State: domain.OrgStateInactive, + }, nil) + return func(_ database.QueryExecutor) domain.OrganizationRepository { + return repo + } + }, + inputID: "org-1", + inputName: "test org update", + expectedError: domain.NewOrgNotFoundError("DOM-OcA1jq"), + }, { testName: "when org update fails should return error", orgRepo: func(ctrl *gomock.Controller) func(database.QueryExecutor) domain.OrganizationRepository { repo := domainmock.NewOrgRepo(ctrl) + repo.EXPECT(). + Get(gomock.Any(), gomock.Any()). + Times(1). + Return(&domain.Organization{ + ID: "org-1", + Name: "old org name", + InstanceID: "instance-1", + State: domain.OrgStateActive, + }, nil) + repo.EXPECT(). Update(gomock.Any(), repo.IDCondition("org-1"), "instance-1", repo.SetName("test org update")). - Return(int64(0), updateErr). - AnyTimes() + Times(1). + Return(int64(0), updateErr) return func(_ database.QueryExecutor) domain.OrganizationRepository { return repo } @@ -65,10 +133,20 @@ func TestUpdateOrgCommand_Execute(t *testing.T) { inputName: "test org update", orgRepo: func(ctrl *gomock.Controller) func(database.QueryExecutor) domain.OrganizationRepository { repo := domainmock.NewOrgRepo(ctrl) + repo.EXPECT(). + Get(gomock.Any(), gomock.Any()). + Times(1). + Return(&domain.Organization{ + ID: "org-1", + Name: "old org name", + InstanceID: "instance-1", + State: domain.OrgStateActive, + }, nil) + repo.EXPECT(). Update(gomock.Any(), repo.IDCondition("org-1"), "instance-1", repo.SetName("test org update")). - Return(int64(0), nil). - AnyTimes() + Times(1). + Return(int64(0), nil) return func(_ database.QueryExecutor) domain.OrganizationRepository { return repo } @@ -79,10 +157,20 @@ func TestUpdateOrgCommand_Execute(t *testing.T) { testName: "when org update returns more than 1 row updated should return internal error", orgRepo: func(ctrl *gomock.Controller) func(database.QueryExecutor) domain.OrganizationRepository { repo := domainmock.NewOrgRepo(ctrl) + repo.EXPECT(). + Get(gomock.Any(), gomock.Any()). + Times(1). + Return(&domain.Organization{ + ID: "org-1", + Name: "old org name", + InstanceID: "instance-1", + State: domain.OrgStateActive, + }, nil) + repo.EXPECT(). Update(gomock.Any(), repo.IDCondition("org-1"), "instance-1", repo.SetName("test org update")). - Return(int64(2), nil). - AnyTimes() + Times(1). + Return(int64(2), nil) return func(_ database.QueryExecutor) domain.OrganizationRepository { return repo } @@ -95,10 +183,20 @@ func TestUpdateOrgCommand_Execute(t *testing.T) { testName: "when org update returns 1 row updated should return no error and set cache", orgRepo: func(ctrl *gomock.Controller) func(database.QueryExecutor) domain.OrganizationRepository { repo := domainmock.NewOrgRepo(ctrl) + repo.EXPECT(). + Get(gomock.Any(), gomock.Any()). + Times(1). + Return(&domain.Organization{ + ID: "org-1", + Name: "old org name", + InstanceID: "instance-1", + State: domain.OrgStateActive, + }, nil) + repo.EXPECT(). Update(gomock.Any(), repo.IDCondition("org-1"), "instance-1", repo.SetName("test org update")). - Return(int64(1), nil). - AnyTimes() + Times(1). + Return(int64(1), nil) return func(_ database.QueryExecutor) domain.OrganizationRepository { return repo } @@ -129,6 +227,7 @@ func TestUpdateOrgCommand_Execute(t *testing.T) { opts.DB = tc.queryExecutor(ctrl) } + // Test err := cmd.Execute(ctx, opts) // Verify