diff --git a/internal/api/grpc/group/v2/integration_test/query_test.go b/internal/api/grpc/group/v2/integration_test/query_test.go new file mode 100644 index 00000000000..1f325ede985 --- /dev/null +++ b/internal/api/grpc/group/v2/integration_test/query_test.go @@ -0,0 +1,369 @@ +//go:build integration + +package group_test + +import ( + "context" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + + "github.com/zitadel/zitadel/internal/integration" + "github.com/zitadel/zitadel/pkg/grpc/filter/v2" + group_v2 "github.com/zitadel/zitadel/pkg/grpc/group/v2" +) + +func TestServer_GetGroup(t *testing.T) { + iamOwnerCtx := instance.WithAuthorizationToken(CTX, integration.UserTypeIAMOwner) + type args struct { + ctx context.Context + req *group_v2.GetGroupRequest + dep func(*group_v2.GetGroupRequest, *group_v2.GetGroupResponse) + } + tests := []struct { + name string + args args + want *group_v2.GetGroupResponse + wantErrCode codes.Code + wantErrMsg string + }{ + { + name: "unauthenticated", + args: args{ + ctx: context.Background(), + req: &group_v2.GetGroupRequest{}, + }, + wantErrCode: codes.Unauthenticated, + wantErrMsg: "auth header missing", + }, + { + name: "missing id", + args: args{ + ctx: iamOwnerCtx, + req: &group_v2.GetGroupRequest{}, + }, + wantErrCode: codes.InvalidArgument, + wantErrMsg: "invalid GetGroupRequest.Id: value length must be between 1 and 200 runes, inclusive", + }, + { + name: "get group, not found", + args: args{ + ctx: iamOwnerCtx, + req: &group_v2.GetGroupRequest{ + Id: "group1", + }, + }, + wantErrCode: codes.NotFound, + wantErrMsg: "Errors.Group.NotFound (QUERY-SG4WbR)", + }, + { + name: "get group, found", + args: args{ + ctx: iamOwnerCtx, + dep: func(req *group_v2.GetGroupRequest, resp *group_v2.GetGroupResponse) { + orgResp := instance.CreateOrganization(iamOwnerCtx, integration.OrganizationName(), integration.Email()) + groupName := integration.GroupName() + group := instance.CreateGroup(iamOwnerCtx, t, orgResp.GetOrganizationId(), groupName) + + req.Id = group.GetId() + resp.Group = &group_v2.Group{ + Id: group.GetId(), + Name: groupName, + Description: "", + OrganizationId: orgResp.GetOrganizationId(), + ChangeDate: group.GetCreationDate(), + CreationDate: group.GetCreationDate(), + } + }, + req: &group_v2.GetGroupRequest{}, + }, + want: &group_v2.GetGroupResponse{ + Group: &group_v2.Group{}, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.args.dep != nil { + tt.args.dep(tt.args.req, tt.want) + } + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(iamOwnerCtx, time.Minute) + require.EventuallyWithT(t, func(ttt *assert.CollectT) { + got, err := instance.Client.GroupV2.GetGroup(tt.args.ctx, tt.args.req) + if tt.wantErrCode != codes.OK { + require.Error(t, err) + assert.Equal(t, tt.wantErrCode, status.Code(err)) + assert.Equal(t, tt.wantErrMsg, status.Convert(err).Message()) + return + } + require.NoError(t, err) + assert.EqualExportedValues(t, tt.want.Group, got.Group, "want: %v, got: %v", tt.want.Group, got.Group) + }, retryDuration, tick, "timeout waiting for expected result") + }) + } +} + +func TestServer_ListGroups(t *testing.T) { + iamOwnerCtx := instance.WithAuthorizationToken(CTX, integration.UserTypeIAMOwner) + type args struct { + ctx context.Context + req *group_v2.ListGroupsRequest + dep func(*group_v2.ListGroupsRequest, *group_v2.ListGroupsResponse) + } + tests := []struct { + name string + args args + want *group_v2.ListGroupsResponse + wantErrCode codes.Code + wantErrMsg string + }{ + { + name: "list groups, unauthenticated", + args: args{ + ctx: CTX, + req: &group_v2.ListGroupsRequest{}, + }, + wantErrCode: codes.Unauthenticated, + wantErrMsg: "auth header missing", + }, + { + name: "group ID not found", + args: args{ + ctx: iamOwnerCtx, + req: &group_v2.ListGroupsRequest{ + Filters: []*group_v2.GroupsSearchFilter{ + { + Filter: &group_v2.GroupsSearchFilter_GroupIds{ + GroupIds: &filter.InIDsFilter{ + Ids: []string{"random-group"}, + }, + }, + }, + }, + }, + }, + want: &group_v2.ListGroupsResponse{ + Pagination: &filter.PaginationResponse{ + TotalResult: 0, + AppliedLimit: 100, + }, + }, + }, + { + name: "list single group by ID, ok", + args: args{ + ctx: iamOwnerCtx, + dep: func(req *group_v2.ListGroupsRequest, resp *group_v2.ListGroupsResponse) { + orgResp := instance.CreateOrganization(iamOwnerCtx, integration.OrganizationName(), integration.Email()) + groupName := integration.GroupName() + group1 := instance.CreateGroup(iamOwnerCtx, t, orgResp.GetOrganizationId(), groupName) + + resp.Groups[0] = &group_v2.Group{ + Id: group1.GetId(), + Name: groupName, + Description: "", + OrganizationId: orgResp.GetOrganizationId(), + CreationDate: group1.GetCreationDate(), + ChangeDate: group1.GetCreationDate(), + } + req.Filters[0].Filter = &group_v2.GroupsSearchFilter_GroupIds{ + GroupIds: &filter.InIDsFilter{ + Ids: []string{group1.GetId()}, + }, + } + }, + req: &group_v2.ListGroupsRequest{ + Filters: []*group_v2.GroupsSearchFilter{{}}, + }, + }, + want: &group_v2.ListGroupsResponse{ + Pagination: &filter.PaginationResponse{ + TotalResult: 1, + AppliedLimit: 100, + }, + Groups: []*group_v2.Group{ + {}, + }, + }, + }, + { + name: "list multiple groups by IDs, ok", + args: args{ + ctx: iamOwnerCtx, + dep: func(req *group_v2.ListGroupsRequest, resp *group_v2.ListGroupsResponse) { + orgResp := instance.CreateOrganization(iamOwnerCtx, integration.OrganizationName(), integration.Email()) + groupName1 := integration.GroupName() + group1 := instance.CreateGroup(iamOwnerCtx, t, orgResp.GetOrganizationId(), groupName1) + + resp.Groups[1] = &group_v2.Group{ + Id: group1.GetId(), + Name: groupName1, + Description: "", + OrganizationId: orgResp.GetOrganizationId(), + CreationDate: group1.GetCreationDate(), + ChangeDate: group1.GetCreationDate(), + } + groupName2 := integration.GroupName() + group2 := instance.CreateGroup(iamOwnerCtx, t, orgResp.GetOrganizationId(), groupName2) + + resp.Groups[0] = &group_v2.Group{ + Id: group2.GetId(), + Name: groupName2, + Description: "", + OrganizationId: orgResp.GetOrganizationId(), + CreationDate: group2.GetCreationDate(), + ChangeDate: group2.GetCreationDate(), + } + req.Filters[0].Filter = &group_v2.GroupsSearchFilter_GroupIds{ + GroupIds: &filter.InIDsFilter{ + Ids: []string{group1.GetId(), group2.GetId()}, + }, + } + }, + req: &group_v2.ListGroupsRequest{ + Filters: []*group_v2.GroupsSearchFilter{{}}, + }, + }, + want: &group_v2.ListGroupsResponse{ + Pagination: &filter.PaginationResponse{ + TotalResult: 2, + AppliedLimit: 100, + }, + Groups: []*group_v2.Group{ + {}, {}, + }, + }, + }, + { + name: "list group by name, ok", + args: args{ + ctx: iamOwnerCtx, + dep: func(req *group_v2.ListGroupsRequest, resp *group_v2.ListGroupsResponse) { + orgResp := instance.CreateOrganization(iamOwnerCtx, integration.OrganizationName(), integration.Email()) + groupName := integration.GroupName() + group1 := instance.CreateGroup(iamOwnerCtx, t, orgResp.GetOrganizationId(), groupName) + + resp.Groups[0] = &group_v2.Group{ + Id: group1.GetId(), + Name: groupName, + Description: "", + OrganizationId: orgResp.GetOrganizationId(), + CreationDate: group1.GetCreationDate(), + ChangeDate: group1.GetCreationDate(), + } + req.Filters[0].Filter = &group_v2.GroupsSearchFilter_NameFilter{ + NameFilter: &group_v2.GroupNameFilter{ + Name: groupName, + }, + } + }, + req: &group_v2.ListGroupsRequest{ + Filters: []*group_v2.GroupsSearchFilter{{}}, + }, + }, + want: &group_v2.ListGroupsResponse{ + Pagination: &filter.PaginationResponse{ + TotalResult: 1, + AppliedLimit: 100, + }, + Groups: []*group_v2.Group{ + {}, + }, + }, + }, + { + name: "list by organization ID, ok", + args: args{ + ctx: iamOwnerCtx, + dep: func(req *group_v2.ListGroupsRequest, resp *group_v2.ListGroupsResponse) { + org1 := instance.CreateOrganization(iamOwnerCtx, integration.OrganizationName(), integration.Email()) + groupName2 := integration.GroupName() + group2 := instance.CreateGroup(iamOwnerCtx, t, org1.GetOrganizationId(), groupName2) + + resp.Groups[2] = &group_v2.Group{ + Id: group2.GetId(), + Name: groupName2, + Description: "", + OrganizationId: org1.GetOrganizationId(), + CreationDate: group2.GetCreationDate(), + ChangeDate: group2.GetCreationDate(), + } + groupName1 := integration.GroupName() + group1 := instance.CreateGroup(iamOwnerCtx, t, org1.GetOrganizationId(), groupName1) + + resp.Groups[1] = &group_v2.Group{ + Id: group1.GetId(), + Name: groupName1, + Description: "", + OrganizationId: org1.GetOrganizationId(), + CreationDate: group1.GetCreationDate(), + ChangeDate: group1.GetCreationDate(), + } + groupName0 := integration.GroupName() + group0 := instance.CreateGroup(iamOwnerCtx, t, org1.GetOrganizationId(), groupName0) + + resp.Groups[0] = &group_v2.Group{ + Id: group0.GetId(), + Name: groupName0, + Description: "", + OrganizationId: org1.GetOrganizationId(), + CreationDate: group0.GetCreationDate(), + ChangeDate: group0.GetCreationDate(), + } + org2 := instance.CreateOrganization(iamOwnerCtx, integration.OrganizationName(), integration.Email()) + org2GroupName0 := integration.GroupName() + _ = instance.CreateGroup(iamOwnerCtx, t, org2.GetOrganizationId(), org2GroupName0) + + req.Filters[0].Filter = &group_v2.GroupsSearchFilter_OrganizationId{ + OrganizationId: &filter.IDFilter{ + Id: org1.GetOrganizationId(), + }, + } + }, + req: &group_v2.ListGroupsRequest{ + Filters: []*group_v2.GroupsSearchFilter{{}}, + }, + }, + want: &group_v2.ListGroupsResponse{ + Pagination: &filter.PaginationResponse{ + TotalResult: 3, + AppliedLimit: 100, + }, + Groups: []*group_v2.Group{ + {}, {}, {}, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if tt.args.dep != nil { + tt.args.dep(tt.args.req, tt.want) + } + + retryDuration, tick := integration.WaitForAndTickWithMaxDuration(iamOwnerCtx, time.Minute) + require.EventuallyWithT(t, func(ttt *assert.CollectT) { + got, err := instance.Client.GroupV2.ListGroups(tt.args.ctx, tt.args.req) + if tt.wantErrCode != codes.OK { + require.Error(t, err) + assert.Equal(t, tt.wantErrCode, status.Code(err)) + assert.Equal(t, tt.wantErrMsg, status.Convert(err).Message()) + return + } + require.NoError(t, err) + if assert.Len(t, got.Groups, len(tt.want.Groups)) { + for i := range got.Groups { + assert.EqualExportedValues(t, tt.want.Groups[i], got.Groups[i], "want: %v, got: %v", tt.want.Groups[i], got.Groups[i]) + } + } + assert.Equal(t, tt.want.Pagination.AppliedLimit, got.Pagination.AppliedLimit) + assert.Equal(t, tt.want.Pagination.TotalResult, got.Pagination.TotalResult) + }, retryDuration, tick, "timeout waiting for expected result") + }) + } +} diff --git a/internal/api/grpc/group/v2/query.go b/internal/api/grpc/group/v2/query.go index 859b438caf3..b4cbec31589 100644 --- a/internal/api/grpc/group/v2/query.go +++ b/internal/api/grpc/group/v2/query.go @@ -6,40 +6,117 @@ import ( "connectrpc.com/connect" "google.golang.org/protobuf/types/known/timestamppb" + "github.com/zitadel/zitadel/internal/api/grpc/filter/v2" + "github.com/zitadel/zitadel/internal/config/systemdefaults" "github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/zerrors" - group "github.com/zitadel/zitadel/pkg/grpc/group/v2" + group_v2 "github.com/zitadel/zitadel/pkg/grpc/group/v2" ) // GetGroup returns a group that matches the group ID in the request -func (s *Server) GetGroup(ctx context.Context, req *connect.Request[group.GetGroupRequest]) (*connect.Response[group.GetGroupResponse], error) { - return nil, zerrors.ThrowUnimplemented(nil, "GRP-1234", "Errors.Internal.Unimplemented") -} - -// ListGroups returns a list of groups that match the search criteria -func (s *Server) ListGroups(ctx context.Context, req *connect.Request[group.ListGroupsRequest]) (*connect.Response[group.ListGroupsResponse], error) { - resp, err := s.query.SearchGroups(ctx) +func (s *Server) GetGroup(ctx context.Context, req *connect.Request[group_v2.GetGroupRequest]) (*connect.Response[group_v2.GetGroupResponse], error) { + group, err := s.query.GetGroupByID(ctx, req.Msg.GetId()) if err != nil { return nil, err } - return connect.NewResponse(&group.ListGroupsResponse{ - Groups: groupsToPb(resp.Groups), + return connect.NewResponse(&group_v2.GetGroupResponse{ + Group: groupToPb(group), }), nil } -func groupsToPb(groups []*query.Group) []*group.Group { - pbGroups := make([]*group.Group, len(groups)) +// ListGroups returns a list of groups that match the search criteria +func (s *Server) ListGroups(ctx context.Context, req *connect.Request[group_v2.ListGroupsRequest]) (*connect.Response[group_v2.ListGroupsResponse], error) { + queries, err := listGroupsRequestToModel(req.Msg, s.systemDefaults) + if err != nil { + return nil, err + } + resp, err := s.query.SearchGroups(ctx, queries) + if err != nil { + return nil, err + } + return connect.NewResponse(&group_v2.ListGroupsResponse{ + Groups: groupsToPb(resp.Groups), + Pagination: filter.QueryToPaginationPb(queries.SearchRequest, resp.SearchResponse), + }), nil +} + +func listGroupsRequestToModel(req *group_v2.ListGroupsRequest, systemDefaults systemdefaults.SystemDefaults) (*query.GroupSearchQuery, error) { + offset, limit, asc, err := filter.PaginationPbToQuery(systemDefaults, req.GetPagination()) + if err != nil { + return nil, err + } + queries, err := groupSearchFiltersToQuery(req.GetFilters()) + if err != nil { + return nil, err + } + return &query.GroupSearchQuery{ + SearchRequest: query.SearchRequest{ + Offset: offset, + Limit: limit, + Asc: asc, + SortingColumn: groupFieldNameToSortingColumn(req.SortingColumn), + }, + Queries: queries, + }, nil +} + +func groupSearchFiltersToQuery(filters []*group_v2.GroupsSearchFilter) (_ []query.SearchQuery, err error) { + q := make([]query.SearchQuery, len(filters)) + for i, f := range filters { + q[i], err = groupFilterToQuery(f) + if err != nil { + return nil, err + } + } + return q, nil +} + +func groupFilterToQuery(f *group_v2.GroupsSearchFilter) (query.SearchQuery, error) { + switch q := f.Filter.(type) { + case *group_v2.GroupsSearchFilter_GroupIds: + return query.NewGroupIDsSearchQuery(q.GroupIds.GetIds()) + case *group_v2.GroupsSearchFilter_NameFilter: + return query.NewGroupNameSearchQuery(q.NameFilter.GetName(), filter.TextMethodPbToQuery(q.NameFilter.GetMethod())) + case *group_v2.GroupsSearchFilter_OrganizationId: + return query.NewGroupOrganizationIdSearchQuery(q.OrganizationId.GetId()) + default: + return nil, zerrors.ThrowInvalidArgument(nil, "GRPC-g3f4g", "List.Query.Invalid") + } +} + +func groupFieldNameToSortingColumn(field *group_v2.FieldName) query.Column { + if field == nil { + return query.GroupColumnCreationDate + } + switch *field { + case group_v2.FieldName_FIELD_NAME_CREATION_DATE, group_v2.FieldName_FIELD_NAME_UNSPECIFIED: + return query.GroupColumnCreationDate + case group_v2.FieldName_FIELD_NAME_ID: + return query.GroupColumnID + case group_v2.FieldName_FIELD_NAME_NAME: + return query.GroupColumnName + case group_v2.FieldName_FIELD_NAME_CHANGE_DATE: + return query.GroupColumnChangeDate + default: + return query.GroupColumnCreationDate + } +} + +func groupsToPb(groups []*query.Group) []*group_v2.Group { + pbGroups := make([]*group_v2.Group, len(groups)) for i, g := range groups { pbGroups[i] = groupToPb(g) } return pbGroups } -func groupToPb(g *query.Group) *group.Group { - return &group.Group{ - Id: g.ID, - Name: g.Name, - CreationDate: timestamppb.New(g.CreationDate), - ChangeDate: timestamppb.New(g.ChangeDate), +func groupToPb(g *query.Group) *group_v2.Group { + return &group_v2.Group{ + Id: g.ID, + Name: g.Name, + Description: g.Description, + OrganizationId: g.ResourceOwner, + CreationDate: timestamppb.New(g.CreationDate), + ChangeDate: timestamppb.New(g.ChangeDate), } } diff --git a/internal/api/grpc/group/v2/query_test.go b/internal/api/grpc/group/v2/query_test.go new file mode 100644 index 00000000000..c00e0c083f9 --- /dev/null +++ b/internal/api/grpc/group/v2/query_test.go @@ -0,0 +1,269 @@ +package group + +import ( + "errors" + "fmt" + "testing" + "time" + + "github.com/muhlemmer/gu" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/timestamppb" + + "github.com/zitadel/zitadel/internal/config/systemdefaults" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/query" + "github.com/zitadel/zitadel/internal/zerrors" + "github.com/zitadel/zitadel/pkg/grpc/filter/v2" + group_v2 "github.com/zitadel/zitadel/pkg/grpc/group/v2" +) + +func Test_ListGroupsRequestToModel(t *testing.T) { + + groupIDsSearchQuery, err := query.NewGroupIDsSearchQuery([]string{"group1", "group2"}) + require.NoError(t, err) + + tests := []struct { + name string + maxQueryLimit uint64 + req *group_v2.ListGroupsRequest + wantResp *query.GroupSearchQuery + wantErr error + }{ + { + name: "max query limit exceeded", + maxQueryLimit: 1, + req: &group_v2.ListGroupsRequest{ + Pagination: &filter.PaginationRequest{ + Limit: 5, + }, + Filters: []*group_v2.GroupsSearchFilter{ + { + Filter: &group_v2.GroupsSearchFilter_GroupIds{ + GroupIds: &filter.InIDsFilter{ + Ids: []string{"group1", "group2"}, + }, + }, + }, + }, + }, + wantErr: zerrors.ThrowInvalidArgumentf(errors.New("given: 5, allowed: 1"), "QUERY-4M0fs", "Errors.Query.LimitExceeded"), + }, + { + name: "valid request, list of group IDs, ok", + req: &group_v2.ListGroupsRequest{ + Filters: []*group_v2.GroupsSearchFilter{ + { + Filter: &group_v2.GroupsSearchFilter_GroupIds{ + GroupIds: &filter.InIDsFilter{ + Ids: []string{"group1", "group2"}, + }, + }, + }, + }, + }, + wantResp: &query.GroupSearchQuery{ + SearchRequest: query.SearchRequest{ + Offset: 0, + Limit: 0, + SortingColumn: query.GroupColumnCreationDate, + }, + Queries: []query.SearchQuery{groupIDsSearchQuery}, + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + sysDefaults := systemdefaults.SystemDefaults{MaxQueryLimit: tt.maxQueryLimit} + got, err := listGroupsRequestToModel(tt.req, sysDefaults) + if tt.wantErr != nil { + assert.Equal(t, tt.wantErr, err) + return + } + for _, q := range got.Queries { + fmt.Printf("%+v", q) + } + + require.NoError(t, err) + assert.Equal(t, tt.wantResp, got) + }) + } +} + +func Test_GroupSearchFiltersToQuery(t *testing.T) { + groupIDsSearchQuery, err := query.NewGroupIDsSearchQuery([]string{"group1", "group2"}) + require.NoError(t, err) + groupNameSearchQuery, err := query.NewGroupNameSearchQuery("mygroup", query.TextStartsWith) + require.NoError(t, err) + groupOrgIDSearchQuery, err := query.NewGroupOrganizationIdSearchQuery("org1") + require.NoError(t, err) + + tests := []struct { + name string + filters []*group_v2.GroupsSearchFilter + want []query.SearchQuery + wantErr error + }{ + { + name: "empty", + filters: []*group_v2.GroupsSearchFilter{}, + want: []query.SearchQuery{}, + }, + { + name: "all filters", + filters: []*group_v2.GroupsSearchFilter{ + { + Filter: &group_v2.GroupsSearchFilter_GroupIds{ + GroupIds: &filter.InIDsFilter{ + Ids: []string{"group1", "group2"}, + }, + }, + }, + { + Filter: &group_v2.GroupsSearchFilter_NameFilter{ + NameFilter: &group_v2.GroupNameFilter{ + Name: "mygroup", + Method: filter.TextFilterMethod_TEXT_FILTER_METHOD_STARTS_WITH, + }, + }, + }, + { + Filter: &group_v2.GroupsSearchFilter_OrganizationId{ + OrganizationId: &filter.IDFilter{ + Id: "org1", + }, + }, + }, + }, + want: []query.SearchQuery{ + groupIDsSearchQuery, + groupNameSearchQuery, + groupOrgIDSearchQuery, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := groupSearchFiltersToQuery(tt.filters) + if tt.wantErr != nil { + assert.Equal(t, tt.wantErr, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_GroupFieldNameToSortingColumn(t *testing.T) { + tests := []struct { + name string + field *group_v2.FieldName + want query.Column + }{ + { + name: "nil", + field: nil, + want: query.GroupColumnCreationDate, + }, + { + name: "creation date", + field: gu.Ptr(group_v2.FieldName_FIELD_NAME_CREATION_DATE), + want: query.GroupColumnCreationDate, + }, + { + name: "unspecified", + field: gu.Ptr(group_v2.FieldName_FIELD_NAME_UNSPECIFIED), + want: query.GroupColumnCreationDate, + }, + { + name: "id", + field: gu.Ptr(group_v2.FieldName_FIELD_NAME_ID), + want: query.GroupColumnID, + }, + { + name: "name", + field: gu.Ptr(group_v2.FieldName_FIELD_NAME_NAME), + want: query.GroupColumnName, + }, + { + name: "change date", + field: gu.Ptr(group_v2.FieldName_FIELD_NAME_CHANGE_DATE), + want: query.GroupColumnChangeDate, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := groupFieldNameToSortingColumn(tt.field) + assert.Equal(t, tt.want, got) + }) + } +} + +func Test_GroupsToPb(t *testing.T) { + timeNow := time.Now().UTC() + tests := []struct { + name string + groups []*query.Group + want []*group_v2.Group + }{ + { + name: "empty", + groups: []*query.Group{}, + want: []*group_v2.Group{}, + }, + { + name: "with groups, ok", + groups: []*query.Group{ + { + ID: "group1", + Name: "mygroup", + Description: "my first group", + CreationDate: timeNow, + ChangeDate: timeNow, + ResourceOwner: "org1", + InstanceID: "instance1", + State: domain.GroupStateActive, + Sequence: 1, + }, + { + ID: "group2", + Name: "mygroup2", + Description: "my second group", + CreationDate: timeNow, + ChangeDate: timeNow, + ResourceOwner: "org1", + InstanceID: "instance1", + State: domain.GroupStateActive, + Sequence: 1, + }, + }, + want: []*group_v2.Group{ + { + Id: "group1", + Name: "mygroup", + Description: "my first group", + OrganizationId: "org1", + ChangeDate: timestamppb.New(timeNow), + CreationDate: timestamppb.New(timeNow), + }, + { + Id: "group2", + Name: "mygroup2", + Description: "my second group", + OrganizationId: "org1", + ChangeDate: timestamppb.New(timeNow), + CreationDate: timestamppb.New(timeNow), + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := groupsToPb(tt.groups) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/internal/query/group.go b/internal/query/group.go index 983c9f60f44..493f33433ea 100644 --- a/internal/query/group.go +++ b/internal/query/group.go @@ -2,11 +2,63 @@ package query import ( "context" + "database/sql" + "errors" "time" + sq "github.com/Masterminds/squirrel" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/query/projection" + "github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/zerrors" ) +var ( + groupsTable = table{ + name: projection.GroupProjectionTable, + instanceIDCol: projection.GroupColumnInstanceID, + } + + GroupColumnID = Column{ + name: projection.GroupColumnID, + table: groupsTable, + } + GroupColumnName = Column{ + name: projection.GroupColumnName, + table: groupsTable, + } + GroupColumnDescription = Column{ + name: projection.GroupColumnDescription, + table: groupsTable, + } + GroupColumnResourceOwner = Column{ + name: projection.GroupColumnResourceOwner, + table: groupsTable, + } + GroupColumnCreationDate = Column{ + name: projection.GroupColumnCreationDate, + table: groupsTable, + } + GroupColumnChangeDate = Column{ + name: projection.GroupColumnChangeDate, + table: groupsTable, + } + GroupColumnInstanceID = Column{ + name: projection.GroupColumnInstanceID, + table: groupsTable, + } + GroupColumnSequence = Column{ + name: projection.GroupColumnSequence, + table: groupsTable, + } + GroupColumnState = Column{ + name: projection.GroupColumnState, + table: groupsTable, + } +) + type Groups struct { SearchResponse Groups []*Group @@ -19,9 +71,185 @@ type Group struct { CreationDate time.Time ChangeDate time.Time ResourceOwner string + InstanceID string + State domain.GroupState + Sequence uint64 +} + +type GroupSearchQuery struct { + SearchRequest + Queries []SearchQuery +} + +func (q *Queries) GetGroupByID(ctx context.Context, id string) (group *Group, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + // todo: add permission check + + stmt, scan := prepareGroupQuery() + eq := sq.Eq{ + GroupColumnID.identifier(): id, + GroupColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), + } + query, args, err := stmt.Where(eq).ToSql() + if err != nil { + return nil, zerrors.ThrowInternal(err, "QUERY-8bde1", "Errors.Query.SQLStatement") + } + + err = q.client.QueryRowContext(ctx, func(row *sql.Row) error { + group, err = scan(row) + return err + }, query, args...) + return group, err + } // SearchGroups returns the list of groups that match the search criteria -func (q *Queries) SearchGroups(ctx context.Context) (*Groups, error) { - return nil, zerrors.ThrowUnimplemented(nil, "QUERY-grpfli", "Not implemented") +func (q *Queries) SearchGroups(ctx context.Context, queries *GroupSearchQuery) (_ *Groups, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + // todo: add permission check + + groups, err := q.searchGroups(ctx, queries) + if err != nil { + return nil, err + } + return groups, nil +} + +func NewGroupNameSearchQuery(value string, comparison TextComparison) (SearchQuery, error) { + return NewTextQuery(GroupColumnName, value, comparison) +} + +func NewGroupIDsSearchQuery(ids []string) (SearchQuery, error) { + list := make([]interface{}, len(ids)) + for i, value := range ids { + list[i] = value + } + return NewListQuery(GroupColumnID, list, ListIn) +} + +func NewGroupOrganizationIdSearchQuery(id string) (SearchQuery, error) { + return NewTextQuery(GroupColumnResourceOwner, id, TextEquals) +} + +func prepareGroupQuery() (sq.SelectBuilder, func(*sql.Row) (*Group, error)) { + return sq.Select( + GroupColumnID.identifier(), + GroupColumnName.identifier(), + GroupColumnDescription.identifier(), + GroupColumnCreationDate.identifier(), + GroupColumnChangeDate.identifier(), + GroupColumnResourceOwner.identifier(), + GroupColumnInstanceID.identifier(), + GroupColumnSequence.identifier(), + GroupColumnState.identifier()). + From(groupsTable.identifier()). + PlaceholderFormat(sq.Dollar), + func(row *sql.Row) (*Group, error) { + group := new(Group) + err := row.Scan( + &group.ID, + &group.Name, + &group.Description, + &group.CreationDate, + &group.ChangeDate, + &group.ResourceOwner, + &group.InstanceID, + &group.Sequence, + &group.State, + ) + if err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, zerrors.ThrowNotFound(err, "QUERY-SG4WbR", "Errors.Group.NotFound") + } + return nil, zerrors.ThrowInternal(err, "QUERY-6yHJEz", "Errors.Internal") + } + return group, nil + } +} + +func (q *Queries) searchGroups(ctx context.Context, queries *GroupSearchQuery) (groups *Groups, err error) { + ctx, span := tracing.NewSpan(ctx) + defer func() { span.EndWithError(err) }() + + query, scan := prepareGroupsQuery() + eq := sq.And{ + sq.Eq{ + GroupColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), + }, + } + stmt, args, err := queries.toQuery(query).Where(eq).ToSql() + if err != nil { + return nil, zerrors.ThrowInvalidArgument(err, "QUERY-FpBnrv", "Errors.Query.InvalidRequest") + } + + err = q.client.QueryContext(ctx, func(rows *sql.Rows) error { + groups, err = scan(rows) + return err + }, stmt, args...) + if err != nil { + return nil, zerrors.ThrowInternal(err, "QUERY-vnQf5N", "Errors.Internal") + } + groups.State, err = q.latestState(ctx, groupsTable) + return groups, err +} + +func prepareGroupsQuery() (sq.SelectBuilder, func(*sql.Rows) (*Groups, error)) { + return sq.Select( + GroupColumnID.identifier(), + GroupColumnName.identifier(), + GroupColumnDescription.identifier(), + GroupColumnCreationDate.identifier(), + GroupColumnChangeDate.identifier(), + GroupColumnResourceOwner.identifier(), + GroupColumnInstanceID.identifier(), + GroupColumnSequence.identifier(), + GroupColumnState.identifier(), + countColumn.identifier()). + From(groupsTable.identifier()). + PlaceholderFormat(sq.Dollar), + func(rows *sql.Rows) (*Groups, error) { + groups := make([]*Group, 0) + var count uint64 + for rows.Next() { + group := new(Group) + err := rows.Scan( + &group.ID, + &group.Name, + &group.Description, + &group.CreationDate, + &group.ChangeDate, + &group.ResourceOwner, + &group.InstanceID, + &group.Sequence, + &group.State, + &count, + ) + if err != nil { + return nil, err + } + groups = append(groups, group) + } + if err := rows.Close(); err != nil { + return nil, zerrors.ThrowInternal(err, "QUERY-ndNVod", "Errors.Query.CloseRows") + } + + return &Groups{ + Groups: groups, + SearchResponse: SearchResponse{ + Count: count, + }, + }, nil + } +} + +func (q *GroupSearchQuery) toQuery(query sq.SelectBuilder) sq.SelectBuilder { + query = q.SearchRequest.toQuery(query) + for _, q := range q.Queries { + query = q.toQuery(query) + } + return query } diff --git a/internal/query/group_test.go b/internal/query/group_test.go new file mode 100644 index 00000000000..fa86151414f --- /dev/null +++ b/internal/query/group_test.go @@ -0,0 +1,305 @@ +package query + +import ( + "database/sql" + "database/sql/driver" + "errors" + "fmt" + "regexp" + "testing" + + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/zerrors" +) + +var ( + prepareGroupsStmt = `SELECT projections.groups1.id,` + + ` projections.groups1.name,` + + ` projections.groups1.description,` + + ` projections.groups1.creation_date,` + + ` projections.groups1.change_date,` + + ` projections.groups1.resource_owner,` + + ` projections.groups1.instance_id,` + + ` projections.groups1.sequence,` + + ` projections.groups1.state,` + + ` COUNT(*) OVER ()` + + ` FROM projections.groups1` + + prepareGroupsColumns = []string{ + "id", + "name", + "description", + "creation_date", + "change_date", + "resource_owner", + "instance_id", + "sequence", + "state", + "count", + } + + prepareGroupStmt = `SELECT projections.groups1.id,` + + ` projections.groups1.name,` + + ` projections.groups1.description,` + + ` projections.groups1.creation_date,` + + ` projections.groups1.change_date,` + + ` projections.groups1.resource_owner,` + + ` projections.groups1.instance_id,` + + ` projections.groups1.sequence,` + + ` projections.groups1.state` + + ` FROM projections.groups1` + + prepareGroupColumns = []string{ + "id", + "name", + "description", + "creation_date", + "change_date", + "resource_owner", + "instance_id", + "sequence", + "state", + } +) + +func Test_GroupPrepares(t *testing.T) { + type want struct { + sqlExpectations sqlExpectation + err checkErr + } + tests := []struct { + name string + prepare interface{} + want want + object interface{} + }{ + { + name: "prepareGroupsQuery, no result", + prepare: prepareGroupsQuery, + want: want{ + sqlExpectations: mockQueries( + regexp.QuoteMeta(prepareGroupsStmt), + nil, + nil, + ), + }, + object: &Groups{Groups: []*Group{}}, + }, + { + name: "prepareGroupsQuery, one result", + prepare: prepareGroupsQuery, + want: want{ + sqlExpectations: mockQueries( + regexp.QuoteMeta(prepareGroupsStmt), + prepareGroupsColumns, + [][]driver.Value{ + { + "9090", + "group1", + "my new group", + testNow, + testNow, + "org1", + "instance1", + 1, + domain.GroupStateActive, + }, + }, + ), + }, + object: &Groups{ + SearchResponse: SearchResponse{ + Count: 1, + }, + Groups: []*Group{ + { + ID: "9090", + Name: "group1", + Description: "my new group", + CreationDate: testNow, + ChangeDate: testNow, + ResourceOwner: "org1", + InstanceID: "instance1", + Sequence: 1, + State: domain.GroupStateActive, + }, + }, + }, + }, + { + name: "prepareGroupsQuery, multiple results", + prepare: prepareGroupsQuery, + want: want{ + sqlExpectations: mockQueries( + regexp.QuoteMeta(prepareGroupsStmt), + prepareGroupsColumns, + [][]driver.Value{ + { + "9091", + "group1", + "my first group", + testNow, + testNow, + "org1", + "instance1", + 1, + domain.GroupStateActive, + }, + { + "9092", + "group2", + "my second group", + testNow, + testNow, + "org1", + "instance1", + 1, + domain.GroupStateActive, + }, + }, + ), + }, + object: &Groups{ + SearchResponse: SearchResponse{ + Count: 2, + }, + Groups: []*Group{ + { + ID: "9091", + Name: "group1", + Description: "my first group", + CreationDate: testNow, + ChangeDate: testNow, + ResourceOwner: "org1", + InstanceID: "instance1", + Sequence: 1, + State: domain.GroupStateActive, + }, + { + ID: "9092", + Name: "group2", + Description: "my second group", + CreationDate: testNow, + ChangeDate: testNow, + ResourceOwner: "org1", + InstanceID: "instance1", + Sequence: 1, + State: domain.GroupStateActive, + }, + }, + }, + }, + { + name: "prepareGroupsQuery, sql err", + prepare: prepareGroupsQuery, + want: want{ + sqlExpectations: mockQueryErr( + regexp.QuoteMeta(prepareGroupsStmt), + sql.ErrConnDone, + ), + err: func(err error) (error, bool) { + if !errors.Is(err, sql.ErrConnDone) { + return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false + } + return nil, true + }, + }, + object: (*Groups)(nil), + }, + { + name: "prepareGroupsQuery, no result", + prepare: prepareGroupsQuery, + want: want{ + sqlExpectations: mockQueriesScanErr( + regexp.QuoteMeta(prepareGroupsStmt), + nil, + nil, + ), + err: func(err error) (error, bool) { + if !zerrors.IsNotFound(err) { + return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false + } + return nil, true + }, + }, + object: &Groups{ + SearchResponse: SearchResponse{ + Count: 0, + }, + Groups: []*Group{}, + }, + }, + { + name: "prepareGroupQuery, no result", + prepare: prepareGroupQuery, + want: want{ + sqlExpectations: mockQueriesScanErr( + prepareGroupStmt, + nil, + nil, + ), + err: func(err error) (error, bool) { + if !zerrors.IsNotFound(err) { + return fmt.Errorf("err should be zitadel.NotFoundError got: %w", err), false + } + return nil, true + }, + }, + object: (*Group)(nil), + }, + { + name: "prepareGroupQuery, sql err", + prepare: prepareGroupQuery, + want: want{ + sqlExpectations: mockQueryErr( + regexp.QuoteMeta(prepareGroupStmt), + sql.ErrConnDone, + ), + err: func(err error) (error, bool) { + if !errors.Is(err, sql.ErrConnDone) { + return fmt.Errorf("err should be sql.ErrConnDone got: %w", err), false + } + return nil, true + }, + }, + object: (*Group)(nil), + }, + { + name: "prepareGroupQuery, found", + prepare: prepareGroupQuery, + want: want{ + sqlExpectations: mockQuery( + regexp.QuoteMeta(prepareGroupStmt), + prepareGroupColumns, + []driver.Value{ + "9090", + "group1", + "my new group", + testNow, + testNow, + "org1", + "instance1", + 1, + domain.GroupStateActive, + }, + ), + }, + object: &Group{ + ID: "9090", + Name: "group1", + Description: "my new group", + CreationDate: testNow, + ChangeDate: testNow, + ResourceOwner: "org1", + InstanceID: "instance1", + Sequence: 1, + State: domain.GroupStateActive, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assertPrepare(t, tt.prepare, tt.object, tt.want.sqlExpectations, tt.want.err) + }) + } +} diff --git a/internal/query/projection/group.go b/internal/query/projection/group.go new file mode 100644 index 00000000000..10495c75726 --- /dev/null +++ b/internal/query/projection/group.go @@ -0,0 +1,179 @@ +package projection + +import ( + "context" + + "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/eventstore" + old_handler "github.com/zitadel/zitadel/internal/eventstore/handler" + "github.com/zitadel/zitadel/internal/eventstore/handler/v2" + "github.com/zitadel/zitadel/internal/repository/group" + "github.com/zitadel/zitadel/internal/repository/instance" + "github.com/zitadel/zitadel/internal/repository/org" +) + +const ( + GroupProjectionTable = "projections.groups1" + + GroupColumnID = "id" + GroupColumnName = "name" + GroupColumnResourceOwner = "resource_owner" + GroupColumnInstanceID = "instance_id" + GroupColumnCreationDate = "creation_date" + GroupColumnChangeDate = "change_date" + GroupColumnSequence = "sequence" + GroupColumnState = "state" + GroupColumnDescription = "description" +) + +type groupProjection struct{} + +func (g *groupProjection) Name() string { + return GroupProjectionTable +} + +func newGroupProjection(ctx context.Context, config handler.Config) *handler.Handler { + return handler.NewHandler(ctx, &config, new(groupProjection)) +} + +func (g *groupProjection) Init() *old_handler.Check { + return handler.NewTableCheck( + handler.NewTable([]*handler.InitColumn{ + handler.NewColumn(GroupColumnID, handler.ColumnTypeText), + handler.NewColumn(GroupColumnName, handler.ColumnTypeText), + handler.NewColumn(GroupColumnResourceOwner, handler.ColumnTypeText), + handler.NewColumn(GroupColumnInstanceID, handler.ColumnTypeText), + handler.NewColumn(GroupColumnDescription, handler.ColumnTypeText), + handler.NewColumn(GroupColumnCreationDate, handler.ColumnTypeTimestamp), + handler.NewColumn(GroupColumnChangeDate, handler.ColumnTypeTimestamp), + handler.NewColumn(GroupColumnSequence, handler.ColumnTypeInt64), + handler.NewColumn(GroupColumnState, handler.ColumnTypeEnum), + }, + handler.NewPrimaryKey(GroupColumnInstanceID, GroupColumnID), + handler.WithIndex(handler.NewIndex("resource_owner", []string{GroupColumnResourceOwner})), + ), + ) +} + +func (g *groupProjection) Reducers() []handler.AggregateReducer { + return []handler.AggregateReducer{ + { + Aggregate: group.AggregateType, + EventReducers: []handler.EventReducer{ + { + Event: group.GroupAddedEventType, + Reduce: g.reduceGroupAdded, + }, + { + Event: group.GroupChangedEventType, + Reduce: g.reduceGroupChanged, + }, + { + Event: group.GroupRemovedEventType, + Reduce: g.reduceGroupRemoved, + }, + }, + }, + { + Aggregate: group.AggregateType, + EventReducers: []handler.EventReducer{ + { + Event: org.OrgRemovedEventType, + Reduce: g.reduceOwnerRemoved, + }, + }, + }, + { + Aggregate: instance.AggregateType, + EventReducers: []handler.EventReducer{ + { + Event: instance.InstanceRemovedEventType, + Reduce: reduceInstanceRemovedHelper(GroupColumnInstanceID), + }, + }, + }, + } +} + +func (g *groupProjection) reduceGroupAdded(event eventstore.Event) (*handler.Statement, error) { + e, err := assertEvent[*group.GroupAddedEvent](event) + if err != nil { + return nil, err + } + return handler.NewCreateStatement( + e, + []handler.Column{ + handler.NewCol(GroupColumnID, e.Aggregate().ID), + handler.NewCol(GroupColumnName, e.Name), + handler.NewCol(GroupColumnResourceOwner, e.Aggregate().ResourceOwner), + handler.NewCol(GroupColumnInstanceID, e.Aggregate().InstanceID), + handler.NewCol(GroupColumnDescription, e.Description), + handler.NewCol(GroupColumnCreationDate, e.CreationDate()), + handler.NewCol(GroupColumnChangeDate, e.CreationDate()), + handler.NewCol(GroupColumnSequence, e.Sequence()), + handler.NewCol(GroupColumnState, domain.GroupStateActive), + }, + ), nil +} + +func (g *groupProjection) reduceGroupChanged(event eventstore.Event) (*handler.Statement, error) { + e, err := assertEvent[*group.GroupChangedEvent](event) + if err != nil { + return nil, err + } + + columns := make([]handler.Column, 0, 4) + + if e.Name != nil { + columns = append(columns, handler.NewCol(GroupColumnName, *e.Name)) + } + if e.Description != nil { + columns = append(columns, handler.NewCol(GroupColumnDescription, *e.Description)) + } + if len(columns) == 0 { + return handler.NewNoOpStatement(e), nil + } + + columns = append( + columns, + handler.NewCol(GroupColumnChangeDate, e.CreationDate()), + handler.NewCol(GroupColumnSequence, e.Sequence()), + ) + + return handler.NewUpdateStatement( + e, + columns, + []handler.Condition{ + handler.NewCond(GroupColumnID, e.Aggregate().ID), + handler.NewCond(GroupColumnInstanceID, e.Aggregate().InstanceID), + }, + ), nil +} + +func (g *groupProjection) reduceGroupRemoved(event eventstore.Event) (*handler.Statement, error) { + e, err := assertEvent[*group.GroupRemovedEvent](event) + if err != nil { + return nil, err + } + return handler.NewDeleteStatement( + e, + []handler.Condition{ + handler.NewCond(GroupColumnID, e.Aggregate().ID), + handler.NewCond(GroupColumnInstanceID, e.Aggregate().InstanceID), + }, + ), nil +} + +func (g *groupProjection) reduceOwnerRemoved(event eventstore.Event) (*handler.Statement, error) { + e, err := assertEvent[*org.OrgRemovedEvent](event) + if err != nil { + return nil, err + } + return handler.NewDeleteStatement( + e, + []handler.Condition{ + handler.NewCond(GroupColumnInstanceID, e.Aggregate().InstanceID), + handler.NewCond(GroupColumnResourceOwner, e.Aggregate().ID), + }, + ), nil +} diff --git a/internal/query/projection/projection.go b/internal/query/projection/projection.go index de84bccda79..bea9bcaa26a 100644 --- a/internal/query/projection/projection.go +++ b/internal/query/projection/projection.go @@ -102,6 +102,8 @@ var ( InstanceDomainFields *handler.FieldHandler MembershipFields *handler.FieldHandler PermissionFields *handler.FieldHandler + + GroupProjection *handler.Handler ) type projection interface { @@ -207,6 +209,8 @@ func Create(ctx context.Context, sqlClient *database.DB, es handler.EventStore, PermissionFields = newFillPermissionFields(applyCustomConfig(projectionConfig, config.Customizations[fieldsPermission])) // Don't forget to add the new field handler to [ProjectInstanceFields] + GroupProjection = newGroupProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["groups"])) + InstanceRelationalProjection = newInstanceRelationalProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["instances_relational"])) OrganizationRelationalProjection = newOrgRelationalProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["organizations_relational"])) InstanceDomainRelationalProjection = newInstanceDomainRelationalProjection(ctx, applyCustomConfig(projectionConfig, config.Customizations["instance_domains_relational"])) @@ -396,6 +400,7 @@ func newProjectionsList() { DebugEventsProjection, HostedLoginTranslationProjection, OrganizationSettingsProjection, + GroupProjection, InstanceRelationalProjection, OrganizationRelationalProjection,