mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 17:27:31 +00:00
perf(query): org permission function for resources (#9677)
# Which Problems Are Solved Classic permission checks execute for every returned row on resource based search APIs. Complete background and problem definition can be found here: https://github.com/zitadel/zitadel/issues/9188 # How the Problems Are Solved - PermissionClause function now support dynamic query building, so it supports multiple cases. - PermissionClause is applied to all list resources which support org level permissions. - Wrap permission logic into wrapper functions so we keep the business logic clean. # Additional Changes - Handle org ID optimization in the query package, so it is reusable for all resources, instead of extracting the filter in the API. - Cleanup and test system user conversion in the authz package. (context middleware) - Fix: `core_integration_db_up` make recipe was missing the postgres service. # Additional Context - Related to https://github.com/zitadel/zitadel/issues/9190
This commit is contained in:
2
Makefile
2
Makefile
@@ -112,7 +112,7 @@ core_unit_test:
|
|||||||
|
|
||||||
.PHONY: core_integration_db_up
|
.PHONY: core_integration_db_up
|
||||||
core_integration_db_up:
|
core_integration_db_up:
|
||||||
docker compose -f internal/integration/config/docker-compose.yaml up --pull always --wait cache
|
docker compose -f internal/integration/config/docker-compose.yaml up --pull always --wait cache postgres
|
||||||
|
|
||||||
.PHONY: core_integration_db_down
|
.PHONY: core_integration_db_down
|
||||||
core_integration_db_down:
|
core_integration_db_down:
|
||||||
|
@@ -7,8 +7,6 @@ import (
|
|||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/zitadel/logging"
|
|
||||||
|
|
||||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||||
"github.com/zitadel/zitadel/internal/zerrors"
|
"github.com/zitadel/zitadel/internal/zerrors"
|
||||||
)
|
)
|
||||||
@@ -26,14 +24,13 @@ func CheckUserAuthorization(ctx context.Context, req interface{}, token, orgID,
|
|||||||
ctx, span := tracing.NewServerInterceptorSpan(ctx)
|
ctx, span := tracing.NewServerInterceptorSpan(ctx)
|
||||||
defer func() { span.EndWithError(err) }()
|
defer func() { span.EndWithError(err) }()
|
||||||
|
|
||||||
ctxData, err := VerifyTokenAndCreateCtxData(ctx, token, orgID, orgDomain, verifier)
|
ctxData, err := VerifyTokenAndCreateCtxData(ctx, token, orgID, orgDomain, verifier, systemRolePermissionMapping)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if requiredAuthOption.Permission == authenticated {
|
if requiredAuthOption.Permission == authenticated {
|
||||||
return func(parent context.Context) context.Context {
|
return func(parent context.Context) context.Context {
|
||||||
parent = addGetSystemUserRolesToCtx(parent, systemRolePermissionMapping, ctxData)
|
|
||||||
return context.WithValue(parent, dataKey, ctxData)
|
return context.WithValue(parent, dataKey, ctxData)
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -54,7 +51,6 @@ func CheckUserAuthorization(ctx context.Context, req interface{}, token, orgID,
|
|||||||
parent = context.WithValue(parent, dataKey, ctxData)
|
parent = context.WithValue(parent, dataKey, ctxData)
|
||||||
parent = context.WithValue(parent, allPermissionsKey, allPermissions)
|
parent = context.WithValue(parent, allPermissionsKey, allPermissions)
|
||||||
parent = context.WithValue(parent, requestPermissionsKey, requestedPermissions)
|
parent = context.WithValue(parent, requestPermissionsKey, requestedPermissions)
|
||||||
parent = addGetSystemUserRolesToCtx(parent, systemRolePermissionMapping, ctxData)
|
|
||||||
return parent
|
return parent
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
@@ -131,42 +127,32 @@ func GetAllPermissionCtxIDs(perms []string) []string {
|
|||||||
return ctxIDs
|
return ctxIDs
|
||||||
}
|
}
|
||||||
|
|
||||||
type SystemUserPermissionsDBQuery struct {
|
type SystemUserPermissions struct {
|
||||||
MemberType string `json:"member_type"`
|
MemberType MemberType `json:"member_type"`
|
||||||
AggregateID string `json:"aggregate_id"`
|
AggregateID string `json:"aggregate_id"`
|
||||||
ObjectID string `json:"object_id"`
|
ObjectID string `json:"object_id"`
|
||||||
Permissions []string `json:"permissions"`
|
Permissions []string `json:"permissions"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func addGetSystemUserRolesToCtx(ctx context.Context, systemUserRoleMap []RoleMapping, ctxData CtxData) context.Context {
|
// systemMembershipsToUserPermissions converts system memberships based on roles,
|
||||||
if len(ctxData.SystemMemberships) == 0 {
|
// to SystemUserPermissions, using the passed role mapping.
|
||||||
return ctx
|
func systemMembershipsToUserPermissions(memberships Memberships, roleMap []RoleMapping) []SystemUserPermissions {
|
||||||
|
if memberships == nil {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
systemUserPermissions := make([]SystemUserPermissionsDBQuery, len(ctxData.SystemMemberships))
|
systemUserPermissions := make([]SystemUserPermissions, len(memberships))
|
||||||
for i, systemPerm := range ctxData.SystemMemberships {
|
for i, systemPerm := range memberships {
|
||||||
permissions := make([]string, 0, len(systemPerm.Roles))
|
permissions := make([]string, 0, len(systemPerm.Roles))
|
||||||
for _, role := range systemPerm.Roles {
|
for _, role := range systemPerm.Roles {
|
||||||
permissions = append(permissions, getPermissionsFromRole(systemUserRoleMap, role)...)
|
permissions = append(permissions, getPermissionsFromRole(roleMap, role)...)
|
||||||
}
|
}
|
||||||
slices.Sort(permissions)
|
slices.Sort(permissions)
|
||||||
permissions = slices.Compact(permissions)
|
permissions = slices.Compact(permissions) // remove duplicates
|
||||||
|
|
||||||
systemUserPermissions[i].MemberType = systemPerm.MemberType.String()
|
systemUserPermissions[i].MemberType = systemPerm.MemberType
|
||||||
systemUserPermissions[i].AggregateID = systemPerm.AggregateID
|
systemUserPermissions[i].AggregateID = systemPerm.AggregateID
|
||||||
|
systemUserPermissions[i].ObjectID = systemPerm.ObjectID
|
||||||
systemUserPermissions[i].Permissions = permissions
|
systemUserPermissions[i].Permissions = permissions
|
||||||
}
|
}
|
||||||
return context.WithValue(ctx, systemUserRolesKey, systemUserPermissions)
|
return systemUserPermissions
|
||||||
}
|
|
||||||
|
|
||||||
func GetSystemUserPermissions(ctx context.Context) []SystemUserPermissionsDBQuery {
|
|
||||||
getSystemUserRolesFuncValue := ctx.Value(systemUserRolesKey)
|
|
||||||
if getSystemUserRolesFuncValue == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
systemUserRoles, ok := getSystemUserRolesFuncValue.([]SystemUserPermissionsDBQuery)
|
|
||||||
if !ok {
|
|
||||||
logging.WithFields("Authz").Error("unable to cast []SystemUserPermissionsDBQuery")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return systemUserRoles
|
|
||||||
}
|
}
|
||||||
|
@@ -3,6 +3,8 @@ package authz
|
|||||||
import (
|
import (
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
"github.com/zitadel/zitadel/internal/zerrors"
|
"github.com/zitadel/zitadel/internal/zerrors"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -276,3 +278,127 @@ func Test_GetPermissionCtxIDs(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_systemMembershipsToUserPermissions(t *testing.T) {
|
||||||
|
roleMap := []RoleMapping{
|
||||||
|
{
|
||||||
|
Role: "FOO_BAR",
|
||||||
|
Permissions: []string{"foo.bar.read", "foo.bar.write"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Role: "BAR_FOO",
|
||||||
|
Permissions: []string{"bar.foo.read", "bar.foo.write", "foo.bar.read"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
memberships Memberships
|
||||||
|
roleMap []RoleMapping
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want []SystemUserPermissions
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil memberships",
|
||||||
|
args: args{
|
||||||
|
memberships: nil,
|
||||||
|
roleMap: roleMap,
|
||||||
|
},
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty memberships",
|
||||||
|
args: args{
|
||||||
|
memberships: Memberships{},
|
||||||
|
roleMap: roleMap,
|
||||||
|
},
|
||||||
|
want: []SystemUserPermissions{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single membership",
|
||||||
|
args: args{
|
||||||
|
memberships: Memberships{
|
||||||
|
{
|
||||||
|
MemberType: MemberTypeSystem,
|
||||||
|
AggregateID: "1",
|
||||||
|
ObjectID: "2",
|
||||||
|
Roles: []string{"FOO_BAR"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
roleMap: roleMap,
|
||||||
|
},
|
||||||
|
want: []SystemUserPermissions{
|
||||||
|
{
|
||||||
|
MemberType: MemberTypeSystem,
|
||||||
|
AggregateID: "1",
|
||||||
|
ObjectID: "2",
|
||||||
|
Permissions: []string{"foo.bar.read", "foo.bar.write"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple memberships",
|
||||||
|
args: args{
|
||||||
|
memberships: Memberships{
|
||||||
|
{
|
||||||
|
MemberType: MemberTypeSystem,
|
||||||
|
AggregateID: "1",
|
||||||
|
ObjectID: "2",
|
||||||
|
Roles: []string{"FOO_BAR"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
MemberType: MemberTypeIAM,
|
||||||
|
AggregateID: "1",
|
||||||
|
ObjectID: "2",
|
||||||
|
Roles: []string{"BAR_FOO"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
roleMap: roleMap,
|
||||||
|
},
|
||||||
|
want: []SystemUserPermissions{
|
||||||
|
{
|
||||||
|
MemberType: MemberTypeSystem,
|
||||||
|
AggregateID: "1",
|
||||||
|
ObjectID: "2",
|
||||||
|
Permissions: []string{"foo.bar.read", "foo.bar.write"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
MemberType: MemberTypeIAM,
|
||||||
|
AggregateID: "1",
|
||||||
|
ObjectID: "2",
|
||||||
|
Permissions: []string{"bar.foo.read", "bar.foo.write", "foo.bar.read"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multiple roles",
|
||||||
|
args: args{
|
||||||
|
memberships: Memberships{
|
||||||
|
{
|
||||||
|
MemberType: MemberTypeSystem,
|
||||||
|
AggregateID: "1",
|
||||||
|
ObjectID: "2",
|
||||||
|
Roles: []string{"FOO_BAR", "BAR_FOO"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
roleMap: roleMap,
|
||||||
|
},
|
||||||
|
want: []SystemUserPermissions{
|
||||||
|
{
|
||||||
|
MemberType: MemberTypeSystem,
|
||||||
|
AggregateID: "1",
|
||||||
|
ObjectID: "2",
|
||||||
|
Permissions: []string{"bar.foo.read", "bar.foo.write", "foo.bar.read", "foo.bar.write"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := systemMembershipsToUserPermissions(tt.args.memberships, tt.args.roleMap)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@@ -1,4 +1,4 @@
|
|||||||
//go:generate enumer -type MemberType -trimprefix MemberType
|
//go:generate enumer -type MemberType -trimprefix MemberType -json
|
||||||
|
|
||||||
package authz
|
package authz
|
||||||
|
|
||||||
@@ -22,7 +22,6 @@ const (
|
|||||||
dataKey key = 2
|
dataKey key = 2
|
||||||
allPermissionsKey key = 3
|
allPermissionsKey key = 3
|
||||||
instanceKey key = 4
|
instanceKey key = 4
|
||||||
systemUserRolesKey key = 5
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type CtxData struct {
|
type CtxData struct {
|
||||||
@@ -33,6 +32,7 @@ type CtxData struct {
|
|||||||
PreferredLanguage string
|
PreferredLanguage string
|
||||||
ResourceOwner string
|
ResourceOwner string
|
||||||
SystemMemberships Memberships
|
SystemMemberships Memberships
|
||||||
|
SystemUserPermissions []SystemUserPermissions
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ctxData CtxData) IsZero() bool {
|
func (ctxData CtxData) IsZero() bool {
|
||||||
@@ -98,7 +98,7 @@ func (s SystemTokenVerifierFunc) VerifySystemToken(ctx context.Context, token st
|
|||||||
return s(ctx, token, orgID)
|
return s(ctx, token, orgID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func VerifyTokenAndCreateCtxData(ctx context.Context, token, orgID, orgDomain string, t APITokenVerifier) (_ CtxData, err error) {
|
func VerifyTokenAndCreateCtxData(ctx context.Context, token, orgID, orgDomain string, t APITokenVerifier, systemRoleMap []RoleMapping) (_ CtxData, err error) {
|
||||||
ctx, span := tracing.NewSpan(ctx)
|
ctx, span := tracing.NewSpan(ctx)
|
||||||
defer func() { span.EndWithError(err) }()
|
defer func() { span.EndWithError(err) }()
|
||||||
tokenWOBearer, err := extractBearerToken(token)
|
tokenWOBearer, err := extractBearerToken(token)
|
||||||
@@ -140,6 +140,7 @@ func VerifyTokenAndCreateCtxData(ctx context.Context, token, orgID, orgDomain st
|
|||||||
PreferredLanguage: prefLang,
|
PreferredLanguage: prefLang,
|
||||||
ResourceOwner: resourceOwner,
|
ResourceOwner: resourceOwner,
|
||||||
SystemMemberships: sysMemberships,
|
SystemMemberships: sysMemberships,
|
||||||
|
SystemUserPermissions: systemMembershipsToUserPermissions(sysMemberships, systemRoleMap),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@@ -1,8 +1,9 @@
|
|||||||
// Code generated by "enumer -type MemberType -trimprefix MemberType"; DO NOT EDIT.
|
// Code generated by "enumer -type MemberType -trimprefix MemberType -json"; DO NOT EDIT.
|
||||||
|
|
||||||
package authz
|
package authz
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
@@ -92,3 +93,20 @@ func (i MemberType) IsAMemberType() bool {
|
|||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MarshalJSON implements the json.Marshaler interface for MemberType
|
||||||
|
func (i MemberType) MarshalJSON() ([]byte, error) {
|
||||||
|
return json.Marshal(i.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// UnmarshalJSON implements the json.Unmarshaler interface for MemberType
|
||||||
|
func (i *MemberType) UnmarshalJSON(data []byte) error {
|
||||||
|
var s string
|
||||||
|
if err := json.Unmarshal(data, &s); err != nil {
|
||||||
|
return fmt.Errorf("MemberType should be a string, got %s", data)
|
||||||
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
*i, err = MemberTypeString(s)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
@@ -554,7 +554,7 @@ func (s *Server) getUsers(ctx context.Context, org string, withPasswords bool, w
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, nil, err
|
return nil, nil, nil, nil, err
|
||||||
}
|
}
|
||||||
users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{orgSearch}}, org, nil)
|
users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{orgSearch}}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, nil, nil, err
|
return nil, nil, nil, nil, err
|
||||||
}
|
}
|
||||||
|
@@ -108,7 +108,7 @@ func (s *Server) getClaimedUserIDsOfOrgDomain(ctx context.Context, orgDomain str
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, "", nil)
|
users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@@ -329,7 +329,7 @@ func (s *Server) getClaimedUserIDsOfOrgDomain(ctx context.Context, orgDomain, or
|
|||||||
}
|
}
|
||||||
queries = append(queries, owner)
|
queries = append(queries, owner)
|
||||||
}
|
}
|
||||||
users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: queries}, orgID, nil)
|
users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: queries}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@@ -69,7 +69,7 @@ func (s *Server) ListUsers(ctx context.Context, req *mgmt_pb.ListUsersRequest) (
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
res, err := s.query.SearchUsers(ctx, queries, orgID, nil)
|
res, err := s.query.SearchUsers(ctx, queries, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@@ -30,11 +30,11 @@ func (s *Server) GetUserByID(ctx context.Context, req *user.GetUserByIDRequest)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) ListUsers(ctx context.Context, req *user.ListUsersRequest) (*user.ListUsersResponse, error) {
|
func (s *Server) ListUsers(ctx context.Context, req *user.ListUsersRequest) (*user.ListUsersResponse, error) {
|
||||||
queries, filterOrgId, err := listUsersRequestToModel(req)
|
queries, err := listUsersRequestToModel(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
res, err := s.query.SearchUsers(ctx, queries, filterOrgId, s.checkPermission)
|
res, err := s.query.SearchUsers(ctx, queries, s.checkPermission)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -171,11 +171,11 @@ func accessTokenTypeToPb(accessTokenType domain.OIDCTokenType) user.AccessTokenT
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueries, string, error) {
|
func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueries, error) {
|
||||||
offset, limit, asc := object.ListQueryToQuery(req.Query)
|
offset, limit, asc := object.ListQueryToQuery(req.Query)
|
||||||
queries, filterOrgId, err := userQueriesToQuery(req.Queries, 0 /*start from level 0*/)
|
queries, err := userQueriesToQuery(req.Queries, 0 /*start from level 0*/)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &query.UserSearchQueries{
|
return &query.UserSearchQueries{
|
||||||
SearchRequest: query.SearchRequest{
|
SearchRequest: query.SearchRequest{
|
||||||
@@ -185,7 +185,7 @@ func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueri
|
|||||||
SortingColumn: userFieldNameToSortingColumn(req.SortingColumn),
|
SortingColumn: userFieldNameToSortingColumn(req.SortingColumn),
|
||||||
},
|
},
|
||||||
Queries: queries,
|
Queries: queries,
|
||||||
}, filterOrgId, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func userFieldNameToSortingColumn(field user.UserFieldName) query.Column {
|
func userFieldNameToSortingColumn(field user.UserFieldName) query.Column {
|
||||||
@@ -215,18 +215,15 @@ func userFieldNameToSortingColumn(field user.UserFieldName) query.Column {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func userQueriesToQuery(queries []*user.SearchQuery, level uint8) (_ []query.SearchQuery, filterOrgId string, err error) {
|
func userQueriesToQuery(queries []*user.SearchQuery, level uint8) (_ []query.SearchQuery, err error) {
|
||||||
q := make([]query.SearchQuery, len(queries))
|
q := make([]query.SearchQuery, len(queries))
|
||||||
for i, query := range queries {
|
for i, query := range queries {
|
||||||
if orgFilter := query.GetOrganizationIdQuery(); orgFilter != nil {
|
|
||||||
filterOrgId = orgFilter.OrganizationId
|
|
||||||
}
|
|
||||||
q[i], err = userQueryToQuery(query, level)
|
q[i], err = userQueryToQuery(query, level)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, filterOrgId, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return q, filterOrgId, nil
|
return q, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func userQueryToQuery(query *user.SearchQuery, level uint8) (query.SearchQuery, error) {
|
func userQueryToQuery(query *user.SearchQuery, level uint8) (query.SearchQuery, error) {
|
||||||
@@ -320,14 +317,14 @@ func inUserIdsQueryToQuery(q *user.InUserIDQuery) (query.SearchQuery, error) {
|
|||||||
return query.NewUserInUserIdsSearchQuery(q.UserIds)
|
return query.NewUserInUserIdsSearchQuery(q.UserIds)
|
||||||
}
|
}
|
||||||
func orQueryToQuery(q *user.OrQuery, level uint8) (query.SearchQuery, error) {
|
func orQueryToQuery(q *user.OrQuery, level uint8) (query.SearchQuery, error) {
|
||||||
mappedQueries, _, err := userQueriesToQuery(q.Queries, level+1)
|
mappedQueries, err := userQueriesToQuery(q.Queries, level+1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return query.NewUserOrSearchQuery(mappedQueries)
|
return query.NewUserOrSearchQuery(mappedQueries)
|
||||||
}
|
}
|
||||||
func andQueryToQuery(q *user.AndQuery, level uint8) (query.SearchQuery, error) {
|
func andQueryToQuery(q *user.AndQuery, level uint8) (query.SearchQuery, error) {
|
||||||
mappedQueries, _, err := userQueriesToQuery(q.Queries, level+1)
|
mappedQueries, err := userQueriesToQuery(q.Queries, level+1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@@ -29,11 +29,11 @@ func (s *Server) GetUserByID(ctx context.Context, req *user.GetUserByIDRequest)
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) ListUsers(ctx context.Context, req *user.ListUsersRequest) (*user.ListUsersResponse, error) {
|
func (s *Server) ListUsers(ctx context.Context, req *user.ListUsersRequest) (*user.ListUsersResponse, error) {
|
||||||
queries, filterOrgIds, err := listUsersRequestToModel(req)
|
queries, err := listUsersRequestToModel(req)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
res, err := s.query.SearchUsers(ctx, queries, filterOrgIds, s.checkPermission)
|
res, err := s.query.SearchUsers(ctx, queries, s.checkPermission)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -165,11 +165,11 @@ func accessTokenTypeToPb(accessTokenType domain.OIDCTokenType) user.AccessTokenT
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueries, string, error) {
|
func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueries, error) {
|
||||||
offset, limit, asc := object.ListQueryToQuery(req.Query)
|
offset, limit, asc := object.ListQueryToQuery(req.Query)
|
||||||
queries, filterOrgId, err := userQueriesToQuery(req.Queries, 0 /*start from level 0*/)
|
queries, err := userQueriesToQuery(req.Queries, 0 /*start from level 0*/)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, err
|
||||||
}
|
}
|
||||||
return &query.UserSearchQueries{
|
return &query.UserSearchQueries{
|
||||||
SearchRequest: query.SearchRequest{
|
SearchRequest: query.SearchRequest{
|
||||||
@@ -179,7 +179,7 @@ func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueri
|
|||||||
SortingColumn: userFieldNameToSortingColumn(req.SortingColumn),
|
SortingColumn: userFieldNameToSortingColumn(req.SortingColumn),
|
||||||
},
|
},
|
||||||
Queries: queries,
|
Queries: queries,
|
||||||
}, filterOrgId, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func userFieldNameToSortingColumn(field user.UserFieldName) query.Column {
|
func userFieldNameToSortingColumn(field user.UserFieldName) query.Column {
|
||||||
@@ -209,18 +209,15 @@ func userFieldNameToSortingColumn(field user.UserFieldName) query.Column {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func userQueriesToQuery(queries []*user.SearchQuery, level uint8) (_ []query.SearchQuery, filterOrgId string, err error) {
|
func userQueriesToQuery(queries []*user.SearchQuery, level uint8) (_ []query.SearchQuery, err error) {
|
||||||
q := make([]query.SearchQuery, len(queries))
|
q := make([]query.SearchQuery, len(queries))
|
||||||
for i, query := range queries {
|
for i, query := range queries {
|
||||||
if orgFilter := query.GetOrganizationIdQuery(); orgFilter != nil {
|
|
||||||
filterOrgId = orgFilter.OrganizationId
|
|
||||||
}
|
|
||||||
q[i], err = userQueryToQuery(query, level)
|
q[i], err = userQueryToQuery(query, level)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, filterOrgId, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return q, filterOrgId, nil
|
return q, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func userQueryToQuery(query *user.SearchQuery, level uint8) (query.SearchQuery, error) {
|
func userQueryToQuery(query *user.SearchQuery, level uint8) (query.SearchQuery, error) {
|
||||||
@@ -314,14 +311,14 @@ func inUserIdsQueryToQuery(q *user.InUserIDQuery) (query.SearchQuery, error) {
|
|||||||
return query.NewUserInUserIdsSearchQuery(q.UserIds)
|
return query.NewUserInUserIdsSearchQuery(q.UserIds)
|
||||||
}
|
}
|
||||||
func orQueryToQuery(q *user.OrQuery, level uint8) (query.SearchQuery, error) {
|
func orQueryToQuery(q *user.OrQuery, level uint8) (query.SearchQuery, error) {
|
||||||
mappedQueries, _, err := userQueriesToQuery(q.Queries, level+1)
|
mappedQueries, err := userQueriesToQuery(q.Queries, level+1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
return query.NewUserOrSearchQuery(mappedQueries)
|
return query.NewUserOrSearchQuery(mappedQueries)
|
||||||
}
|
}
|
||||||
func andQueryToQuery(q *user.AndQuery, level uint8) (query.SearchQuery, error) {
|
func andQueryToQuery(q *user.AndQuery, level uint8) (query.SearchQuery, error) {
|
||||||
mappedQueries, _, err := userQueriesToQuery(q.Queries, level+1)
|
mappedQueries, err := userQueriesToQuery(q.Queries, level+1)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@@ -240,7 +240,7 @@ func (h *UsersHandler) List(ctx context.Context, request *ListRequest) (*ListRes
|
|||||||
return NewListResponse(count, q.SearchRequest, make([]*ScimUser, 0)), nil
|
return NewListResponse(count, q.SearchRequest, make([]*ScimUser, 0)), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
users, err := h.query.SearchUsers(ctx, q, authz.GetCtxData(ctx).OrgID, nil)
|
users, err := h.query.SearchUsers(ctx, q, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@@ -182,7 +182,7 @@ func (l *Login) getClaimedUserIDsOfOrgDomain(ctx context.Context, orgName string
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
users, err := l.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, "", nil)
|
users, err := l.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@@ -225,3 +225,37 @@ func (d *NullDuration) Scan(src any) error {
|
|||||||
d.Duration, d.Valid = time.Duration(*duration), true
|
d.Duration, d.Valid = time.Duration(*duration), true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JSONArray allows sending and receiving JSON arrays to and from the database.
|
||||||
|
// It implements the [database/sql.Scanner] and [database/sql/driver.Valuer] interfaces.
|
||||||
|
// Values are marshaled and unmarshaled using the [encoding/json] package.
|
||||||
|
type JSONArray[T any] []T
|
||||||
|
|
||||||
|
// NewJSONArray wraps an existing slice into a JSONArray.
|
||||||
|
func NewJSONArray[T any](a []T) JSONArray[T] {
|
||||||
|
return JSONArray[T](a)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Scan implements the [database/sql.Scanner] interface.
|
||||||
|
func (a *JSONArray[T]) Scan(src any) error {
|
||||||
|
if src == nil {
|
||||||
|
*a = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
bytes := src.([]byte)
|
||||||
|
if len(bytes) == 0 {
|
||||||
|
*a = nil
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return json.Unmarshal(bytes, a)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Value implements the [database/sql/driver.Valuer] interface.
|
||||||
|
func (a JSONArray[T]) Value() (driver.Value, error) {
|
||||||
|
if a == nil {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
return json.Marshal(a)
|
||||||
|
}
|
||||||
|
@@ -5,6 +5,7 @@ import (
|
|||||||
"database/sql/driver"
|
"database/sql/driver"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/muhlemmer/gu"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
@@ -452,3 +453,89 @@ func TestDuration_Scan(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestJSONArray_Scan(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
src any
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want *JSONArray[string]
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil",
|
||||||
|
args: args{src: nil},
|
||||||
|
want: new(JSONArray[string]),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "zero bytes",
|
||||||
|
args: args{src: []byte("")},
|
||||||
|
want: new(JSONArray[string]),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
args: args{src: []byte("[]")},
|
||||||
|
want: gu.Ptr(JSONArray[string]{}),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok",
|
||||||
|
args: args{src: []byte("[\"a\", \"b\"]")},
|
||||||
|
want: gu.Ptr(JSONArray[string]{"a", "b"}),
|
||||||
|
wantErr: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "json error",
|
||||||
|
args: args{src: []byte("{\"a\": \"b\"}")},
|
||||||
|
want: new(JSONArray[string]),
|
||||||
|
wantErr: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := new(JSONArray[string])
|
||||||
|
err := got.Scan(tt.args.src)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJSONArray_Value(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
a []string
|
||||||
|
want driver.Value
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil",
|
||||||
|
a: nil,
|
||||||
|
want: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
a: []string{},
|
||||||
|
want: []byte("[]"),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ok",
|
||||||
|
a: []string{"a", "b"},
|
||||||
|
want: []byte("[\"a\",\"b\"]"),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := NewJSONArray(tt.a).Value()
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
@@ -106,12 +106,26 @@ func idpLinksCheckPermission(ctx context.Context, links *IDPUserLinks, permissio
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func idpLinksPermissionCheckV2(ctx context.Context, query sq.SelectBuilder, enabled bool, queries *IDPUserLinksSearchQuery) sq.SelectBuilder {
|
||||||
|
if !enabled {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
return query.Where(PermissionClause(
|
||||||
|
ctx,
|
||||||
|
IDPUserLinkResourceOwnerCol,
|
||||||
|
domain.PermissionUserRead,
|
||||||
|
SingleOrgPermissionOption(queries.Queries),
|
||||||
|
OwnedRowsPermissionOption(IDPUserLinkUserIDCol),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
func (q *Queries) IDPUserLinks(ctx context.Context, queries *IDPUserLinksSearchQuery, permissionCheck domain.PermissionCheck) (idps *IDPUserLinks, err error) {
|
func (q *Queries) IDPUserLinks(ctx context.Context, queries *IDPUserLinksSearchQuery, permissionCheck domain.PermissionCheck) (idps *IDPUserLinks, err error) {
|
||||||
links, err := q.idpUserLinks(ctx, queries, false)
|
permissionCheckV2 := PermissionV2(ctx, permissionCheck)
|
||||||
|
links, err := q.idpUserLinks(ctx, queries, permissionCheckV2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if permissionCheck != nil && len(links.Links) > 0 {
|
if permissionCheck != nil && len(links.Links) > 0 && !permissionCheckV2 {
|
||||||
// when userID for query is provided, only one check has to be done
|
// when userID for query is provided, only one check has to be done
|
||||||
if queries.hasUserID() {
|
if queries.hasUserID() {
|
||||||
if err := userCheckPermission(ctx, links.Links[0].ResourceOwner, links.Links[0].UserID, permissionCheck); err != nil {
|
if err := userCheckPermission(ctx, links.Links[0].ResourceOwner, links.Links[0].UserID, permissionCheck); err != nil {
|
||||||
@@ -124,14 +138,15 @@ func (q *Queries) IDPUserLinks(ctx context.Context, queries *IDPUserLinksSearchQ
|
|||||||
return links, nil
|
return links, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) idpUserLinks(ctx context.Context, queries *IDPUserLinksSearchQuery, withOwnerRemoved bool) (idps *IDPUserLinks, err error) {
|
func (q *Queries) idpUserLinks(ctx context.Context, queries *IDPUserLinksSearchQuery, permissionCheckV2 bool) (idps *IDPUserLinks, err error) {
|
||||||
ctx, span := tracing.NewSpan(ctx)
|
ctx, span := tracing.NewSpan(ctx)
|
||||||
defer func() { span.EndWithError(err) }()
|
defer func() { span.EndWithError(err) }()
|
||||||
|
|
||||||
query, scan := prepareIDPUserLinksQuery()
|
query, scan := prepareIDPUserLinksQuery()
|
||||||
eq := sq.Eq{IDPUserLinkInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID()}
|
query = idpLinksPermissionCheckV2(ctx, query, permissionCheckV2, queries)
|
||||||
if !withOwnerRemoved {
|
eq := sq.Eq{
|
||||||
eq[IDPUserLinkOwnerRemovedCol.identifier()] = false
|
IDPUserLinkInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
|
||||||
|
IDPUserLinkOwnerRemovedCol.identifier(): false,
|
||||||
}
|
}
|
||||||
stmt, args, err := queries.toQuery(query).Where(eq).ToSql()
|
stmt, args, err := queries.toQuery(query).Where(eq).ToSql()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@@ -93,6 +93,17 @@ func orgsCheckPermission(ctx context.Context, orgs *Orgs, permissionCheck domain
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func orgsPermissionCheckV2(ctx context.Context, query sq.SelectBuilder, enabled bool) sq.SelectBuilder {
|
||||||
|
if !enabled {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
return query.Where(PermissionClause(
|
||||||
|
ctx,
|
||||||
|
OrgColumnID,
|
||||||
|
domain_pkg.PermissionOrgRead,
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
type OrgSearchQueries struct {
|
type OrgSearchQueries struct {
|
||||||
SearchRequest
|
SearchRequest
|
||||||
Queries []SearchQuery
|
Queries []SearchQuery
|
||||||
@@ -283,21 +294,23 @@ func (q *Queries) ExistsOrg(ctx context.Context, id, domain string) (verifiedID
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) SearchOrgs(ctx context.Context, queries *OrgSearchQueries, permissionCheck domain_pkg.PermissionCheck) (*Orgs, error) {
|
func (q *Queries) SearchOrgs(ctx context.Context, queries *OrgSearchQueries, permissionCheck domain_pkg.PermissionCheck) (*Orgs, error) {
|
||||||
orgs, err := q.searchOrgs(ctx, queries)
|
permissionCheckV2 := PermissionV2(ctx, permissionCheck)
|
||||||
|
orgs, err := q.searchOrgs(ctx, queries, permissionCheckV2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if permissionCheck != nil {
|
if permissionCheck != nil && !permissionCheckV2 {
|
||||||
orgsCheckPermission(ctx, orgs, permissionCheck)
|
orgsCheckPermission(ctx, orgs, permissionCheck)
|
||||||
}
|
}
|
||||||
return orgs, nil
|
return orgs, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) searchOrgs(ctx context.Context, queries *OrgSearchQueries) (orgs *Orgs, err error) {
|
func (q *Queries) searchOrgs(ctx context.Context, queries *OrgSearchQueries, permissionCheckV2 bool) (orgs *Orgs, err error) {
|
||||||
ctx, span := tracing.NewSpan(ctx)
|
ctx, span := tracing.NewSpan(ctx)
|
||||||
defer func() { span.EndWithError(err) }()
|
defer func() { span.EndWithError(err) }()
|
||||||
|
|
||||||
query, scan := prepareOrgsQuery()
|
query, scan := prepareOrgsQuery()
|
||||||
|
query = orgsPermissionCheckV2(ctx, query, permissionCheckV2)
|
||||||
stmt, args, err := queries.toQuery(query).
|
stmt, args, err := queries.toQuery(query).
|
||||||
Where(sq.And{
|
Where(sq.And{
|
||||||
sq.Eq{
|
sq.Eq{
|
||||||
|
@@ -2,74 +2,109 @@ package query
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
sq "github.com/Masterminds/squirrel"
|
sq "github.com/Masterminds/squirrel"
|
||||||
"github.com/zitadel/logging"
|
"github.com/zitadel/logging"
|
||||||
|
|
||||||
"github.com/zitadel/zitadel/internal/api/authz"
|
"github.com/zitadel/zitadel/internal/api/authz"
|
||||||
"github.com/zitadel/zitadel/internal/zerrors"
|
"github.com/zitadel/zitadel/internal/database"
|
||||||
|
domain_pkg "github.com/zitadel/zitadel/internal/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// eventstore.permitted_orgs(instanceid text, userid text, system_user_perms JSONB, perm text filter_orgs text)
|
// eventstore.permitted_orgs(instanceid text, userid text, system_user_perms JSONB, perm text, filter_org text)
|
||||||
wherePermittedOrgsClause = "%s = ANY(eventstore.permitted_orgs(?, ?, ?, ?, ?))"
|
wherePermittedOrgsExpr = "%s = ANY(eventstore.permitted_orgs(?, ?, ?, ?, ?))"
|
||||||
wherePermittedOrgsOrCurrentUserClause = "(" + wherePermittedOrgsClause + " OR %s = ?" + ")"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// wherePermittedOrgs sets a `WHERE` clause to the query that filters the orgs
|
type permissionClauseBuilder struct {
|
||||||
// for which the authenticated user has the requested permission for.
|
orgIDColumn Column
|
||||||
// The user ID is taken from the context.
|
instanceID string
|
||||||
// The `orgIDColumn` specifies the table column to which this filter must be applied,
|
userID string
|
||||||
// and is typically the `resource_owner` column in ZITADEL.
|
systemPermissions []authz.SystemUserPermissions
|
||||||
// We use full identifiers in the query builder so this function should be
|
permission string
|
||||||
// called with something like `UserResourceOwnerCol.identifier()` for example.
|
orgID string
|
||||||
// func wherePermittedOrgs(ctx context.Context, query sq.SelectBuilder, filterOrgIds, orgIDColumn, permission string) (sq.SelectBuilder, error) {
|
connections []sq.Eq
|
||||||
// userID := authz.GetCtxData(ctx).UserID
|
}
|
||||||
// logging.WithFields("permission_check_v2_flag", authz.GetFeatures(ctx).PermissionCheckV2, "org_id_column", orgIDColumn, "permission", permission, "user_id", userID).Debug("permitted orgs check used")
|
|
||||||
|
func (b *permissionClauseBuilder) appendConnection(column string, value any) {
|
||||||
// systemUserPermissions := authz.GetSystemUserPermissions(ctx)
|
b.connections = append(b.connections, sq.Eq{column: value})
|
||||||
// var systemUserPermissionsJson []byte
|
}
|
||||||
// if systemUserPermissions != nil {
|
|
||||||
// var err error
|
func (b *permissionClauseBuilder) clauses() sq.Or {
|
||||||
// systemUserPermissionsJson, err = json.Marshal(systemUserPermissions)
|
clauses := make(sq.Or, 1, len(b.connections)+1)
|
||||||
// if err != nil {
|
clauses[0] = sq.Expr(
|
||||||
// return query, err
|
fmt.Sprintf(wherePermittedOrgsExpr, b.orgIDColumn.identifier()),
|
||||||
// }
|
b.instanceID,
|
||||||
// }
|
b.userID,
|
||||||
|
database.NewJSONArray(b.systemPermissions),
|
||||||
// return query.Where(
|
b.permission,
|
||||||
// fmt.Sprintf(wherePermittedOrgsClause, orgIDColumn),
|
b.orgID,
|
||||||
// authz.GetInstance(ctx).InstanceID(),
|
)
|
||||||
// userID,
|
for _, include := range b.connections {
|
||||||
// systemUserPermissionsJson,
|
clauses = append(clauses, include)
|
||||||
// permission,
|
}
|
||||||
// filterOrgIds,
|
return clauses
|
||||||
// ), nil
|
}
|
||||||
// }
|
|
||||||
|
type PermissionOption func(b *permissionClauseBuilder)
|
||||||
func wherePermittedOrgsOrCurrentUser(ctx context.Context, query sq.SelectBuilder, filterOrgIds, orgIDColumn, userIdColum, permission string) (sq.SelectBuilder, error) {
|
|
||||||
userID := authz.GetCtxData(ctx).UserID
|
// OwnedRowsPermissionOption allows rows to be returned of which the current user is the owner.
|
||||||
logging.WithFields("permission_check_v2_flag", authz.GetFeatures(ctx).PermissionCheckV2, "org_id_column", orgIDColumn, "user_id_colum", userIdColum, "permission", permission, "user_id", userID).Debug("permitted orgs check used")
|
// Even if the user does not have an explicit permission for the organization.
|
||||||
|
// For example an authenticated user can always see his own user account.
|
||||||
systemUserPermissions := authz.GetSystemUserPermissions(ctx)
|
func OwnedRowsPermissionOption(userIDColumn Column) PermissionOption {
|
||||||
var systemUserPermissionsJson []byte
|
return func(b *permissionClauseBuilder) {
|
||||||
if systemUserPermissions != nil {
|
b.appendConnection(userIDColumn.identifier(), b.userID)
|
||||||
var err error
|
}
|
||||||
systemUserPermissionsJson, err = json.Marshal(systemUserPermissions)
|
}
|
||||||
if err != nil {
|
|
||||||
return query, zerrors.ThrowInternal(err, "AUTHZ-HS4us", "Errors.Internal")
|
// ConnectionPermissionOption allows returning of rows where the value is matched.
|
||||||
}
|
// Even if the user does not have an explicit permission for the organization.
|
||||||
}
|
func ConnectionPermissionOption(column Column, value any) PermissionOption {
|
||||||
|
return func(b *permissionClauseBuilder) {
|
||||||
return query.Where(
|
b.appendConnection(column.identifier(), value)
|
||||||
fmt.Sprintf(wherePermittedOrgsOrCurrentUserClause, orgIDColumn, userIdColum),
|
}
|
||||||
authz.GetInstance(ctx).InstanceID(),
|
}
|
||||||
userID,
|
|
||||||
systemUserPermissionsJson,
|
// SingleOrgPermissionOption may be used to optimize the permitted orgs function by limiting the
|
||||||
permission,
|
// returned organizations, to the one used in the requested filters.
|
||||||
filterOrgIds,
|
func SingleOrgPermissionOption(queries []SearchQuery) PermissionOption {
|
||||||
userID,
|
return func(b *permissionClauseBuilder) {
|
||||||
), nil
|
b.orgID = findTextEqualsQuery(b.orgIDColumn, queries)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// PermissionClause sets a `WHERE` clause to query,
|
||||||
|
// which filters returned rows the current authenticated user has the requested permission to.
|
||||||
|
//
|
||||||
|
// Experimental: Work in progress. Currently only organization permissions are supported
|
||||||
|
func PermissionClause(ctx context.Context, orgIDCol Column, permission string, options ...PermissionOption) sq.Or {
|
||||||
|
ctxData := authz.GetCtxData(ctx)
|
||||||
|
b := &permissionClauseBuilder{
|
||||||
|
orgIDColumn: orgIDCol,
|
||||||
|
instanceID: authz.GetInstance(ctx).InstanceID(),
|
||||||
|
userID: ctxData.UserID,
|
||||||
|
systemPermissions: ctxData.SystemUserPermissions,
|
||||||
|
permission: permission,
|
||||||
|
}
|
||||||
|
for _, opt := range options {
|
||||||
|
opt(b)
|
||||||
|
}
|
||||||
|
logging.WithFields(
|
||||||
|
"org_id_column", b.orgIDColumn,
|
||||||
|
"instance_id", b.instanceID,
|
||||||
|
"user_id", b.userID,
|
||||||
|
"system_user_permissions", b.systemPermissions,
|
||||||
|
"permission", b.permission,
|
||||||
|
"org_id", b.orgID,
|
||||||
|
"overrides", b.connections,
|
||||||
|
).Debug("permitted orgs check used")
|
||||||
|
|
||||||
|
return b.clauses()
|
||||||
|
}
|
||||||
|
|
||||||
|
// PermissionV2 checks are enabled when the feature flag is set and the permission check function is not nil.
|
||||||
|
// When the permission check function is nil, it indicates a v1 API and no resource based permission check is needed.
|
||||||
|
func PermissionV2(ctx context.Context, cf domain_pkg.PermissionCheck) bool {
|
||||||
|
return authz.GetFeatures(ctx).PermissionCheckV2 && cf != nil
|
||||||
}
|
}
|
||||||
|
208
internal/query/permission_test.go
Normal file
208
internal/query/permission_test.go
Normal file
@@ -0,0 +1,208 @@
|
|||||||
|
package query
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
sq "github.com/Masterminds/squirrel"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
|
||||||
|
"github.com/zitadel/zitadel/internal/api/authz"
|
||||||
|
"github.com/zitadel/zitadel/internal/database"
|
||||||
|
domain_pkg "github.com/zitadel/zitadel/internal/domain"
|
||||||
|
"github.com/zitadel/zitadel/internal/feature"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPermissionClause(t *testing.T) {
|
||||||
|
var permissions = []authz.SystemUserPermissions{
|
||||||
|
{
|
||||||
|
MemberType: authz.MemberTypeOrganization,
|
||||||
|
AggregateID: "orgID",
|
||||||
|
Permissions: []string{"permission1", "permission2"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
MemberType: authz.MemberTypeIAM,
|
||||||
|
Permissions: []string{"permission2", "permission3"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
ctx := authz.WithInstanceID(context.Background(), "instanceID")
|
||||||
|
ctx = authz.SetCtxData(ctx, authz.CtxData{
|
||||||
|
UserID: "userID",
|
||||||
|
SystemUserPermissions: permissions,
|
||||||
|
})
|
||||||
|
|
||||||
|
type args struct {
|
||||||
|
ctx context.Context
|
||||||
|
orgIDCol Column
|
||||||
|
permission string
|
||||||
|
options []PermissionOption
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
wantClause sq.Or
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "no options",
|
||||||
|
args: args{
|
||||||
|
ctx: ctx,
|
||||||
|
orgIDCol: UserResourceOwnerCol,
|
||||||
|
permission: "permission1",
|
||||||
|
},
|
||||||
|
wantClause: sq.Or{
|
||||||
|
sq.Expr(
|
||||||
|
"projections.users14.resource_owner = ANY(eventstore.permitted_orgs(?, ?, ?, ?, ?))",
|
||||||
|
"instanceID",
|
||||||
|
"userID",
|
||||||
|
database.NewJSONArray(permissions),
|
||||||
|
"permission1",
|
||||||
|
"",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "owned rows option",
|
||||||
|
args: args{
|
||||||
|
ctx: ctx,
|
||||||
|
orgIDCol: UserResourceOwnerCol,
|
||||||
|
permission: "permission1",
|
||||||
|
options: []PermissionOption{
|
||||||
|
OwnedRowsPermissionOption(UserIDCol),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantClause: sq.Or{
|
||||||
|
sq.Expr(
|
||||||
|
"projections.users14.resource_owner = ANY(eventstore.permitted_orgs(?, ?, ?, ?, ?))",
|
||||||
|
"instanceID",
|
||||||
|
"userID",
|
||||||
|
database.NewJSONArray(permissions),
|
||||||
|
"permission1",
|
||||||
|
"",
|
||||||
|
),
|
||||||
|
sq.Eq{"projections.users14.id": "userID"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "connection rows option",
|
||||||
|
args: args{
|
||||||
|
ctx: ctx,
|
||||||
|
orgIDCol: UserResourceOwnerCol,
|
||||||
|
permission: "permission1",
|
||||||
|
options: []PermissionOption{
|
||||||
|
OwnedRowsPermissionOption(UserIDCol),
|
||||||
|
ConnectionPermissionOption(UserStateCol, "bar"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantClause: sq.Or{
|
||||||
|
sq.Expr(
|
||||||
|
"projections.users14.resource_owner = ANY(eventstore.permitted_orgs(?, ?, ?, ?, ?))",
|
||||||
|
"instanceID",
|
||||||
|
"userID",
|
||||||
|
database.NewJSONArray(permissions),
|
||||||
|
"permission1",
|
||||||
|
"",
|
||||||
|
),
|
||||||
|
sq.Eq{"projections.users14.id": "userID"},
|
||||||
|
sq.Eq{"projections.users14.state": "bar"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "single org option",
|
||||||
|
args: args{
|
||||||
|
ctx: ctx,
|
||||||
|
orgIDCol: UserResourceOwnerCol,
|
||||||
|
permission: "permission1",
|
||||||
|
options: []PermissionOption{
|
||||||
|
SingleOrgPermissionOption([]SearchQuery{
|
||||||
|
mustSearchQuery(NewUserDisplayNameSearchQuery("zitadel", TextContains)),
|
||||||
|
mustSearchQuery(NewUserResourceOwnerSearchQuery("orgID", TextEquals)),
|
||||||
|
}),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
wantClause: sq.Or{
|
||||||
|
sq.Expr(
|
||||||
|
"projections.users14.resource_owner = ANY(eventstore.permitted_orgs(?, ?, ?, ?, ?))",
|
||||||
|
"instanceID",
|
||||||
|
"userID",
|
||||||
|
database.NewJSONArray(permissions),
|
||||||
|
"permission1",
|
||||||
|
"orgID",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
gotClause := PermissionClause(tt.args.ctx, tt.args.orgIDCol, tt.args.permission, tt.args.options...)
|
||||||
|
assert.Equal(t, tt.wantClause, gotClause)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustSearchQuery(q SearchQuery, err error) SearchQuery {
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return q
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPermissionV2(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
ctx context.Context
|
||||||
|
cf domain_pkg.PermissionCheck
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "feature disabled, no permission check",
|
||||||
|
args: args{
|
||||||
|
ctx: context.Background(),
|
||||||
|
cf: nil,
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "feature enabled, no permission check",
|
||||||
|
args: args{
|
||||||
|
ctx: authz.WithFeatures(context.Background(), feature.Features{
|
||||||
|
PermissionCheckV2: true,
|
||||||
|
}),
|
||||||
|
cf: nil,
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "feature enabled, with permission check",
|
||||||
|
args: args{
|
||||||
|
ctx: authz.WithFeatures(context.Background(), feature.Features{
|
||||||
|
PermissionCheckV2: true,
|
||||||
|
}),
|
||||||
|
cf: func(context.Context, string, string, string) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "feature disabled, with permission check",
|
||||||
|
args: args{
|
||||||
|
ctx: authz.WithFeatures(context.Background(), feature.Features{
|
||||||
|
PermissionCheckV2: false,
|
||||||
|
}),
|
||||||
|
cf: func(context.Context, string, string, string) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got := PermissionV2(tt.args.ctx, tt.args.cf)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
@@ -148,3 +148,16 @@ func triggerBatch(ctx context.Context, handlers ...*handler.Handler) {
|
|||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func findTextEqualsQuery(column Column, queries []SearchQuery) string {
|
||||||
|
for _, query := range queries {
|
||||||
|
if query.Col() != column {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
tq, ok := query.(*textQuery)
|
||||||
|
if ok && tq.Compare == TextEquals {
|
||||||
|
return tq.Text
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
@@ -113,6 +113,22 @@ func sessionCheckPermission(ctx context.Context, resourceOwner string, creator s
|
|||||||
return permissionCheck(ctx, domain.PermissionSessionRead, resourceOwner, "")
|
return permissionCheck(ctx, domain.PermissionSessionRead, resourceOwner, "")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sessionsPermissionCheckV2(ctx context.Context, query sq.SelectBuilder, enabled bool) sq.SelectBuilder {
|
||||||
|
if !enabled {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
return query.Where(PermissionClause(
|
||||||
|
ctx,
|
||||||
|
SessionColumnResourceOwner,
|
||||||
|
domain.PermissionSessionRead,
|
||||||
|
// Allow if user is creator
|
||||||
|
OwnedRowsPermissionOption(SessionColumnCreator),
|
||||||
|
// Allow if session belongs to the user
|
||||||
|
OwnedRowsPermissionOption(SessionColumnUserID),
|
||||||
|
ConnectionPermissionOption(SessionColumnUserAgentFingerprintID, authz.GetCtxData(ctx).AgentID),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
func (q *SessionsSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
|
func (q *SessionsSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
|
||||||
query = q.SearchRequest.toQuery(query)
|
query = q.SearchRequest.toQuery(query)
|
||||||
for _, q := range q.Queries {
|
for _, q := range q.Queries {
|
||||||
@@ -282,21 +298,23 @@ func (q *Queries) sessionByID(ctx context.Context, shouldTriggerBulk bool, id st
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) SearchSessions(ctx context.Context, queries *SessionsSearchQueries, permissionCheck domain.PermissionCheck) (*Sessions, error) {
|
func (q *Queries) SearchSessions(ctx context.Context, queries *SessionsSearchQueries, permissionCheck domain.PermissionCheck) (*Sessions, error) {
|
||||||
sessions, err := q.searchSessions(ctx, queries)
|
permissionCheckV2 := PermissionV2(ctx, permissionCheck)
|
||||||
|
sessions, err := q.searchSessions(ctx, queries, permissionCheckV2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if permissionCheck != nil {
|
if permissionCheck != nil && !permissionCheckV2 {
|
||||||
sessionsCheckPermission(ctx, sessions, permissionCheck)
|
sessionsCheckPermission(ctx, sessions, permissionCheck)
|
||||||
}
|
}
|
||||||
return sessions, nil
|
return sessions, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) searchSessions(ctx context.Context, queries *SessionsSearchQueries) (sessions *Sessions, err error) {
|
func (q *Queries) searchSessions(ctx context.Context, queries *SessionsSearchQueries, permissionCheckV2 bool) (sessions *Sessions, err error) {
|
||||||
ctx, span := tracing.NewSpan(ctx)
|
ctx, span := tracing.NewSpan(ctx)
|
||||||
defer func() { span.EndWithError(err) }()
|
defer func() { span.EndWithError(err) }()
|
||||||
|
|
||||||
query, scan := prepareSessionsQuery()
|
query, scan := prepareSessionsQuery()
|
||||||
|
query = sessionsPermissionCheckV2(ctx, query, permissionCheckV2)
|
||||||
stmt, args, err := queries.toQuery(query).
|
stmt, args, err := queries.toQuery(query).
|
||||||
Where(sq.Eq{
|
Where(sq.Eq{
|
||||||
SessionColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
|
SessionColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),
|
||||||
|
@@ -132,6 +132,19 @@ func usersCheckPermission(ctx context.Context, users *Users, permissionCheck dom
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func userPermissionCheckV2(ctx context.Context, query sq.SelectBuilder, enabled bool, queries *UserSearchQueries) sq.SelectBuilder {
|
||||||
|
if !enabled {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
return query.Where(PermissionClause(
|
||||||
|
ctx,
|
||||||
|
UserResourceOwnerCol,
|
||||||
|
domain.PermissionUserRead,
|
||||||
|
SingleOrgPermissionOption(queries.Queries),
|
||||||
|
OwnedRowsPermissionOption(UserIDCol),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
type UserSearchQueries struct {
|
type UserSearchQueries struct {
|
||||||
SearchRequest
|
SearchRequest
|
||||||
Queries []SearchQuery
|
Queries []SearchQuery
|
||||||
@@ -606,8 +619,9 @@ func (q *Queries) CountUsers(ctx context.Context, queries *UserSearchQueries) (c
|
|||||||
return count, err
|
return count, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) SearchUsers(ctx context.Context, queries *UserSearchQueries, filterOrgIds string, permissionCheck domain.PermissionCheck) (*Users, error) {
|
func (q *Queries) SearchUsers(ctx context.Context, queries *UserSearchQueries, permissionCheck domain.PermissionCheck) (*Users, error) {
|
||||||
users, err := q.searchUsers(ctx, queries, filterOrgIds, permissionCheck != nil && authz.GetFeatures(ctx).PermissionCheckV2)
|
permissionCheckV2 := PermissionV2(ctx, permissionCheck)
|
||||||
|
users, err := q.searchUsers(ctx, queries, permissionCheckV2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -617,22 +631,15 @@ func (q *Queries) SearchUsers(ctx context.Context, queries *UserSearchQueries, f
|
|||||||
return users, nil
|
return users, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) searchUsers(ctx context.Context, queries *UserSearchQueries, filterOrgIds string, permissionCheckV2 bool) (users *Users, err error) {
|
func (q *Queries) searchUsers(ctx context.Context, queries *UserSearchQueries, permissionCheckV2 bool) (users *Users, err error) {
|
||||||
ctx, span := tracing.NewSpan(ctx)
|
ctx, span := tracing.NewSpan(ctx)
|
||||||
defer func() { span.EndWithError(err) }()
|
defer func() { span.EndWithError(err) }()
|
||||||
|
|
||||||
query, scan := prepareUsersQuery()
|
query, scan := prepareUsersQuery()
|
||||||
query = queries.toQuery(query).Where(sq.Eq{
|
query = userPermissionCheckV2(ctx, query, permissionCheckV2, queries)
|
||||||
|
stmt, args, err := queries.toQuery(query).Where(sq.Eq{
|
||||||
UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
|
UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
|
||||||
})
|
}).ToSql()
|
||||||
if permissionCheckV2 {
|
|
||||||
query, err = wherePermittedOrgsOrCurrentUser(ctx, query, filterOrgIds, UserResourceOwnerCol.identifier(), UserIDCol.identifier(), domain.PermissionUserRead)
|
|
||||||
if err != nil {
|
|
||||||
return nil, zerrors.ThrowInternal(err, "AUTHZ-HS4us", "Errors.Internal")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
stmt, args, err := query.ToSql()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, zerrors.ThrowInternal(err, "QUERY-Dgbg2", "Errors.Query.SQLStatment")
|
return nil, zerrors.ThrowInternal(err, "QUERY-Dgbg2", "Errors.Query.SQLStatment")
|
||||||
}
|
}
|
||||||
|
@@ -104,6 +104,18 @@ func authMethodsCheckPermission(ctx context.Context, methods *AuthMethods, permi
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func userAuthMethodPermissionCheckV2(ctx context.Context, query sq.SelectBuilder, enabled bool) sq.SelectBuilder {
|
||||||
|
if !enabled {
|
||||||
|
return query
|
||||||
|
}
|
||||||
|
return query.Where(PermissionClause(
|
||||||
|
ctx,
|
||||||
|
UserAuthMethodColumnResourceOwner,
|
||||||
|
domain.PermissionUserRead,
|
||||||
|
OwnedRowsPermissionOption(UserIDCol),
|
||||||
|
))
|
||||||
|
}
|
||||||
|
|
||||||
type AuthMethod struct {
|
type AuthMethod struct {
|
||||||
UserID string
|
UserID string
|
||||||
CreationDate time.Time
|
CreationDate time.Time
|
||||||
@@ -137,11 +149,12 @@ func (q *UserAuthMethodSearchQueries) hasUserID() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) SearchUserAuthMethods(ctx context.Context, queries *UserAuthMethodSearchQueries, permissionCheck domain.PermissionCheck) (userAuthMethods *AuthMethods, err error) {
|
func (q *Queries) SearchUserAuthMethods(ctx context.Context, queries *UserAuthMethodSearchQueries, permissionCheck domain.PermissionCheck) (userAuthMethods *AuthMethods, err error) {
|
||||||
methods, err := q.searchUserAuthMethods(ctx, queries)
|
permissionCheckV2 := PermissionV2(ctx, permissionCheck)
|
||||||
|
methods, err := q.searchUserAuthMethods(ctx, queries, permissionCheckV2)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
if permissionCheck != nil && len(methods.AuthMethods) > 0 {
|
if permissionCheck != nil && len(methods.AuthMethods) > 0 && !permissionCheckV2 {
|
||||||
// when userID for query is provided, only one check has to be done
|
// when userID for query is provided, only one check has to be done
|
||||||
if queries.hasUserID() {
|
if queries.hasUserID() {
|
||||||
if err := userCheckPermission(ctx, methods.AuthMethods[0].ResourceOwner, methods.AuthMethods[0].UserID, permissionCheck); err != nil {
|
if err := userCheckPermission(ctx, methods.AuthMethods[0].ResourceOwner, methods.AuthMethods[0].UserID, permissionCheck); err != nil {
|
||||||
@@ -154,11 +167,12 @@ func (q *Queries) SearchUserAuthMethods(ctx context.Context, queries *UserAuthMe
|
|||||||
return methods, nil
|
return methods, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (q *Queries) searchUserAuthMethods(ctx context.Context, queries *UserAuthMethodSearchQueries) (userAuthMethods *AuthMethods, err error) {
|
func (q *Queries) searchUserAuthMethods(ctx context.Context, queries *UserAuthMethodSearchQueries, permissionCheckV2 bool) (userAuthMethods *AuthMethods, err error) {
|
||||||
ctx, span := tracing.NewSpan(ctx)
|
ctx, span := tracing.NewSpan(ctx)
|
||||||
defer func() { span.EndWithError(err) }()
|
defer func() { span.EndWithError(err) }()
|
||||||
|
|
||||||
query, scan := prepareUserAuthMethodsQuery()
|
query, scan := prepareUserAuthMethodsQuery()
|
||||||
|
query = userAuthMethodPermissionCheckV2(ctx, query, permissionCheckV2)
|
||||||
stmt, args, err := queries.toQuery(query).Where(sq.Eq{UserAuthMethodColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()}).ToSql()
|
stmt, args, err := queries.toQuery(query).Where(sq.Eq{UserAuthMethodColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()}).ToSql()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, zerrors.ThrowInvalidArgument(err, "QUERY-j9NJd", "Errors.Query.InvalidRequest")
|
return nil, zerrors.ThrowInvalidArgument(err, "QUERY-j9NJd", "Errors.Query.InvalidRequest")
|
||||||
|
Reference in New Issue
Block a user