From 0cb03808269339bbc65d9fdb2e3d7dfa54cf5305 Mon Sep 17 00:00:00 2001 From: Iraq <66622793+kkrime@users.noreply.github.com> Date: Mon, 17 Feb 2025 09:55:28 +0000 Subject: [PATCH] feat: updating eventstore.permitted_orgs sql function (#9309) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # Which Problems Are Solved Performance issue for GRPC call `zitadel.user.v2.UserService.ListUsers` due to lack of org filtering on `ListUsers` # Additional Context Replace this example with links to related issues, discussions, discord threads, or other sources with more context. Use the Closing #issue syntax for issues that are resolved with this PR. - Closes https://github.com/zitadel/zitadel/issues/9191 --------- Co-authored-by: Iraq Jaber Co-authored-by: Tim Möhlmann --- cmd/setup/49.go | 39 ++++++++++++++ cmd/setup/49/01-permitted_orgs_function.sql | 56 +++++++++++++++++++++ cmd/setup/config.go | 1 + cmd/setup/setup.go | 2 + internal/api/grpc/admin/export.go | 2 +- internal/api/grpc/admin/org.go | 2 +- internal/api/grpc/management/org.go | 2 +- internal/api/grpc/management/user.go | 5 +- internal/api/grpc/user/v2/query.go | 25 +++++---- internal/api/grpc/user/v2beta/query.go | 25 +++++---- internal/api/scim/resources/user.go | 2 +- internal/api/ui/login/login.go | 2 +- internal/query/permission.go | 8 +-- internal/query/user.go | 8 +-- 14 files changed, 143 insertions(+), 36 deletions(-) create mode 100644 cmd/setup/49.go create mode 100644 cmd/setup/49/01-permitted_orgs_function.sql diff --git a/cmd/setup/49.go b/cmd/setup/49.go new file mode 100644 index 0000000000..28bf797110 --- /dev/null +++ b/cmd/setup/49.go @@ -0,0 +1,39 @@ +package setup + +import ( + "context" + "embed" + "fmt" + + "github.com/zitadel/logging" + + "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/eventstore" +) + +type InitPermittedOrgsFunction struct { + eventstoreClient *database.DB +} + +var ( + //go:embed 49/*.sql + permittedOrgsFunction embed.FS +) + +func (mig *InitPermittedOrgsFunction) Execute(ctx context.Context, _ eventstore.Event) error { + statements, err := readStatements(permittedOrgsFunction, "49", "") + if err != nil { + return err + } + for _, stmt := range statements { + logging.WithFields("file", stmt.file, "migration", mig.String()).Info("execute statement") + if _, err := mig.eventstoreClient.ExecContext(ctx, stmt.query); err != nil { + return fmt.Errorf("%s %s: %w", mig.String(), stmt.file, err) + } + } + return nil +} + +func (*InitPermittedOrgsFunction) String() string { + return "49_init_permitted_orgs_function" +} diff --git a/cmd/setup/49/01-permitted_orgs_function.sql b/cmd/setup/49/01-permitted_orgs_function.sql new file mode 100644 index 0000000000..9f291c016b --- /dev/null +++ b/cmd/setup/49/01-permitted_orgs_function.sql @@ -0,0 +1,56 @@ +DROP FUNCTION IF EXISTS eventstore.permitted_orgs; + +CREATE OR REPLACE FUNCTION eventstore.permitted_orgs( + instanceId TEXT + , userId TEXT + , perm TEXT + , filter_orgs TEXT + + , org_ids OUT TEXT[] +) + LANGUAGE 'plpgsql' + STABLE +AS $$ +DECLARE + matched_roles TEXT[]; -- roles containing permission +BEGIN + SELECT array_agg(rp.role) INTO matched_roles + FROM eventstore.role_permissions rp + WHERE rp.instance_id = instanceId + AND rp.permission = perm; + + -- First try if the permission was granted thru an instance-level role + DECLARE + has_instance_permission bool; + BEGIN + SELECT true INTO has_instance_permission + FROM eventstore.instance_members im + WHERE im.role = ANY(matched_roles) + AND im.instance_id = instanceId + AND im.user_id = userId + LIMIT 1; + + IF has_instance_permission THEN + -- Return all organizations or only those in filter_orgs + SELECT array_agg(o.org_id) INTO org_ids + FROM eventstore.instance_orgs o + WHERE o.instance_id = instanceId + AND CASE WHEN filter_orgs != '' + THEN o.org_id IN (filter_orgs) + ELSE TRUE END; + RETURN; + END IF; + END; + + -- Return the organizations where permission were granted thru org-level roles + SELECT array_agg(org_id) INTO org_ids + FROM ( + SELECT DISTINCT om.org_id + FROM eventstore.org_members om + WHERE om.role = ANY(matched_roles) + AND om.instance_id = instanceID + AND om.user_id = userId + ); + RETURN; +END; +$$; diff --git a/cmd/setup/config.go b/cmd/setup/config.go index d782a32dd6..0153f7227f 100644 --- a/cmd/setup/config.go +++ b/cmd/setup/config.go @@ -137,6 +137,7 @@ type Steps struct { s46InitPermissionFunctions *InitPermissionFunctions s47FillMembershipFields *FillMembershipFields s48Apps7SAMLConfigsLoginVersion *Apps7SAMLConfigsLoginVersion + s49InitPermittedOrgsFunction *InitPermittedOrgsFunction } func MustNewSteps(v *viper.Viper) *Steps { diff --git a/cmd/setup/setup.go b/cmd/setup/setup.go index bfa289ab36..74b16355f3 100644 --- a/cmd/setup/setup.go +++ b/cmd/setup/setup.go @@ -174,6 +174,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string) steps.s46InitPermissionFunctions = &InitPermissionFunctions{eventstoreClient: dbClient} steps.s47FillMembershipFields = &FillMembershipFields{eventstore: eventstoreClient} steps.s48Apps7SAMLConfigsLoginVersion = &Apps7SAMLConfigsLoginVersion{dbClient: dbClient} + steps.s49InitPermittedOrgsFunction = &InitPermittedOrgsFunction{eventstoreClient: dbClient} err = projection.Create(ctx, dbClient, eventstoreClient, config.Projections, nil, nil, nil) logging.OnError(err).Fatal("unable to start projections") @@ -238,6 +239,7 @@ func Setup(ctx context.Context, config *Config, steps *Steps, masterKey string) steps.s45CorrectProjectOwners, steps.s46InitPermissionFunctions, steps.s47FillMembershipFields, + steps.s49InitPermittedOrgsFunction, } { mustExecuteMigration(ctx, eventstoreClient, step, "migration failed") } diff --git a/internal/api/grpc/admin/export.go b/internal/api/grpc/admin/export.go index 68b6053c2c..da364909cb 100644 --- a/internal/api/grpc/admin/export.go +++ b/internal/api/grpc/admin/export.go @@ -554,7 +554,7 @@ func (s *Server) getUsers(ctx context.Context, org string, withPasswords bool, w if err != nil { return nil, nil, nil, nil, err } - users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{orgSearch}}, nil) + users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{orgSearch}}, org, nil) if err != nil { return nil, nil, nil, nil, err } diff --git a/internal/api/grpc/admin/org.go b/internal/api/grpc/admin/org.go index 934de1b570..f788bb5f5a 100644 --- a/internal/api/grpc/admin/org.go +++ b/internal/api/grpc/admin/org.go @@ -108,7 +108,7 @@ func (s *Server) getClaimedUserIDsOfOrgDomain(ctx context.Context, orgDomain str if err != nil { return nil, err } - users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, nil) + users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, "", nil) if err != nil { return nil, err } diff --git a/internal/api/grpc/management/org.go b/internal/api/grpc/management/org.go index d25d46d852..abc179a763 100644 --- a/internal/api/grpc/management/org.go +++ b/internal/api/grpc/management/org.go @@ -330,7 +330,7 @@ func (s *Server) getClaimedUserIDsOfOrgDomain(ctx context.Context, orgDomain, or } queries = append(queries, owner) } - users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: queries}, nil) + users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: queries}, orgID, nil) if err != nil { return nil, err } diff --git a/internal/api/grpc/management/user.go b/internal/api/grpc/management/user.go index dac651af81..17bca58993 100644 --- a/internal/api/grpc/management/user.go +++ b/internal/api/grpc/management/user.go @@ -64,11 +64,12 @@ func (s *Server) ListUsers(ctx context.Context, req *mgmt_pb.ListUsersRequest) ( return nil, err } - err = queries.AppendMyResourceOwnerQuery(authz.GetCtxData(ctx).OrgID) + orgID := authz.GetCtxData(ctx).OrgID + err = queries.AppendMyResourceOwnerQuery(orgID) if err != nil { return nil, err } - res, err := s.query.SearchUsers(ctx, queries, nil) + res, err := s.query.SearchUsers(ctx, queries, orgID, nil) if err != nil { return nil, err } diff --git a/internal/api/grpc/user/v2/query.go b/internal/api/grpc/user/v2/query.go index aeb17d5dcf..aec5367ded 100644 --- a/internal/api/grpc/user/v2/query.go +++ b/internal/api/grpc/user/v2/query.go @@ -29,11 +29,11 @@ func (s *Server) GetUserByID(ctx context.Context, req *user.GetUserByIDRequest) } func (s *Server) ListUsers(ctx context.Context, req *user.ListUsersRequest) (*user.ListUsersResponse, error) { - queries, err := listUsersRequestToModel(req) + queries, filterOrgId, err := listUsersRequestToModel(req) if err != nil { return nil, err } - res, err := s.query.SearchUsers(ctx, queries, s.checkPermission) + res, err := s.query.SearchUsers(ctx, queries, filterOrgId, s.checkPermission) if err != nil { return nil, err } @@ -169,11 +169,11 @@ func accessTokenTypeToPb(accessTokenType domain.OIDCTokenType) user.AccessTokenT } } -func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueries, error) { +func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueries, string, error) { offset, limit, asc := object.ListQueryToQuery(req.Query) - queries, err := userQueriesToQuery(req.Queries, 0 /*start from level 0*/) + queries, filterOrgId, err := userQueriesToQuery(req.Queries, 0 /*start from level 0*/) if err != nil { - return nil, err + return nil, "", err } return &query.UserSearchQueries{ SearchRequest: query.SearchRequest{ @@ -183,7 +183,7 @@ func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueri SortingColumn: userFieldNameToSortingColumn(req.SortingColumn), }, Queries: queries, - }, nil + }, filterOrgId, nil } func userFieldNameToSortingColumn(field user.UserFieldName) query.Column { @@ -213,15 +213,18 @@ func userFieldNameToSortingColumn(field user.UserFieldName) query.Column { } } -func userQueriesToQuery(queries []*user.SearchQuery, level uint8) (_ []query.SearchQuery, err error) { +func userQueriesToQuery(queries []*user.SearchQuery, level uint8) (_ []query.SearchQuery, filterOrgId string, err error) { q := make([]query.SearchQuery, len(queries)) for i, query := range queries { + if orgFilter := query.GetOrganizationIdQuery(); orgFilter != nil { + filterOrgId = orgFilter.OrganizationId + } q[i], err = userQueryToQuery(query, level) if err != nil { - return nil, err + return nil, filterOrgId, err } } - return q, nil + return q, filterOrgId, nil } func userQueryToQuery(query *user.SearchQuery, level uint8) (query.SearchQuery, error) { @@ -315,14 +318,14 @@ func inUserIdsQueryToQuery(q *user.InUserIDQuery) (query.SearchQuery, error) { return query.NewUserInUserIdsSearchQuery(q.UserIds) } func orQueryToQuery(q *user.OrQuery, level uint8) (query.SearchQuery, error) { - mappedQueries, err := userQueriesToQuery(q.Queries, level+1) + mappedQueries, _, err := userQueriesToQuery(q.Queries, level+1) if err != nil { return nil, err } return query.NewUserOrSearchQuery(mappedQueries) } func andQueryToQuery(q *user.AndQuery, level uint8) (query.SearchQuery, error) { - mappedQueries, err := userQueriesToQuery(q.Queries, level+1) + mappedQueries, _, err := userQueriesToQuery(q.Queries, level+1) if err != nil { return nil, err } diff --git a/internal/api/grpc/user/v2beta/query.go b/internal/api/grpc/user/v2beta/query.go index e3602abc33..7baa53e73e 100644 --- a/internal/api/grpc/user/v2beta/query.go +++ b/internal/api/grpc/user/v2beta/query.go @@ -29,11 +29,11 @@ func (s *Server) GetUserByID(ctx context.Context, req *user.GetUserByIDRequest) } func (s *Server) ListUsers(ctx context.Context, req *user.ListUsersRequest) (*user.ListUsersResponse, error) { - queries, err := listUsersRequestToModel(req) + queries, filterOrgIds, err := listUsersRequestToModel(req) if err != nil { return nil, err } - res, err := s.query.SearchUsers(ctx, queries, s.checkPermission) + res, err := s.query.SearchUsers(ctx, queries, filterOrgIds, s.checkPermission) if err != nil { return nil, err } @@ -165,11 +165,11 @@ func accessTokenTypeToPb(accessTokenType domain.OIDCTokenType) user.AccessTokenT } } -func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueries, error) { +func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueries, string, error) { offset, limit, asc := object.ListQueryToQuery(req.Query) - queries, err := userQueriesToQuery(req.Queries, 0 /*start from level 0*/) + queries, filterOrgId, err := userQueriesToQuery(req.Queries, 0 /*start from level 0*/) if err != nil { - return nil, err + return nil, "", err } return &query.UserSearchQueries{ SearchRequest: query.SearchRequest{ @@ -179,7 +179,7 @@ func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueri SortingColumn: userFieldNameToSortingColumn(req.SortingColumn), }, Queries: queries, - }, nil + }, filterOrgId, nil } func userFieldNameToSortingColumn(field user.UserFieldName) query.Column { @@ -209,15 +209,18 @@ func userFieldNameToSortingColumn(field user.UserFieldName) query.Column { } } -func userQueriesToQuery(queries []*user.SearchQuery, level uint8) (_ []query.SearchQuery, err error) { +func userQueriesToQuery(queries []*user.SearchQuery, level uint8) (_ []query.SearchQuery, filterOrgId string, err error) { q := make([]query.SearchQuery, len(queries)) for i, query := range queries { + if orgFilter := query.GetOrganizationIdQuery(); orgFilter != nil { + filterOrgId = orgFilter.OrganizationId + } q[i], err = userQueryToQuery(query, level) if err != nil { - return nil, err + return nil, filterOrgId, err } } - return q, nil + return q, filterOrgId, nil } func userQueryToQuery(query *user.SearchQuery, level uint8) (query.SearchQuery, error) { @@ -311,14 +314,14 @@ func inUserIdsQueryToQuery(q *user.InUserIDQuery) (query.SearchQuery, error) { return query.NewUserInUserIdsSearchQuery(q.UserIds) } func orQueryToQuery(q *user.OrQuery, level uint8) (query.SearchQuery, error) { - mappedQueries, err := userQueriesToQuery(q.Queries, level+1) + mappedQueries, _, err := userQueriesToQuery(q.Queries, level+1) if err != nil { return nil, err } return query.NewUserOrSearchQuery(mappedQueries) } func andQueryToQuery(q *user.AndQuery, level uint8) (query.SearchQuery, error) { - mappedQueries, err := userQueriesToQuery(q.Queries, level+1) + mappedQueries, _, err := userQueriesToQuery(q.Queries, level+1) if err != nil { return nil, err } diff --git a/internal/api/scim/resources/user.go b/internal/api/scim/resources/user.go index bc8d864994..ffd39aa23f 100644 --- a/internal/api/scim/resources/user.go +++ b/internal/api/scim/resources/user.go @@ -240,7 +240,7 @@ func (h *UsersHandler) List(ctx context.Context, request *ListRequest) (*ListRes return NewListResponse(count, q.SearchRequest, make([]*ScimUser, 0)), nil } - users, err := h.query.SearchUsers(ctx, q, nil) + users, err := h.query.SearchUsers(ctx, q, authz.GetCtxData(ctx).OrgID, nil) if err != nil { return nil, err } diff --git a/internal/api/ui/login/login.go b/internal/api/ui/login/login.go index 444c5aaa85..4b028a347f 100644 --- a/internal/api/ui/login/login.go +++ b/internal/api/ui/login/login.go @@ -182,7 +182,7 @@ func (l *Login) getClaimedUserIDsOfOrgDomain(ctx context.Context, orgName string if err != nil { return nil, err } - users, err := l.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, nil) + users, err := l.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, "", nil) if err != nil { return nil, err } diff --git a/internal/query/permission.go b/internal/query/permission.go index 96d7db6c6a..591493375e 100644 --- a/internal/query/permission.go +++ b/internal/query/permission.go @@ -11,8 +11,8 @@ import ( ) const ( - // eventstore.permitted_orgs(instanceid text, userid text, perm text) - wherePermittedOrgsClause = "%s = ANY(eventstore.permitted_orgs(?, ?, ?))" + // eventstore.permitted_orgs(instanceid text, userid text, perm text, filter_orgs text) + wherePermittedOrgsClause = "%s = ANY(eventstore.permitted_orgs(?, ?, ?, ?))" ) // wherePermittedOrgs sets a `WHERE` clause to the query that filters the orgs @@ -23,13 +23,15 @@ const ( // and is typically the `resource_owner` column in ZITADEL. // We use full identifiers in the query builder so this function should be // called with something like `UserResourceOwnerCol.identifier()` for example. -func wherePermittedOrgs(ctx context.Context, query sq.SelectBuilder, orgIDColumn, permission string) sq.SelectBuilder { +func wherePermittedOrgs(ctx context.Context, query sq.SelectBuilder, filterOrgIds, orgIDColumn, permission string) sq.SelectBuilder { userID := authz.GetCtxData(ctx).UserID logging.WithFields("permission_check_v2_flag", authz.GetFeatures(ctx).PermissionCheckV2, "org_id_column", orgIDColumn, "permission", permission, "user_id", userID).Debug("permitted orgs check used") + return query.Where( fmt.Sprintf(wherePermittedOrgsClause, orgIDColumn), authz.GetInstance(ctx).InstanceID(), userID, permission, + filterOrgIds, ) } diff --git a/internal/query/user.go b/internal/query/user.go index bb76e51f66..0b00b45e03 100644 --- a/internal/query/user.go +++ b/internal/query/user.go @@ -635,8 +635,8 @@ func (q *Queries) CountUsers(ctx context.Context, queries *UserSearchQueries) (c return count, err } -func (q *Queries) SearchUsers(ctx context.Context, queries *UserSearchQueries, permissionCheck domain.PermissionCheck) (*Users, error) { - users, err := q.searchUsers(ctx, queries, permissionCheck != nil && authz.GetFeatures(ctx).PermissionCheckV2) +func (q *Queries) SearchUsers(ctx context.Context, queries *UserSearchQueries, filterOrgIds string, permissionCheck domain.PermissionCheck) (*Users, error) { + users, err := q.searchUsers(ctx, queries, filterOrgIds, permissionCheck != nil && authz.GetFeatures(ctx).PermissionCheckV2) if err != nil { return nil, err } @@ -646,7 +646,7 @@ func (q *Queries) SearchUsers(ctx context.Context, queries *UserSearchQueries, p return users, nil } -func (q *Queries) searchUsers(ctx context.Context, queries *UserSearchQueries, permissionCheckV2 bool) (users *Users, err error) { +func (q *Queries) searchUsers(ctx context.Context, queries *UserSearchQueries, filterOrgIds string, permissionCheckV2 bool) (users *Users, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() @@ -655,7 +655,7 @@ func (q *Queries) searchUsers(ctx context.Context, queries *UserSearchQueries, p UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(), }) if permissionCheckV2 { - query = wherePermittedOrgs(ctx, query, UserResourceOwnerCol.identifier(), domain.PermissionUserRead) + query = wherePermittedOrgs(ctx, query, filterOrgIds, UserResourceOwnerCol.identifier(), domain.PermissionUserRead) } stmt, args, err := query.ToSql()