feat: updating eventstore.permitted_orgs sql function (#9309)

# Which Problems Are Solved

Performance issue for GRPC call `zitadel.user.v2.UserService.ListUsers`
due to lack of org filtering on `ListUsers`

# Additional Context

Replace this example with links to related issues, discussions, discord
threads, or other sources with more context.
Use the Closing #issue syntax for issues that are resolved with this PR.
- Closes https://github.com/zitadel/zitadel/issues/9191

---------

Co-authored-by: Iraq Jaber <IraqJaber@gmail.com>
Co-authored-by: Tim Möhlmann <tim+github@zitadel.com>
This commit is contained in:
Iraq
2025-02-17 09:55:28 +00:00
committed by GitHub
parent 7c96dcd9a2
commit 0cb0380826
14 changed files with 143 additions and 36 deletions

View File

@@ -554,7 +554,7 @@ func (s *Server) getUsers(ctx context.Context, org string, withPasswords bool, w
if err != nil {
return nil, nil, nil, nil, err
}
users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{orgSearch}}, nil)
users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{orgSearch}}, org, nil)
if err != nil {
return nil, nil, nil, nil, err
}

View File

@@ -108,7 +108,7 @@ func (s *Server) getClaimedUserIDsOfOrgDomain(ctx context.Context, orgDomain str
if err != nil {
return nil, err
}
users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, nil)
users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, "", nil)
if err != nil {
return nil, err
}

View File

@@ -330,7 +330,7 @@ func (s *Server) getClaimedUserIDsOfOrgDomain(ctx context.Context, orgDomain, or
}
queries = append(queries, owner)
}
users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: queries}, nil)
users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: queries}, orgID, nil)
if err != nil {
return nil, err
}

View File

@@ -64,11 +64,12 @@ func (s *Server) ListUsers(ctx context.Context, req *mgmt_pb.ListUsersRequest) (
return nil, err
}
err = queries.AppendMyResourceOwnerQuery(authz.GetCtxData(ctx).OrgID)
orgID := authz.GetCtxData(ctx).OrgID
err = queries.AppendMyResourceOwnerQuery(orgID)
if err != nil {
return nil, err
}
res, err := s.query.SearchUsers(ctx, queries, nil)
res, err := s.query.SearchUsers(ctx, queries, orgID, nil)
if err != nil {
return nil, err
}

View File

@@ -29,11 +29,11 @@ func (s *Server) GetUserByID(ctx context.Context, req *user.GetUserByIDRequest)
}
func (s *Server) ListUsers(ctx context.Context, req *user.ListUsersRequest) (*user.ListUsersResponse, error) {
queries, err := listUsersRequestToModel(req)
queries, filterOrgId, err := listUsersRequestToModel(req)
if err != nil {
return nil, err
}
res, err := s.query.SearchUsers(ctx, queries, s.checkPermission)
res, err := s.query.SearchUsers(ctx, queries, filterOrgId, s.checkPermission)
if err != nil {
return nil, err
}
@@ -169,11 +169,11 @@ func accessTokenTypeToPb(accessTokenType domain.OIDCTokenType) user.AccessTokenT
}
}
func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueries, error) {
func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueries, string, error) {
offset, limit, asc := object.ListQueryToQuery(req.Query)
queries, err := userQueriesToQuery(req.Queries, 0 /*start from level 0*/)
queries, filterOrgId, err := userQueriesToQuery(req.Queries, 0 /*start from level 0*/)
if err != nil {
return nil, err
return nil, "", err
}
return &query.UserSearchQueries{
SearchRequest: query.SearchRequest{
@@ -183,7 +183,7 @@ func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueri
SortingColumn: userFieldNameToSortingColumn(req.SortingColumn),
},
Queries: queries,
}, nil
}, filterOrgId, nil
}
func userFieldNameToSortingColumn(field user.UserFieldName) query.Column {
@@ -213,15 +213,18 @@ func userFieldNameToSortingColumn(field user.UserFieldName) query.Column {
}
}
func userQueriesToQuery(queries []*user.SearchQuery, level uint8) (_ []query.SearchQuery, err error) {
func userQueriesToQuery(queries []*user.SearchQuery, level uint8) (_ []query.SearchQuery, filterOrgId string, err error) {
q := make([]query.SearchQuery, len(queries))
for i, query := range queries {
if orgFilter := query.GetOrganizationIdQuery(); orgFilter != nil {
filterOrgId = orgFilter.OrganizationId
}
q[i], err = userQueryToQuery(query, level)
if err != nil {
return nil, err
return nil, filterOrgId, err
}
}
return q, nil
return q, filterOrgId, nil
}
func userQueryToQuery(query *user.SearchQuery, level uint8) (query.SearchQuery, error) {
@@ -315,14 +318,14 @@ func inUserIdsQueryToQuery(q *user.InUserIDQuery) (query.SearchQuery, error) {
return query.NewUserInUserIdsSearchQuery(q.UserIds)
}
func orQueryToQuery(q *user.OrQuery, level uint8) (query.SearchQuery, error) {
mappedQueries, err := userQueriesToQuery(q.Queries, level+1)
mappedQueries, _, err := userQueriesToQuery(q.Queries, level+1)
if err != nil {
return nil, err
}
return query.NewUserOrSearchQuery(mappedQueries)
}
func andQueryToQuery(q *user.AndQuery, level uint8) (query.SearchQuery, error) {
mappedQueries, err := userQueriesToQuery(q.Queries, level+1)
mappedQueries, _, err := userQueriesToQuery(q.Queries, level+1)
if err != nil {
return nil, err
}

View File

@@ -29,11 +29,11 @@ func (s *Server) GetUserByID(ctx context.Context, req *user.GetUserByIDRequest)
}
func (s *Server) ListUsers(ctx context.Context, req *user.ListUsersRequest) (*user.ListUsersResponse, error) {
queries, err := listUsersRequestToModel(req)
queries, filterOrgIds, err := listUsersRequestToModel(req)
if err != nil {
return nil, err
}
res, err := s.query.SearchUsers(ctx, queries, s.checkPermission)
res, err := s.query.SearchUsers(ctx, queries, filterOrgIds, s.checkPermission)
if err != nil {
return nil, err
}
@@ -165,11 +165,11 @@ func accessTokenTypeToPb(accessTokenType domain.OIDCTokenType) user.AccessTokenT
}
}
func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueries, error) {
func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueries, string, error) {
offset, limit, asc := object.ListQueryToQuery(req.Query)
queries, err := userQueriesToQuery(req.Queries, 0 /*start from level 0*/)
queries, filterOrgId, err := userQueriesToQuery(req.Queries, 0 /*start from level 0*/)
if err != nil {
return nil, err
return nil, "", err
}
return &query.UserSearchQueries{
SearchRequest: query.SearchRequest{
@@ -179,7 +179,7 @@ func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueri
SortingColumn: userFieldNameToSortingColumn(req.SortingColumn),
},
Queries: queries,
}, nil
}, filterOrgId, nil
}
func userFieldNameToSortingColumn(field user.UserFieldName) query.Column {
@@ -209,15 +209,18 @@ func userFieldNameToSortingColumn(field user.UserFieldName) query.Column {
}
}
func userQueriesToQuery(queries []*user.SearchQuery, level uint8) (_ []query.SearchQuery, err error) {
func userQueriesToQuery(queries []*user.SearchQuery, level uint8) (_ []query.SearchQuery, filterOrgId string, err error) {
q := make([]query.SearchQuery, len(queries))
for i, query := range queries {
if orgFilter := query.GetOrganizationIdQuery(); orgFilter != nil {
filterOrgId = orgFilter.OrganizationId
}
q[i], err = userQueryToQuery(query, level)
if err != nil {
return nil, err
return nil, filterOrgId, err
}
}
return q, nil
return q, filterOrgId, nil
}
func userQueryToQuery(query *user.SearchQuery, level uint8) (query.SearchQuery, error) {
@@ -311,14 +314,14 @@ func inUserIdsQueryToQuery(q *user.InUserIDQuery) (query.SearchQuery, error) {
return query.NewUserInUserIdsSearchQuery(q.UserIds)
}
func orQueryToQuery(q *user.OrQuery, level uint8) (query.SearchQuery, error) {
mappedQueries, err := userQueriesToQuery(q.Queries, level+1)
mappedQueries, _, err := userQueriesToQuery(q.Queries, level+1)
if err != nil {
return nil, err
}
return query.NewUserOrSearchQuery(mappedQueries)
}
func andQueryToQuery(q *user.AndQuery, level uint8) (query.SearchQuery, error) {
mappedQueries, err := userQueriesToQuery(q.Queries, level+1)
mappedQueries, _, err := userQueriesToQuery(q.Queries, level+1)
if err != nil {
return nil, err
}

View File

@@ -240,7 +240,7 @@ func (h *UsersHandler) List(ctx context.Context, request *ListRequest) (*ListRes
return NewListResponse(count, q.SearchRequest, make([]*ScimUser, 0)), nil
}
users, err := h.query.SearchUsers(ctx, q, nil)
users, err := h.query.SearchUsers(ctx, q, authz.GetCtxData(ctx).OrgID, nil)
if err != nil {
return nil, err
}

View File

@@ -182,7 +182,7 @@ func (l *Login) getClaimedUserIDsOfOrgDomain(ctx context.Context, orgName string
if err != nil {
return nil, err
}
users, err := l.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, nil)
users, err := l.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, "", nil)
if err != nil {
return nil, err
}

View File

@@ -11,8 +11,8 @@ import (
)
const (
// eventstore.permitted_orgs(instanceid text, userid text, perm text)
wherePermittedOrgsClause = "%s = ANY(eventstore.permitted_orgs(?, ?, ?))"
// eventstore.permitted_orgs(instanceid text, userid text, perm text, filter_orgs text)
wherePermittedOrgsClause = "%s = ANY(eventstore.permitted_orgs(?, ?, ?, ?))"
)
// wherePermittedOrgs sets a `WHERE` clause to the query that filters the orgs
@@ -23,13 +23,15 @@ const (
// and is typically the `resource_owner` column in ZITADEL.
// We use full identifiers in the query builder so this function should be
// called with something like `UserResourceOwnerCol.identifier()` for example.
func wherePermittedOrgs(ctx context.Context, query sq.SelectBuilder, orgIDColumn, permission string) sq.SelectBuilder {
func wherePermittedOrgs(ctx context.Context, query sq.SelectBuilder, filterOrgIds, orgIDColumn, permission string) sq.SelectBuilder {
userID := authz.GetCtxData(ctx).UserID
logging.WithFields("permission_check_v2_flag", authz.GetFeatures(ctx).PermissionCheckV2, "org_id_column", orgIDColumn, "permission", permission, "user_id", userID).Debug("permitted orgs check used")
return query.Where(
fmt.Sprintf(wherePermittedOrgsClause, orgIDColumn),
authz.GetInstance(ctx).InstanceID(),
userID,
permission,
filterOrgIds,
)
}

View File

@@ -635,8 +635,8 @@ func (q *Queries) CountUsers(ctx context.Context, queries *UserSearchQueries) (c
return count, err
}
func (q *Queries) SearchUsers(ctx context.Context, queries *UserSearchQueries, permissionCheck domain.PermissionCheck) (*Users, error) {
users, err := q.searchUsers(ctx, queries, permissionCheck != nil && authz.GetFeatures(ctx).PermissionCheckV2)
func (q *Queries) SearchUsers(ctx context.Context, queries *UserSearchQueries, filterOrgIds string, permissionCheck domain.PermissionCheck) (*Users, error) {
users, err := q.searchUsers(ctx, queries, filterOrgIds, permissionCheck != nil && authz.GetFeatures(ctx).PermissionCheckV2)
if err != nil {
return nil, err
}
@@ -646,7 +646,7 @@ func (q *Queries) SearchUsers(ctx context.Context, queries *UserSearchQueries, p
return users, nil
}
func (q *Queries) searchUsers(ctx context.Context, queries *UserSearchQueries, permissionCheckV2 bool) (users *Users, err error) {
func (q *Queries) searchUsers(ctx context.Context, queries *UserSearchQueries, filterOrgIds string, permissionCheckV2 bool) (users *Users, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
@@ -655,7 +655,7 @@ func (q *Queries) searchUsers(ctx context.Context, queries *UserSearchQueries, p
UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
})
if permissionCheckV2 {
query = wherePermittedOrgs(ctx, query, UserResourceOwnerCol.identifier(), domain.PermissionUserRead)
query = wherePermittedOrgs(ctx, query, filterOrgIds, UserResourceOwnerCol.identifier(), domain.PermissionUserRead)
}
stmt, args, err := query.ToSql()