mirror of
https://github.com/zitadel/zitadel.git
synced 2025-08-11 21:17:32 +00:00
feat: exchange gRPC server implementation to connectRPC (#10145)
# Which Problems Are Solved The current maintained gRPC server in combination with a REST (grpc) gateway is getting harder and harder to maintain. Additionally, there have been and still are issues with supporting / displaying `oneOf`s correctly. We therefore decided to exchange the server implementation to connectRPC, which apart from supporting connect as protocol, also also "standard" gRCP clients as well as HTTP/1.1 / rest like clients, e.g. curl directly call the server without any additional gateway. # How the Problems Are Solved - All v2 services are moved to connectRPC implementation. (v1 services are still served as pure grpc servers) - All gRPC server interceptors were migrated / copied to a corresponding connectRPC interceptor. - API.ListGrpcServices and API. ListGrpcMethods were changed to include the connect services and endpoints. - gRPC server reflection was changed to a `StaticReflector` using the `ListGrpcServices` list. - The `grpc.Server` interfaces was split into different combinations to be able to handle the different cases (grpc server and prefixed gateway, connect server with grpc gateway, connect server only, ...) - Docs of services serving connectRPC only with no additional gateway (instance, webkey, project, app, org v2 beta) are changed to expose that - since the plugin is not yet available on buf, we download it using `postinstall` hook of the docs # Additional Changes - WebKey service is added as v2 service (in addition to the current v2beta) # Additional Context closes #9483 --------- Co-authored-by: Elio Bischof <elio@zitadel.com>
This commit is contained in:
@@ -0,0 +1,57 @@
|
||||
package connect_middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
http_util "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/logstore"
|
||||
"github.com/zitadel/zitadel/internal/logstore/record"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
func AccessStorageInterceptor(svc *logstore.Service[*record.AccessLog]) connect.UnaryInterceptorFunc {
|
||||
return func(handler connect.UnaryFunc) connect.UnaryFunc {
|
||||
return func(ctx context.Context, req connect.AnyRequest) (_ connect.AnyResponse, err error) {
|
||||
if !svc.Enabled() {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
resp, handlerErr := handler(ctx, req)
|
||||
|
||||
interceptorCtx, span := tracing.NewServerInterceptorSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
var respStatus uint32
|
||||
if code := connect.CodeOf(handlerErr); code != connect.CodeUnknown {
|
||||
respStatus = uint32(code)
|
||||
}
|
||||
|
||||
respHeader := http.Header{}
|
||||
if resp != nil {
|
||||
respHeader = resp.Header()
|
||||
}
|
||||
instance := authz.GetInstance(ctx)
|
||||
domainCtx := http_util.DomainContext(ctx)
|
||||
|
||||
r := &record.AccessLog{
|
||||
LogDate: time.Now(),
|
||||
Protocol: record.GRPC,
|
||||
RequestURL: req.Spec().Procedure,
|
||||
ResponseStatus: respStatus,
|
||||
RequestHeaders: req.Header(),
|
||||
ResponseHeaders: respHeader,
|
||||
InstanceID: instance.InstanceID(),
|
||||
ProjectID: instance.ProjectID(),
|
||||
RequestedDomain: domainCtx.RequestedDomain(),
|
||||
RequestedHost: domainCtx.RequestedHost(),
|
||||
}
|
||||
|
||||
svc.Handle(interceptorCtx, r)
|
||||
return resp, handlerErr
|
||||
}
|
||||
}
|
||||
}
|
@@ -0,0 +1,52 @@
|
||||
package connect_middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/activity"
|
||||
"github.com/zitadel/zitadel/internal/api/grpc/gerrors"
|
||||
ainfo "github.com/zitadel/zitadel/internal/api/info"
|
||||
)
|
||||
|
||||
func ActivityInterceptor() connect.UnaryInterceptorFunc {
|
||||
return func(handler connect.UnaryFunc) connect.UnaryFunc {
|
||||
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
|
||||
ctx = activityInfoFromGateway(ctx, req.Header()).SetMethod(req.Spec().Procedure).IntoContext(ctx)
|
||||
resp, err := handler(ctx, req)
|
||||
if isResourceAPI(req.Spec().Procedure) {
|
||||
code, _, _, _ := gerrors.ExtractZITADELError(err)
|
||||
ctx = ainfo.ActivityInfoFromContext(ctx).SetGRPCStatus(code).IntoContext(ctx)
|
||||
activity.TriggerGRPCWithContext(ctx, activity.ResourceAPI)
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
var resourcePrefixes = []string{
|
||||
"/zitadel.management.v1.ManagementService/",
|
||||
"/zitadel.admin.v1.AdminService/",
|
||||
"/zitadel.user.v2.UserService/",
|
||||
"/zitadel.settings.v2.SettingsService/",
|
||||
"/zitadel.user.v2beta.UserService/",
|
||||
"/zitadel.settings.v2beta.SettingsService/",
|
||||
"/zitadel.auth.v1.AuthService/",
|
||||
}
|
||||
|
||||
func isResourceAPI(method string) bool {
|
||||
return slices.ContainsFunc(resourcePrefixes, func(prefix string) bool {
|
||||
return strings.HasPrefix(method, prefix)
|
||||
})
|
||||
}
|
||||
|
||||
func activityInfoFromGateway(ctx context.Context, headers http.Header) *ainfo.ActivityInfo {
|
||||
info := ainfo.ActivityInfoFromContext(ctx)
|
||||
path := headers.Get(activity.PathKey)
|
||||
requestMethod := headers.Get(activity.RequestMethodKey)
|
||||
return info.SetPath(path).SetRequestMethod(requestMethod)
|
||||
}
|
@@ -0,0 +1,65 @@
|
||||
package connect_middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
func AuthorizationInterceptor(verifier authz.APITokenVerifier, systemUserPermissions authz.Config, authConfig authz.Config) connect.UnaryInterceptorFunc {
|
||||
return func(handler connect.UnaryFunc) connect.UnaryFunc {
|
||||
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
|
||||
return authorize(ctx, req, handler, verifier, systemUserPermissions, authConfig)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func authorize(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, verifier authz.APITokenVerifier, systemUserPermissions authz.Config, authConfig authz.Config) (_ connect.AnyResponse, err error) {
|
||||
authOpt, needsToken := verifier.CheckAuthMethod(req.Spec().Procedure)
|
||||
if !needsToken {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
authCtx, span := tracing.NewServerInterceptorSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
authToken := req.Header().Get(http.Authorization)
|
||||
if authToken == "" {
|
||||
return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("auth header missing"))
|
||||
}
|
||||
|
||||
orgID, orgDomain := orgIDAndDomainFromRequest(req)
|
||||
ctxSetter, err := authz.CheckUserAuthorization(authCtx, req, authToken, orgID, orgDomain, verifier, systemUserPermissions.RolePermissionMappings, authConfig.RolePermissionMappings, authOpt, req.Spec().Procedure)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
span.End()
|
||||
return handler(ctxSetter(ctx), req)
|
||||
}
|
||||
|
||||
func orgIDAndDomainFromRequest(req connect.AnyRequest) (id, domain string) {
|
||||
orgID := req.Header().Get(http.ZitadelOrgID)
|
||||
oz, ok := req.Any().(OrganizationFromRequest)
|
||||
if ok {
|
||||
id = oz.OrganizationFromRequestConnect().ID
|
||||
domain = oz.OrganizationFromRequestConnect().Domain
|
||||
if id != "" || domain != "" {
|
||||
return id, domain
|
||||
}
|
||||
}
|
||||
return orgID, domain
|
||||
}
|
||||
|
||||
type Organization struct {
|
||||
ID string
|
||||
Domain string
|
||||
}
|
||||
|
||||
type OrganizationFromRequest interface {
|
||||
OrganizationFromRequestConnect() *Organization
|
||||
}
|
@@ -0,0 +1,318 @@
|
||||
package connect_middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
const anAPIRole = "AN_API_ROLE"
|
||||
|
||||
type authzRepoMock struct{}
|
||||
|
||||
func (v *authzRepoMock) VerifyAccessToken(ctx context.Context, token, clientID, projectID string) (string, string, string, string, string, error) {
|
||||
return "", "", "", "", "", nil
|
||||
}
|
||||
|
||||
func (v *authzRepoMock) SearchMyMemberships(ctx context.Context, orgID string, _ bool) ([]*authz.Membership, error) {
|
||||
return authz.Memberships{{
|
||||
MemberType: authz.MemberTypeOrganization,
|
||||
AggregateID: orgID,
|
||||
Roles: []string{anAPIRole},
|
||||
}}, nil
|
||||
}
|
||||
|
||||
func (v *authzRepoMock) ProjectIDAndOriginsByClientID(ctx context.Context, clientID string) (string, []string, error) {
|
||||
return "", nil, nil
|
||||
}
|
||||
|
||||
func (v *authzRepoMock) ExistsOrg(ctx context.Context, orgID, domain string) (string, error) {
|
||||
return orgID, nil
|
||||
}
|
||||
|
||||
func (v *authzRepoMock) VerifierClientID(ctx context.Context, appName string) (string, string, error) {
|
||||
return "", "", nil
|
||||
}
|
||||
|
||||
var (
|
||||
accessTokenOK = authz.AccessTokenVerifierFunc(func(ctx context.Context, token string) (userID string, clientID string, agentID string, prefLan string, resourceOwner string, err error) {
|
||||
return "user1", "", "", "", "org1", nil
|
||||
})
|
||||
accessTokenNOK = authz.AccessTokenVerifierFunc(func(ctx context.Context, token string) (userID string, clientID string, agentID string, prefLan string, resourceOwner string, err error) {
|
||||
return "", "", "", "", "", zerrors.ThrowUnauthenticated(nil, "TEST-fQHDI", "unauthenticaded")
|
||||
})
|
||||
systemTokenNOK = authz.SystemTokenVerifierFunc(func(ctx context.Context, token string, orgID string) (memberships authz.Memberships, userID string, err error) {
|
||||
return nil, "", errors.New("system token error")
|
||||
})
|
||||
)
|
||||
|
||||
type mockOrgFromRequest struct {
|
||||
id string
|
||||
}
|
||||
|
||||
func (m *mockOrgFromRequest) OrganizationFromRequestConnect() *Organization {
|
||||
return &Organization{
|
||||
ID: m.id,
|
||||
Domain: "",
|
||||
}
|
||||
}
|
||||
|
||||
func Test_authorize(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
req connect.AnyRequest
|
||||
handler func(t *testing.T) connect.UnaryFunc
|
||||
verifier func() authz.APITokenVerifier
|
||||
authConfig authz.Config
|
||||
}
|
||||
type res struct {
|
||||
want interface{}
|
||||
wantErr bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"no token needed ok",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
req: &mockReq[struct{}]{procedure: "/no/token/needed"},
|
||||
handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{}),
|
||||
verifier: func() authz.APITokenVerifier {
|
||||
verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenOK, systemTokenNOK)
|
||||
verifier.RegisterServer("need", "need", authz.MethodMapping{})
|
||||
return verifier
|
||||
},
|
||||
},
|
||||
res{
|
||||
&connect.Response[struct{}]{},
|
||||
false,
|
||||
},
|
||||
},
|
||||
{
|
||||
"auth header missing error",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
req: &mockReq[struct{}]{procedure: "/need/authentication"},
|
||||
handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{}),
|
||||
verifier: func() authz.APITokenVerifier {
|
||||
verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenOK, systemTokenNOK)
|
||||
verifier.RegisterServer("need", "need", authz.MethodMapping{"/need/authentication": authz.Option{Permission: "authenticated"}})
|
||||
return verifier
|
||||
},
|
||||
authConfig: authz.Config{},
|
||||
},
|
||||
res{
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"unauthorized error",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
req: &mockReq[struct{}]{procedure: "/need/authentication", header: http.Header{"Authorization": []string{"wrong"}}},
|
||||
handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{}),
|
||||
verifier: func() authz.APITokenVerifier {
|
||||
verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenOK, systemTokenNOK)
|
||||
verifier.RegisterServer("need", "need", authz.MethodMapping{"/need/authentication": authz.Option{Permission: "authenticated"}})
|
||||
return verifier
|
||||
},
|
||||
authConfig: authz.Config{},
|
||||
},
|
||||
res{
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"authorized ok",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
req: &mockReq[struct{}]{procedure: "/need/authentication", header: http.Header{"Authorization": []string{"Bearer token"}}},
|
||||
handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{
|
||||
UserID: "user1",
|
||||
OrgID: "org1",
|
||||
ResourceOwner: "org1",
|
||||
}),
|
||||
verifier: func() authz.APITokenVerifier {
|
||||
verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenOK, systemTokenNOK)
|
||||
verifier.RegisterServer("need", "need", authz.MethodMapping{"/need/authentication": authz.Option{Permission: "authenticated"}})
|
||||
return verifier
|
||||
},
|
||||
authConfig: authz.Config{},
|
||||
},
|
||||
res{
|
||||
&connect.Response[struct{}]{},
|
||||
false,
|
||||
},
|
||||
},
|
||||
{
|
||||
"authorized ok, org by request",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
req: &mockReq[mockOrgFromRequest]{
|
||||
Request: connect.Request[mockOrgFromRequest]{Msg: &mockOrgFromRequest{"id"}},
|
||||
procedure: "/need/authentication",
|
||||
header: http.Header{"Authorization": []string{"Bearer token"}},
|
||||
},
|
||||
handler: emptyMockHandler(&connect.Response[mockOrgFromRequest]{Msg: &mockOrgFromRequest{"id"}}, authz.CtxData{
|
||||
UserID: "user1",
|
||||
OrgID: "id",
|
||||
ResourceOwner: "org1",
|
||||
}),
|
||||
verifier: func() authz.APITokenVerifier {
|
||||
verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenOK, systemTokenNOK)
|
||||
verifier.RegisterServer("need", "need", authz.MethodMapping{"/need/authentication": authz.Option{Permission: "authenticated"}})
|
||||
return verifier
|
||||
},
|
||||
authConfig: authz.Config{},
|
||||
},
|
||||
res{
|
||||
&connect.Response[mockOrgFromRequest]{Msg: &mockOrgFromRequest{"id"}},
|
||||
false,
|
||||
},
|
||||
},
|
||||
{
|
||||
"permission denied error",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
req: &mockReq[struct{}]{procedure: "/need/authentication", header: http.Header{"Authorization": []string{"Bearer token"}}},
|
||||
handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{
|
||||
UserID: "user1",
|
||||
OrgID: "org1",
|
||||
ResourceOwner: "org1",
|
||||
}),
|
||||
verifier: func() authz.APITokenVerifier {
|
||||
verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenOK, systemTokenNOK)
|
||||
verifier.RegisterServer("need", "need", authz.MethodMapping{"/need/authentication": authz.Option{Permission: "to.do.something"}})
|
||||
return verifier
|
||||
},
|
||||
authConfig: authz.Config{
|
||||
RolePermissionMappings: []authz.RoleMapping{{
|
||||
Role: anAPIRole,
|
||||
Permissions: []string{"to.do.something.else"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
res{
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"permission ok",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
req: &mockReq[struct{}]{procedure: "/need/authentication", header: http.Header{"Authorization": []string{"Bearer token"}}},
|
||||
handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{
|
||||
UserID: "user1",
|
||||
OrgID: "org1",
|
||||
ResourceOwner: "org1",
|
||||
}),
|
||||
verifier: func() authz.APITokenVerifier {
|
||||
verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenOK, systemTokenNOK)
|
||||
verifier.RegisterServer("need", "need", authz.MethodMapping{"/need/authentication": authz.Option{Permission: "to.do.something"}})
|
||||
return verifier
|
||||
},
|
||||
authConfig: authz.Config{
|
||||
RolePermissionMappings: []authz.RoleMapping{{
|
||||
Role: anAPIRole,
|
||||
Permissions: []string{"to.do.something"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
res{
|
||||
&connect.Response[struct{}]{},
|
||||
false,
|
||||
},
|
||||
},
|
||||
{
|
||||
"system token permission denied error",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
req: &mockReq[struct{}]{procedure: "/need/authentication", header: http.Header{"Authorization": []string{"Bearer token"}}},
|
||||
handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{}),
|
||||
verifier: func() authz.APITokenVerifier {
|
||||
verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenNOK, authz.SystemTokenVerifierFunc(func(ctx context.Context, token string, orgID string) (memberships authz.Memberships, userID string, err error) {
|
||||
return authz.Memberships{{
|
||||
MemberType: authz.MemberTypeSystem,
|
||||
Roles: []string{"A_SYSTEM_ROLE"},
|
||||
}}, "systemuser", nil
|
||||
}))
|
||||
verifier.RegisterServer("need", "need", authz.MethodMapping{"/need/authentication": authz.Option{Permission: "to.do.something"}})
|
||||
return verifier
|
||||
},
|
||||
authConfig: authz.Config{
|
||||
RolePermissionMappings: []authz.RoleMapping{{
|
||||
Role: "A_SYSTEM_ROLE",
|
||||
Permissions: []string{"to.do.something.else"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
res{
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"system token permission denied error",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
req: &mockReq[struct{}]{procedure: "/need/authentication", header: http.Header{"Authorization": []string{"Bearer token"}}},
|
||||
handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{
|
||||
UserID: "systemuser",
|
||||
SystemMemberships: authz.Memberships{{
|
||||
MemberType: authz.MemberTypeSystem,
|
||||
Roles: []string{"A_SYSTEM_ROLE"},
|
||||
}},
|
||||
SystemUserPermissions: []authz.SystemUserPermissions{{
|
||||
MemberType: authz.MemberTypeSystem,
|
||||
Permissions: []string{"to.do.something"},
|
||||
}},
|
||||
}),
|
||||
verifier: func() authz.APITokenVerifier {
|
||||
verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenNOK, authz.SystemTokenVerifierFunc(func(ctx context.Context, token string, orgID string) (memberships authz.Memberships, userID string, err error) {
|
||||
return authz.Memberships{{
|
||||
MemberType: authz.MemberTypeSystem,
|
||||
Roles: []string{"A_SYSTEM_ROLE"},
|
||||
}}, "systemuser", nil
|
||||
}))
|
||||
verifier.RegisterServer("need", "need", authz.MethodMapping{"/need/authentication": authz.Option{Permission: "to.do.something"}})
|
||||
return verifier
|
||||
},
|
||||
authConfig: authz.Config{
|
||||
RolePermissionMappings: []authz.RoleMapping{{
|
||||
Role: "A_SYSTEM_ROLE",
|
||||
Permissions: []string{"to.do.something"},
|
||||
}},
|
||||
},
|
||||
},
|
||||
res{
|
||||
&connect.Response[struct{}]{},
|
||||
false,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := authorize(tt.args.ctx, tt.args.req, tt.args.handler(t), tt.args.verifier(), tt.args.authConfig, tt.args.authConfig)
|
||||
if (err != nil) != tt.res.wantErr {
|
||||
t.Errorf("authorize() error = %v, wantErr %v", err, tt.res.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.res.want) {
|
||||
t.Errorf("authorize() got = %v, want %v", got, tt.res.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@@ -0,0 +1,31 @@
|
||||
package connect_middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
_ "github.com/zitadel/zitadel/internal/statik"
|
||||
)
|
||||
|
||||
func NoCacheInterceptor() connect.UnaryInterceptorFunc {
|
||||
return func(handler connect.UnaryFunc) connect.UnaryFunc {
|
||||
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
|
||||
headers := map[string]string{
|
||||
"cache-control": "no-store",
|
||||
"expires": time.Now().UTC().Format(http.TimeFormat),
|
||||
"pragma": "no-cache",
|
||||
}
|
||||
resp, err := handler(ctx, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for key, value := range headers {
|
||||
resp.Header().Set(key, value)
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
}
|
||||
}
|
@@ -0,0 +1,18 @@
|
||||
package connect_middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/call"
|
||||
)
|
||||
|
||||
func CallDurationHandler() connect.UnaryInterceptorFunc {
|
||||
return func(handler connect.UnaryFunc) connect.UnaryFunc {
|
||||
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
|
||||
ctx = call.WithTimestamp(ctx)
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
}
|
@@ -0,0 +1,23 @@
|
||||
package connect_middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/grpc/gerrors"
|
||||
_ "github.com/zitadel/zitadel/internal/statik"
|
||||
)
|
||||
|
||||
func ErrorHandler() connect.UnaryInterceptorFunc {
|
||||
return func(handler connect.UnaryFunc) connect.UnaryFunc {
|
||||
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
|
||||
return toConnectError(ctx, req, handler)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func toConnectError(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc) (connect.AnyResponse, error) {
|
||||
resp, err := handler(ctx, req)
|
||||
return resp, gerrors.ZITADELToConnectError(err) // TODO !
|
||||
}
|
@@ -0,0 +1,65 @@
|
||||
package connect_middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"reflect"
|
||||
"testing"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
)
|
||||
|
||||
func Test_toGRPCError(t *testing.T) {
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
req connect.AnyRequest
|
||||
handler func(t *testing.T) connect.UnaryFunc
|
||||
}
|
||||
type res struct {
|
||||
want interface{}
|
||||
wantErr bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"no error",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
req: &mockReq[struct{}]{},
|
||||
handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{}),
|
||||
},
|
||||
res{
|
||||
&connect.Response[struct{}]{},
|
||||
false,
|
||||
},
|
||||
},
|
||||
{
|
||||
"error",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
req: &mockReq[struct{}]{},
|
||||
handler: errorMockHandler(),
|
||||
},
|
||||
res{
|
||||
nil,
|
||||
true,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := toConnectError(tt.args.ctx, tt.args.req, tt.args.handler(t))
|
||||
if (err != nil) != tt.res.wantErr {
|
||||
t.Errorf("toGRPCError() error = %v, wantErr %v", err, tt.res.wantErr)
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(got, tt.res.want) {
|
||||
t.Errorf("toGRPCError() got = %v, want %v", got, tt.res.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@@ -0,0 +1,160 @@
|
||||
package connect_middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/execution"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
func ExecutionHandler(queries *query.Queries) connect.UnaryInterceptorFunc {
|
||||
return func(handler connect.UnaryFunc) connect.UnaryFunc {
|
||||
return func(ctx context.Context, req connect.AnyRequest) (_ connect.AnyResponse, err error) {
|
||||
requestTargets, responseTargets := execution.QueryExecutionTargetsForRequestAndResponse(ctx, queries, req.Spec().Procedure)
|
||||
|
||||
// call targets otherwise return req
|
||||
handledReq, err := executeTargetsForRequest(ctx, requestTargets, req.Spec().Procedure, req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
response, err := handler(ctx, handledReq)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return executeTargetsForResponse(ctx, responseTargets, req.Spec().Procedure, handledReq, response)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func executeTargetsForRequest(ctx context.Context, targets []execution.Target, fullMethod string, req connect.AnyRequest) (_ connect.AnyRequest, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
// if no targets are found, return without any calls
|
||||
if len(targets) == 0 {
|
||||
return req, nil
|
||||
}
|
||||
|
||||
ctxData := authz.GetCtxData(ctx)
|
||||
info := &ContextInfoRequest{
|
||||
FullMethod: fullMethod,
|
||||
InstanceID: authz.GetInstance(ctx).InstanceID(),
|
||||
ProjectID: ctxData.ProjectID,
|
||||
OrgID: ctxData.OrgID,
|
||||
UserID: ctxData.UserID,
|
||||
Request: Message{req.Any().(proto.Message)},
|
||||
}
|
||||
|
||||
_, err = execution.CallTargets(ctx, targets, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func executeTargetsForResponse(ctx context.Context, targets []execution.Target, fullMethod string, req connect.AnyRequest, resp connect.AnyResponse) (_ connect.AnyResponse, err error) {
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
// if no targets are found, return without any calls
|
||||
if len(targets) == 0 {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
ctxData := authz.GetCtxData(ctx)
|
||||
info := &ContextInfoResponse{
|
||||
FullMethod: fullMethod,
|
||||
InstanceID: authz.GetInstance(ctx).InstanceID(),
|
||||
ProjectID: ctxData.ProjectID,
|
||||
OrgID: ctxData.OrgID,
|
||||
UserID: ctxData.UserID,
|
||||
Request: Message{req.Any().(proto.Message)},
|
||||
Response: Message{resp.Any().(proto.Message)},
|
||||
}
|
||||
|
||||
_, err = execution.CallTargets(ctx, targets, info)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
var _ execution.ContextInfo = &ContextInfoRequest{}
|
||||
|
||||
type ContextInfoRequest struct {
|
||||
FullMethod string `json:"fullMethod,omitempty"`
|
||||
InstanceID string `json:"instanceID,omitempty"`
|
||||
OrgID string `json:"orgID,omitempty"`
|
||||
ProjectID string `json:"projectID,omitempty"`
|
||||
UserID string `json:"userID,omitempty"`
|
||||
Request Message `json:"request,omitempty"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
proto.Message
|
||||
}
|
||||
|
||||
func (r *Message) MarshalJSON() ([]byte, error) {
|
||||
data, err := protojson.Marshal(r.Message)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (r *Message) UnmarshalJSON(data []byte) error {
|
||||
return protojson.Unmarshal(data, r.Message)
|
||||
}
|
||||
|
||||
func (c *ContextInfoRequest) GetHTTPRequestBody() []byte {
|
||||
data, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func (c *ContextInfoRequest) SetHTTPResponseBody(resp []byte) error {
|
||||
return json.Unmarshal(resp, &c.Request)
|
||||
}
|
||||
|
||||
func (c *ContextInfoRequest) GetContent() interface{} {
|
||||
return c.Request.Message
|
||||
}
|
||||
|
||||
var _ execution.ContextInfo = &ContextInfoResponse{}
|
||||
|
||||
type ContextInfoResponse struct {
|
||||
FullMethod string `json:"fullMethod,omitempty"`
|
||||
InstanceID string `json:"instanceID,omitempty"`
|
||||
OrgID string `json:"orgID,omitempty"`
|
||||
ProjectID string `json:"projectID,omitempty"`
|
||||
UserID string `json:"userID,omitempty"`
|
||||
Request Message `json:"request,omitempty"`
|
||||
Response Message `json:"response,omitempty"`
|
||||
}
|
||||
|
||||
func (c *ContextInfoResponse) GetHTTPRequestBody() []byte {
|
||||
data, err := json.Marshal(c)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func (c *ContextInfoResponse) SetHTTPResponseBody(resp []byte) error {
|
||||
return json.Unmarshal(resp, &c.Response)
|
||||
}
|
||||
|
||||
func (c *ContextInfoResponse) GetContent() interface{} {
|
||||
return c.Response.Message
|
||||
}
|
@@ -0,0 +1,815 @@
|
||||
package connect_middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"google.golang.org/protobuf/encoding/protojson"
|
||||
"google.golang.org/protobuf/proto"
|
||||
"google.golang.org/protobuf/types/known/structpb"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/domain"
|
||||
"github.com/zitadel/zitadel/internal/execution"
|
||||
)
|
||||
|
||||
var _ execution.Target = &mockExecutionTarget{}
|
||||
|
||||
type mockExecutionTarget struct {
|
||||
InstanceID string
|
||||
ExecutionID string
|
||||
TargetID string
|
||||
TargetType domain.TargetType
|
||||
Endpoint string
|
||||
Timeout time.Duration
|
||||
InterruptOnError bool
|
||||
SigningKey string
|
||||
}
|
||||
|
||||
func (e *mockExecutionTarget) SetEndpoint(endpoint string) {
|
||||
e.Endpoint = endpoint
|
||||
}
|
||||
func (e *mockExecutionTarget) IsInterruptOnError() bool {
|
||||
return e.InterruptOnError
|
||||
}
|
||||
func (e *mockExecutionTarget) GetEndpoint() string {
|
||||
return e.Endpoint
|
||||
}
|
||||
func (e *mockExecutionTarget) GetTargetType() domain.TargetType {
|
||||
return e.TargetType
|
||||
}
|
||||
func (e *mockExecutionTarget) GetTimeout() time.Duration {
|
||||
return e.Timeout
|
||||
}
|
||||
func (e *mockExecutionTarget) GetTargetID() string {
|
||||
return e.TargetID
|
||||
}
|
||||
func (e *mockExecutionTarget) GetExecutionID() string {
|
||||
return e.ExecutionID
|
||||
}
|
||||
func (e *mockExecutionTarget) GetSigningKey() string {
|
||||
return e.SigningKey
|
||||
}
|
||||
|
||||
func newMockContentRequest(content string) *connect.Request[structpb.Struct] {
|
||||
return connect.NewRequest(&structpb.Struct{
|
||||
Fields: map[string]*structpb.Value{
|
||||
"content": {
|
||||
Kind: &structpb.Value_StringValue{StringValue: content},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func newMockContentResponse(content string) *connect.Response[structpb.Struct] {
|
||||
return connect.NewResponse(&structpb.Struct{
|
||||
Fields: map[string]*structpb.Value{
|
||||
"content": {
|
||||
Kind: &structpb.Value_StringValue{StringValue: content},
|
||||
},
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
func newMockContextInfoRequest(fullMethod, request string) *ContextInfoRequest {
|
||||
return &ContextInfoRequest{
|
||||
FullMethod: fullMethod,
|
||||
Request: Message{Message: newMockContentRequest(request).Msg},
|
||||
}
|
||||
}
|
||||
|
||||
func newMockContextInfoResponse(fullMethod, request, response string) *ContextInfoResponse {
|
||||
return &ContextInfoResponse{
|
||||
FullMethod: fullMethod,
|
||||
Request: Message{Message: newMockContentRequest(request).Msg},
|
||||
Response: Message{Message: newMockContentResponse(response).Msg},
|
||||
}
|
||||
}
|
||||
|
||||
func Test_executeTargetsForGRPCFullMethod_request(t *testing.T) {
|
||||
type target struct {
|
||||
reqBody execution.ContextInfo
|
||||
sleep time.Duration
|
||||
statusCode int
|
||||
respBody connect.AnyResponse
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
|
||||
executionTargets []execution.Target
|
||||
targets []target
|
||||
fullMethod string
|
||||
req connect.AnyRequest
|
||||
}
|
||||
type res struct {
|
||||
want interface{}
|
||||
wantErr bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"target, executionTargets nil",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
fullMethod: "/service/method",
|
||||
executionTargets: nil,
|
||||
req: newMockContentRequest("request"),
|
||||
},
|
||||
res{
|
||||
want: newMockContentRequest("request"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"target, executionTargets empty",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
fullMethod: "/service/method",
|
||||
executionTargets: []execution.Target{},
|
||||
req: newMockContentRequest("request"),
|
||||
},
|
||||
res{
|
||||
want: newMockContentRequest("request"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"target, not reachable",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
fullMethod: "/service/method",
|
||||
executionTargets: []execution.Target{
|
||||
&mockExecutionTarget{
|
||||
InstanceID: "instance",
|
||||
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
|
||||
TargetID: "target",
|
||||
TargetType: domain.TargetTypeCall,
|
||||
Timeout: time.Minute,
|
||||
InterruptOnError: true,
|
||||
},
|
||||
},
|
||||
targets: []target{},
|
||||
req: newMockContentRequest("content"),
|
||||
},
|
||||
res{
|
||||
wantErr: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"target, error without interrupt",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
fullMethod: "/service/method",
|
||||
executionTargets: []execution.Target{
|
||||
&mockExecutionTarget{
|
||||
InstanceID: "instance",
|
||||
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
|
||||
TargetID: "target",
|
||||
TargetType: domain.TargetTypeCall,
|
||||
Timeout: time.Minute,
|
||||
SigningKey: "signingkey",
|
||||
},
|
||||
},
|
||||
targets: []target{
|
||||
{
|
||||
reqBody: newMockContextInfoRequest("/service/method", "content"),
|
||||
respBody: newMockContentResponse("content1"),
|
||||
sleep: 0,
|
||||
statusCode: http.StatusBadRequest,
|
||||
},
|
||||
},
|
||||
req: newMockContentRequest("content"),
|
||||
},
|
||||
res{
|
||||
want: newMockContentRequest("content"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"target, interruptOnError",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
fullMethod: "/service/method",
|
||||
executionTargets: []execution.Target{
|
||||
&mockExecutionTarget{
|
||||
InstanceID: "instance",
|
||||
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
|
||||
TargetID: "target",
|
||||
TargetType: domain.TargetTypeCall,
|
||||
Timeout: time.Minute,
|
||||
InterruptOnError: true,
|
||||
SigningKey: "signingkey",
|
||||
},
|
||||
},
|
||||
|
||||
targets: []target{
|
||||
{
|
||||
reqBody: newMockContextInfoRequest("/service/method", "content"),
|
||||
respBody: newMockContentResponse("content1"),
|
||||
sleep: 0,
|
||||
statusCode: http.StatusBadRequest,
|
||||
},
|
||||
},
|
||||
req: newMockContentRequest("content"),
|
||||
},
|
||||
res{
|
||||
wantErr: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"target, timeout",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
fullMethod: "/service/method",
|
||||
executionTargets: []execution.Target{
|
||||
&mockExecutionTarget{
|
||||
InstanceID: "instance",
|
||||
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
|
||||
TargetID: "target",
|
||||
TargetType: domain.TargetTypeCall,
|
||||
Timeout: time.Second,
|
||||
InterruptOnError: true,
|
||||
SigningKey: "signingkey",
|
||||
},
|
||||
},
|
||||
targets: []target{
|
||||
{
|
||||
reqBody: newMockContextInfoRequest("/service/method", "content"),
|
||||
respBody: newMockContentResponse("content1"),
|
||||
sleep: 5 * time.Second,
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
},
|
||||
req: newMockContentRequest("content"),
|
||||
},
|
||||
res{
|
||||
wantErr: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"target, wrong request",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
fullMethod: "/service/method",
|
||||
executionTargets: []execution.Target{
|
||||
&mockExecutionTarget{
|
||||
InstanceID: "instance",
|
||||
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
|
||||
TargetID: "target",
|
||||
TargetType: domain.TargetTypeCall,
|
||||
Timeout: time.Second,
|
||||
InterruptOnError: true,
|
||||
SigningKey: "signingkey",
|
||||
},
|
||||
},
|
||||
targets: []target{
|
||||
{reqBody: newMockContextInfoRequest("/service/method", "wrong")},
|
||||
},
|
||||
req: newMockContentRequest("content"),
|
||||
},
|
||||
res{
|
||||
wantErr: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"target, ok",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
fullMethod: "/service/method",
|
||||
executionTargets: []execution.Target{
|
||||
&mockExecutionTarget{
|
||||
InstanceID: "instance",
|
||||
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
|
||||
TargetID: "target",
|
||||
TargetType: domain.TargetTypeCall,
|
||||
Timeout: time.Minute,
|
||||
InterruptOnError: true,
|
||||
SigningKey: "signingkey",
|
||||
},
|
||||
},
|
||||
targets: []target{
|
||||
{
|
||||
reqBody: newMockContextInfoRequest("/service/method", "content"),
|
||||
respBody: newMockContentResponse("content1"),
|
||||
sleep: 0,
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
},
|
||||
req: newMockContentRequest("content"),
|
||||
},
|
||||
res{
|
||||
want: newMockContentRequest("content1"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"target async, timeout",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
fullMethod: "/service/method",
|
||||
executionTargets: []execution.Target{
|
||||
&mockExecutionTarget{
|
||||
InstanceID: "instance",
|
||||
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
|
||||
TargetID: "target",
|
||||
TargetType: domain.TargetTypeAsync,
|
||||
Timeout: time.Second,
|
||||
SigningKey: "signingkey",
|
||||
},
|
||||
},
|
||||
targets: []target{
|
||||
{
|
||||
reqBody: newMockContextInfoRequest("/service/method", "content"),
|
||||
respBody: newMockContentResponse("content1"),
|
||||
sleep: 5 * time.Second,
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
},
|
||||
req: newMockContentRequest("content"),
|
||||
},
|
||||
res{
|
||||
want: newMockContentRequest("content"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"target async, ok",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
fullMethod: "/service/method",
|
||||
executionTargets: []execution.Target{
|
||||
&mockExecutionTarget{
|
||||
InstanceID: "instance",
|
||||
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
|
||||
TargetID: "target",
|
||||
TargetType: domain.TargetTypeAsync,
|
||||
Timeout: time.Minute,
|
||||
SigningKey: "signingkey",
|
||||
},
|
||||
},
|
||||
targets: []target{
|
||||
{
|
||||
reqBody: newMockContextInfoRequest("/service/method", "content"),
|
||||
respBody: newMockContentResponse("content1"),
|
||||
sleep: 0,
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
},
|
||||
req: newMockContentRequest("content"),
|
||||
},
|
||||
res{
|
||||
want: newMockContentRequest("content"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"webhook, error",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
fullMethod: "/service/method",
|
||||
executionTargets: []execution.Target{
|
||||
&mockExecutionTarget{
|
||||
InstanceID: "instance",
|
||||
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
|
||||
TargetID: "target",
|
||||
TargetType: domain.TargetTypeWebhook,
|
||||
Timeout: time.Minute,
|
||||
InterruptOnError: true,
|
||||
SigningKey: "signingkey",
|
||||
},
|
||||
},
|
||||
targets: []target{
|
||||
{
|
||||
reqBody: newMockContextInfoRequest("/service/method", "content"),
|
||||
sleep: 0,
|
||||
statusCode: http.StatusInternalServerError,
|
||||
},
|
||||
},
|
||||
req: newMockContentRequest("content"),
|
||||
},
|
||||
res{
|
||||
wantErr: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"webhook, timeout",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
fullMethod: "/service/method",
|
||||
executionTargets: []execution.Target{
|
||||
&mockExecutionTarget{
|
||||
InstanceID: "instance",
|
||||
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
|
||||
TargetID: "target",
|
||||
TargetType: domain.TargetTypeWebhook,
|
||||
Timeout: time.Second,
|
||||
InterruptOnError: true,
|
||||
SigningKey: "signingkey",
|
||||
},
|
||||
},
|
||||
targets: []target{
|
||||
{
|
||||
reqBody: newMockContextInfoRequest("/service/method", "content"),
|
||||
respBody: newMockContentResponse("content1"),
|
||||
sleep: 5 * time.Second,
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
},
|
||||
req: newMockContentRequest("content"),
|
||||
},
|
||||
res{
|
||||
wantErr: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"webhook, ok",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
fullMethod: "/service/method",
|
||||
executionTargets: []execution.Target{
|
||||
&mockExecutionTarget{
|
||||
InstanceID: "instance",
|
||||
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
|
||||
TargetID: "target",
|
||||
TargetType: domain.TargetTypeWebhook,
|
||||
Timeout: time.Minute,
|
||||
InterruptOnError: true,
|
||||
SigningKey: "signingkey",
|
||||
},
|
||||
},
|
||||
targets: []target{
|
||||
{
|
||||
reqBody: newMockContextInfoRequest("/service/method", "content"),
|
||||
respBody: newMockContentResponse("content1"),
|
||||
sleep: 0,
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
},
|
||||
req: newMockContentRequest("content"),
|
||||
},
|
||||
res{
|
||||
want: newMockContentRequest("content"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"with includes, interruptOnError",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
fullMethod: "/service/method",
|
||||
executionTargets: []execution.Target{
|
||||
&mockExecutionTarget{
|
||||
InstanceID: "instance",
|
||||
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
|
||||
TargetID: "target1",
|
||||
TargetType: domain.TargetTypeCall,
|
||||
Timeout: time.Minute,
|
||||
InterruptOnError: true,
|
||||
SigningKey: "signingkey",
|
||||
},
|
||||
&mockExecutionTarget{
|
||||
InstanceID: "instance",
|
||||
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
|
||||
TargetID: "target2",
|
||||
TargetType: domain.TargetTypeCall,
|
||||
Timeout: time.Minute,
|
||||
InterruptOnError: true,
|
||||
SigningKey: "signingkey",
|
||||
},
|
||||
&mockExecutionTarget{
|
||||
InstanceID: "instance",
|
||||
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
|
||||
TargetID: "target3",
|
||||
TargetType: domain.TargetTypeCall,
|
||||
Timeout: time.Minute,
|
||||
InterruptOnError: true,
|
||||
SigningKey: "signingkey",
|
||||
},
|
||||
},
|
||||
|
||||
targets: []target{
|
||||
{
|
||||
reqBody: newMockContextInfoRequest("/service/method", "content"),
|
||||
respBody: newMockContentResponse("content1"),
|
||||
sleep: 0,
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
reqBody: newMockContextInfoRequest("/service/method", "content1"),
|
||||
respBody: newMockContentResponse("content2"),
|
||||
sleep: 0,
|
||||
statusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
reqBody: newMockContextInfoRequest("/service/method", "content2"),
|
||||
respBody: newMockContentResponse("content3"),
|
||||
sleep: 0,
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
},
|
||||
req: newMockContentRequest("content"),
|
||||
},
|
||||
res{
|
||||
wantErr: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"with includes, timeout",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
fullMethod: "/service/method",
|
||||
executionTargets: []execution.Target{
|
||||
&mockExecutionTarget{
|
||||
InstanceID: "instance",
|
||||
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
|
||||
TargetID: "target1",
|
||||
TargetType: domain.TargetTypeCall,
|
||||
Timeout: time.Minute,
|
||||
InterruptOnError: true,
|
||||
SigningKey: "signingkey",
|
||||
},
|
||||
&mockExecutionTarget{
|
||||
InstanceID: "instance",
|
||||
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
|
||||
TargetID: "target2",
|
||||
TargetType: domain.TargetTypeCall,
|
||||
Timeout: time.Second,
|
||||
InterruptOnError: true,
|
||||
SigningKey: "signingkey",
|
||||
},
|
||||
&mockExecutionTarget{
|
||||
InstanceID: "instance",
|
||||
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
|
||||
TargetID: "target3",
|
||||
TargetType: domain.TargetTypeCall,
|
||||
Timeout: time.Second,
|
||||
InterruptOnError: true,
|
||||
SigningKey: "signingkey",
|
||||
},
|
||||
},
|
||||
targets: []target{
|
||||
{
|
||||
reqBody: newMockContextInfoRequest("/service/method", "content"),
|
||||
respBody: newMockContentResponse("content1"),
|
||||
sleep: 0,
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
{
|
||||
reqBody: newMockContextInfoRequest("/service/method", "content1"),
|
||||
respBody: newMockContentResponse("content2"),
|
||||
sleep: 5 * time.Second,
|
||||
statusCode: http.StatusBadRequest,
|
||||
},
|
||||
{
|
||||
reqBody: newMockContextInfoRequest("/service/method", "content2"),
|
||||
respBody: newMockContentResponse("content3"),
|
||||
sleep: 5 * time.Second,
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
},
|
||||
req: newMockContentRequest("content"),
|
||||
},
|
||||
res{
|
||||
wantErr: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
closeFuncs := make([]func(), len(tt.args.targets))
|
||||
for i, target := range tt.args.targets {
|
||||
url, closeF := testServerCall(
|
||||
target.reqBody,
|
||||
target.sleep,
|
||||
target.statusCode,
|
||||
target.respBody,
|
||||
)
|
||||
|
||||
et := tt.args.executionTargets[i].(*mockExecutionTarget)
|
||||
et.SetEndpoint(url)
|
||||
closeFuncs[i] = closeF
|
||||
}
|
||||
|
||||
resp, err := executeTargetsForRequest(
|
||||
tt.args.ctx,
|
||||
tt.args.executionTargets,
|
||||
tt.args.fullMethod,
|
||||
tt.args.req,
|
||||
)
|
||||
|
||||
if tt.res.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.EqualExportedValues(t, tt.res.want, resp)
|
||||
|
||||
for _, closeF := range closeFuncs {
|
||||
closeF()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func testServerCall(
|
||||
reqBody interface{},
|
||||
sleep time.Duration,
|
||||
statusCode int,
|
||||
respBody connect.AnyResponse,
|
||||
) (string, func()) {
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
data, err := json.Marshal(reqBody)
|
||||
if err != nil {
|
||||
http.Error(w, "error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
sentBody, err := io.ReadAll(r.Body)
|
||||
if err != nil {
|
||||
http.Error(w, "error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(data, sentBody) {
|
||||
http.Error(w, "error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
|
||||
if statusCode != http.StatusOK {
|
||||
http.Error(w, "error", statusCode)
|
||||
return
|
||||
}
|
||||
|
||||
time.Sleep(sleep)
|
||||
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
resp, err := protojson.Marshal(respBody.Any().(proto.Message))
|
||||
if err != nil {
|
||||
http.Error(w, "error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
if _, err := w.Write(resp); err != nil {
|
||||
http.Error(w, "error", http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(handler))
|
||||
|
||||
return server.URL, server.Close
|
||||
}
|
||||
|
||||
func Test_executeTargetsForGRPCFullMethod_response(t *testing.T) {
|
||||
type target struct {
|
||||
reqBody execution.ContextInfo
|
||||
sleep time.Duration
|
||||
statusCode int
|
||||
respBody connect.AnyResponse
|
||||
}
|
||||
type args struct {
|
||||
ctx context.Context
|
||||
|
||||
executionTargets []execution.Target
|
||||
targets []target
|
||||
fullMethod string
|
||||
req connect.AnyRequest
|
||||
resp connect.AnyResponse
|
||||
}
|
||||
type res struct {
|
||||
want interface{}
|
||||
wantErr bool
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
res res
|
||||
}{
|
||||
{
|
||||
"target, executionTargets nil",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
fullMethod: "/service/method",
|
||||
executionTargets: nil,
|
||||
req: newMockContentRequest("request"),
|
||||
resp: newMockContentResponse("response"),
|
||||
},
|
||||
res{
|
||||
want: newMockContentResponse("response"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"target, executionTargets empty",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
fullMethod: "/service/method",
|
||||
executionTargets: []execution.Target{},
|
||||
req: newMockContentRequest("request"),
|
||||
resp: newMockContentResponse("response"),
|
||||
},
|
||||
res{
|
||||
want: newMockContentResponse("response"),
|
||||
},
|
||||
},
|
||||
{
|
||||
"target, empty response",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
fullMethod: "/service/method",
|
||||
executionTargets: []execution.Target{
|
||||
&mockExecutionTarget{
|
||||
InstanceID: "instance",
|
||||
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
|
||||
TargetID: "target",
|
||||
TargetType: domain.TargetTypeCall,
|
||||
Timeout: time.Minute,
|
||||
InterruptOnError: true,
|
||||
SigningKey: "signingkey",
|
||||
},
|
||||
},
|
||||
targets: []target{
|
||||
{
|
||||
reqBody: newMockContextInfoRequest("/service/method", "content"),
|
||||
respBody: newMockContentResponse(""),
|
||||
sleep: 0,
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
},
|
||||
req: newMockContentRequest(""),
|
||||
resp: newMockContentResponse(""),
|
||||
},
|
||||
res{
|
||||
wantErr: true,
|
||||
},
|
||||
},
|
||||
{
|
||||
"target, ok",
|
||||
args{
|
||||
ctx: context.Background(),
|
||||
fullMethod: "/service/method",
|
||||
executionTargets: []execution.Target{
|
||||
&mockExecutionTarget{
|
||||
InstanceID: "instance",
|
||||
ExecutionID: "response./zitadel.session.v2.SessionService/SetSession",
|
||||
TargetID: "target",
|
||||
TargetType: domain.TargetTypeCall,
|
||||
Timeout: time.Minute,
|
||||
InterruptOnError: true,
|
||||
SigningKey: "signingkey",
|
||||
},
|
||||
},
|
||||
targets: []target{
|
||||
{
|
||||
reqBody: newMockContextInfoResponse("/service/method", "request", "response"),
|
||||
respBody: newMockContentResponse("response1"),
|
||||
sleep: 0,
|
||||
statusCode: http.StatusOK,
|
||||
},
|
||||
},
|
||||
req: newMockContentRequest("request"),
|
||||
resp: newMockContentResponse("response"),
|
||||
},
|
||||
res{
|
||||
want: newMockContentResponse("response1"),
|
||||
},
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
closeFuncs := make([]func(), len(tt.args.targets))
|
||||
for i, target := range tt.args.targets {
|
||||
url, closeF := testServerCall(
|
||||
target.reqBody,
|
||||
target.sleep,
|
||||
target.statusCode,
|
||||
target.respBody,
|
||||
)
|
||||
|
||||
et := tt.args.executionTargets[i].(*mockExecutionTarget)
|
||||
et.SetEndpoint(url)
|
||||
closeFuncs[i] = closeF
|
||||
}
|
||||
|
||||
resp, err := executeTargetsForResponse(
|
||||
tt.args.ctx,
|
||||
tt.args.executionTargets,
|
||||
tt.args.fullMethod,
|
||||
tt.args.req,
|
||||
tt.args.resp,
|
||||
)
|
||||
|
||||
if tt.res.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
assert.EqualExportedValues(t, tt.res.want, resp)
|
||||
|
||||
for _, closeF := range closeFuncs {
|
||||
closeF()
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
@@ -0,0 +1,107 @@
|
||||
package connect_middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/zitadel/logging"
|
||||
"golang.org/x/text/language"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
zitadel_http "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/i18n"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
object_v3 "github.com/zitadel/zitadel/pkg/grpc/object/v3alpha"
|
||||
)
|
||||
|
||||
func InstanceInterceptor(verifier authz.InstanceVerifier, externalDomain string, explicitInstanceIdServices ...string) connect.UnaryInterceptorFunc {
|
||||
translator, err := i18n.NewZitadelTranslator(language.English)
|
||||
logging.OnError(err).Panic("unable to get translator")
|
||||
return func(handler connect.UnaryFunc) connect.UnaryFunc {
|
||||
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
|
||||
return setInstance(ctx, req, handler, verifier, externalDomain, translator, explicitInstanceIdServices...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func setInstance(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, verifier authz.InstanceVerifier, externalDomain string, translator *i18n.Translator, idFromRequestsServices ...string) (_ connect.AnyResponse, err error) {
|
||||
interceptorCtx, span := tracing.NewServerInterceptorSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
for _, service := range idFromRequestsServices {
|
||||
if !strings.HasPrefix(service, "/") {
|
||||
service = "/" + service
|
||||
}
|
||||
if strings.HasPrefix(req.Spec().Procedure, service) {
|
||||
withInstanceIDProperty, ok := req.Any().(interface {
|
||||
GetInstanceId() string
|
||||
})
|
||||
if !ok {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
return addInstanceByID(interceptorCtx, req, handler, verifier, translator, withInstanceIDProperty.GetInstanceId())
|
||||
}
|
||||
}
|
||||
explicitInstanceRequest, ok := req.Any().(interface {
|
||||
GetInstance() *object_v3.Instance
|
||||
})
|
||||
if ok {
|
||||
instance := explicitInstanceRequest.GetInstance()
|
||||
if id := instance.GetId(); id != "" {
|
||||
return addInstanceByID(interceptorCtx, req, handler, verifier, translator, id)
|
||||
}
|
||||
if domain := instance.GetDomain(); domain != "" {
|
||||
return addInstanceByDomain(interceptorCtx, req, handler, verifier, translator, domain)
|
||||
}
|
||||
}
|
||||
return addInstanceByRequestedHost(interceptorCtx, req, handler, verifier, translator, externalDomain)
|
||||
}
|
||||
|
||||
func addInstanceByID(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, verifier authz.InstanceVerifier, translator *i18n.Translator, id string) (connect.AnyResponse, error) {
|
||||
instance, err := verifier.InstanceByID(ctx, id)
|
||||
if err != nil {
|
||||
notFoundErr := new(zerrors.ZitadelError)
|
||||
if errors.As(err, ¬FoundErr) {
|
||||
notFoundErr.Message = translator.LocalizeFromCtx(ctx, notFoundErr.GetMessage(), nil)
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("unable to set instance using id %s: %w", id, notFoundErr))
|
||||
}
|
||||
return handler(authz.WithInstance(ctx, instance), req)
|
||||
}
|
||||
|
||||
func addInstanceByDomain(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, verifier authz.InstanceVerifier, translator *i18n.Translator, domain string) (connect.AnyResponse, error) {
|
||||
instance, err := verifier.InstanceByHost(ctx, domain, "")
|
||||
if err != nil {
|
||||
notFoundErr := new(zerrors.NotFoundError)
|
||||
if errors.As(err, ¬FoundErr) {
|
||||
notFoundErr.Message = translator.LocalizeFromCtx(ctx, notFoundErr.GetMessage(), nil)
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("unable to set instance using domain %s: %w", domain, notFoundErr))
|
||||
}
|
||||
return handler(authz.WithInstance(ctx, instance), req)
|
||||
}
|
||||
|
||||
func addInstanceByRequestedHost(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, verifier authz.InstanceVerifier, translator *i18n.Translator, externalDomain string) (connect.AnyResponse, error) {
|
||||
requestContext := zitadel_http.DomainContext(ctx)
|
||||
if requestContext.InstanceHost == "" {
|
||||
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).Error("unable to set instance")
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("no instanceHost specified"))
|
||||
}
|
||||
instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceHost, requestContext.PublicHost)
|
||||
if err != nil {
|
||||
origin := zitadel_http.DomainContext(ctx)
|
||||
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).WithError(err).Error("unable to set instance")
|
||||
zErr := new(zerrors.ZitadelError)
|
||||
if errors.As(err, &zErr) {
|
||||
zErr.SetMessage(translator.LocalizeFromCtx(ctx, zErr.GetMessage(), nil))
|
||||
zErr.Parent = err
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("unable to set instance using origin %s (ExternalDomain is %s): %s", origin, externalDomain, zErr.Error()))
|
||||
}
|
||||
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("unable to set instance using origin %s (ExternalDomain is %s)", origin, externalDomain))
|
||||
}
|
||||
return handler(authz.WithInstance(ctx, instance), req)
|
||||
}
|
@@ -0,0 +1,34 @@
|
||||
package connect_middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func LimitsInterceptor(ignoreService ...string) connect.UnaryInterceptorFunc {
|
||||
for idx, service := range ignoreService {
|
||||
if !strings.HasPrefix(service, "/") {
|
||||
ignoreService[idx] = "/" + service
|
||||
}
|
||||
}
|
||||
|
||||
return func(handler connect.UnaryFunc) connect.UnaryFunc {
|
||||
return func(ctx context.Context, req connect.AnyRequest) (_ connect.AnyResponse, err error) {
|
||||
for _, service := range ignoreService {
|
||||
if strings.HasPrefix(req.Spec().Procedure, service) {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
instance := authz.GetInstance(ctx)
|
||||
if block := instance.Block(); block != nil && *block {
|
||||
return nil, zerrors.ThrowResourceExhausted(nil, "LIMITS-molsj", "Errors.Limits.Instance.Blocked")
|
||||
}
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
}
|
@@ -0,0 +1,96 @@
|
||||
package connect_middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/grpc-ecosystem/grpc-gateway/runtime"
|
||||
"github.com/zitadel/logging"
|
||||
"go.opentelemetry.io/otel/attribute"
|
||||
"google.golang.org/grpc/codes"
|
||||
|
||||
_ "github.com/zitadel/zitadel/internal/statik"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/metrics"
|
||||
)
|
||||
|
||||
const (
|
||||
GrpcMethod = "grpc_method"
|
||||
ReturnCode = "return_code"
|
||||
GrpcRequestCounter = "grpc.server.request_counter"
|
||||
GrpcRequestCounterDescription = "Grpc request counter"
|
||||
TotalGrpcRequestCounter = "grpc.server.total_request_counter"
|
||||
TotalGrpcRequestCounterDescription = "Total grpc request counter"
|
||||
GrpcStatusCodeCounter = "grpc.server.grpc_status_code"
|
||||
GrpcStatusCodeCounterDescription = "Grpc status code counter"
|
||||
)
|
||||
|
||||
func MetricsHandler(metricTypes []metrics.MetricType, ignoredMethodSuffixes ...string) connect.UnaryInterceptorFunc {
|
||||
return func(handler connect.UnaryFunc) connect.UnaryFunc {
|
||||
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
|
||||
return RegisterMetrics(ctx, req, handler, metricTypes, ignoredMethodSuffixes...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func RegisterMetrics(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, metricTypes []metrics.MetricType, ignoredMethodSuffixes ...string) (_ connect.AnyResponse, err error) {
|
||||
if len(metricTypes) == 0 {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
for _, ignore := range ignoredMethodSuffixes {
|
||||
if strings.HasSuffix(req.Spec().Procedure, ignore) {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
resp, err := handler(ctx, req)
|
||||
if containsMetricsMethod(metrics.MetricTypeRequestCount, metricTypes) {
|
||||
RegisterGrpcRequestCounter(ctx, req.Spec().Procedure)
|
||||
}
|
||||
if containsMetricsMethod(metrics.MetricTypeTotalCount, metricTypes) {
|
||||
RegisterGrpcTotalRequestCounter(ctx)
|
||||
}
|
||||
if containsMetricsMethod(metrics.MetricTypeStatusCode, metricTypes) {
|
||||
RegisterGrpcRequestCodeCounter(ctx, req.Spec().Procedure, err)
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func RegisterGrpcRequestCounter(ctx context.Context, path string) {
|
||||
var labels = map[string]attribute.Value{
|
||||
GrpcMethod: attribute.StringValue(path),
|
||||
}
|
||||
err := metrics.RegisterCounter(GrpcRequestCounter, GrpcRequestCounterDescription)
|
||||
logging.OnError(err).Warn("failed to register grpc request counter")
|
||||
err = metrics.AddCount(ctx, GrpcRequestCounter, 1, labels)
|
||||
logging.OnError(err).Warn("failed to add grpc request count")
|
||||
}
|
||||
|
||||
func RegisterGrpcTotalRequestCounter(ctx context.Context) {
|
||||
err := metrics.RegisterCounter(TotalGrpcRequestCounter, TotalGrpcRequestCounterDescription)
|
||||
logging.OnError(err).Warn("failed to register total grpc request counter")
|
||||
err = metrics.AddCount(ctx, TotalGrpcRequestCounter, 1, nil)
|
||||
logging.OnError(err).Warn("failed to add total grpc request count")
|
||||
}
|
||||
|
||||
func RegisterGrpcRequestCodeCounter(ctx context.Context, path string, err error) {
|
||||
statusCode := connect.CodeOf(err)
|
||||
var labels = map[string]attribute.Value{
|
||||
GrpcMethod: attribute.StringValue(path),
|
||||
ReturnCode: attribute.IntValue(runtime.HTTPStatusFromCode(codes.Code(statusCode))),
|
||||
}
|
||||
err = metrics.RegisterCounter(GrpcStatusCodeCounter, GrpcStatusCodeCounterDescription)
|
||||
logging.OnError(err).Warn("failed to register grpc status code counter")
|
||||
err = metrics.AddCount(ctx, GrpcStatusCodeCounter, 1, labels)
|
||||
logging.OnError(err).Warn("failed to add grpc status code count")
|
||||
}
|
||||
|
||||
func containsMetricsMethod(metricType metrics.MetricType, metricTypes []metrics.MetricType) bool {
|
||||
for _, m := range metricTypes {
|
||||
if m == metricType {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
50
internal/api/grpc/server/connect_middleware/mock_test.go
Normal file
50
internal/api/grpc/server/connect_middleware/mock_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
package connect_middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func emptyMockHandler(resp connect.AnyResponse, expectedCtxData authz.CtxData) func(*testing.T) connect.UnaryFunc {
|
||||
return func(t *testing.T) connect.UnaryFunc {
|
||||
return func(ctx context.Context, _ connect.AnyRequest) (connect.AnyResponse, error) {
|
||||
assert.Equal(t, expectedCtxData, authz.GetCtxData(ctx))
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func errorMockHandler() func(*testing.T) connect.UnaryFunc {
|
||||
return func(t *testing.T) connect.UnaryFunc {
|
||||
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
|
||||
return nil, zerrors.ThrowInternal(nil, "test", "error")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type mockReq[t any] struct {
|
||||
connect.Request[t]
|
||||
|
||||
procedure string
|
||||
header http.Header
|
||||
}
|
||||
|
||||
func (m *mockReq[T]) Spec() connect.Spec {
|
||||
return connect.Spec{
|
||||
Procedure: m.procedure,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockReq[T]) Header() http.Header {
|
||||
if m.header == nil {
|
||||
m.header = make(http.Header)
|
||||
}
|
||||
return m.header
|
||||
}
|
@@ -0,0 +1,53 @@
|
||||
package connect_middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/logstore"
|
||||
"github.com/zitadel/zitadel/internal/logstore/record"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
func QuotaExhaustedInterceptor(svc *logstore.Service[*record.AccessLog], ignoreService ...string) connect.UnaryInterceptorFunc {
|
||||
for idx, service := range ignoreService {
|
||||
if !strings.HasPrefix(service, "/") {
|
||||
ignoreService[idx] = "/" + service
|
||||
}
|
||||
}
|
||||
return func(handler connect.UnaryFunc) connect.UnaryFunc {
|
||||
return func(ctx context.Context, req connect.AnyRequest) (_ connect.AnyResponse, err error) {
|
||||
if !svc.Enabled() {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
interceptorCtx, span := tracing.NewServerInterceptorSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
// The auth interceptor will ensure that only authorized or public requests are allowed.
|
||||
// So if there's no authorization context, we don't need to check for limitation
|
||||
// Also, we don't limit calls with system user tokens
|
||||
ctxData := authz.GetCtxData(ctx)
|
||||
if ctxData.IsZero() || ctxData.SystemMemberships != nil {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
for _, service := range ignoreService {
|
||||
if strings.HasPrefix(req.Spec().Procedure, service) {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
instance := authz.GetInstance(ctx)
|
||||
remaining := svc.Limit(interceptorCtx, instance.InstanceID())
|
||||
if remaining != nil && *remaining == 0 {
|
||||
return nil, zerrors.ThrowResourceExhausted(nil, "QUOTA-vjAy8", "Quota.Access.Exhausted")
|
||||
}
|
||||
span.End()
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
}
|
@@ -0,0 +1,45 @@
|
||||
package connect_middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/service"
|
||||
_ "github.com/zitadel/zitadel/internal/statik"
|
||||
)
|
||||
|
||||
const (
|
||||
unknown = "UNKNOWN"
|
||||
)
|
||||
|
||||
func ServiceHandler() connect.UnaryInterceptorFunc {
|
||||
return func(handler connect.UnaryFunc) connect.UnaryFunc {
|
||||
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
|
||||
serviceName, _ := serviceAndMethod(req.Spec().Procedure)
|
||||
if serviceName != unknown {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
ctx = service.WithService(ctx, serviceName)
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// serviceAndMethod returns the service and method from a procedure.
|
||||
func serviceAndMethod(procedure string) (string, string) {
|
||||
procedure = strings.TrimPrefix(procedure, "/")
|
||||
serviceName, method := unknown, unknown
|
||||
if strings.Contains(procedure, "/") {
|
||||
long := strings.Split(procedure, "/")[0]
|
||||
if strings.Contains(long, ".") {
|
||||
split := strings.Split(long, ".")
|
||||
serviceName = split[len(split)-1]
|
||||
}
|
||||
}
|
||||
if strings.Contains(procedure, "/") {
|
||||
method = strings.Split(procedure, "/")[1]
|
||||
}
|
||||
return serviceName, method
|
||||
}
|
@@ -0,0 +1,48 @@
|
||||
package connect_middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/i18n"
|
||||
_ "github.com/zitadel/zitadel/internal/statik"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
func TranslationHandler() connect.UnaryInterceptorFunc {
|
||||
|
||||
return func(handler connect.UnaryFunc) connect.UnaryFunc {
|
||||
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
|
||||
resp, err := handler(ctx, req)
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
if err != nil {
|
||||
translator, translatorError := getTranslator(ctx)
|
||||
if translatorError != nil {
|
||||
return resp, err
|
||||
}
|
||||
return resp, translateError(ctx, err, translator)
|
||||
}
|
||||
if loc, ok := resp.Any().(localizers); ok {
|
||||
translator, translatorError := getTranslator(ctx)
|
||||
if translatorError != nil {
|
||||
return resp, err
|
||||
}
|
||||
translateFields(ctx, loc, translator)
|
||||
}
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getTranslator(ctx context.Context) (*i18n.Translator, error) {
|
||||
translator, err := i18n.NewZitadelTranslator(authz.GetInstance(ctx).DefaultLanguage())
|
||||
if err != nil {
|
||||
logging.New().WithError(err).Error("could not load translator")
|
||||
}
|
||||
return translator, err
|
||||
}
|
37
internal/api/grpc/server/connect_middleware/translator.go
Normal file
37
internal/api/grpc/server/connect_middleware/translator.go
Normal file
@@ -0,0 +1,37 @@
|
||||
package connect_middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/i18n"
|
||||
"github.com/zitadel/zitadel/internal/zerrors"
|
||||
)
|
||||
|
||||
type localizers interface {
|
||||
Localizers() []Localizer
|
||||
}
|
||||
type Localizer interface {
|
||||
LocalizationKey() string
|
||||
SetLocalizedMessage(string)
|
||||
}
|
||||
|
||||
func translateFields(ctx context.Context, object localizers, translator *i18n.Translator) {
|
||||
if translator == nil || object == nil {
|
||||
return
|
||||
}
|
||||
for _, field := range object.Localizers() {
|
||||
field.SetLocalizedMessage(translator.LocalizeFromCtx(ctx, field.LocalizationKey(), nil))
|
||||
}
|
||||
}
|
||||
|
||||
func translateError(ctx context.Context, err error, translator *i18n.Translator) error {
|
||||
if translator == nil || err == nil {
|
||||
return err
|
||||
}
|
||||
caosErr := new(zerrors.ZitadelError)
|
||||
if errors.As(err, &caosErr) {
|
||||
caosErr.SetMessage(translator.LocalizeFromCtx(ctx, caosErr.GetMessage(), nil))
|
||||
}
|
||||
return err
|
||||
}
|
@@ -0,0 +1,36 @@
|
||||
package connect_middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
// import to make sure go.mod does not lose it
|
||||
// because dependency is only needed for generated code
|
||||
_ "github.com/envoyproxy/protoc-gen-validate/validate"
|
||||
)
|
||||
|
||||
func ValidationHandler() connect.UnaryInterceptorFunc {
|
||||
return func(handler connect.UnaryFunc) connect.UnaryFunc {
|
||||
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
|
||||
return validate(ctx, req, handler)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// validator interface needed for github.com/envoyproxy/protoc-gen-validate
|
||||
// (it does not expose an interface itself)
|
||||
type validator interface {
|
||||
Validate() error
|
||||
}
|
||||
|
||||
func validate(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc) (connect.AnyResponse, error) {
|
||||
validate, ok := req.Any().(validator)
|
||||
if !ok {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
err := validate.Validate()
|
||||
if err != nil {
|
||||
return nil, connect.NewError(connect.CodeInvalidArgument, err)
|
||||
}
|
||||
return handler(ctx, req)
|
||||
}
|
@@ -171,7 +171,7 @@ func CreateGateway(
|
||||
}, nil
|
||||
}
|
||||
|
||||
func RegisterGateway(ctx context.Context, gateway *Gateway, server Server) error {
|
||||
func RegisterGateway(ctx context.Context, gateway *Gateway, server WithGateway) error {
|
||||
err := server.RegisterGateway()(ctx, gateway.mux, gateway.connection)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to register grpc gateway: %w", err)
|
||||
|
@@ -2,11 +2,14 @@ package server
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
"net/http"
|
||||
|
||||
"connectrpc.com/connect"
|
||||
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
healthpb "google.golang.org/grpc/health/grpc_health_v1"
|
||||
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
grpc_api "github.com/zitadel/zitadel/internal/api/grpc"
|
||||
@@ -19,21 +22,36 @@ import (
|
||||
)
|
||||
|
||||
type Server interface {
|
||||
RegisterServer(*grpc.Server)
|
||||
RegisterGateway() RegisterGatewayFunc
|
||||
AppName() string
|
||||
MethodPrefix() string
|
||||
AuthMethods() authz.MethodMapping
|
||||
}
|
||||
|
||||
type GrpcServer interface {
|
||||
Server
|
||||
RegisterServer(*grpc.Server)
|
||||
}
|
||||
|
||||
type WithGateway interface {
|
||||
Server
|
||||
RegisterGateway() RegisterGatewayFunc
|
||||
}
|
||||
|
||||
// WithGatewayPrefix extends the server interface with a prefix for the grpc gateway
|
||||
//
|
||||
// it's used for the System, Admin, Mgmt and Auth API
|
||||
type WithGatewayPrefix interface {
|
||||
Server
|
||||
GrpcServer
|
||||
WithGateway
|
||||
GatewayPathPrefix() string
|
||||
}
|
||||
|
||||
type ConnectServer interface {
|
||||
Server
|
||||
RegisterConnectServer(interceptors ...connect.Interceptor) (string, http.Handler)
|
||||
FileDescriptor() protoreflect.FileDescriptor
|
||||
}
|
||||
|
||||
func CreateServer(
|
||||
verifier authz.APITokenVerifier,
|
||||
systemAuthz authz.Config,
|
||||
|
Reference in New Issue
Block a user