perf(query): org permission function for resources (#9677)

# Which Problems Are Solved

Classic permission checks execute for every returned row on resource
based search APIs. Complete background and problem definition can be
found here: https://github.com/zitadel/zitadel/issues/9188

# How the Problems Are Solved

- PermissionClause function now support dynamic query building, so it
supports multiple cases.
- PermissionClause is applied to all list resources which support org
level permissions.
- Wrap permission logic into wrapper functions so we keep the business
logic clean.

# Additional Changes

- Handle org ID optimization in the query package, so it is reusable for
all resources, instead of extracting the filter in the API.
- Cleanup and test system user conversion in the authz package. (context
middleware)
- Fix: `core_integration_db_up` make recipe was missing the postgres
service.

# Additional Context

- Related to https://github.com/zitadel/zitadel/issues/9190
This commit is contained in:
Tim Möhlmann
2025-04-15 19:38:25 +03:00
committed by GitHub
parent 3b8a2ab811
commit a2f60f2e7a
23 changed files with 741 additions and 172 deletions

View File

@@ -112,7 +112,7 @@ core_unit_test:
.PHONY: core_integration_db_up
core_integration_db_up:
docker compose -f internal/integration/config/docker-compose.yaml up --pull always --wait cache
docker compose -f internal/integration/config/docker-compose.yaml up --pull always --wait cache postgres
.PHONY: core_integration_db_down
core_integration_db_down:

View File

@@ -7,8 +7,6 @@ import (
"slices"
"strings"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
@@ -26,14 +24,13 @@ func CheckUserAuthorization(ctx context.Context, req interface{}, token, orgID,
ctx, span := tracing.NewServerInterceptorSpan(ctx)
defer func() { span.EndWithError(err) }()
ctxData, err := VerifyTokenAndCreateCtxData(ctx, token, orgID, orgDomain, verifier)
ctxData, err := VerifyTokenAndCreateCtxData(ctx, token, orgID, orgDomain, verifier, systemRolePermissionMapping)
if err != nil {
return nil, err
}
if requiredAuthOption.Permission == authenticated {
return func(parent context.Context) context.Context {
parent = addGetSystemUserRolesToCtx(parent, systemRolePermissionMapping, ctxData)
return context.WithValue(parent, dataKey, ctxData)
}, nil
}
@@ -54,7 +51,6 @@ func CheckUserAuthorization(ctx context.Context, req interface{}, token, orgID,
parent = context.WithValue(parent, dataKey, ctxData)
parent = context.WithValue(parent, allPermissionsKey, allPermissions)
parent = context.WithValue(parent, requestPermissionsKey, requestedPermissions)
parent = addGetSystemUserRolesToCtx(parent, systemRolePermissionMapping, ctxData)
return parent
}, nil
}
@@ -131,42 +127,32 @@ func GetAllPermissionCtxIDs(perms []string) []string {
return ctxIDs
}
type SystemUserPermissionsDBQuery struct {
MemberType string `json:"member_type"`
type SystemUserPermissions struct {
MemberType MemberType `json:"member_type"`
AggregateID string `json:"aggregate_id"`
ObjectID string `json:"object_id"`
Permissions []string `json:"permissions"`
}
func addGetSystemUserRolesToCtx(ctx context.Context, systemUserRoleMap []RoleMapping, ctxData CtxData) context.Context {
if len(ctxData.SystemMemberships) == 0 {
return ctx
// systemMembershipsToUserPermissions converts system memberships based on roles,
// to SystemUserPermissions, using the passed role mapping.
func systemMembershipsToUserPermissions(memberships Memberships, roleMap []RoleMapping) []SystemUserPermissions {
if memberships == nil {
return nil
}
systemUserPermissions := make([]SystemUserPermissionsDBQuery, len(ctxData.SystemMemberships))
for i, systemPerm := range ctxData.SystemMemberships {
systemUserPermissions := make([]SystemUserPermissions, len(memberships))
for i, systemPerm := range memberships {
permissions := make([]string, 0, len(systemPerm.Roles))
for _, role := range systemPerm.Roles {
permissions = append(permissions, getPermissionsFromRole(systemUserRoleMap, role)...)
permissions = append(permissions, getPermissionsFromRole(roleMap, role)...)
}
slices.Sort(permissions)
permissions = slices.Compact(permissions)
permissions = slices.Compact(permissions) // remove duplicates
systemUserPermissions[i].MemberType = systemPerm.MemberType.String()
systemUserPermissions[i].MemberType = systemPerm.MemberType
systemUserPermissions[i].AggregateID = systemPerm.AggregateID
systemUserPermissions[i].ObjectID = systemPerm.ObjectID
systemUserPermissions[i].Permissions = permissions
}
return context.WithValue(ctx, systemUserRolesKey, systemUserPermissions)
}
func GetSystemUserPermissions(ctx context.Context) []SystemUserPermissionsDBQuery {
getSystemUserRolesFuncValue := ctx.Value(systemUserRolesKey)
if getSystemUserRolesFuncValue == nil {
return nil
}
systemUserRoles, ok := getSystemUserRolesFuncValue.([]SystemUserPermissionsDBQuery)
if !ok {
logging.WithFields("Authz").Error("unable to cast []SystemUserPermissionsDBQuery")
return nil
}
return systemUserRoles
return systemUserPermissions
}

View File

@@ -3,6 +3,8 @@ package authz
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/zerrors"
)
@@ -276,3 +278,127 @@ func Test_GetPermissionCtxIDs(t *testing.T) {
})
}
}
func Test_systemMembershipsToUserPermissions(t *testing.T) {
roleMap := []RoleMapping{
{
Role: "FOO_BAR",
Permissions: []string{"foo.bar.read", "foo.bar.write"},
},
{
Role: "BAR_FOO",
Permissions: []string{"bar.foo.read", "bar.foo.write", "foo.bar.read"},
},
}
type args struct {
memberships Memberships
roleMap []RoleMapping
}
tests := []struct {
name string
args args
want []SystemUserPermissions
}{
{
name: "nil memberships",
args: args{
memberships: nil,
roleMap: roleMap,
},
want: nil,
},
{
name: "empty memberships",
args: args{
memberships: Memberships{},
roleMap: roleMap,
},
want: []SystemUserPermissions{},
},
{
name: "single membership",
args: args{
memberships: Memberships{
{
MemberType: MemberTypeSystem,
AggregateID: "1",
ObjectID: "2",
Roles: []string{"FOO_BAR"},
},
},
roleMap: roleMap,
},
want: []SystemUserPermissions{
{
MemberType: MemberTypeSystem,
AggregateID: "1",
ObjectID: "2",
Permissions: []string{"foo.bar.read", "foo.bar.write"},
},
},
},
{
name: "multiple memberships",
args: args{
memberships: Memberships{
{
MemberType: MemberTypeSystem,
AggregateID: "1",
ObjectID: "2",
Roles: []string{"FOO_BAR"},
},
{
MemberType: MemberTypeIAM,
AggregateID: "1",
ObjectID: "2",
Roles: []string{"BAR_FOO"},
},
},
roleMap: roleMap,
},
want: []SystemUserPermissions{
{
MemberType: MemberTypeSystem,
AggregateID: "1",
ObjectID: "2",
Permissions: []string{"foo.bar.read", "foo.bar.write"},
},
{
MemberType: MemberTypeIAM,
AggregateID: "1",
ObjectID: "2",
Permissions: []string{"bar.foo.read", "bar.foo.write", "foo.bar.read"},
},
},
},
{
name: "multiple roles",
args: args{
memberships: Memberships{
{
MemberType: MemberTypeSystem,
AggregateID: "1",
ObjectID: "2",
Roles: []string{"FOO_BAR", "BAR_FOO"},
},
},
roleMap: roleMap,
},
want: []SystemUserPermissions{
{
MemberType: MemberTypeSystem,
AggregateID: "1",
ObjectID: "2",
Permissions: []string{"bar.foo.read", "bar.foo.write", "foo.bar.read", "foo.bar.write"},
},
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := systemMembershipsToUserPermissions(tt.args.memberships, tt.args.roleMap)
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -1,4 +1,4 @@
//go:generate enumer -type MemberType -trimprefix MemberType
//go:generate enumer -type MemberType -trimprefix MemberType -json
package authz
@@ -22,7 +22,6 @@ const (
dataKey key = 2
allPermissionsKey key = 3
instanceKey key = 4
systemUserRolesKey key = 5
)
type CtxData struct {
@@ -33,6 +32,7 @@ type CtxData struct {
PreferredLanguage string
ResourceOwner string
SystemMemberships Memberships
SystemUserPermissions []SystemUserPermissions
}
func (ctxData CtxData) IsZero() bool {
@@ -98,7 +98,7 @@ func (s SystemTokenVerifierFunc) VerifySystemToken(ctx context.Context, token st
return s(ctx, token, orgID)
}
func VerifyTokenAndCreateCtxData(ctx context.Context, token, orgID, orgDomain string, t APITokenVerifier) (_ CtxData, err error) {
func VerifyTokenAndCreateCtxData(ctx context.Context, token, orgID, orgDomain string, t APITokenVerifier, systemRoleMap []RoleMapping) (_ CtxData, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
tokenWOBearer, err := extractBearerToken(token)
@@ -140,6 +140,7 @@ func VerifyTokenAndCreateCtxData(ctx context.Context, token, orgID, orgDomain st
PreferredLanguage: prefLang,
ResourceOwner: resourceOwner,
SystemMemberships: sysMemberships,
SystemUserPermissions: systemMembershipsToUserPermissions(sysMemberships, systemRoleMap),
}, nil
}

View File

@@ -1,8 +1,9 @@
// Code generated by "enumer -type MemberType -trimprefix MemberType"; DO NOT EDIT.
// Code generated by "enumer -type MemberType -trimprefix MemberType -json"; DO NOT EDIT.
package authz
import (
"encoding/json"
"fmt"
"strings"
)
@@ -92,3 +93,20 @@ func (i MemberType) IsAMemberType() bool {
}
return false
}
// MarshalJSON implements the json.Marshaler interface for MemberType
func (i MemberType) MarshalJSON() ([]byte, error) {
return json.Marshal(i.String())
}
// UnmarshalJSON implements the json.Unmarshaler interface for MemberType
func (i *MemberType) UnmarshalJSON(data []byte) error {
var s string
if err := json.Unmarshal(data, &s); err != nil {
return fmt.Errorf("MemberType should be a string, got %s", data)
}
var err error
*i, err = MemberTypeString(s)
return err
}

View File

@@ -554,7 +554,7 @@ func (s *Server) getUsers(ctx context.Context, org string, withPasswords bool, w
if err != nil {
return nil, nil, nil, nil, err
}
users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{orgSearch}}, org, nil)
users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{orgSearch}}, nil)
if err != nil {
return nil, nil, nil, nil, err
}

View File

@@ -108,7 +108,7 @@ func (s *Server) getClaimedUserIDsOfOrgDomain(ctx context.Context, orgDomain str
if err != nil {
return nil, err
}
users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, "", nil)
users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, nil)
if err != nil {
return nil, err
}

View File

@@ -329,7 +329,7 @@ func (s *Server) getClaimedUserIDsOfOrgDomain(ctx context.Context, orgDomain, or
}
queries = append(queries, owner)
}
users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: queries}, orgID, nil)
users, err := s.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: queries}, nil)
if err != nil {
return nil, err
}

View File

@@ -69,7 +69,7 @@ func (s *Server) ListUsers(ctx context.Context, req *mgmt_pb.ListUsersRequest) (
if err != nil {
return nil, err
}
res, err := s.query.SearchUsers(ctx, queries, orgID, nil)
res, err := s.query.SearchUsers(ctx, queries, nil)
if err != nil {
return nil, err
}

View File

@@ -30,11 +30,11 @@ func (s *Server) GetUserByID(ctx context.Context, req *user.GetUserByIDRequest)
}
func (s *Server) ListUsers(ctx context.Context, req *user.ListUsersRequest) (*user.ListUsersResponse, error) {
queries, filterOrgId, err := listUsersRequestToModel(req)
queries, err := listUsersRequestToModel(req)
if err != nil {
return nil, err
}
res, err := s.query.SearchUsers(ctx, queries, filterOrgId, s.checkPermission)
res, err := s.query.SearchUsers(ctx, queries, s.checkPermission)
if err != nil {
return nil, err
}
@@ -171,11 +171,11 @@ func accessTokenTypeToPb(accessTokenType domain.OIDCTokenType) user.AccessTokenT
}
}
func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueries, string, error) {
func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueries, error) {
offset, limit, asc := object.ListQueryToQuery(req.Query)
queries, filterOrgId, err := userQueriesToQuery(req.Queries, 0 /*start from level 0*/)
queries, err := userQueriesToQuery(req.Queries, 0 /*start from level 0*/)
if err != nil {
return nil, "", err
return nil, err
}
return &query.UserSearchQueries{
SearchRequest: query.SearchRequest{
@@ -185,7 +185,7 @@ func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueri
SortingColumn: userFieldNameToSortingColumn(req.SortingColumn),
},
Queries: queries,
}, filterOrgId, nil
}, nil
}
func userFieldNameToSortingColumn(field user.UserFieldName) query.Column {
@@ -215,18 +215,15 @@ func userFieldNameToSortingColumn(field user.UserFieldName) query.Column {
}
}
func userQueriesToQuery(queries []*user.SearchQuery, level uint8) (_ []query.SearchQuery, filterOrgId string, err error) {
func userQueriesToQuery(queries []*user.SearchQuery, level uint8) (_ []query.SearchQuery, err error) {
q := make([]query.SearchQuery, len(queries))
for i, query := range queries {
if orgFilter := query.GetOrganizationIdQuery(); orgFilter != nil {
filterOrgId = orgFilter.OrganizationId
}
q[i], err = userQueryToQuery(query, level)
if err != nil {
return nil, filterOrgId, err
return nil, err
}
}
return q, filterOrgId, nil
return q, nil
}
func userQueryToQuery(query *user.SearchQuery, level uint8) (query.SearchQuery, error) {
@@ -320,14 +317,14 @@ func inUserIdsQueryToQuery(q *user.InUserIDQuery) (query.SearchQuery, error) {
return query.NewUserInUserIdsSearchQuery(q.UserIds)
}
func orQueryToQuery(q *user.OrQuery, level uint8) (query.SearchQuery, error) {
mappedQueries, _, err := userQueriesToQuery(q.Queries, level+1)
mappedQueries, err := userQueriesToQuery(q.Queries, level+1)
if err != nil {
return nil, err
}
return query.NewUserOrSearchQuery(mappedQueries)
}
func andQueryToQuery(q *user.AndQuery, level uint8) (query.SearchQuery, error) {
mappedQueries, _, err := userQueriesToQuery(q.Queries, level+1)
mappedQueries, err := userQueriesToQuery(q.Queries, level+1)
if err != nil {
return nil, err
}

View File

@@ -29,11 +29,11 @@ func (s *Server) GetUserByID(ctx context.Context, req *user.GetUserByIDRequest)
}
func (s *Server) ListUsers(ctx context.Context, req *user.ListUsersRequest) (*user.ListUsersResponse, error) {
queries, filterOrgIds, err := listUsersRequestToModel(req)
queries, err := listUsersRequestToModel(req)
if err != nil {
return nil, err
}
res, err := s.query.SearchUsers(ctx, queries, filterOrgIds, s.checkPermission)
res, err := s.query.SearchUsers(ctx, queries, s.checkPermission)
if err != nil {
return nil, err
}
@@ -165,11 +165,11 @@ func accessTokenTypeToPb(accessTokenType domain.OIDCTokenType) user.AccessTokenT
}
}
func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueries, string, error) {
func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueries, error) {
offset, limit, asc := object.ListQueryToQuery(req.Query)
queries, filterOrgId, err := userQueriesToQuery(req.Queries, 0 /*start from level 0*/)
queries, err := userQueriesToQuery(req.Queries, 0 /*start from level 0*/)
if err != nil {
return nil, "", err
return nil, err
}
return &query.UserSearchQueries{
SearchRequest: query.SearchRequest{
@@ -179,7 +179,7 @@ func listUsersRequestToModel(req *user.ListUsersRequest) (*query.UserSearchQueri
SortingColumn: userFieldNameToSortingColumn(req.SortingColumn),
},
Queries: queries,
}, filterOrgId, nil
}, nil
}
func userFieldNameToSortingColumn(field user.UserFieldName) query.Column {
@@ -209,18 +209,15 @@ func userFieldNameToSortingColumn(field user.UserFieldName) query.Column {
}
}
func userQueriesToQuery(queries []*user.SearchQuery, level uint8) (_ []query.SearchQuery, filterOrgId string, err error) {
func userQueriesToQuery(queries []*user.SearchQuery, level uint8) (_ []query.SearchQuery, err error) {
q := make([]query.SearchQuery, len(queries))
for i, query := range queries {
if orgFilter := query.GetOrganizationIdQuery(); orgFilter != nil {
filterOrgId = orgFilter.OrganizationId
}
q[i], err = userQueryToQuery(query, level)
if err != nil {
return nil, filterOrgId, err
return nil, err
}
}
return q, filterOrgId, nil
return q, nil
}
func userQueryToQuery(query *user.SearchQuery, level uint8) (query.SearchQuery, error) {
@@ -314,14 +311,14 @@ func inUserIdsQueryToQuery(q *user.InUserIDQuery) (query.SearchQuery, error) {
return query.NewUserInUserIdsSearchQuery(q.UserIds)
}
func orQueryToQuery(q *user.OrQuery, level uint8) (query.SearchQuery, error) {
mappedQueries, _, err := userQueriesToQuery(q.Queries, level+1)
mappedQueries, err := userQueriesToQuery(q.Queries, level+1)
if err != nil {
return nil, err
}
return query.NewUserOrSearchQuery(mappedQueries)
}
func andQueryToQuery(q *user.AndQuery, level uint8) (query.SearchQuery, error) {
mappedQueries, _, err := userQueriesToQuery(q.Queries, level+1)
mappedQueries, err := userQueriesToQuery(q.Queries, level+1)
if err != nil {
return nil, err
}

View File

@@ -240,7 +240,7 @@ func (h *UsersHandler) List(ctx context.Context, request *ListRequest) (*ListRes
return NewListResponse(count, q.SearchRequest, make([]*ScimUser, 0)), nil
}
users, err := h.query.SearchUsers(ctx, q, authz.GetCtxData(ctx).OrgID, nil)
users, err := h.query.SearchUsers(ctx, q, nil)
if err != nil {
return nil, err
}

View File

@@ -182,7 +182,7 @@ func (l *Login) getClaimedUserIDsOfOrgDomain(ctx context.Context, orgName string
if err != nil {
return nil, err
}
users, err := l.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, "", nil)
users, err := l.query.SearchUsers(ctx, &query.UserSearchQueries{Queries: []query.SearchQuery{loginName}}, nil)
if err != nil {
return nil, err
}

View File

@@ -225,3 +225,37 @@ func (d *NullDuration) Scan(src any) error {
d.Duration, d.Valid = time.Duration(*duration), true
return nil
}
// JSONArray allows sending and receiving JSON arrays to and from the database.
// It implements the [database/sql.Scanner] and [database/sql/driver.Valuer] interfaces.
// Values are marshaled and unmarshaled using the [encoding/json] package.
type JSONArray[T any] []T
// NewJSONArray wraps an existing slice into a JSONArray.
func NewJSONArray[T any](a []T) JSONArray[T] {
return JSONArray[T](a)
}
// Scan implements the [database/sql.Scanner] interface.
func (a *JSONArray[T]) Scan(src any) error {
if src == nil {
*a = nil
return nil
}
bytes := src.([]byte)
if len(bytes) == 0 {
*a = nil
return nil
}
return json.Unmarshal(bytes, a)
}
// Value implements the [database/sql/driver.Valuer] interface.
func (a JSONArray[T]) Value() (driver.Value, error) {
if a == nil {
return nil, nil
}
return json.Marshal(a)
}

View File

@@ -5,6 +5,7 @@ import (
"database/sql/driver"
"testing"
"github.com/muhlemmer/gu"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
@@ -452,3 +453,89 @@ func TestDuration_Scan(t *testing.T) {
})
}
}
func TestJSONArray_Scan(t *testing.T) {
type args struct {
src any
}
tests := []struct {
name string
args args
want *JSONArray[string]
wantErr bool
}{
{
name: "nil",
args: args{src: nil},
want: new(JSONArray[string]),
wantErr: false,
},
{
name: "zero bytes",
args: args{src: []byte("")},
want: new(JSONArray[string]),
wantErr: false,
},
{
name: "empty",
args: args{src: []byte("[]")},
want: gu.Ptr(JSONArray[string]{}),
wantErr: false,
},
{
name: "ok",
args: args{src: []byte("[\"a\", \"b\"]")},
want: gu.Ptr(JSONArray[string]{"a", "b"}),
wantErr: false,
},
{
name: "json error",
args: args{src: []byte("{\"a\": \"b\"}")},
want: new(JSONArray[string]),
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := new(JSONArray[string])
err := got.Scan(tt.args.src)
if tt.wantErr {
assert.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}
func TestJSONArray_Value(t *testing.T) {
tests := []struct {
name string
a []string
want driver.Value
}{
{
name: "nil",
a: nil,
want: nil,
},
{
name: "empty",
a: []string{},
want: []byte("[]"),
},
{
name: "ok",
a: []string{"a", "b"},
want: []byte("[\"a\",\"b\"]"),
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := NewJSONArray(tt.a).Value()
require.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -106,12 +106,26 @@ func idpLinksCheckPermission(ctx context.Context, links *IDPUserLinks, permissio
)
}
func idpLinksPermissionCheckV2(ctx context.Context, query sq.SelectBuilder, enabled bool, queries *IDPUserLinksSearchQuery) sq.SelectBuilder {
if !enabled {
return query
}
return query.Where(PermissionClause(
ctx,
IDPUserLinkResourceOwnerCol,
domain.PermissionUserRead,
SingleOrgPermissionOption(queries.Queries),
OwnedRowsPermissionOption(IDPUserLinkUserIDCol),
))
}
func (q *Queries) IDPUserLinks(ctx context.Context, queries *IDPUserLinksSearchQuery, permissionCheck domain.PermissionCheck) (idps *IDPUserLinks, err error) {
links, err := q.idpUserLinks(ctx, queries, false)
permissionCheckV2 := PermissionV2(ctx, permissionCheck)
links, err := q.idpUserLinks(ctx, queries, permissionCheckV2)
if err != nil {
return nil, err
}
if permissionCheck != nil && len(links.Links) > 0 {
if permissionCheck != nil && len(links.Links) > 0 && !permissionCheckV2 {
// when userID for query is provided, only one check has to be done
if queries.hasUserID() {
if err := userCheckPermission(ctx, links.Links[0].ResourceOwner, links.Links[0].UserID, permissionCheck); err != nil {
@@ -124,14 +138,15 @@ func (q *Queries) IDPUserLinks(ctx context.Context, queries *IDPUserLinksSearchQ
return links, nil
}
func (q *Queries) idpUserLinks(ctx context.Context, queries *IDPUserLinksSearchQuery, withOwnerRemoved bool) (idps *IDPUserLinks, err error) {
func (q *Queries) idpUserLinks(ctx context.Context, queries *IDPUserLinksSearchQuery, permissionCheckV2 bool) (idps *IDPUserLinks, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareIDPUserLinksQuery()
eq := sq.Eq{IDPUserLinkInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID()}
if !withOwnerRemoved {
eq[IDPUserLinkOwnerRemovedCol.identifier()] = false
query = idpLinksPermissionCheckV2(ctx, query, permissionCheckV2, queries)
eq := sq.Eq{
IDPUserLinkInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
IDPUserLinkOwnerRemovedCol.identifier(): false,
}
stmt, args, err := queries.toQuery(query).Where(eq).ToSql()
if err != nil {

View File

@@ -93,6 +93,17 @@ func orgsCheckPermission(ctx context.Context, orgs *Orgs, permissionCheck domain
)
}
func orgsPermissionCheckV2(ctx context.Context, query sq.SelectBuilder, enabled bool) sq.SelectBuilder {
if !enabled {
return query
}
return query.Where(PermissionClause(
ctx,
OrgColumnID,
domain_pkg.PermissionOrgRead,
))
}
type OrgSearchQueries struct {
SearchRequest
Queries []SearchQuery
@@ -283,21 +294,23 @@ func (q *Queries) ExistsOrg(ctx context.Context, id, domain string) (verifiedID
}
func (q *Queries) SearchOrgs(ctx context.Context, queries *OrgSearchQueries, permissionCheck domain_pkg.PermissionCheck) (*Orgs, error) {
orgs, err := q.searchOrgs(ctx, queries)
permissionCheckV2 := PermissionV2(ctx, permissionCheck)
orgs, err := q.searchOrgs(ctx, queries, permissionCheckV2)
if err != nil {
return nil, err
}
if permissionCheck != nil {
if permissionCheck != nil && !permissionCheckV2 {
orgsCheckPermission(ctx, orgs, permissionCheck)
}
return orgs, nil
}
func (q *Queries) searchOrgs(ctx context.Context, queries *OrgSearchQueries) (orgs *Orgs, err error) {
func (q *Queries) searchOrgs(ctx context.Context, queries *OrgSearchQueries, permissionCheckV2 bool) (orgs *Orgs, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareOrgsQuery()
query = orgsPermissionCheckV2(ctx, query, permissionCheckV2)
stmt, args, err := queries.toQuery(query).
Where(sq.And{
sq.Eq{

View File

@@ -2,74 +2,109 @@ package query
import (
"context"
"encoding/json"
"fmt"
sq "github.com/Masterminds/squirrel"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/zerrors"
"github.com/zitadel/zitadel/internal/database"
domain_pkg "github.com/zitadel/zitadel/internal/domain"
)
const (
// eventstore.permitted_orgs(instanceid text, userid text, system_user_perms JSONB, perm text filter_orgs text)
wherePermittedOrgsClause = "%s = ANY(eventstore.permitted_orgs(?, ?, ?, ?, ?))"
wherePermittedOrgsOrCurrentUserClause = "(" + wherePermittedOrgsClause + " OR %s = ?" + ")"
// eventstore.permitted_orgs(instanceid text, userid text, system_user_perms JSONB, perm text, filter_org text)
wherePermittedOrgsExpr = "%s = ANY(eventstore.permitted_orgs(?, ?, ?, ?, ?))"
)
// wherePermittedOrgs sets a `WHERE` clause to the query that filters the orgs
// for which the authenticated user has the requested permission for.
// The user ID is taken from the context.
// The `orgIDColumn` specifies the table column to which this filter must be applied,
// and is typically the `resource_owner` column in ZITADEL.
// We use full identifiers in the query builder so this function should be
// called with something like `UserResourceOwnerCol.identifier()` for example.
// func wherePermittedOrgs(ctx context.Context, query sq.SelectBuilder, filterOrgIds, orgIDColumn, permission string) (sq.SelectBuilder, error) {
// userID := authz.GetCtxData(ctx).UserID
// logging.WithFields("permission_check_v2_flag", authz.GetFeatures(ctx).PermissionCheckV2, "org_id_column", orgIDColumn, "permission", permission, "user_id", userID).Debug("permitted orgs check used")
// systemUserPermissions := authz.GetSystemUserPermissions(ctx)
// var systemUserPermissionsJson []byte
// if systemUserPermissions != nil {
// var err error
// systemUserPermissionsJson, err = json.Marshal(systemUserPermissions)
// if err != nil {
// return query, err
// }
// }
// return query.Where(
// fmt.Sprintf(wherePermittedOrgsClause, orgIDColumn),
// authz.GetInstance(ctx).InstanceID(),
// userID,
// systemUserPermissionsJson,
// permission,
// filterOrgIds,
// ), nil
// }
func wherePermittedOrgsOrCurrentUser(ctx context.Context, query sq.SelectBuilder, filterOrgIds, orgIDColumn, userIdColum, permission string) (sq.SelectBuilder, error) {
userID := authz.GetCtxData(ctx).UserID
logging.WithFields("permission_check_v2_flag", authz.GetFeatures(ctx).PermissionCheckV2, "org_id_column", orgIDColumn, "user_id_colum", userIdColum, "permission", permission, "user_id", userID).Debug("permitted orgs check used")
systemUserPermissions := authz.GetSystemUserPermissions(ctx)
var systemUserPermissionsJson []byte
if systemUserPermissions != nil {
var err error
systemUserPermissionsJson, err = json.Marshal(systemUserPermissions)
if err != nil {
return query, zerrors.ThrowInternal(err, "AUTHZ-HS4us", "Errors.Internal")
}
}
return query.Where(
fmt.Sprintf(wherePermittedOrgsOrCurrentUserClause, orgIDColumn, userIdColum),
authz.GetInstance(ctx).InstanceID(),
userID,
systemUserPermissionsJson,
permission,
filterOrgIds,
userID,
), nil
type permissionClauseBuilder struct {
orgIDColumn Column
instanceID string
userID string
systemPermissions []authz.SystemUserPermissions
permission string
orgID string
connections []sq.Eq
}
func (b *permissionClauseBuilder) appendConnection(column string, value any) {
b.connections = append(b.connections, sq.Eq{column: value})
}
func (b *permissionClauseBuilder) clauses() sq.Or {
clauses := make(sq.Or, 1, len(b.connections)+1)
clauses[0] = sq.Expr(
fmt.Sprintf(wherePermittedOrgsExpr, b.orgIDColumn.identifier()),
b.instanceID,
b.userID,
database.NewJSONArray(b.systemPermissions),
b.permission,
b.orgID,
)
for _, include := range b.connections {
clauses = append(clauses, include)
}
return clauses
}
type PermissionOption func(b *permissionClauseBuilder)
// OwnedRowsPermissionOption allows rows to be returned of which the current user is the owner.
// Even if the user does not have an explicit permission for the organization.
// For example an authenticated user can always see his own user account.
func OwnedRowsPermissionOption(userIDColumn Column) PermissionOption {
return func(b *permissionClauseBuilder) {
b.appendConnection(userIDColumn.identifier(), b.userID)
}
}
// ConnectionPermissionOption allows returning of rows where the value is matched.
// Even if the user does not have an explicit permission for the organization.
func ConnectionPermissionOption(column Column, value any) PermissionOption {
return func(b *permissionClauseBuilder) {
b.appendConnection(column.identifier(), value)
}
}
// SingleOrgPermissionOption may be used to optimize the permitted orgs function by limiting the
// returned organizations, to the one used in the requested filters.
func SingleOrgPermissionOption(queries []SearchQuery) PermissionOption {
return func(b *permissionClauseBuilder) {
b.orgID = findTextEqualsQuery(b.orgIDColumn, queries)
}
}
// PermissionClause sets a `WHERE` clause to query,
// which filters returned rows the current authenticated user has the requested permission to.
//
// Experimental: Work in progress. Currently only organization permissions are supported
func PermissionClause(ctx context.Context, orgIDCol Column, permission string, options ...PermissionOption) sq.Or {
ctxData := authz.GetCtxData(ctx)
b := &permissionClauseBuilder{
orgIDColumn: orgIDCol,
instanceID: authz.GetInstance(ctx).InstanceID(),
userID: ctxData.UserID,
systemPermissions: ctxData.SystemUserPermissions,
permission: permission,
}
for _, opt := range options {
opt(b)
}
logging.WithFields(
"org_id_column", b.orgIDColumn,
"instance_id", b.instanceID,
"user_id", b.userID,
"system_user_permissions", b.systemPermissions,
"permission", b.permission,
"org_id", b.orgID,
"overrides", b.connections,
).Debug("permitted orgs check used")
return b.clauses()
}
// PermissionV2 checks are enabled when the feature flag is set and the permission check function is not nil.
// When the permission check function is nil, it indicates a v1 API and no resource based permission check is needed.
func PermissionV2(ctx context.Context, cf domain_pkg.PermissionCheck) bool {
return authz.GetFeatures(ctx).PermissionCheckV2 && cf != nil
}

View File

@@ -0,0 +1,208 @@
package query
import (
"context"
"testing"
sq "github.com/Masterminds/squirrel"
"github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/database"
domain_pkg "github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/feature"
)
func TestPermissionClause(t *testing.T) {
var permissions = []authz.SystemUserPermissions{
{
MemberType: authz.MemberTypeOrganization,
AggregateID: "orgID",
Permissions: []string{"permission1", "permission2"},
},
{
MemberType: authz.MemberTypeIAM,
Permissions: []string{"permission2", "permission3"},
},
}
ctx := authz.WithInstanceID(context.Background(), "instanceID")
ctx = authz.SetCtxData(ctx, authz.CtxData{
UserID: "userID",
SystemUserPermissions: permissions,
})
type args struct {
ctx context.Context
orgIDCol Column
permission string
options []PermissionOption
}
tests := []struct {
name string
args args
wantClause sq.Or
}{
{
name: "no options",
args: args{
ctx: ctx,
orgIDCol: UserResourceOwnerCol,
permission: "permission1",
},
wantClause: sq.Or{
sq.Expr(
"projections.users14.resource_owner = ANY(eventstore.permitted_orgs(?, ?, ?, ?, ?))",
"instanceID",
"userID",
database.NewJSONArray(permissions),
"permission1",
"",
),
},
},
{
name: "owned rows option",
args: args{
ctx: ctx,
orgIDCol: UserResourceOwnerCol,
permission: "permission1",
options: []PermissionOption{
OwnedRowsPermissionOption(UserIDCol),
},
},
wantClause: sq.Or{
sq.Expr(
"projections.users14.resource_owner = ANY(eventstore.permitted_orgs(?, ?, ?, ?, ?))",
"instanceID",
"userID",
database.NewJSONArray(permissions),
"permission1",
"",
),
sq.Eq{"projections.users14.id": "userID"},
},
},
{
name: "connection rows option",
args: args{
ctx: ctx,
orgIDCol: UserResourceOwnerCol,
permission: "permission1",
options: []PermissionOption{
OwnedRowsPermissionOption(UserIDCol),
ConnectionPermissionOption(UserStateCol, "bar"),
},
},
wantClause: sq.Or{
sq.Expr(
"projections.users14.resource_owner = ANY(eventstore.permitted_orgs(?, ?, ?, ?, ?))",
"instanceID",
"userID",
database.NewJSONArray(permissions),
"permission1",
"",
),
sq.Eq{"projections.users14.id": "userID"},
sq.Eq{"projections.users14.state": "bar"},
},
},
{
name: "single org option",
args: args{
ctx: ctx,
orgIDCol: UserResourceOwnerCol,
permission: "permission1",
options: []PermissionOption{
SingleOrgPermissionOption([]SearchQuery{
mustSearchQuery(NewUserDisplayNameSearchQuery("zitadel", TextContains)),
mustSearchQuery(NewUserResourceOwnerSearchQuery("orgID", TextEquals)),
}),
},
},
wantClause: sq.Or{
sq.Expr(
"projections.users14.resource_owner = ANY(eventstore.permitted_orgs(?, ?, ?, ?, ?))",
"instanceID",
"userID",
database.NewJSONArray(permissions),
"permission1",
"orgID",
),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
gotClause := PermissionClause(tt.args.ctx, tt.args.orgIDCol, tt.args.permission, tt.args.options...)
assert.Equal(t, tt.wantClause, gotClause)
})
}
}
func mustSearchQuery(q SearchQuery, err error) SearchQuery {
if err != nil {
panic(err)
}
return q
}
func TestPermissionV2(t *testing.T) {
type args struct {
ctx context.Context
cf domain_pkg.PermissionCheck
}
tests := []struct {
name string
args args
want bool
}{
{
name: "feature disabled, no permission check",
args: args{
ctx: context.Background(),
cf: nil,
},
want: false,
},
{
name: "feature enabled, no permission check",
args: args{
ctx: authz.WithFeatures(context.Background(), feature.Features{
PermissionCheckV2: true,
}),
cf: nil,
},
want: false,
},
{
name: "feature enabled, with permission check",
args: args{
ctx: authz.WithFeatures(context.Background(), feature.Features{
PermissionCheckV2: true,
}),
cf: func(context.Context, string, string, string) error {
return nil
},
},
want: true,
},
{
name: "feature disabled, with permission check",
args: args{
ctx: authz.WithFeatures(context.Background(), feature.Features{
PermissionCheckV2: false,
}),
cf: func(context.Context, string, string, string) error {
return nil
},
},
want: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got := PermissionV2(tt.args.ctx, tt.args.cf)
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -148,3 +148,16 @@ func triggerBatch(ctx context.Context, handlers ...*handler.Handler) {
wg.Wait()
}
func findTextEqualsQuery(column Column, queries []SearchQuery) string {
for _, query := range queries {
if query.Col() != column {
continue
}
tq, ok := query.(*textQuery)
if ok && tq.Compare == TextEquals {
return tq.Text
}
}
return ""
}

View File

@@ -113,6 +113,22 @@ func sessionCheckPermission(ctx context.Context, resourceOwner string, creator s
return permissionCheck(ctx, domain.PermissionSessionRead, resourceOwner, "")
}
func sessionsPermissionCheckV2(ctx context.Context, query sq.SelectBuilder, enabled bool) sq.SelectBuilder {
if !enabled {
return query
}
return query.Where(PermissionClause(
ctx,
SessionColumnResourceOwner,
domain.PermissionSessionRead,
// Allow if user is creator
OwnedRowsPermissionOption(SessionColumnCreator),
// Allow if session belongs to the user
OwnedRowsPermissionOption(SessionColumnUserID),
ConnectionPermissionOption(SessionColumnUserAgentFingerprintID, authz.GetCtxData(ctx).AgentID),
))
}
func (q *SessionsSearchQueries) toQuery(query sq.SelectBuilder) sq.SelectBuilder {
query = q.SearchRequest.toQuery(query)
for _, q := range q.Queries {
@@ -282,21 +298,23 @@ func (q *Queries) sessionByID(ctx context.Context, shouldTriggerBulk bool, id st
}
func (q *Queries) SearchSessions(ctx context.Context, queries *SessionsSearchQueries, permissionCheck domain.PermissionCheck) (*Sessions, error) {
sessions, err := q.searchSessions(ctx, queries)
permissionCheckV2 := PermissionV2(ctx, permissionCheck)
sessions, err := q.searchSessions(ctx, queries, permissionCheckV2)
if err != nil {
return nil, err
}
if permissionCheck != nil {
if permissionCheck != nil && !permissionCheckV2 {
sessionsCheckPermission(ctx, sessions, permissionCheck)
}
return sessions, nil
}
func (q *Queries) searchSessions(ctx context.Context, queries *SessionsSearchQueries) (sessions *Sessions, err error) {
func (q *Queries) searchSessions(ctx context.Context, queries *SessionsSearchQueries, permissionCheckV2 bool) (sessions *Sessions, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareSessionsQuery()
query = sessionsPermissionCheckV2(ctx, query, permissionCheckV2)
stmt, args, err := queries.toQuery(query).
Where(sq.Eq{
SessionColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID(),

View File

@@ -132,6 +132,19 @@ func usersCheckPermission(ctx context.Context, users *Users, permissionCheck dom
)
}
func userPermissionCheckV2(ctx context.Context, query sq.SelectBuilder, enabled bool, queries *UserSearchQueries) sq.SelectBuilder {
if !enabled {
return query
}
return query.Where(PermissionClause(
ctx,
UserResourceOwnerCol,
domain.PermissionUserRead,
SingleOrgPermissionOption(queries.Queries),
OwnedRowsPermissionOption(UserIDCol),
))
}
type UserSearchQueries struct {
SearchRequest
Queries []SearchQuery
@@ -606,8 +619,9 @@ func (q *Queries) CountUsers(ctx context.Context, queries *UserSearchQueries) (c
return count, err
}
func (q *Queries) SearchUsers(ctx context.Context, queries *UserSearchQueries, filterOrgIds string, permissionCheck domain.PermissionCheck) (*Users, error) {
users, err := q.searchUsers(ctx, queries, filterOrgIds, permissionCheck != nil && authz.GetFeatures(ctx).PermissionCheckV2)
func (q *Queries) SearchUsers(ctx context.Context, queries *UserSearchQueries, permissionCheck domain.PermissionCheck) (*Users, error) {
permissionCheckV2 := PermissionV2(ctx, permissionCheck)
users, err := q.searchUsers(ctx, queries, permissionCheckV2)
if err != nil {
return nil, err
}
@@ -617,22 +631,15 @@ func (q *Queries) SearchUsers(ctx context.Context, queries *UserSearchQueries, f
return users, nil
}
func (q *Queries) searchUsers(ctx context.Context, queries *UserSearchQueries, filterOrgIds string, permissionCheckV2 bool) (users *Users, err error) {
func (q *Queries) searchUsers(ctx context.Context, queries *UserSearchQueries, permissionCheckV2 bool) (users *Users, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareUsersQuery()
query = queries.toQuery(query).Where(sq.Eq{
query = userPermissionCheckV2(ctx, query, permissionCheckV2, queries)
stmt, args, err := queries.toQuery(query).Where(sq.Eq{
UserInstanceIDCol.identifier(): authz.GetInstance(ctx).InstanceID(),
})
if permissionCheckV2 {
query, err = wherePermittedOrgsOrCurrentUser(ctx, query, filterOrgIds, UserResourceOwnerCol.identifier(), UserIDCol.identifier(), domain.PermissionUserRead)
if err != nil {
return nil, zerrors.ThrowInternal(err, "AUTHZ-HS4us", "Errors.Internal")
}
}
stmt, args, err := query.ToSql()
}).ToSql()
if err != nil {
return nil, zerrors.ThrowInternal(err, "QUERY-Dgbg2", "Errors.Query.SQLStatment")
}

View File

@@ -104,6 +104,18 @@ func authMethodsCheckPermission(ctx context.Context, methods *AuthMethods, permi
)
}
func userAuthMethodPermissionCheckV2(ctx context.Context, query sq.SelectBuilder, enabled bool) sq.SelectBuilder {
if !enabled {
return query
}
return query.Where(PermissionClause(
ctx,
UserAuthMethodColumnResourceOwner,
domain.PermissionUserRead,
OwnedRowsPermissionOption(UserIDCol),
))
}
type AuthMethod struct {
UserID string
CreationDate time.Time
@@ -137,11 +149,12 @@ func (q *UserAuthMethodSearchQueries) hasUserID() bool {
}
func (q *Queries) SearchUserAuthMethods(ctx context.Context, queries *UserAuthMethodSearchQueries, permissionCheck domain.PermissionCheck) (userAuthMethods *AuthMethods, err error) {
methods, err := q.searchUserAuthMethods(ctx, queries)
permissionCheckV2 := PermissionV2(ctx, permissionCheck)
methods, err := q.searchUserAuthMethods(ctx, queries, permissionCheckV2)
if err != nil {
return nil, err
}
if permissionCheck != nil && len(methods.AuthMethods) > 0 {
if permissionCheck != nil && len(methods.AuthMethods) > 0 && !permissionCheckV2 {
// when userID for query is provided, only one check has to be done
if queries.hasUserID() {
if err := userCheckPermission(ctx, methods.AuthMethods[0].ResourceOwner, methods.AuthMethods[0].UserID, permissionCheck); err != nil {
@@ -154,11 +167,12 @@ func (q *Queries) SearchUserAuthMethods(ctx context.Context, queries *UserAuthMe
return methods, nil
}
func (q *Queries) searchUserAuthMethods(ctx context.Context, queries *UserAuthMethodSearchQueries) (userAuthMethods *AuthMethods, err error) {
func (q *Queries) searchUserAuthMethods(ctx context.Context, queries *UserAuthMethodSearchQueries, permissionCheckV2 bool) (userAuthMethods *AuthMethods, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
query, scan := prepareUserAuthMethodsQuery()
query = userAuthMethodPermissionCheckV2(ctx, query, permissionCheckV2)
stmt, args, err := queries.toQuery(query).Where(sq.Eq{UserAuthMethodColumnInstanceID.identifier(): authz.GetInstance(ctx).InstanceID()}).ToSql()
if err != nil {
return nil, zerrors.ThrowInvalidArgument(err, "QUERY-j9NJd", "Errors.Query.InvalidRequest")