fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! fixup! feat(permissions): Addeding system user support for permission check v2

This commit is contained in:
Iraq Jaber
2025-03-07 08:43:05 +00:00
parent d82803907f
commit 3f49f5b699
4 changed files with 16 additions and 16 deletions

View File

@@ -4,7 +4,7 @@ CREATE OR REPLACE FUNCTION eventstore.permitted_orgs(
instanceId TEXT instanceId TEXT
, userId TEXT , userId TEXT
, perm TEXT , perm TEXT
, system_roles TEXT[] , system_user_roles TEXT[]
, filter_orgs TEXT , filter_orgs TEXT
, org_ids OUT TEXT[] , org_ids OUT TEXT[]
@@ -20,12 +20,12 @@ BEGIN
WHERE rp.instance_id = instanceId WHERE rp.instance_id = instanceId
AND rp.permission = perm; AND rp.permission = perm;
IF system_roles IS NOT NULL THEN IF system_user_roles IS NOT NULL THEN
DECLARE DECLARE
permission_found_in_system_roles bool; permission_found_in_system_roles bool;
BEGIN BEGIN
SELECT result.role_found INTO permission_found_in_system_roles SELECT result.role_found INTO permission_found_in_system_roles
FROM (SELECT matched_roles && system_roles AS role_found) AS result; FROM (SELECT matched_roles && system_user_roles AS role_found) AS result;
IF permission_found_in_system_roles THEN IF permission_found_in_system_roles THEN
SELECT array_agg(o.org_id) INTO org_ids SELECT array_agg(o.org_id) INTO org_ids

View File

@@ -31,7 +31,7 @@ func CheckUserAuthorization(ctx context.Context, req interface{}, token, orgID,
if requiredAuthOption.Permission == authenticated { if requiredAuthOption.Permission == authenticated {
return func(parent context.Context) context.Context { return func(parent context.Context) context.Context {
parent = addGetSystemRolesFuncToCtx(parent, ctxData) parent = addGetSystemUserRolesFuncToCtx(parent, ctxData)
return context.WithValue(parent, dataKey, ctxData) return context.WithValue(parent, dataKey, ctxData)
}, nil }, nil
} }
@@ -52,7 +52,7 @@ func CheckUserAuthorization(ctx context.Context, req interface{}, token, orgID,
parent = context.WithValue(parent, dataKey, ctxData) parent = context.WithValue(parent, dataKey, ctxData)
parent = context.WithValue(parent, allPermissionsKey, allPermissions) parent = context.WithValue(parent, allPermissionsKey, allPermissions)
parent = context.WithValue(parent, requestPermissionsKey, requestedPermissions) parent = context.WithValue(parent, requestPermissionsKey, requestedPermissions)
parent = addGetSystemRolesFuncToCtx(parent, ctxData) parent = addGetSystemUserRolesFuncToCtx(parent, ctxData)
return parent return parent
}, nil }, nil
} }
@@ -129,7 +129,7 @@ func GetAllPermissionCtxIDs(perms []string) []string {
return ctxIDs return ctxIDs
} }
func addGetSystemRolesFuncToCtx(ctx context.Context, ctxData CtxData) context.Context { func addGetSystemUserRolesFuncToCtx(ctx context.Context, ctxData CtxData) context.Context {
if len(ctxData.SystemMemberships) != 0 { if len(ctxData.SystemMemberships) != 0 {
ctx = context.WithValue(ctx, systemUserRolesFuncKey, func() func(ctx context.Context) ([]string, error) { ctx = context.WithValue(ctx, systemUserRolesFuncKey, func() func(ctx context.Context) ([]string, error) {
var roles []string var roles []string
@@ -138,7 +138,7 @@ func addGetSystemRolesFuncToCtx(ctx context.Context, ctxData CtxData) context.Co
return roles, nil return roles, nil
} }
var err error var err error
roles, err = getSystemRoles(ctx) roles, err = getSystemUserRoles(ctx)
return roles, err return roles, err
} }
}()) }())
@@ -146,7 +146,7 @@ func addGetSystemRolesFuncToCtx(ctx context.Context, ctxData CtxData) context.Co
return ctx return ctx
} }
func GetSystemRoles(ctx context.Context) ([]string, error) { func GetSystemUserRoles(ctx context.Context) ([]string, error) {
getSystemUserRolesFuncValue := ctx.Value(systemUserRolesFuncKey) getSystemUserRolesFuncValue := ctx.Value(systemUserRolesFuncKey)
if getSystemUserRolesFuncValue == nil { if getSystemUserRolesFuncValue == nil {
return nil, nil return nil, nil
@@ -158,7 +158,7 @@ func GetSystemRoles(ctx context.Context) ([]string, error) {
return getSystemUserRolesFunc(ctx) return getSystemUserRolesFunc(ctx)
} }
func getSystemRoles(ctx context.Context) ([]string, error) { func getSystemUserRoles(ctx context.Context) ([]string, error) {
ctxData, ok := ctx.Value(dataKey).(CtxData) ctxData, ok := ctx.Value(dataKey).(CtxData)
if !ok { if !ok {
return nil, errors.New("unable to obtain ctxData") return nil, errors.New("unable to obtain ctxData")

View File

@@ -11,7 +11,7 @@ import (
) )
const ( const (
// eventstore.permitted_orgs(instanceid text, userid text, perm text, system_roles text[], filter_orgs text) // eventstore.permitted_orgs(instanceid text, userid text, perm text, system_user_roles text[], filter_orgs text)
wherePermittedOrgsClause = "%s = ANY(eventstore.permitted_orgs(?, ?, ?, ?, ?))" wherePermittedOrgsClause = "%s = ANY(eventstore.permitted_orgs(?, ?, ?, ?, ?))"
wherePermittedOrgsOrCurrentUserClause = "(" + wherePermittedOrgsClause + " OR %s = ?" + ")" wherePermittedOrgsOrCurrentUserClause = "(" + wherePermittedOrgsClause + " OR %s = ?" + ")"
) )
@@ -24,7 +24,7 @@ const (
// and is typically the `resource_owner` column in ZITADEL. // and is typically the `resource_owner` column in ZITADEL.
// We use full identifiers in the query builder so this function should be // We use full identifiers in the query builder so this function should be
// called with something like `UserResourceOwnerCol.identifier()` for example. // called with something like `UserResourceOwnerCol.identifier()` for example.
func wherePermittedOrgs(ctx context.Context, query sq.SelectBuilder, systemRoles []string, filterOrgIds, orgIDColumn, permission string) sq.SelectBuilder { func wherePermittedOrgs(ctx context.Context, query sq.SelectBuilder, systemUserRoles []string, filterOrgIds, orgIDColumn, permission string) sq.SelectBuilder {
userID := authz.GetCtxData(ctx).UserID userID := authz.GetCtxData(ctx).UserID
logging.WithFields("permission_check_v2_flag", authz.GetFeatures(ctx).PermissionCheckV2, "org_id_column", orgIDColumn, "permission", permission, "user_id", userID).Debug("permitted orgs check used") logging.WithFields("permission_check_v2_flag", authz.GetFeatures(ctx).PermissionCheckV2, "org_id_column", orgIDColumn, "permission", permission, "user_id", userID).Debug("permitted orgs check used")
@@ -33,12 +33,12 @@ func wherePermittedOrgs(ctx context.Context, query sq.SelectBuilder, systemRoles
authz.GetInstance(ctx).InstanceID(), authz.GetInstance(ctx).InstanceID(),
userID, userID,
permission, permission,
systemRoles, systemUserRoles,
filterOrgIds, filterOrgIds,
) )
} }
func wherePermittedOrgsOrCurrentUser(ctx context.Context, query sq.SelectBuilder, systemRoles []string, filterOrgIds, orgIDColumn, userIdColum, permission string) sq.SelectBuilder { func wherePermittedOrgsOrCurrentUser(ctx context.Context, query sq.SelectBuilder, systemUserRoles []string, filterOrgIds, orgIDColumn, userIdColum, permission string) sq.SelectBuilder {
userID := authz.GetCtxData(ctx).UserID userID := authz.GetCtxData(ctx).UserID
logging.WithFields("permission_check_v2_flag", authz.GetFeatures(ctx).PermissionCheckV2, "org_id_column", orgIDColumn, "user_id_colum", userIdColum, "permission", permission, "user_id", userID).Debug("permitted orgs check used") logging.WithFields("permission_check_v2_flag", authz.GetFeatures(ctx).PermissionCheckV2, "org_id_column", orgIDColumn, "user_id_colum", userIdColum, "permission", permission, "user_id", userID).Debug("permitted orgs check used")
@@ -47,7 +47,7 @@ func wherePermittedOrgsOrCurrentUser(ctx context.Context, query sq.SelectBuilder
authz.GetInstance(ctx).InstanceID(), authz.GetInstance(ctx).InstanceID(),
userID, userID,
permission, permission,
systemRoles, systemUserRoles,
filterOrgIds, filterOrgIds,
userID, userID,
) )

View File

@@ -656,11 +656,11 @@ func (q *Queries) searchUsers(ctx context.Context, queries *UserSearchQueries, f
}) })
if permissionCheckV2 { if permissionCheckV2 {
// extract system roles // extract system roles
systemRoles, err := authz.GetSystemRoles(ctx) systemUserRoles, err := authz.GetSystemUserRoles(ctx)
if err != nil { if err != nil {
return nil, zerrors.ThrowInternal(err, "QUERY-GS9gs", "Errors.Internal") return nil, zerrors.ThrowInternal(err, "QUERY-GS9gs", "Errors.Internal")
} }
query = wherePermittedOrgsOrCurrentUser(ctx, query, systemRoles, filterOrgIds, UserResourceOwnerCol.identifier(), UserIDCol.identifier(), domain.PermissionUserRead) query = wherePermittedOrgsOrCurrentUser(ctx, query, systemUserRoles, filterOrgIds, UserResourceOwnerCol.identifier(), UserIDCol.identifier(), domain.PermissionUserRead)
} }
stmt, args, err := query.ToSql() stmt, args, err := query.ToSql()