From a2f60f2e7af4d5c3282d1a568fa04492e8b1a7ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tim=20M=C3=B6hlmann?= Date: Tue, 15 Apr 2025 19:38:25 +0300 Subject: [PATCH] perf(query): org permission function for resources (#9677) # Which Problems Are Solved Classic permission checks execute for every returned row on resource based search APIs. Complete background and problem definition can be found here: https://github.com/zitadel/zitadel/issues/9188 # How the Problems Are Solved - PermissionClause function now support dynamic query building, so it supports multiple cases. - PermissionClause is applied to all list resources which support org level permissions. - Wrap permission logic into wrapper functions so we keep the business logic clean. # Additional Changes - Handle org ID optimization in the query package, so it is reusable for all resources, instead of extracting the filter in the API. - Cleanup and test system user conversion in the authz package. (context middleware) - Fix: `core_integration_db_up` make recipe was missing the postgres service. # Additional Context - Related to https://github.com/zitadel/zitadel/issues/9190 --- Makefile | 2 +- internal/api/authz/authorization.go | 50 ++---- internal/api/authz/authorization_test.go | 126 ++++++++++++++ internal/api/authz/context.go | 35 ++-- internal/api/authz/membertype_enumer.go | 20 ++- internal/api/grpc/admin/export.go | 2 +- internal/api/grpc/admin/org.go | 2 +- internal/api/grpc/management/org.go | 2 +- internal/api/grpc/management/user.go | 2 +- internal/api/grpc/user/v2/query.go | 25 ++- internal/api/grpc/user/v2beta/query.go | 25 ++- internal/api/scim/resources/user.go | 2 +- internal/api/ui/login/login.go | 2 +- internal/database/type.go | 34 ++++ internal/database/type_test.go | 87 ++++++++++ internal/query/idp_user_link.go | 27 ++- internal/query/org.go | 19 ++- internal/query/permission.go | 153 ++++++++++------- internal/query/permission_test.go | 208 +++++++++++++++++++++++ internal/query/query.go | 13 ++ internal/query/session.go | 24 ++- internal/query/user.go | 33 ++-- internal/query/user_auth_method.go | 20 ++- 23 files changed, 741 insertions(+), 172 deletions(-) create mode 100644 internal/query/permission_test.go diff --git a/Makefile b/Makefile index b5145cef3d..3c50231bee 100644 --- a/Makefile +++ b/Makefile @@ -112,7 +112,7 @@ core_unit_test: .PHONY: core_integration_db_up core_integration_db_up: - docker compose -f internal/integration/config/docker-compose.yaml up --pull always --wait cache + docker compose -f internal/integration/config/docker-compose.yaml up --pull always --wait cache postgres .PHONY: core_integration_db_down core_integration_db_down: diff --git a/internal/api/authz/authorization.go b/internal/api/authz/authorization.go index ea20a2438f..25130584a0 100644 --- a/internal/api/authz/authorization.go +++ b/internal/api/authz/authorization.go @@ -7,8 +7,6 @@ import ( "slices" "strings" - "github.com/zitadel/logging" - "github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/zerrors" ) @@ -26,14 +24,13 @@ func CheckUserAuthorization(ctx context.Context, req interface{}, token, orgID, ctx, span := tracing.NewServerInterceptorSpan(ctx) defer func() { span.EndWithError(err) }() - ctxData, err := VerifyTokenAndCreateCtxData(ctx, token, orgID, orgDomain, verifier) + ctxData, err := VerifyTokenAndCreateCtxData(ctx, token, orgID, orgDomain, verifier, systemRolePermissionMapping) if err != nil { return nil, err } if requiredAuthOption.Permission == authenticated { return func(parent context.Context) context.Context { - parent = addGetSystemUserRolesToCtx(parent, systemRolePermissionMapping, ctxData) return context.WithValue(parent, dataKey, ctxData) }, nil } @@ -54,7 +51,6 @@ func CheckUserAuthorization(ctx context.Context, req interface{}, token, orgID, parent = context.WithValue(parent, dataKey, ctxData) parent = context.WithValue(parent, allPermissionsKey, allPermissions) parent = context.WithValue(parent, requestPermissionsKey, requestedPermissions) - parent = addGetSystemUserRolesToCtx(parent, systemRolePermissionMapping, ctxData) return parent }, nil } @@ -131,42 +127,32 @@ func GetAllPermissionCtxIDs(perms []string) []string { return ctxIDs } -type SystemUserPermissionsDBQuery struct { - MemberType string `json:"member_type"` - AggregateID string `json:"aggregate_id"` - ObjectID string `json:"object_id"` - Permissions []string `json:"permissions"` +type SystemUserPermissions struct { + MemberType MemberType `json:"member_type"` + AggregateID string `json:"aggregate_id"` + ObjectID string `json:"object_id"` + Permissions []string `json:"permissions"` } -func addGetSystemUserRolesToCtx(ctx context.Context, systemUserRoleMap []RoleMapping, ctxData CtxData) context.Context { - if len(ctxData.SystemMemberships) == 0 { - return ctx +// systemMembershipsToUserPermissions converts system memberships based on roles, +// to SystemUserPermissions, using the passed role mapping. +func systemMembershipsToUserPermissions(memberships Memberships, roleMap []RoleMapping) []SystemUserPermissions { + if memberships == nil { + return nil } - systemUserPermissions := make([]SystemUserPermissionsDBQuery, len(ctxData.SystemMemberships)) - for i, systemPerm := range ctxData.SystemMemberships { + systemUserPermissions := make([]SystemUserPermissions, len(memberships)) + for i, systemPerm := range memberships { permissions := make([]string, 0, len(systemPerm.Roles)) for _, role := range systemPerm.Roles { - permissions = append(permissions, getPermissionsFromRole(systemUserRoleMap, role)...) + permissions = append(permissions, getPermissionsFromRole(roleMap, role)...) } slices.Sort(permissions) - permissions = slices.Compact(permissions) + permissions = slices.Compact(permissions) // remove duplicates - systemUserPermissions[i].MemberType = systemPerm.MemberType.String() + systemUserPermissions[i].MemberType = systemPerm.MemberType systemUserPermissions[i].AggregateID = systemPerm.AggregateID + systemUserPermissions[i].ObjectID = systemPerm.ObjectID systemUserPermissions[i].Permissions = permissions } - return context.WithValue(ctx, systemUserRolesKey, systemUserPermissions) -} - -func GetSystemUserPermissions(ctx context.Context) []SystemUserPermissionsDBQuery { - getSystemUserRolesFuncValue := ctx.Value(systemUserRolesKey) - if getSystemUserRolesFuncValue == nil { - return nil - } - systemUserRoles, ok := getSystemUserRolesFuncValue.([]SystemUserPermissionsDBQuery) - if !ok { - logging.WithFields("Authz").Error("unable to cast []SystemUserPermissionsDBQuery") - return nil - } - return systemUserRoles + return systemUserPermissions } diff --git a/internal/api/authz/authorization_test.go b/internal/api/authz/authorization_test.go index 4b81c73d81..af49dcc5c6 100644 --- a/internal/api/authz/authorization_test.go +++ b/internal/api/authz/authorization_test.go @@ -3,6 +3,8 @@ package authz import ( "testing" + "github.com/stretchr/testify/assert" + "github.com/zitadel/zitadel/internal/zerrors" ) @@ -276,3 +278,127 @@ func Test_GetPermissionCtxIDs(t *testing.T) { }) } } + +func Test_systemMembershipsToUserPermissions(t *testing.T) { + roleMap := []RoleMapping{ + { + Role: "FOO_BAR", + Permissions: []string{"foo.bar.read", "foo.bar.write"}, + }, + { + Role: "BAR_FOO", + Permissions: []string{"bar.foo.read", "bar.foo.write", "foo.bar.read"}, + }, + } + + type args struct { + memberships Memberships + roleMap []RoleMapping + } + tests := []struct { + name string + args args + want []SystemUserPermissions + }{ + { + name: "nil memberships", + args: args{ + memberships: nil, + roleMap: roleMap, + }, + want: nil, + }, + { + name: "empty memberships", + args: args{ + memberships: Memberships{}, + roleMap: roleMap, + }, + want: []SystemUserPermissions{}, + }, + { + name: "single membership", + args: args{ + memberships: Memberships{ + { + MemberType: MemberTypeSystem, + AggregateID: "1", + ObjectID: "2", + Roles: []string{"FOO_BAR"}, + }, + }, + roleMap: roleMap, + }, + want: []SystemUserPermissions{ + { + MemberType: MemberTypeSystem, + AggregateID: "1", + ObjectID: "2", + Permissions: []string{"foo.bar.read", "foo.bar.write"}, + }, + }, + }, + { + name: "multiple memberships", + args: args{ + memberships: Memberships{ + { + MemberType: MemberTypeSystem, + AggregateID: "1", + ObjectID: "2", + Roles: []string{"FOO_BAR"}, + }, + { + MemberType: MemberTypeIAM, + AggregateID: "1", + ObjectID: "2", + Roles: []string{"BAR_FOO"}, + }, + }, + roleMap: roleMap, + }, + want: []SystemUserPermissions{ + { + MemberType: MemberTypeSystem, + AggregateID: "1", + ObjectID: "2", + Permissions: []string{"foo.bar.read", "foo.bar.write"}, + }, + { + MemberType: MemberTypeIAM, + AggregateID: "1", + ObjectID: "2", + Permissions: []string{"bar.foo.read", "bar.foo.write", "foo.bar.read"}, + }, + }, + }, + { + name: "multiple roles", + args: args{ + memberships: Memberships{ + { + MemberType: MemberTypeSystem, + AggregateID: "1", + ObjectID: "2", + Roles: []string{"FOO_BAR", "BAR_FOO"}, + }, + }, + roleMap: roleMap, + }, + want: []SystemUserPermissions{ + { + MemberType: MemberTypeSystem, + AggregateID: "1", + ObjectID: "2", + Permissions: []string{"bar.foo.read", "bar.foo.write", "foo.bar.read", "foo.bar.write"}, + }, + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := systemMembershipsToUserPermissions(tt.args.memberships, tt.args.roleMap) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/internal/api/authz/context.go b/internal/api/authz/context.go index d6528cd017..d12a1def44 100644 --- a/internal/api/authz/context.go +++ b/internal/api/authz/context.go @@ -1,4 +1,4 @@ -//go:generate enumer -type MemberType -trimprefix MemberType +//go:generate enumer -type MemberType -trimprefix MemberType -json package authz @@ -22,17 +22,17 @@ const ( dataKey key = 2 allPermissionsKey key = 3 instanceKey key = 4 - systemUserRolesKey key = 5 ) type CtxData struct { - UserID string - OrgID string - ProjectID string - AgentID string - PreferredLanguage string - ResourceOwner string - SystemMemberships Memberships + UserID string + OrgID string + ProjectID string + AgentID string + PreferredLanguage string + ResourceOwner string + SystemMemberships Memberships + SystemUserPermissions []SystemUserPermissions } func (ctxData CtxData) IsZero() bool { @@ -98,7 +98,7 @@ func (s SystemTokenVerifierFunc) VerifySystemToken(ctx context.Context, token st return s(ctx, token, orgID) } -func VerifyTokenAndCreateCtxData(ctx context.Context, token, orgID, orgDomain string, t APITokenVerifier) (_ CtxData, err error) { +func VerifyTokenAndCreateCtxData(ctx context.Context, token, orgID, orgDomain string, t APITokenVerifier, systemRoleMap []RoleMapping) (_ CtxData, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() tokenWOBearer, err := extractBearerToken(token) @@ -133,13 +133,14 @@ func VerifyTokenAndCreateCtxData(ctx context.Context, token, orgID, orgDomain st } } return CtxData{ - UserID: userID, - OrgID: orgID, - ProjectID: projectID, - AgentID: agentID, - PreferredLanguage: prefLang, - ResourceOwner: resourceOwner, - SystemMemberships: sysMemberships, + UserID: userID, + OrgID: orgID, + ProjectID: projectID, + AgentID: agentID, + PreferredLanguage: prefLang, + ResourceOwner: resourceOwner, + SystemMemberships: sysMemberships, + SystemUserPermissions: systemMembershipsToUserPermissions(sysMemberships, systemRoleMap), }, nil } diff --git a/internal/api/authz/membertype_enumer.go b/internal/api/authz/membertype_enumer.go index 5de4c92292..a4275a2254 100644 --- a/internal/api/authz/membertype_enumer.go +++ b/internal/api/authz/membertype_enumer.go @@ -1,8 +1,9 @@ -// Code generated by "enumer -type MemberType -trimprefix MemberType"; DO NOT EDIT. +// Code generated by "enumer -type MemberType -trimprefix MemberType -json"; DO NOT EDIT. package authz import ( + "encoding/json" "fmt" "strings" ) @@ -92,3 +93,20 @@ func (i MemberType) IsAMemberType() bool { } return false } + +// MarshalJSON implements the json.Marshaler interface for MemberType +func (i MemberType) MarshalJSON() ([]byte, error) { + return json.Marshal(i.String()) +} + +// UnmarshalJSON implements the json.Unmarshaler interface for MemberType +func (i *MemberType) UnmarshalJSON(data []byte) error { + var s string + if err := json.Unmarshal(data, &s); err != nil { + return fmt.Errorf("MemberType should be a string, got %s", data) + } + + var err error + *i, err = MemberTypeString(s) + return err +} diff --git a/internal/api/grpc/admin/export.go b/internal/api/grpc/admin/export.go index da364909cb..68b6053c2c 100644 --- a/internal/api/grpc/admin/export.go +++ b/internal/api/grpc/admin/export.go @@ -554,7 +554,7 @@ func (s *Server) getUsers(ctx context.Context, org string, withPasswords bool, w if err != nil { return nil, nil, nil, nil, err } - users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{orgSearch}}, org, nil) + users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{orgSearch}}, nil) if err != nil { return nil, nil, nil, nil, err } diff --git a/internal/api/grpc/admin/org.go b/internal/api/grpc/admin/org.go index 93e6936d42..293e7c74d7 100644 --- a/internal/api/grpc/admin/org.go +++ b/internal/api/grpc/admin/org.go @@ -108,7 +108,7 @@ func (s *Server) getClaimedUserIDsOfOrgDomain(ctx context.Context, orgDomain str if err != nil { return nil, err } - users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, "", nil) + users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, nil) if err != nil { return nil, err } diff --git a/internal/api/grpc/management/org.go b/internal/api/grpc/management/org.go index a6a934160a..70f509a4d7 100644 --- a/internal/api/grpc/management/org.go +++ b/internal/api/grpc/management/org.go @@ -329,7 +329,7 @@ func (s *Server) getClaimedUserIDsOfOrgDomain(ctx context.Context, orgDomain, or } queries = append(queries, owner) } - users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: queries}, orgID, nil) + users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: queries}, nil) if err != nil { return nil, err } diff --git a/internal/api/grpc/management/user.go b/internal/api/grpc/management/user.go index b876999584..5b82eb5afe 100644 --- a/internal/api/grpc/management/user.go +++ b/internal/api/grpc/management/user.go @@ -69,7 +69,7 @@ func (s *Server) ListUsers(ctx context.Context, req *mgmt_pb.ListUsersRequest) ( if err != nil { return nil, err } - res, err := s.query.SearchUsers(ctx, queries, orgID, nil) + res, err := s.query.SearchUsers(ctx, queries, nil) if err != nil { return nil, err } diff --git a/internal/api/grpc/user/v2/query.go b/internal/api/grpc/user/v2/query.go index 23d4b4422c..136a4a0932 100644 --- a/internal/api/grpc/user/v2/query.go +++ b/internal/api/grpc/user/v2/query.go @@ -30,11 +30,11 @@ func (s *Server) GetUserByID(ctx context.Context, req *user.GetUserByIDRequest) } func (s *Server) ListUsers(ctx context.Context, req *user.ListUsersRequest) (*user.ListUsersResponse, error) { - queries, filterOrgId, err := listUsersRequestToModel(req) + queries, err := listUsersRequestToModel(req) if err != nil { return nil, err } - res, err := s.query.SearchUsers(ctx, queries, filterOrgId, s.checkPermission) + res, err := s.query.SearchUsers(ctx, queries, s.checkPermission) if err != nil { return nil, err } @@ -171,11 +171,11 @@ func accessTokenTypeToPb(accessTokenType domain.OIDCTokenType) user.AccessTokenT } } -func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueries, string, error) { +func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueries, error) { offset, limit, asc := object.ListQueryToQuery(req.Query) - queries, filterOrgId, err := userQueriesToQuery(req.Queries, 0 /*start from level 0*/) + queries, err := userQueriesToQuery(req.Queries, 0 /*start from level 0*/) if err != nil { - return nil, "", err + return nil, err } return &query.UserSearchQueries{ SearchRequest: query.SearchRequest{ @@ -185,7 +185,7 @@ func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueri SortingColumn: userFieldNameToSortingColumn(req.SortingColumn), }, Queries: queries, - }, filterOrgId, nil + }, nil } func userFieldNameToSortingColumn(field user.UserFieldName) query.Column { @@ -215,18 +215,15 @@ func userFieldNameToSortingColumn(field user.UserFieldName) query.Column { } } -func userQueriesToQuery(queries []*user.SearchQuery, level uint8) (_ []query.SearchQuery, filterOrgId string, err error) { +func userQueriesToQuery(queries []*user.SearchQuery, level uint8) (_ []query.SearchQuery, err error) { q := make([]query.SearchQuery, len(queries)) for i, query := range queries { - if orgFilter := query.GetOrganizationIdQuery(); orgFilter != nil { - filterOrgId = orgFilter.OrganizationId - } q[i], err = userQueryToQuery(query, level) if err != nil { - return nil, filterOrgId, err + return nil, err } } - return q, filterOrgId, nil + return q, nil } func userQueryToQuery(query *user.SearchQuery, level uint8) (query.SearchQuery, error) { @@ -320,14 +317,14 @@ func inUserIdsQueryToQuery(q *user.InUserIDQuery) (query.SearchQuery, error) { return query.NewUserInUserIdsSearchQuery(q.UserIds) } func orQueryToQuery(q *user.OrQuery, level uint8) (query.SearchQuery, error) { - mappedQueries, _, err := userQueriesToQuery(q.Queries, level+1) + mappedQueries, err := userQueriesToQuery(q.Queries, level+1) if err != nil { return nil, err } return query.NewUserOrSearchQuery(mappedQueries) } func andQueryToQuery(q *user.AndQuery, level uint8) (query.SearchQuery, error) { - mappedQueries, _, err := userQueriesToQuery(q.Queries, level+1) + mappedQueries, err := userQueriesToQuery(q.Queries, level+1) if err != nil { return nil, err } diff --git a/internal/api/grpc/user/v2beta/query.go b/internal/api/grpc/user/v2beta/query.go index 7baa53e73e..e3602abc33 100644 --- a/internal/api/grpc/user/v2beta/query.go +++ b/internal/api/grpc/user/v2beta/query.go @@ -29,11 +29,11 @@ func (s *Server) GetUserByID(ctx context.Context, req *user.GetUserByIDRequest) } func (s *Server) ListUsers(ctx context.Context, req *user.ListUsersRequest) (*user.ListUsersResponse, error) { - queries, filterOrgIds, err := listUsersRequestToModel(req) + queries, err := listUsersRequestToModel(req) if err != nil { return nil, err } - res, err := s.query.SearchUsers(ctx, queries, filterOrgIds, s.checkPermission) + res, err := s.query.SearchUsers(ctx, queries, s.checkPermission) if err != nil { return nil, err } @@ -165,11 +165,11 @@ func accessTokenTypeToPb(accessTokenType domain.OIDCTokenType) user.AccessTokenT } } -func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueries, string, error) { +func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueries, error) { offset, limit, asc := object.ListQueryToQuery(req.Query) - queries, filterOrgId, err := userQueriesToQuery(req.Queries, 0 /*start from level 0*/) + queries, err := userQueriesToQuery(req.Queries, 0 /*start from level 0*/) if err != nil { - return nil, "", err + return nil, err } return &query.UserSearchQueries{ SearchRequest: query.SearchRequest{ @@ -179,7 +179,7 @@ func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueri SortingColumn: userFieldNameToSortingColumn(req.SortingColumn), }, Queries: queries, - }, filterOrgId, nil + }, nil } func userFieldNameToSortingColumn(field user.UserFieldName) query.Column { @@ -209,18 +209,15 @@ func userFieldNameToSortingColumn(field user.UserFieldName) query.Column { } } -func userQueriesToQuery(queries []*user.SearchQuery, level uint8) (_ []query.SearchQuery, filterOrgId string, err error) { +func userQueriesToQuery(queries []*user.SearchQuery, level uint8) (_ []query.SearchQuery, err error) { q := make([]query.SearchQuery, len(queries)) for i, query := range queries { - if orgFilter := query.GetOrganizationIdQuery(); orgFilter != nil { - filterOrgId = orgFilter.OrganizationId - } q[i], err = userQueryToQuery(query, level) if err != nil { - return nil, filterOrgId, err + return nil, err } } - return q, filterOrgId, nil + return q, nil } func userQueryToQuery(query *user.SearchQuery, level uint8) (query.SearchQuery, error) { @@ -314,14 +311,14 @@ func inUserIdsQueryToQuery(q *user.InUserIDQuery) (query.SearchQuery, error) { return query.NewUserInUserIdsSearchQuery(q.UserIds) } func orQueryToQuery(q *user.OrQuery, level uint8) (query.SearchQuery, error) { - mappedQueries, _, err := userQueriesToQuery(q.Queries, level+1) + mappedQueries, err := userQueriesToQuery(q.Queries, level+1) if err != nil { return nil, err } return query.NewUserOrSearchQuery(mappedQueries) } func andQueryToQuery(q *user.AndQuery, level uint8) (query.SearchQuery, error) { - mappedQueries, _, err := userQueriesToQuery(q.Queries, level+1) + mappedQueries, err := userQueriesToQuery(q.Queries, level+1) if err != nil { return nil, err } diff --git a/internal/api/scim/resources/user.go b/internal/api/scim/resources/user.go index ffd39aa23f..bc8d864994 100644 --- a/internal/api/scim/resources/user.go +++ b/internal/api/scim/resources/user.go @@ -240,7 +240,7 @@ func (h *UsersHandler) List(ctx context.Context, request *ListRequest) (*ListRes return NewListResponse(count, q.SearchRequest, make([]*ScimUser, 0)), nil } - users, err := h.query.SearchUsers(ctx, q, authz.GetCtxData(ctx).OrgID, nil) + users, err := h.query.SearchUsers(ctx, q, nil) if err != nil { return nil, err } diff --git a/internal/api/ui/login/login.go b/internal/api/ui/login/login.go index 4b028a347f..444c5aaa85 100644 --- a/internal/api/ui/login/login.go +++ b/internal/api/ui/login/login.go @@ -182,7 +182,7 @@ func (l *Login) getClaimedUserIDsOfOrgDomain(ctx context.Context, orgName string if err != nil { return nil, err } - users, err := l.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, "", nil) + users, err := l.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, nil) if err != nil { return nil, err } diff --git a/internal/database/type.go b/internal/database/type.go index 6a781288a9..bd07e0bfde 100644 --- a/internal/database/type.go +++ b/internal/database/type.go @@ -225,3 +225,37 @@ func (d *NullDuration) Scan(src any) error { d.Duration, d.Valid = time.Duration(*duration), true return nil } + +// JSONArray allows sending and receiving JSON arrays to and from the database. +// It implements the [database/sql.Scanner] and [database/sql/driver.Valuer] interfaces. +// Values are marshaled and unmarshaled using the [encoding/json] package. +type JSONArray[T any] []T + +// NewJSONArray wraps an existing slice into a JSONArray. +func NewJSONArray[T any](a []T) JSONArray[T] { + return JSONArray[T](a) +} + +// Scan implements the [database/sql.Scanner] interface. +func (a *JSONArray[T]) Scan(src any) error { + if src == nil { + *a = nil + return nil + } + + bytes := src.([]byte) + if len(bytes) == 0 { + *a = nil + return nil + } + + return json.Unmarshal(bytes, a) +} + +// Value implements the [database/sql/driver.Valuer] interface. +func (a JSONArray[T]) Value() (driver.Value, error) { + if a == nil { + return nil, nil + } + return json.Marshal(a) +} diff --git a/internal/database/type_test.go b/internal/database/type_test.go index e56cdced76..7fab568a4e 100644 --- a/internal/database/type_test.go +++ b/internal/database/type_test.go @@ -5,6 +5,7 @@ import ( "database/sql/driver" "testing" + "github.com/muhlemmer/gu" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) @@ -452,3 +453,89 @@ func TestDuration_Scan(t *testing.T) { }) } } + +func TestJSONArray_Scan(t *testing.T) { + type args struct { + src any + } + tests := []struct { + name string + args args + want *JSONArray[string] + wantErr bool + }{ + { + name: "nil", + args: args{src: nil}, + want: new(JSONArray[string]), + wantErr: false, + }, + { + name: "zero bytes", + args: args{src: []byte("")}, + want: new(JSONArray[string]), + wantErr: false, + }, + { + name: "empty", + args: args{src: []byte("[]")}, + want: gu.Ptr(JSONArray[string]{}), + wantErr: false, + }, + { + name: "ok", + args: args{src: []byte("[\"a\", \"b\"]")}, + want: gu.Ptr(JSONArray[string]{"a", "b"}), + wantErr: false, + }, + { + name: "json error", + args: args{src: []byte("{\"a\": \"b\"}")}, + want: new(JSONArray[string]), + wantErr: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := new(JSONArray[string]) + err := got.Scan(tt.args.src) + if tt.wantErr { + assert.Error(t, err) + return + } + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} + +func TestJSONArray_Value(t *testing.T) { + tests := []struct { + name string + a []string + want driver.Value + }{ + { + name: "nil", + a: nil, + want: nil, + }, + { + name: "empty", + a: []string{}, + want: []byte("[]"), + }, + { + name: "ok", + a: []string{"a", "b"}, + want: []byte("[\"a\",\"b\"]"), + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := NewJSONArray(tt.a).Value() + require.NoError(t, err) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/internal/query/idp_user_link.go b/internal/query/idp_user_link.go index 23305dfd6e..99bf3c403b 100644 --- a/internal/query/idp_user_link.go +++ b/internal/query/idp_user_link.go @@ -106,12 +106,26 @@ func idpLinksCheckPermission(ctx context.Context, links *IDPUserLinks, permissio ) } +func idpLinksPermissionCheckV2(ctx context.Context, query sq.SelectBuilder, enabled bool, queries *IDPUserLinksSearchQuery) sq.SelectBuilder { + if !enabled { + return query + } + return query.Where(PermissionClause( + ctx, + IDPUserLinkResourceOwnerCol, + domain.PermissionUserRead, + SingleOrgPermissionOption(queries.Queries), + OwnedRowsPermissionOption(IDPUserLinkUserIDCol), + )) +} + func (q *Queries) IDPUserLinks(ctx context.Context, queries *IDPUserLinksSearchQuery, permissionCheck domain.PermissionCheck) (idps *IDPUserLinks, err error) { - links, err := q.idpUserLinks(ctx, queries, false) + permissionCheckV2 := PermissionV2(ctx, permissionCheck) + links, err := q.idpUserLinks(ctx, queries, permissionCheckV2) if err != nil { return nil, err } - if permissionCheck != nil && len(links.Links) > 0 { + if permissionCheck != nil && len(links.Links) > 0 && !permissionCheckV2 { // when userID for query is provided, only one check has to be done if queries.hasUserID() { if err := userCheckPermission(ctx, links.Links[0].ResourceOwner, links.Links[0].UserID, permissionCheck); err != nil { @@ -124,14 +138,15 @@ func (q *Queries) IDPUserLinks(ctx context.Context, queries *IDPUserLinksSearchQ return links, nil } -func (q *Queries) idpUserLinks(ctx context.Context, queries *IDPUserLinksSearchQuery, withOwnerRemoved bool) (idps *IDPUserLinks, err error) { +func (q *Queries) idpUserLinks(ctx context.Context, queries *IDPUserLinksSearchQuery, permissionCheckV2 bool) (idps *IDPUserLinks, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() query, scan := prepareIDPUserLinksQuery() - eq := sq.Eq{IDPUserLinkInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID()} - if !withOwnerRemoved { - eq[IDPUserLinkOwnerRemovedCol.identifier()] = false + query = idpLinksPermissionCheckV2(ctx, query, permissionCheckV2, queries) + eq := sq.Eq{ + IDPUserLinkInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(), + IDPUserLinkOwnerRemovedCol.identifier(): false, } stmt, args, err := queries.toQuery(query).Where(eq).ToSql() if err != nil { diff --git a/internal/query/org.go b/internal/query/org.go index b1f5eaea02..643aec291a 100644 --- a/internal/query/org.go +++ b/internal/query/org.go @@ -93,6 +93,17 @@ func orgsCheckPermission(ctx context.Context, orgs *Orgs, permissionCheck domain ) } +func orgsPermissionCheckV2(ctx context.Context, query sq.SelectBuilder, enabled bool) sq.SelectBuilder { + if !enabled { + return query + } + return query.Where(PermissionClause( + ctx, + OrgColumnID, + domain_pkg.PermissionOrgRead, + )) +} + type OrgSearchQueries struct { SearchRequest Queries []SearchQuery @@ -283,21 +294,23 @@ func (q *Queries) ExistsOrg(ctx context.Context, id, domain string) (verifiedID } func (q *Queries) SearchOrgs(ctx context.Context, queries *OrgSearchQueries, permissionCheck domain_pkg.PermissionCheck) (*Orgs, error) { - orgs, err := q.searchOrgs(ctx, queries) + permissionCheckV2 := PermissionV2(ctx, permissionCheck) + orgs, err := q.searchOrgs(ctx, queries, permissionCheckV2) if err != nil { return nil, err } - if permissionCheck != nil { + if permissionCheck != nil && !permissionCheckV2 { orgsCheckPermission(ctx, orgs, permissionCheck) } return orgs, nil } -func (q *Queries) searchOrgs(ctx context.Context, queries *OrgSearchQueries) (orgs *Orgs, err error) { +func (q *Queries) searchOrgs(ctx context.Context, queries *OrgSearchQueries, permissionCheckV2 bool) (orgs *Orgs, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() query, scan := prepareOrgsQuery() + query = orgsPermissionCheckV2(ctx, query, permissionCheckV2) stmt, args, err := queries.toQuery(query). Where(sq.And{ sq.Eq{ diff --git a/internal/query/permission.go b/internal/query/permission.go index c52b491144..3157430264 100644 --- a/internal/query/permission.go +++ b/internal/query/permission.go @@ -2,74 +2,109 @@ package query import ( "context" - "encoding/json" "fmt" sq "github.com/Masterminds/squirrel" "github.com/zitadel/logging" "github.com/zitadel/zitadel/internal/api/authz" - "github.com/zitadel/zitadel/internal/zerrors" + "github.com/zitadel/zitadel/internal/database" + domain_pkg "github.com/zitadel/zitadel/internal/domain" ) const ( - // eventstore.permitted_orgs(instanceid text, userid text, system_user_perms JSONB, perm text filter_orgs text) - wherePermittedOrgsClause = "%s = ANY(eventstore.permitted_orgs(?, ?, ?, ?, ?))" - wherePermittedOrgsOrCurrentUserClause = "(" + wherePermittedOrgsClause + " OR %s = ?" + ")" + // eventstore.permitted_orgs(instanceid text, userid text, system_user_perms JSONB, perm text, filter_org text) + wherePermittedOrgsExpr = "%s = ANY(eventstore.permitted_orgs(?, ?, ?, ?, ?))" ) -// wherePermittedOrgs sets a `WHERE` clause to the query that filters the orgs -// for which the authenticated user has the requested permission for. -// The user ID is taken from the context. -// The `orgIDColumn` specifies the table column to which this filter must be applied, -// and is typically the `resource_owner` column in ZITADEL. -// We use full identifiers in the query builder so this function should be -// called with something like `UserResourceOwnerCol.identifier()` for example. -// func wherePermittedOrgs(ctx context.Context, query sq.SelectBuilder, filterOrgIds, orgIDColumn, permission string) (sq.SelectBuilder, error) { -// userID := authz.GetCtxData(ctx).UserID -// logging.WithFields("permission_check_v2_flag", authz.GetFeatures(ctx).PermissionCheckV2, "org_id_column", orgIDColumn, "permission", permission, "user_id", userID).Debug("permitted orgs check used") - -// systemUserPermissions := authz.GetSystemUserPermissions(ctx) -// var systemUserPermissionsJson []byte -// if systemUserPermissions != nil { -// var err error -// systemUserPermissionsJson, err = json.Marshal(systemUserPermissions) -// if err != nil { -// return query, err -// } -// } - -// return query.Where( -// fmt.Sprintf(wherePermittedOrgsClause, orgIDColumn), -// authz.GetInstance(ctx).InstanceID(), -// userID, -// systemUserPermissionsJson, -// permission, -// filterOrgIds, -// ), nil -// } - -func wherePermittedOrgsOrCurrentUser(ctx context.Context, query sq.SelectBuilder, filterOrgIds, orgIDColumn, userIdColum, permission string) (sq.SelectBuilder, error) { - userID := authz.GetCtxData(ctx).UserID - logging.WithFields("permission_check_v2_flag", authz.GetFeatures(ctx).PermissionCheckV2, "org_id_column", orgIDColumn, "user_id_colum", userIdColum, "permission", permission, "user_id", userID).Debug("permitted orgs check used") - - systemUserPermissions := authz.GetSystemUserPermissions(ctx) - var systemUserPermissionsJson []byte - if systemUserPermissions != nil { - var err error - systemUserPermissionsJson, err = json.Marshal(systemUserPermissions) - if err != nil { - return query, zerrors.ThrowInternal(err, "AUTHZ-HS4us", "Errors.Internal") - } - } - - return query.Where( - fmt.Sprintf(wherePermittedOrgsOrCurrentUserClause, orgIDColumn, userIdColum), - authz.GetInstance(ctx).InstanceID(), - userID, - systemUserPermissionsJson, - permission, - filterOrgIds, - userID, - ), nil +type permissionClauseBuilder struct { + orgIDColumn Column + instanceID string + userID string + systemPermissions []authz.SystemUserPermissions + permission string + orgID string + connections []sq.Eq +} + +func (b *permissionClauseBuilder) appendConnection(column string, value any) { + b.connections = append(b.connections, sq.Eq{column: value}) +} + +func (b *permissionClauseBuilder) clauses() sq.Or { + clauses := make(sq.Or, 1, len(b.connections)+1) + clauses[0] = sq.Expr( + fmt.Sprintf(wherePermittedOrgsExpr, b.orgIDColumn.identifier()), + b.instanceID, + b.userID, + database.NewJSONArray(b.systemPermissions), + b.permission, + b.orgID, + ) + for _, include := range b.connections { + clauses = append(clauses, include) + } + return clauses +} + +type PermissionOption func(b *permissionClauseBuilder) + +// OwnedRowsPermissionOption allows rows to be returned of which the current user is the owner. +// Even if the user does not have an explicit permission for the organization. +// For example an authenticated user can always see his own user account. +func OwnedRowsPermissionOption(userIDColumn Column) PermissionOption { + return func(b *permissionClauseBuilder) { + b.appendConnection(userIDColumn.identifier(), b.userID) + } +} + +// ConnectionPermissionOption allows returning of rows where the value is matched. +// Even if the user does not have an explicit permission for the organization. +func ConnectionPermissionOption(column Column, value any) PermissionOption { + return func(b *permissionClauseBuilder) { + b.appendConnection(column.identifier(), value) + } +} + +// SingleOrgPermissionOption may be used to optimize the permitted orgs function by limiting the +// returned organizations, to the one used in the requested filters. +func SingleOrgPermissionOption(queries []SearchQuery) PermissionOption { + return func(b *permissionClauseBuilder) { + b.orgID = findTextEqualsQuery(b.orgIDColumn, queries) + } +} + +// PermissionClause sets a `WHERE` clause to query, +// which filters returned rows the current authenticated user has the requested permission to. +// +// Experimental: Work in progress. Currently only organization permissions are supported +func PermissionClause(ctx context.Context, orgIDCol Column, permission string, options ...PermissionOption) sq.Or { + ctxData := authz.GetCtxData(ctx) + b := &permissionClauseBuilder{ + orgIDColumn: orgIDCol, + instanceID: authz.GetInstance(ctx).InstanceID(), + userID: ctxData.UserID, + systemPermissions: ctxData.SystemUserPermissions, + permission: permission, + } + for _, opt := range options { + opt(b) + } + logging.WithFields( + "org_id_column", b.orgIDColumn, + "instance_id", b.instanceID, + "user_id", b.userID, + "system_user_permissions", b.systemPermissions, + "permission", b.permission, + "org_id", b.orgID, + "overrides", b.connections, + ).Debug("permitted orgs check used") + + return b.clauses() +} + +// PermissionV2 checks are enabled when the feature flag is set and the permission check function is not nil. +// When the permission check function is nil, it indicates a v1 API and no resource based permission check is needed. +func PermissionV2(ctx context.Context, cf domain_pkg.PermissionCheck) bool { + return authz.GetFeatures(ctx).PermissionCheckV2 && cf != nil } diff --git a/internal/query/permission_test.go b/internal/query/permission_test.go new file mode 100644 index 0000000000..f6ecd94b46 --- /dev/null +++ b/internal/query/permission_test.go @@ -0,0 +1,208 @@ +package query + +import ( + "context" + "testing" + + sq "github.com/Masterminds/squirrel" + "github.com/stretchr/testify/assert" + + "github.com/zitadel/zitadel/internal/api/authz" + "github.com/zitadel/zitadel/internal/database" + domain_pkg "github.com/zitadel/zitadel/internal/domain" + "github.com/zitadel/zitadel/internal/feature" +) + +func TestPermissionClause(t *testing.T) { + var permissions = []authz.SystemUserPermissions{ + { + MemberType: authz.MemberTypeOrganization, + AggregateID: "orgID", + Permissions: []string{"permission1", "permission2"}, + }, + { + MemberType: authz.MemberTypeIAM, + Permissions: []string{"permission2", "permission3"}, + }, + } + ctx := authz.WithInstanceID(context.Background(), "instanceID") + ctx = authz.SetCtxData(ctx, authz.CtxData{ + UserID: "userID", + SystemUserPermissions: permissions, + }) + + type args struct { + ctx context.Context + orgIDCol Column + permission string + options []PermissionOption + } + tests := []struct { + name string + args args + wantClause sq.Or + }{ + { + name: "no options", + args: args{ + ctx: ctx, + orgIDCol: UserResourceOwnerCol, + permission: "permission1", + }, + wantClause: sq.Or{ + sq.Expr( + "projections.users14.resource_owner = ANY(eventstore.permitted_orgs(?, ?, ?, ?, ?))", + "instanceID", + "userID", + database.NewJSONArray(permissions), + "permission1", + "", + ), + }, + }, + { + name: "owned rows option", + args: args{ + ctx: ctx, + orgIDCol: UserResourceOwnerCol, + permission: "permission1", + options: []PermissionOption{ + OwnedRowsPermissionOption(UserIDCol), + }, + }, + wantClause: sq.Or{ + sq.Expr( + "projections.users14.resource_owner = ANY(eventstore.permitted_orgs(?, ?, ?, ?, ?))", + "instanceID", + "userID", + database.NewJSONArray(permissions), + "permission1", + "", + ), + sq.Eq{"projections.users14.id": "userID"}, + }, + }, + { + name: "connection rows option", + args: args{ + ctx: ctx, + orgIDCol: UserResourceOwnerCol, + permission: "permission1", + options: []PermissionOption{ + OwnedRowsPermissionOption(UserIDCol), + ConnectionPermissionOption(UserStateCol, "bar"), + }, + }, + wantClause: sq.Or{ + sq.Expr( + "projections.users14.resource_owner = ANY(eventstore.permitted_orgs(?, ?, ?, ?, ?))", + "instanceID", + "userID", + database.NewJSONArray(permissions), + "permission1", + "", + ), + sq.Eq{"projections.users14.id": "userID"}, + sq.Eq{"projections.users14.state": "bar"}, + }, + }, + { + name: "single org option", + args: args{ + ctx: ctx, + orgIDCol: UserResourceOwnerCol, + permission: "permission1", + options: []PermissionOption{ + SingleOrgPermissionOption([]SearchQuery{ + mustSearchQuery(NewUserDisplayNameSearchQuery("zitadel", TextContains)), + mustSearchQuery(NewUserResourceOwnerSearchQuery("orgID", TextEquals)), + }), + }, + }, + wantClause: sq.Or{ + sq.Expr( + "projections.users14.resource_owner = ANY(eventstore.permitted_orgs(?, ?, ?, ?, ?))", + "instanceID", + "userID", + database.NewJSONArray(permissions), + "permission1", + "orgID", + ), + }, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotClause := PermissionClause(tt.args.ctx, tt.args.orgIDCol, tt.args.permission, tt.args.options...) + assert.Equal(t, tt.wantClause, gotClause) + }) + } +} + +func mustSearchQuery(q SearchQuery, err error) SearchQuery { + if err != nil { + panic(err) + } + return q +} + +func TestPermissionV2(t *testing.T) { + type args struct { + ctx context.Context + cf domain_pkg.PermissionCheck + } + tests := []struct { + name string + args args + want bool + }{ + { + name: "feature disabled, no permission check", + args: args{ + ctx: context.Background(), + cf: nil, + }, + want: false, + }, + { + name: "feature enabled, no permission check", + args: args{ + ctx: authz.WithFeatures(context.Background(), feature.Features{ + PermissionCheckV2: true, + }), + cf: nil, + }, + want: false, + }, + { + name: "feature enabled, with permission check", + args: args{ + ctx: authz.WithFeatures(context.Background(), feature.Features{ + PermissionCheckV2: true, + }), + cf: func(context.Context, string, string, string) error { + return nil + }, + }, + want: true, + }, + { + name: "feature disabled, with permission check", + args: args{ + ctx: authz.WithFeatures(context.Background(), feature.Features{ + PermissionCheckV2: false, + }), + cf: func(context.Context, string, string, string) error { + return nil + }, + }, + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := PermissionV2(tt.args.ctx, tt.args.cf) + assert.Equal(t, tt.want, got) + }) + } +} diff --git a/internal/query/query.go b/internal/query/query.go index c0c051f7b7..bd50d3c0be 100644 --- a/internal/query/query.go +++ b/internal/query/query.go @@ -148,3 +148,16 @@ func triggerBatch(ctx context.Context, handlers ...*handler.Handler) { wg.Wait() } + +func findTextEqualsQuery(column Column, queries []SearchQuery) string { + for _, query := range queries { + if query.Col() != column { + continue + } + tq, ok := query.(*textQuery) + if ok && tq.Compare == TextEquals { + return tq.Text + } + } + return "" +} diff --git a/internal/query/session.go b/internal/query/session.go index 111eb462a0..004f29fe81 100644 --- a/internal/query/session.go +++ b/internal/query/session.go @@ -113,6 +113,22 @@ func sessionCheckPermission(ctx context.Context, resourceOwner string, creator s return permissionCheck(ctx, domain.PermissionSessionRead, resourceOwner, "") } +func sessionsPermissionCheckV2(ctx context.Context, query sq.SelectBuilder, enabled bool) sq.SelectBuilder { + if !enabled { + return query + } + return query.Where(PermissionClause( + ctx, + SessionColumnResourceOwner, + domain.PermissionSessionRead, + // Allow if user is creator + OwnedRowsPermissionOption(SessionColumnCreator), + // Allow if session belongs to the user + OwnedRowsPermissionOption(SessionColumnUserID), + ConnectionPermissionOption(SessionColumnUserAgentFingerprintID, authz.GetCtxData(ctx).AgentID), + )) +} + func (q *SessionsSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuilder { query = q.SearchRequest.toQuery(query) for _, q := range q.Queries { @@ -282,21 +298,23 @@ func (q *Queries) sessionByID(ctx context.Context, shouldTriggerBulk bool, id st } func (q *Queries) SearchSessions(ctx context.Context, queries *SessionsSearchQueries, permissionCheck domain.PermissionCheck) (*Sessions, error) { - sessions, err := q.searchSessions(ctx, queries) + permissionCheckV2 := PermissionV2(ctx, permissionCheck) + sessions, err := q.searchSessions(ctx, queries, permissionCheckV2) if err != nil { return nil, err } - if permissionCheck != nil { + if permissionCheck != nil && !permissionCheckV2 { sessionsCheckPermission(ctx, sessions, permissionCheck) } return sessions, nil } -func (q *Queries) searchSessions(ctx context.Context, queries *SessionsSearchQueries) (sessions *Sessions, err error) { +func (q *Queries) searchSessions(ctx context.Context, queries *SessionsSearchQueries, permissionCheckV2 bool) (sessions *Sessions, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() query, scan := prepareSessionsQuery() + query = sessionsPermissionCheckV2(ctx, query, permissionCheckV2) stmt, args, err := queries.toQuery(query). Where(sq.Eq{ SessionColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(), diff --git a/internal/query/user.go b/internal/query/user.go index c30eaaec74..47694736c4 100644 --- a/internal/query/user.go +++ b/internal/query/user.go @@ -132,6 +132,19 @@ func usersCheckPermission(ctx context.Context, users *Users, permissionCheck dom ) } +func userPermissionCheckV2(ctx context.Context, query sq.SelectBuilder, enabled bool, queries *UserSearchQueries) sq.SelectBuilder { + if !enabled { + return query + } + return query.Where(PermissionClause( + ctx, + UserResourceOwnerCol, + domain.PermissionUserRead, + SingleOrgPermissionOption(queries.Queries), + OwnedRowsPermissionOption(UserIDCol), + )) +} + type UserSearchQueries struct { SearchRequest Queries []SearchQuery @@ -606,8 +619,9 @@ func (q *Queries) CountUsers(ctx context.Context, queries *UserSearchQueries) (c return count, err } -func (q *Queries) SearchUsers(ctx context.Context, queries *UserSearchQueries, filterOrgIds string, permissionCheck domain.PermissionCheck) (*Users, error) { - users, err := q.searchUsers(ctx, queries, filterOrgIds, permissionCheck != nil && authz.GetFeatures(ctx).PermissionCheckV2) +func (q *Queries) SearchUsers(ctx context.Context, queries *UserSearchQueries, permissionCheck domain.PermissionCheck) (*Users, error) { + permissionCheckV2 := PermissionV2(ctx, permissionCheck) + users, err := q.searchUsers(ctx, queries, permissionCheckV2) if err != nil { return nil, err } @@ -617,22 +631,15 @@ func (q *Queries) SearchUsers(ctx context.Context, queries *UserSearchQueries, f return users, nil } -func (q *Queries) searchUsers(ctx context.Context, queries *UserSearchQueries, filterOrgIds string, permissionCheckV2 bool) (users *Users, err error) { +func (q *Queries) searchUsers(ctx context.Context, queries *UserSearchQueries, permissionCheckV2 bool) (users *Users, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() query, scan := prepareUsersQuery() - query = queries.toQuery(query).Where(sq.Eq{ + query = userPermissionCheckV2(ctx, query, permissionCheckV2, queries) + stmt, args, err := queries.toQuery(query).Where(sq.Eq{ UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(), - }) - if permissionCheckV2 { - query, err = wherePermittedOrgsOrCurrentUser(ctx, query, filterOrgIds, UserResourceOwnerCol.identifier(), UserIDCol.identifier(), domain.PermissionUserRead) - if err != nil { - return nil, zerrors.ThrowInternal(err, "AUTHZ-HS4us", "Errors.Internal") - } - } - - stmt, args, err := query.ToSql() + }).ToSql() if err != nil { return nil, zerrors.ThrowInternal(err, "QUERY-Dgbg2", "Errors.Query.SQLStatment") } diff --git a/internal/query/user_auth_method.go b/internal/query/user_auth_method.go index 8b26389f1a..acf61bf0e6 100644 --- a/internal/query/user_auth_method.go +++ b/internal/query/user_auth_method.go @@ -104,6 +104,18 @@ func authMethodsCheckPermission(ctx context.Context, methods *AuthMethods, permi ) } +func userAuthMethodPermissionCheckV2(ctx context.Context, query sq.SelectBuilder, enabled bool) sq.SelectBuilder { + if !enabled { + return query + } + return query.Where(PermissionClause( + ctx, + UserAuthMethodColumnResourceOwner, + domain.PermissionUserRead, + OwnedRowsPermissionOption(UserIDCol), + )) +} + type AuthMethod struct { UserID string CreationDate time.Time @@ -137,11 +149,12 @@ func (q *UserAuthMethodSearchQueries) hasUserID() bool { } func (q *Queries) SearchUserAuthMethods(ctx context.Context, queries *UserAuthMethodSearchQueries, permissionCheck domain.PermissionCheck) (userAuthMethods *AuthMethods, err error) { - methods, err := q.searchUserAuthMethods(ctx, queries) + permissionCheckV2 := PermissionV2(ctx, permissionCheck) + methods, err := q.searchUserAuthMethods(ctx, queries, permissionCheckV2) if err != nil { return nil, err } - if permissionCheck != nil && len(methods.AuthMethods) > 0 { + if permissionCheck != nil && len(methods.AuthMethods) > 0 && !permissionCheckV2 { // when userID for query is provided, only one check has to be done if queries.hasUserID() { if err := userCheckPermission(ctx, methods.AuthMethods[0].ResourceOwner, methods.AuthMethods[0].UserID, permissionCheck); err != nil { @@ -154,11 +167,12 @@ func (q *Queries) SearchUserAuthMethods(ctx context.Context, queries *UserAuthMe return methods, nil } -func (q *Queries) searchUserAuthMethods(ctx context.Context, queries *UserAuthMethodSearchQueries) (userAuthMethods *AuthMethods, err error) { +func (q *Queries) searchUserAuthMethods(ctx context.Context, queries *UserAuthMethodSearchQueries, permissionCheckV2 bool) (userAuthMethods *AuthMethods, err error) { ctx, span := tracing.NewSpan(ctx) defer func() { span.EndWithError(err) }() query, scan := prepareUserAuthMethodsQuery() + query = userAuthMethodPermissionCheckV2(ctx, query, permissionCheckV2) stmt, args, err := queries.toQuery(query).Where(sq.Eq{UserAuthMethodColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()}).ToSql() if err != nil { return nil, zerrors.ThrowInvalidArgument(err, "QUERY-j9NJd", "Errors.Query.InvalidRequest")