mirror of
https://github.com/zitadel/zitadel.git
synced 2025-02-28 20:07:23 +00:00
feat: updating eventstore.permitted_orgs sql function (#9309)
# Which Problems Are Solved Performance issue for GRPC call `zitadel.user.v2.UserService.ListUsers` due to lack of org filtering on `ListUsers` # Additional Context Replace this example with links to related issues, discussions, discord threads, or other sources with more context. Use the Closing #issue syntax for issues that are resolved with this PR. - Closes https://github.com/zitadel/zitadel/issues/9191 --------- Co-authored-by: Iraq Jaber <IraqJaber@gmail.com> Co-authored-by: Tim Möhlmann <tim+github@zitadel.com>
This commit is contained in:
parent
7c96dcd9a2
commit
0cb0380826
39
cmd/setup/49.go
Normal file
39
cmd/setup/49.go
Normal file
@ -0,0 +1,39 @@
|
||||
package setup
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"fmt"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/database"
|
||||
"github.com/zitadel/zitadel/internal/eventstore"
|
||||
)
|
||||
|
||||
type InitPermittedOrgsFunction struct {
|
||||
eventstoreClient *database.DB
|
||||
}
|
||||
|
||||
var (
|
||||
//go:embed 49/*.sql
|
||||
permittedOrgsFunction embed.FS
|
||||
)
|
||||
|
||||
func (mig *InitPermittedOrgsFunction) Execute(ctx context.Context, _ eventstore.Event) error {
|
||||
statements, err := readStatements(permittedOrgsFunction, "49", "")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, stmt := range statements {
|
||||
logging.WithFields("file", stmt.file, "migration", mig.String()).Info("execute statement")
|
||||
if _, err := mig.eventstoreClient.ExecContext(ctx, stmt.query); err != nil {
|
||||
return fmt.Errorf("%s %s: %w", mig.String(), stmt.file, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*InitPermittedOrgsFunction) String() string {
|
||||
return "49_init_permitted_orgs_function"
|
||||
}
|
56
cmd/setup/49/01-permitted_orgs_function.sql
Normal file
56
cmd/setup/49/01-permitted_orgs_function.sql
Normal file
@ -0,0 +1,56 @@
|
||||
DROP FUNCTION IF EXISTS eventstore.permitted_orgs;
|
||||
|
||||
CREATE OR REPLACE FUNCTION eventstore.permitted_orgs(
|
||||
instanceId TEXT
|
||||
, userId TEXT
|
||||
, perm TEXT
|
||||
, filter_orgs TEXT
|
||||
|
||||
, org_ids OUT TEXT[]
|
||||
)
|
||||
LANGUAGE 'plpgsql'
|
||||
STABLE
|
||||
AS $$
|
||||
DECLARE
|
||||
matched_roles TEXT[]; -- roles containing permission
|
||||
BEGIN
|
||||
SELECT array_agg(rp.role) INTO matched_roles
|
||||
FROM eventstore.role_permissions rp
|
||||
WHERE rp.instance_id = instanceId
|
||||
AND rp.permission = perm;
|
||||
|
||||
-- First try if the permission was granted thru an instance-level role
|
||||
DECLARE
|
||||
has_instance_permission bool;
|
||||
BEGIN
|
||||
SELECT true INTO has_instance_permission
|
||||
FROM eventstore.instance_members im
|
||||
WHERE im.role = ANY(matched_roles)
|
||||
AND im.instance_id = instanceId
|
||||
AND im.user_id = userId
|
||||
LIMIT 1;
|
||||
|
||||
IF has_instance_permission THEN
|
||||
-- Return all organizations or only those in filter_orgs
|
||||
SELECT array_agg(o.org_id) INTO org_ids
|
||||
FROM eventstore.instance_orgs o
|
||||
WHERE o.instance_id = instanceId
|
||||
AND CASE WHEN filter_orgs != ''
|
||||
THEN o.org_id IN (filter_orgs)
|
||||
ELSE TRUE END;
|
||||
RETURN;
|
||||
END IF;
|
||||
END;
|
||||
|
||||
-- Return the organizations where permission were granted thru org-level roles
|
||||
SELECT array_agg(org_id) INTO org_ids
|
||||
FROM (
|
||||
SELECT DISTINCT om.org_id
|
||||
FROM eventstore.org_members om
|
||||
WHERE om.role = ANY(matched_roles)
|
||||
AND om.instance_id = instanceID
|
||||
AND om.user_id = userId
|
||||
);
|
||||
RETURN;
|
||||
END;
|
||||
$$;
|
@ -137,6 +137,7 @@ type Steps struct {
|
||||
s46InitPermissionFunctions *InitPermissionFunctions
|
||||
s47FillMembershipFields *FillMembershipFields
|
||||
s48Apps7SAMLConfigsLoginVersion *Apps7SAMLConfigsLoginVersion
|
||||
s49InitPermittedOrgsFunction *InitPermittedOrgsFunction
|
||||
}
|
||||
|
||||
func MustNewSteps(v *viper.Viper) *Steps {
|
||||
|
@ -174,6 +174,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
|
||||
steps.s46InitPermissionFunctions = &InitPermissionFunctions{eventstoreClient: dbClient}
|
||||
steps.s47FillMembershipFields = &FillMembershipFields{eventstore: eventstoreClient}
|
||||
steps.s48Apps7SAMLConfigsLoginVersion = &Apps7SAMLConfigsLoginVersion{dbClient: dbClient}
|
||||
steps.s49InitPermittedOrgsFunction = &InitPermittedOrgsFunction{eventstoreClient: dbClient}
|
||||
|
||||
err = projection.Create(ctx, dbClient, eventstoreClient, config.Projections, nil, nil, nil)
|
||||
logging.OnError(err).Fatal("unable to start projections")
|
||||
@ -238,6 +239,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string)
|
||||
steps.s45CorrectProjectOwners,
|
||||
steps.s46InitPermissionFunctions,
|
||||
steps.s47FillMembershipFields,
|
||||
steps.s49InitPermittedOrgsFunction,
|
||||
} {
|
||||
mustExecuteMigration(ctx, eventstoreClient, step, "migration failed")
|
||||
}
|
||||
|
@ -554,7 +554,7 @@ func (s *Server) getUsers(ctx context.Context, org string, withPasswords bool, w
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{orgSearch}}, nil)
|
||||
users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{orgSearch}}, org, nil)
|
||||
if err != nil {
|
||||
return nil, nil, nil, nil, err
|
||||
}
|
||||
|
@ -108,7 +108,7 @@ func (s *Server) getClaimedUserIDsOfOrgDomain(ctx context.Context, orgDomain str
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, nil)
|
||||
users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, "", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -330,7 +330,7 @@ func (s *Server) getClaimedUserIDsOfOrgDomain(ctx context.Context, orgDomain, or
|
||||
}
|
||||
queries = append(queries, owner)
|
||||
}
|
||||
users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: queries}, nil)
|
||||
users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: queries}, orgID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -64,11 +64,12 @@ func (s *Server) ListUsers(ctx context.Context, req *mgmt_pb.ListUsersRequest) (
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = queries.AppendMyResourceOwnerQuery(authz.GetCtxData(ctx).OrgID)
|
||||
orgID := authz.GetCtxData(ctx).OrgID
|
||||
err = queries.AppendMyResourceOwnerQuery(orgID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res, err := s.query.SearchUsers(ctx, queries, nil)
|
||||
res, err := s.query.SearchUsers(ctx, queries, orgID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -29,11 +29,11 @@ func (s *Server) GetUserByID(ctx context.Context, req *user.GetUserByIDRequest)
|
||||
}
|
||||
|
||||
func (s *Server) ListUsers(ctx context.Context, req *user.ListUsersRequest) (*user.ListUsersResponse, error) {
|
||||
queries, err := listUsersRequestToModel(req)
|
||||
queries, filterOrgId, err := listUsersRequestToModel(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res, err := s.query.SearchUsers(ctx, queries, s.checkPermission)
|
||||
res, err := s.query.SearchUsers(ctx, queries, filterOrgId, s.checkPermission)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -169,11 +169,11 @@ func accessTokenTypeToPb(accessTokenType domain.OIDCTokenType) user.AccessTokenT
|
||||
}
|
||||
}
|
||||
|
||||
func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueries, error) {
|
||||
func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueries, string, error) {
|
||||
offset, limit, asc := object.ListQueryToQuery(req.Query)
|
||||
queries, err := userQueriesToQuery(req.Queries, 0 /*start from level 0*/)
|
||||
queries, filterOrgId, err := userQueriesToQuery(req.Queries, 0 /*start from level 0*/)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, "", err
|
||||
}
|
||||
return &query.UserSearchQueries{
|
||||
SearchRequest: query.SearchRequest{
|
||||
@ -183,7 +183,7 @@ func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueri
|
||||
SortingColumn: userFieldNameToSortingColumn(req.SortingColumn),
|
||||
},
|
||||
Queries: queries,
|
||||
}, nil
|
||||
}, filterOrgId, nil
|
||||
}
|
||||
|
||||
func userFieldNameToSortingColumn(field user.UserFieldName) query.Column {
|
||||
@ -213,15 +213,18 @@ func userFieldNameToSortingColumn(field user.UserFieldName) query.Column {
|
||||
}
|
||||
}
|
||||
|
||||
func userQueriesToQuery(queries []*user.SearchQuery, level uint8) (_ []query.SearchQuery, err error) {
|
||||
func userQueriesToQuery(queries []*user.SearchQuery, level uint8) (_ []query.SearchQuery, filterOrgId string, err error) {
|
||||
q := make([]query.SearchQuery, len(queries))
|
||||
for i, query := range queries {
|
||||
if orgFilter := query.GetOrganizationIdQuery(); orgFilter != nil {
|
||||
filterOrgId = orgFilter.OrganizationId
|
||||
}
|
||||
q[i], err = userQueryToQuery(query, level)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, filterOrgId, err
|
||||
}
|
||||
}
|
||||
return q, nil
|
||||
return q, filterOrgId, nil
|
||||
}
|
||||
|
||||
func userQueryToQuery(query *user.SearchQuery, level uint8) (query.SearchQuery, error) {
|
||||
@ -315,14 +318,14 @@ func inUserIdsQueryToQuery(q *user.InUserIDQuery) (query.SearchQuery, error) {
|
||||
return query.NewUserInUserIdsSearchQuery(q.UserIds)
|
||||
}
|
||||
func orQueryToQuery(q *user.OrQuery, level uint8) (query.SearchQuery, error) {
|
||||
mappedQueries, err := userQueriesToQuery(q.Queries, level+1)
|
||||
mappedQueries, _, err := userQueriesToQuery(q.Queries, level+1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return query.NewUserOrSearchQuery(mappedQueries)
|
||||
}
|
||||
func andQueryToQuery(q *user.AndQuery, level uint8) (query.SearchQuery, error) {
|
||||
mappedQueries, err := userQueriesToQuery(q.Queries, level+1)
|
||||
mappedQueries, _, err := userQueriesToQuery(q.Queries, level+1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -29,11 +29,11 @@ func (s *Server) GetUserByID(ctx context.Context, req *user.GetUserByIDRequest)
|
||||
}
|
||||
|
||||
func (s *Server) ListUsers(ctx context.Context, req *user.ListUsersRequest) (*user.ListUsersResponse, error) {
|
||||
queries, err := listUsersRequestToModel(req)
|
||||
queries, filterOrgIds, err := listUsersRequestToModel(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
res, err := s.query.SearchUsers(ctx, queries, s.checkPermission)
|
||||
res, err := s.query.SearchUsers(ctx, queries, filterOrgIds, s.checkPermission)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -165,11 +165,11 @@ func accessTokenTypeToPb(accessTokenType domain.OIDCTokenType) user.AccessTokenT
|
||||
}
|
||||
}
|
||||
|
||||
func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueries, error) {
|
||||
func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueries, string, error) {
|
||||
offset, limit, asc := object.ListQueryToQuery(req.Query)
|
||||
queries, err := userQueriesToQuery(req.Queries, 0 /*start from level 0*/)
|
||||
queries, filterOrgId, err := userQueriesToQuery(req.Queries, 0 /*start from level 0*/)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, "", err
|
||||
}
|
||||
return &query.UserSearchQueries{
|
||||
SearchRequest: query.SearchRequest{
|
||||
@ -179,7 +179,7 @@ func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueri
|
||||
SortingColumn: userFieldNameToSortingColumn(req.SortingColumn),
|
||||
},
|
||||
Queries: queries,
|
||||
}, nil
|
||||
}, filterOrgId, nil
|
||||
}
|
||||
|
||||
func userFieldNameToSortingColumn(field user.UserFieldName) query.Column {
|
||||
@ -209,15 +209,18 @@ func userFieldNameToSortingColumn(field user.UserFieldName) query.Column {
|
||||
}
|
||||
}
|
||||
|
||||
func userQueriesToQuery(queries []*user.SearchQuery, level uint8) (_ []query.SearchQuery, err error) {
|
||||
func userQueriesToQuery(queries []*user.SearchQuery, level uint8) (_ []query.SearchQuery, filterOrgId string, err error) {
|
||||
q := make([]query.SearchQuery, len(queries))
|
||||
for i, query := range queries {
|
||||
if orgFilter := query.GetOrganizationIdQuery(); orgFilter != nil {
|
||||
filterOrgId = orgFilter.OrganizationId
|
||||
}
|
||||
q[i], err = userQueryToQuery(query, level)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return nil, filterOrgId, err
|
||||
}
|
||||
}
|
||||
return q, nil
|
||||
return q, filterOrgId, nil
|
||||
}
|
||||
|
||||
func userQueryToQuery(query *user.SearchQuery, level uint8) (query.SearchQuery, error) {
|
||||
@ -311,14 +314,14 @@ func inUserIdsQueryToQuery(q *user.InUserIDQuery) (query.SearchQuery, error) {
|
||||
return query.NewUserInUserIdsSearchQuery(q.UserIds)
|
||||
}
|
||||
func orQueryToQuery(q *user.OrQuery, level uint8) (query.SearchQuery, error) {
|
||||
mappedQueries, err := userQueriesToQuery(q.Queries, level+1)
|
||||
mappedQueries, _, err := userQueriesToQuery(q.Queries, level+1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return query.NewUserOrSearchQuery(mappedQueries)
|
||||
}
|
||||
func andQueryToQuery(q *user.AndQuery, level uint8) (query.SearchQuery, error) {
|
||||
mappedQueries, err := userQueriesToQuery(q.Queries, level+1)
|
||||
mappedQueries, _, err := userQueriesToQuery(q.Queries, level+1)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -240,7 +240,7 @@ func (h *UsersHandler) List(ctx context.Context, request *ListRequest) (*ListRes
|
||||
return NewListResponse(count, q.SearchRequest, make([]*ScimUser, 0)), nil
|
||||
}
|
||||
|
||||
users, err := h.query.SearchUsers(ctx, q, nil)
|
||||
users, err := h.query.SearchUsers(ctx, q, authz.GetCtxData(ctx).OrgID, nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -182,7 +182,7 @@ func (l *Login) getClaimedUserIDsOfOrgDomain(ctx context.Context, orgName string
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
users, err := l.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, nil)
|
||||
users, err := l.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, "", nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -11,8 +11,8 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
// eventstore.permitted_orgs(instanceid text, userid text, perm text)
|
||||
wherePermittedOrgsClause = "%s = ANY(eventstore.permitted_orgs(?, ?, ?))"
|
||||
// eventstore.permitted_orgs(instanceid text, userid text, perm text, filter_orgs text)
|
||||
wherePermittedOrgsClause = "%s = ANY(eventstore.permitted_orgs(?, ?, ?, ?))"
|
||||
)
|
||||
|
||||
// wherePermittedOrgs sets a `WHERE` clause to the query that filters the orgs
|
||||
@ -23,13 +23,15 @@ const (
|
||||
// and is typically the `resource_owner` column in ZITADEL.
|
||||
// We use full identifiers in the query builder so this function should be
|
||||
// called with something like `UserResourceOwnerCol.identifier()` for example.
|
||||
func wherePermittedOrgs(ctx context.Context, query sq.SelectBuilder, orgIDColumn, permission string) sq.SelectBuilder {
|
||||
func wherePermittedOrgs(ctx context.Context, query sq.SelectBuilder, filterOrgIds, orgIDColumn, permission string) sq.SelectBuilder {
|
||||
userID := authz.GetCtxData(ctx).UserID
|
||||
logging.WithFields("permission_check_v2_flag", authz.GetFeatures(ctx).PermissionCheckV2, "org_id_column", orgIDColumn, "permission", permission, "user_id", userID).Debug("permitted orgs check used")
|
||||
|
||||
return query.Where(
|
||||
fmt.Sprintf(wherePermittedOrgsClause, orgIDColumn),
|
||||
authz.GetInstance(ctx).InstanceID(),
|
||||
userID,
|
||||
permission,
|
||||
filterOrgIds,
|
||||
)
|
||||
}
|
||||
|
@ -635,8 +635,8 @@ func (q *Queries) CountUsers(ctx context.Context, queries *UserSearchQueries) (c
|
||||
return count, err
|
||||
}
|
||||
|
||||
func (q *Queries) SearchUsers(ctx context.Context, queries *UserSearchQueries, permissionCheck domain.PermissionCheck) (*Users, error) {
|
||||
users, err := q.searchUsers(ctx, queries, permissionCheck != nil && authz.GetFeatures(ctx).PermissionCheckV2)
|
||||
func (q *Queries) SearchUsers(ctx context.Context, queries *UserSearchQueries, filterOrgIds string, permissionCheck domain.PermissionCheck) (*Users, error) {
|
||||
users, err := q.searchUsers(ctx, queries, filterOrgIds, permissionCheck != nil && authz.GetFeatures(ctx).PermissionCheckV2)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -646,7 +646,7 @@ func (q *Queries) SearchUsers(ctx context.Context, queries *UserSearchQueries, p
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (q *Queries) searchUsers(ctx context.Context, queries *UserSearchQueries, permissionCheckV2 bool) (users *Users, err error) {
|
||||
func (q *Queries) searchUsers(ctx context.Context, queries *UserSearchQueries, filterOrgIds string, permissionCheckV2 bool) (users *Users, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
@ -655,7 +655,7 @@ func (q *Queries) searchUsers(ctx context.Context, queries *UserSearchQueries, p
|
||||
UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
|
||||
})
|
||||
if permissionCheckV2 {
|
||||
query = wherePermittedOrgs(ctx, query, UserResourceOwnerCol.identifier(), domain.PermissionUserRead)
|
||||
query = wherePermittedOrgs(ctx, query, filterOrgIds, UserResourceOwnerCol.identifier(), domain.PermissionUserRead)
|
||||
}
|
||||
|
||||
stmt, args, err := query.ToSql()
|
||||
|
Loading…
x
Reference in New Issue
Block a user