diff --git a/internal/query/saml_request.go b/internal/query/saml_request.go index a0f6fdc6cd..a81a1b2c34 100644 --- a/internal/query/saml_request.go +++ b/internal/query/saml_request.go @@ -12,6 +12,7 @@ import ( "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/call" + "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/eventstore/handler/v2" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -28,9 +29,9 @@ type SamlRequest struct { Binding string } -func (a *SamlRequest) checkLoginClient(ctx context.Context) error { +func (a *SamlRequest) checkLoginClient(ctx context.Context, permissionCheck domain.PermissionCheck) error { if uid := authz.GetCtxData(ctx).UserID; uid != a.LoginClient { - return zerrors.ThrowPermissionDenied(nil, "OIDCv2-aL0ag", "Errors.SamlRequest.WrongLoginClient") + return permissionCheck(ctx, domain.PermissionSessionRead, authz.GetInstance(ctx).InstanceID(), "") } return nil } @@ -72,7 +73,7 @@ func (q *Queries) SamlRequestByID(ctx context.Context, shouldTriggerBulk bool, i } if checkLoginClient { - if err = dst.checkLoginClient(ctx); err != nil { + if err = dst.checkLoginClient(ctx, q.checkPermission); err != nil { return nil, err } } diff --git a/internal/query/saml_request_test.go b/internal/query/saml_request_test.go index 5cf58369cb..6c6c2b6ebe 100644 --- a/internal/query/saml_request_test.go +++ b/internal/query/saml_request_test.go @@ -1,6 +1,7 @@ package query import ( + "context" "database/sql" "database/sql/driver" _ "embed" @@ -13,6 +14,7 @@ import ( "github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/database" + "github.com/zitadel/zitadel/internal/domain" "github.com/zitadel/zitadel/internal/query/projection" "github.com/zitadel/zitadel/internal/zerrors" ) @@ -38,11 +40,12 @@ func TestQueries_SamlRequestByID(t *testing.T) { checkLoginClient bool } tests := []struct { - name string - args args - expect sqlExpectation - want *SamlRequest - wantErr error + name string + args args + expect sqlExpectation + permissionCheck domain.PermissionCheck + want *SamlRequest + wantErr error }{ { name: "success, all values", @@ -89,7 +92,7 @@ func TestQueries_SamlRequestByID(t *testing.T) { wantErr: zerrors.ThrowInternal(sql.ErrConnDone, "QUERY-Ou8ue", "Errors.Internal"), }, { - name: "wrong login client", + name: "wrong login client/ not permitted", args: args{ shouldTriggerBulk: false, id: "123", @@ -104,13 +107,46 @@ func TestQueries_SamlRequestByID(t *testing.T) { "relayState", "binding", }, "123", "instanceID"), - wantErr: zerrors.ThrowPermissionDeniedf(nil, "OIDCv2-aL0ag", "Errors.SamlRequest.WrongLoginClient"), + permissionCheck: func(ctx context.Context, permission, orgID, resourceID string) (err error) { + return zerrors.ThrowPermissionDenied(nil, "id", "not permitted") + }, + wantErr: zerrors.ThrowPermissionDenied(nil, "id", "not permitted"), + }, + { + name: "wrong login client / permitted", + args: args{ + shouldTriggerBulk: false, + id: "123", + checkLoginClient: true, + }, + expect: mockQuery(expQuery, cols, []driver.Value{ + "id", + testNow, + "otherLoginClient", + "issuer", + "acs", + "relayState", + "binding", + }, "123", "instanceID"), + permissionCheck: func(ctx context.Context, permission, orgID, resourceID string) (err error) { + return nil + }, + want: &SamlRequest{ + ID: "id", + CreationDate: testNow, + LoginClient: "otherLoginClient", + Issuer: "issuer", + ACS: "acs", + RelayState: "relayState", + Binding: "binding", + }, }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { execMock(t, tt.expect, func(db *sql.DB) { q := &Queries{ + checkPermission: tt.permissionCheck, client: &database.DB{ DB: db, Database: &prepareDB{},