feat: exchange gRPC server implementation to connectRPC (#10145)

# Which Problems Are Solved

The current maintained gRPC server in combination with a REST (grpc)
gateway is getting harder and harder to maintain. Additionally, there
have been and still are issues with supporting / displaying `oneOf`s
correctly.
We therefore decided to exchange the server implementation to
connectRPC, which apart from supporting connect as protocol, also also
"standard" gRCP clients as well as HTTP/1.1 / rest like clients, e.g.
curl directly call the server without any additional gateway.

# How the Problems Are Solved

- All v2 services are moved to connectRPC implementation. (v1 services
are still served as pure grpc servers)
- All gRPC server interceptors were migrated / copied to a corresponding
connectRPC interceptor.
- API.ListGrpcServices and API. ListGrpcMethods were changed to include
the connect services and endpoints.
- gRPC server reflection was changed to a `StaticReflector` using the
`ListGrpcServices` list.
- The `grpc.Server` interfaces was split into different combinations to
be able to handle the different cases (grpc server and prefixed gateway,
connect server with grpc gateway, connect server only, ...)
- Docs of services serving connectRPC only with no additional gateway
(instance, webkey, project, app, org v2 beta) are changed to expose that
- since the plugin is not yet available on buf, we download it using
`postinstall` hook of the docs

# Additional Changes

- WebKey service is added as v2 service (in addition to the current
v2beta)

# Additional Context

closes #9483

---------

Co-authored-by: Elio Bischof <elio@zitadel.com>
This commit is contained in:
Livio Spring
2025-07-04 10:06:20 -04:00
committed by GitHub
parent 82cd1cee08
commit 9ebf2316c6
133 changed files with 5191 additions and 1187 deletions

View File

@@ -0,0 +1,57 @@
package connect_middleware
import (
"context"
"net/http"
"time"
"connectrpc.com/connect"
"github.com/zitadel/zitadel/internal/api/authz"
http_util "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/logstore/record"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
func AccessStorageInterceptor(svc *logstore.Service[*record.AccessLog]) connect.UnaryInterceptorFunc {
return func(handler connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (_ connect.AnyResponse, err error) {
if !svc.Enabled() {
return handler(ctx, req)
}
resp, handlerErr := handler(ctx, req)
interceptorCtx, span := tracing.NewServerInterceptorSpan(ctx)
defer func() { span.EndWithError(err) }()
var respStatus uint32
if code := connect.CodeOf(handlerErr); code != connect.CodeUnknown {
respStatus = uint32(code)
}
respHeader := http.Header{}
if resp != nil {
respHeader = resp.Header()
}
instance := authz.GetInstance(ctx)
domainCtx := http_util.DomainContext(ctx)
r := &record.AccessLog{
LogDate: time.Now(),
Protocol: record.GRPC,
RequestURL: req.Spec().Procedure,
ResponseStatus: respStatus,
RequestHeaders: req.Header(),
ResponseHeaders: respHeader,
InstanceID: instance.InstanceID(),
ProjectID: instance.ProjectID(),
RequestedDomain: domainCtx.RequestedDomain(),
RequestedHost: domainCtx.RequestedHost(),
}
svc.Handle(interceptorCtx, r)
return resp, handlerErr
}
}
}

View File

@@ -0,0 +1,52 @@
package connect_middleware
import (
"context"
"net/http"
"slices"
"strings"
"connectrpc.com/connect"
"github.com/zitadel/zitadel/internal/activity"
"github.com/zitadel/zitadel/internal/api/grpc/gerrors"
ainfo "github.com/zitadel/zitadel/internal/api/info"
)
func ActivityInterceptor() connect.UnaryInterceptorFunc {
return func(handler connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
ctx = activityInfoFromGateway(ctx, req.Header()).SetMethod(req.Spec().Procedure).IntoContext(ctx)
resp, err := handler(ctx, req)
if isResourceAPI(req.Spec().Procedure) {
code, _, _, _ := gerrors.ExtractZITADELError(err)
ctx = ainfo.ActivityInfoFromContext(ctx).SetGRPCStatus(code).IntoContext(ctx)
activity.TriggerGRPCWithContext(ctx, activity.ResourceAPI)
}
return resp, err
}
}
}
var resourcePrefixes = []string{
"/zitadel.management.v1.ManagementService/",
"/zitadel.admin.v1.AdminService/",
"/zitadel.user.v2.UserService/",
"/zitadel.settings.v2.SettingsService/",
"/zitadel.user.v2beta.UserService/",
"/zitadel.settings.v2beta.SettingsService/",
"/zitadel.auth.v1.AuthService/",
}
func isResourceAPI(method string) bool {
return slices.ContainsFunc(resourcePrefixes, func(prefix string) bool {
return strings.HasPrefix(method, prefix)
})
}
func activityInfoFromGateway(ctx context.Context, headers http.Header) *ainfo.ActivityInfo {
info := ainfo.ActivityInfoFromContext(ctx)
path := headers.Get(activity.PathKey)
requestMethod := headers.Get(activity.RequestMethodKey)
return info.SetPath(path).SetRequestMethod(requestMethod)
}

View File

@@ -0,0 +1,65 @@
package connect_middleware
import (
"context"
"errors"
"connectrpc.com/connect"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
func AuthorizationInterceptor(verifier authz.APITokenVerifier, systemUserPermissions authz.Config, authConfig authz.Config) connect.UnaryInterceptorFunc {
return func(handler connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
return authorize(ctx, req, handler, verifier, systemUserPermissions, authConfig)
}
}
}
func authorize(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, verifier authz.APITokenVerifier, systemUserPermissions authz.Config, authConfig authz.Config) (_ connect.AnyResponse, err error) {
authOpt, needsToken := verifier.CheckAuthMethod(req.Spec().Procedure)
if !needsToken {
return handler(ctx, req)
}
authCtx, span := tracing.NewServerInterceptorSpan(ctx)
defer func() { span.EndWithError(err) }()
authToken := req.Header().Get(http.Authorization)
if authToken == "" {
return nil, connect.NewError(connect.CodeUnauthenticated, errors.New("auth header missing"))
}
orgID, orgDomain := orgIDAndDomainFromRequest(req)
ctxSetter, err := authz.CheckUserAuthorization(authCtx, req, authToken, orgID, orgDomain, verifier, systemUserPermissions.RolePermissionMappings, authConfig.RolePermissionMappings, authOpt, req.Spec().Procedure)
if err != nil {
return nil, err
}
span.End()
return handler(ctxSetter(ctx), req)
}
func orgIDAndDomainFromRequest(req connect.AnyRequest) (id, domain string) {
orgID := req.Header().Get(http.ZitadelOrgID)
oz, ok := req.Any().(OrganizationFromRequest)
if ok {
id = oz.OrganizationFromRequestConnect().ID
domain = oz.OrganizationFromRequestConnect().Domain
if id != "" || domain != "" {
return id, domain
}
}
return orgID, domain
}
type Organization struct {
ID string
Domain string
}
type OrganizationFromRequest interface {
OrganizationFromRequestConnect() *Organization
}

View File

@@ -0,0 +1,318 @@
package connect_middleware
import (
"context"
"errors"
"net/http"
"reflect"
"testing"
"connectrpc.com/connect"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/zerrors"
)
const anAPIRole = "AN_API_ROLE"
type authzRepoMock struct{}
func (v *authzRepoMock) VerifyAccessToken(ctx context.Context, token, clientID, projectID string) (string, string, string, string, string, error) {
return "", "", "", "", "", nil
}
func (v *authzRepoMock) SearchMyMemberships(ctx context.Context, orgID string, _ bool) ([]*authz.Membership, error) {
return authz.Memberships{{
MemberType: authz.MemberTypeOrganization,
AggregateID: orgID,
Roles: []string{anAPIRole},
}}, nil
}
func (v *authzRepoMock) ProjectIDAndOriginsByClientID(ctx context.Context, clientID string) (string, []string, error) {
return "", nil, nil
}
func (v *authzRepoMock) ExistsOrg(ctx context.Context, orgID, domain string) (string, error) {
return orgID, nil
}
func (v *authzRepoMock) VerifierClientID(ctx context.Context, appName string) (string, string, error) {
return "", "", nil
}
var (
accessTokenOK = authz.AccessTokenVerifierFunc(func(ctx context.Context, token string) (userID string, clientID string, agentID string, prefLan string, resourceOwner string, err error) {
return "user1", "", "", "", "org1", nil
})
accessTokenNOK = authz.AccessTokenVerifierFunc(func(ctx context.Context, token string) (userID string, clientID string, agentID string, prefLan string, resourceOwner string, err error) {
return "", "", "", "", "", zerrors.ThrowUnauthenticated(nil, "TEST-fQHDI", "unauthenticaded")
})
systemTokenNOK = authz.SystemTokenVerifierFunc(func(ctx context.Context, token string, orgID string) (memberships authz.Memberships, userID string, err error) {
return nil, "", errors.New("system token error")
})
)
type mockOrgFromRequest struct {
id string
}
func (m *mockOrgFromRequest) OrganizationFromRequestConnect() *Organization {
return &Organization{
ID: m.id,
Domain: "",
}
}
func Test_authorize(t *testing.T) {
type args struct {
ctx context.Context
req connect.AnyRequest
handler func(t *testing.T) connect.UnaryFunc
verifier func() authz.APITokenVerifier
authConfig authz.Config
}
type res struct {
want interface{}
wantErr bool
}
tests := []struct {
name string
args args
res res
}{
{
"no token needed ok",
args{
ctx: context.Background(),
req: &mockReq[struct{}]{procedure: "/no/token/needed"},
handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{}),
verifier: func() authz.APITokenVerifier {
verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenOK, systemTokenNOK)
verifier.RegisterServer("need", "need", authz.MethodMapping{})
return verifier
},
},
res{
&connect.Response[struct{}]{},
false,
},
},
{
"auth header missing error",
args{
ctx: context.Background(),
req: &mockReq[struct{}]{procedure: "/need/authentication"},
handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{}),
verifier: func() authz.APITokenVerifier {
verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenOK, systemTokenNOK)
verifier.RegisterServer("need", "need", authz.MethodMapping{"/need/authentication": authz.Option{Permission: "authenticated"}})
return verifier
},
authConfig: authz.Config{},
},
res{
nil,
true,
},
},
{
"unauthorized error",
args{
ctx: context.Background(),
req: &mockReq[struct{}]{procedure: "/need/authentication", header: http.Header{"Authorization": []string{"wrong"}}},
handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{}),
verifier: func() authz.APITokenVerifier {
verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenOK, systemTokenNOK)
verifier.RegisterServer("need", "need", authz.MethodMapping{"/need/authentication": authz.Option{Permission: "authenticated"}})
return verifier
},
authConfig: authz.Config{},
},
res{
nil,
true,
},
},
{
"authorized ok",
args{
ctx: context.Background(),
req: &mockReq[struct{}]{procedure: "/need/authentication", header: http.Header{"Authorization": []string{"Bearer token"}}},
handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{
UserID: "user1",
OrgID: "org1",
ResourceOwner: "org1",
}),
verifier: func() authz.APITokenVerifier {
verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenOK, systemTokenNOK)
verifier.RegisterServer("need", "need", authz.MethodMapping{"/need/authentication": authz.Option{Permission: "authenticated"}})
return verifier
},
authConfig: authz.Config{},
},
res{
&connect.Response[struct{}]{},
false,
},
},
{
"authorized ok, org by request",
args{
ctx: context.Background(),
req: &mockReq[mockOrgFromRequest]{
Request: connect.Request[mockOrgFromRequest]{Msg: &mockOrgFromRequest{"id"}},
procedure: "/need/authentication",
header: http.Header{"Authorization": []string{"Bearer token"}},
},
handler: emptyMockHandler(&connect.Response[mockOrgFromRequest]{Msg: &mockOrgFromRequest{"id"}}, authz.CtxData{
UserID: "user1",
OrgID: "id",
ResourceOwner: "org1",
}),
verifier: func() authz.APITokenVerifier {
verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenOK, systemTokenNOK)
verifier.RegisterServer("need", "need", authz.MethodMapping{"/need/authentication": authz.Option{Permission: "authenticated"}})
return verifier
},
authConfig: authz.Config{},
},
res{
&connect.Response[mockOrgFromRequest]{Msg: &mockOrgFromRequest{"id"}},
false,
},
},
{
"permission denied error",
args{
ctx: context.Background(),
req: &mockReq[struct{}]{procedure: "/need/authentication", header: http.Header{"Authorization": []string{"Bearer token"}}},
handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{
UserID: "user1",
OrgID: "org1",
ResourceOwner: "org1",
}),
verifier: func() authz.APITokenVerifier {
verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenOK, systemTokenNOK)
verifier.RegisterServer("need", "need", authz.MethodMapping{"/need/authentication": authz.Option{Permission: "to.do.something"}})
return verifier
},
authConfig: authz.Config{
RolePermissionMappings: []authz.RoleMapping{{
Role: anAPIRole,
Permissions: []string{"to.do.something.else"},
}},
},
},
res{
nil,
true,
},
},
{
"permission ok",
args{
ctx: context.Background(),
req: &mockReq[struct{}]{procedure: "/need/authentication", header: http.Header{"Authorization": []string{"Bearer token"}}},
handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{
UserID: "user1",
OrgID: "org1",
ResourceOwner: "org1",
}),
verifier: func() authz.APITokenVerifier {
verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenOK, systemTokenNOK)
verifier.RegisterServer("need", "need", authz.MethodMapping{"/need/authentication": authz.Option{Permission: "to.do.something"}})
return verifier
},
authConfig: authz.Config{
RolePermissionMappings: []authz.RoleMapping{{
Role: anAPIRole,
Permissions: []string{"to.do.something"},
}},
},
},
res{
&connect.Response[struct{}]{},
false,
},
},
{
"system token permission denied error",
args{
ctx: context.Background(),
req: &mockReq[struct{}]{procedure: "/need/authentication", header: http.Header{"Authorization": []string{"Bearer token"}}},
handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{}),
verifier: func() authz.APITokenVerifier {
verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenNOK, authz.SystemTokenVerifierFunc(func(ctx context.Context, token string, orgID string) (memberships authz.Memberships, userID string, err error) {
return authz.Memberships{{
MemberType: authz.MemberTypeSystem,
Roles: []string{"A_SYSTEM_ROLE"},
}}, "systemuser", nil
}))
verifier.RegisterServer("need", "need", authz.MethodMapping{"/need/authentication": authz.Option{Permission: "to.do.something"}})
return verifier
},
authConfig: authz.Config{
RolePermissionMappings: []authz.RoleMapping{{
Role: "A_SYSTEM_ROLE",
Permissions: []string{"to.do.something.else"},
}},
},
},
res{
nil,
true,
},
},
{
"system token permission denied error",
args{
ctx: context.Background(),
req: &mockReq[struct{}]{procedure: "/need/authentication", header: http.Header{"Authorization": []string{"Bearer token"}}},
handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{
UserID: "systemuser",
SystemMemberships: authz.Memberships{{
MemberType: authz.MemberTypeSystem,
Roles: []string{"A_SYSTEM_ROLE"},
}},
SystemUserPermissions: []authz.SystemUserPermissions{{
MemberType: authz.MemberTypeSystem,
Permissions: []string{"to.do.something"},
}},
}),
verifier: func() authz.APITokenVerifier {
verifier := authz.StartAPITokenVerifier(&authzRepoMock{}, accessTokenNOK, authz.SystemTokenVerifierFunc(func(ctx context.Context, token string, orgID string) (memberships authz.Memberships, userID string, err error) {
return authz.Memberships{{
MemberType: authz.MemberTypeSystem,
Roles: []string{"A_SYSTEM_ROLE"},
}}, "systemuser", nil
}))
verifier.RegisterServer("need", "need", authz.MethodMapping{"/need/authentication": authz.Option{Permission: "to.do.something"}})
return verifier
},
authConfig: authz.Config{
RolePermissionMappings: []authz.RoleMapping{{
Role: "A_SYSTEM_ROLE",
Permissions: []string{"to.do.something"},
}},
},
},
res{
&connect.Response[struct{}]{},
false,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := authorize(tt.args.ctx, tt.args.req, tt.args.handler(t), tt.args.verifier(), tt.args.authConfig, tt.args.authConfig)
if (err != nil) != tt.res.wantErr {
t.Errorf("authorize() error = %v, wantErr %v", err, tt.res.wantErr)
return
}
if !reflect.DeepEqual(got, tt.res.want) {
t.Errorf("authorize() got = %v, want %v", got, tt.res.want)
}
})
}
}

View File

@@ -0,0 +1,31 @@
package connect_middleware
import (
"context"
"net/http"
"time"
"connectrpc.com/connect"
_ "github.com/zitadel/zitadel/internal/statik"
)
func NoCacheInterceptor() connect.UnaryInterceptorFunc {
return func(handler connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
headers := map[string]string{
"cache-control": "no-store",
"expires": time.Now().UTC().Format(http.TimeFormat),
"pragma": "no-cache",
}
resp, err := handler(ctx, req)
if err != nil {
return nil, err
}
for key, value := range headers {
resp.Header().Set(key, value)
}
return resp, err
}
}
}

View File

@@ -0,0 +1,18 @@
package connect_middleware
import (
"context"
"connectrpc.com/connect"
"github.com/zitadel/zitadel/internal/api/call"
)
func CallDurationHandler() connect.UnaryInterceptorFunc {
return func(handler connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
ctx = call.WithTimestamp(ctx)
return handler(ctx, req)
}
}
}

View File

@@ -0,0 +1,23 @@
package connect_middleware
import (
"context"
"connectrpc.com/connect"
"github.com/zitadel/zitadel/internal/api/grpc/gerrors"
_ "github.com/zitadel/zitadel/internal/statik"
)
func ErrorHandler() connect.UnaryInterceptorFunc {
return func(handler connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
return toConnectError(ctx, req, handler)
}
}
}
func toConnectError(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc) (connect.AnyResponse, error) {
resp, err := handler(ctx, req)
return resp, gerrors.ZITADELToConnectError(err) // TODO !
}

View File

@@ -0,0 +1,65 @@
package connect_middleware
import (
"context"
"reflect"
"testing"
"connectrpc.com/connect"
"github.com/zitadel/zitadel/internal/api/authz"
)
func Test_toGRPCError(t *testing.T) {
type args struct {
ctx context.Context
req connect.AnyRequest
handler func(t *testing.T) connect.UnaryFunc
}
type res struct {
want interface{}
wantErr bool
}
tests := []struct {
name string
args args
res res
}{
{
"no error",
args{
ctx: context.Background(),
req: &mockReq[struct{}]{},
handler: emptyMockHandler(&connect.Response[struct{}]{}, authz.CtxData{}),
},
res{
&connect.Response[struct{}]{},
false,
},
},
{
"error",
args{
ctx: context.Background(),
req: &mockReq[struct{}]{},
handler: errorMockHandler(),
},
res{
nil,
true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := toConnectError(tt.args.ctx, tt.args.req, tt.args.handler(t))
if (err != nil) != tt.res.wantErr {
t.Errorf("toGRPCError() error = %v, wantErr %v", err, tt.res.wantErr)
return
}
if !reflect.DeepEqual(got, tt.res.want) {
t.Errorf("toGRPCError() got = %v, want %v", got, tt.res.want)
}
})
}
}

View File

@@ -0,0 +1,160 @@
package connect_middleware
import (
"context"
"encoding/json"
"connectrpc.com/connect"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/execution"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
func ExecutionHandler(queries *query.Queries) connect.UnaryInterceptorFunc {
return func(handler connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (_ connect.AnyResponse, err error) {
requestTargets, responseTargets := execution.QueryExecutionTargetsForRequestAndResponse(ctx, queries, req.Spec().Procedure)
// call targets otherwise return req
handledReq, err := executeTargetsForRequest(ctx, requestTargets, req.Spec().Procedure, req)
if err != nil {
return nil, err
}
response, err := handler(ctx, handledReq)
if err != nil {
return nil, err
}
return executeTargetsForResponse(ctx, responseTargets, req.Spec().Procedure, handledReq, response)
}
}
}
func executeTargetsForRequest(ctx context.Context, targets []execution.Target, fullMethod string, req connect.AnyRequest) (_ connect.AnyRequest, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
// if no targets are found, return without any calls
if len(targets) == 0 {
return req, nil
}
ctxData := authz.GetCtxData(ctx)
info := &ContextInfoRequest{
FullMethod: fullMethod,
InstanceID: authz.GetInstance(ctx).InstanceID(),
ProjectID: ctxData.ProjectID,
OrgID: ctxData.OrgID,
UserID: ctxData.UserID,
Request: Message{req.Any().(proto.Message)},
}
_, err = execution.CallTargets(ctx, targets, info)
if err != nil {
return nil, err
}
return req, nil
}
func executeTargetsForResponse(ctx context.Context, targets []execution.Target, fullMethod string, req connect.AnyRequest, resp connect.AnyResponse) (_ connect.AnyResponse, err error) {
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
// if no targets are found, return without any calls
if len(targets) == 0 {
return resp, nil
}
ctxData := authz.GetCtxData(ctx)
info := &ContextInfoResponse{
FullMethod: fullMethod,
InstanceID: authz.GetInstance(ctx).InstanceID(),
ProjectID: ctxData.ProjectID,
OrgID: ctxData.OrgID,
UserID: ctxData.UserID,
Request: Message{req.Any().(proto.Message)},
Response: Message{resp.Any().(proto.Message)},
}
_, err = execution.CallTargets(ctx, targets, info)
if err != nil {
return nil, err
}
return resp, nil
}
var _ execution.ContextInfo = &ContextInfoRequest{}
type ContextInfoRequest struct {
FullMethod string `json:"fullMethod,omitempty"`
InstanceID string `json:"instanceID,omitempty"`
OrgID string `json:"orgID,omitempty"`
ProjectID string `json:"projectID,omitempty"`
UserID string `json:"userID,omitempty"`
Request Message `json:"request,omitempty"`
}
type Message struct {
proto.Message
}
func (r *Message) MarshalJSON() ([]byte, error) {
data, err := protojson.Marshal(r.Message)
if err != nil {
return nil, err
}
return data, nil
}
func (r *Message) UnmarshalJSON(data []byte) error {
return protojson.Unmarshal(data, r.Message)
}
func (c *ContextInfoRequest) GetHTTPRequestBody() []byte {
data, err := json.Marshal(c)
if err != nil {
return nil
}
return data
}
func (c *ContextInfoRequest) SetHTTPResponseBody(resp []byte) error {
return json.Unmarshal(resp, &c.Request)
}
func (c *ContextInfoRequest) GetContent() interface{} {
return c.Request.Message
}
var _ execution.ContextInfo = &ContextInfoResponse{}
type ContextInfoResponse struct {
FullMethod string `json:"fullMethod,omitempty"`
InstanceID string `json:"instanceID,omitempty"`
OrgID string `json:"orgID,omitempty"`
ProjectID string `json:"projectID,omitempty"`
UserID string `json:"userID,omitempty"`
Request Message `json:"request,omitempty"`
Response Message `json:"response,omitempty"`
}
func (c *ContextInfoResponse) GetHTTPRequestBody() []byte {
data, err := json.Marshal(c)
if err != nil {
return nil
}
return data
}
func (c *ContextInfoResponse) SetHTTPResponseBody(resp []byte) error {
return json.Unmarshal(resp, &c.Response)
}
func (c *ContextInfoResponse) GetContent() interface{} {
return c.Response.Message
}

View File

@@ -0,0 +1,815 @@
package connect_middleware
import (
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"reflect"
"testing"
"time"
"connectrpc.com/connect"
"github.com/stretchr/testify/assert"
"google.golang.org/protobuf/encoding/protojson"
"google.golang.org/protobuf/proto"
"google.golang.org/protobuf/types/known/structpb"
"github.com/zitadel/zitadel/internal/domain"
"github.com/zitadel/zitadel/internal/execution"
)
var _ execution.Target = &mockExecutionTarget{}
type mockExecutionTarget struct {
InstanceID string
ExecutionID string
TargetID string
TargetType domain.TargetType
Endpoint string
Timeout time.Duration
InterruptOnError bool
SigningKey string
}
func (e *mockExecutionTarget) SetEndpoint(endpoint string) {
e.Endpoint = endpoint
}
func (e *mockExecutionTarget) IsInterruptOnError() bool {
return e.InterruptOnError
}
func (e *mockExecutionTarget) GetEndpoint() string {
return e.Endpoint
}
func (e *mockExecutionTarget) GetTargetType() domain.TargetType {
return e.TargetType
}
func (e *mockExecutionTarget) GetTimeout() time.Duration {
return e.Timeout
}
func (e *mockExecutionTarget) GetTargetID() string {
return e.TargetID
}
func (e *mockExecutionTarget) GetExecutionID() string {
return e.ExecutionID
}
func (e *mockExecutionTarget) GetSigningKey() string {
return e.SigningKey
}
func newMockContentRequest(content string) *connect.Request[structpb.Struct] {
return connect.NewRequest(&structpb.Struct{
Fields: map[string]*structpb.Value{
"content": {
Kind: &structpb.Value_StringValue{StringValue: content},
},
},
})
}
func newMockContentResponse(content string) *connect.Response[structpb.Struct] {
return connect.NewResponse(&structpb.Struct{
Fields: map[string]*structpb.Value{
"content": {
Kind: &structpb.Value_StringValue{StringValue: content},
},
},
})
}
func newMockContextInfoRequest(fullMethod, request string) *ContextInfoRequest {
return &ContextInfoRequest{
FullMethod: fullMethod,
Request: Message{Message: newMockContentRequest(request).Msg},
}
}
func newMockContextInfoResponse(fullMethod, request, response string) *ContextInfoResponse {
return &ContextInfoResponse{
FullMethod: fullMethod,
Request: Message{Message: newMockContentRequest(request).Msg},
Response: Message{Message: newMockContentResponse(response).Msg},
}
}
func Test_executeTargetsForGRPCFullMethod_request(t *testing.T) {
type target struct {
reqBody execution.ContextInfo
sleep time.Duration
statusCode int
respBody connect.AnyResponse
}
type args struct {
ctx context.Context
executionTargets []execution.Target
targets []target
fullMethod string
req connect.AnyRequest
}
type res struct {
want interface{}
wantErr bool
}
tests := []struct {
name string
args args
res res
}{
{
"target, executionTargets nil",
args{
ctx: context.Background(),
fullMethod: "/service/method",
executionTargets: nil,
req: newMockContentRequest("request"),
},
res{
want: newMockContentRequest("request"),
},
},
{
"target, executionTargets empty",
args{
ctx: context.Background(),
fullMethod: "/service/method",
executionTargets: []execution.Target{},
req: newMockContentRequest("request"),
},
res{
want: newMockContentRequest("request"),
},
},
{
"target, not reachable",
args{
ctx: context.Background(),
fullMethod: "/service/method",
executionTargets: []execution.Target{
&mockExecutionTarget{
InstanceID: "instance",
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
TargetID: "target",
TargetType: domain.TargetTypeCall,
Timeout: time.Minute,
InterruptOnError: true,
},
},
targets: []target{},
req: newMockContentRequest("content"),
},
res{
wantErr: true,
},
},
{
"target, error without interrupt",
args{
ctx: context.Background(),
fullMethod: "/service/method",
executionTargets: []execution.Target{
&mockExecutionTarget{
InstanceID: "instance",
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
TargetID: "target",
TargetType: domain.TargetTypeCall,
Timeout: time.Minute,
SigningKey: "signingkey",
},
},
targets: []target{
{
reqBody: newMockContextInfoRequest("/service/method", "content"),
respBody: newMockContentResponse("content1"),
sleep: 0,
statusCode: http.StatusBadRequest,
},
},
req: newMockContentRequest("content"),
},
res{
want: newMockContentRequest("content"),
},
},
{
"target, interruptOnError",
args{
ctx: context.Background(),
fullMethod: "/service/method",
executionTargets: []execution.Target{
&mockExecutionTarget{
InstanceID: "instance",
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
TargetID: "target",
TargetType: domain.TargetTypeCall,
Timeout: time.Minute,
InterruptOnError: true,
SigningKey: "signingkey",
},
},
targets: []target{
{
reqBody: newMockContextInfoRequest("/service/method", "content"),
respBody: newMockContentResponse("content1"),
sleep: 0,
statusCode: http.StatusBadRequest,
},
},
req: newMockContentRequest("content"),
},
res{
wantErr: true,
},
},
{
"target, timeout",
args{
ctx: context.Background(),
fullMethod: "/service/method",
executionTargets: []execution.Target{
&mockExecutionTarget{
InstanceID: "instance",
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
TargetID: "target",
TargetType: domain.TargetTypeCall,
Timeout: time.Second,
InterruptOnError: true,
SigningKey: "signingkey",
},
},
targets: []target{
{
reqBody: newMockContextInfoRequest("/service/method", "content"),
respBody: newMockContentResponse("content1"),
sleep: 5 * time.Second,
statusCode: http.StatusOK,
},
},
req: newMockContentRequest("content"),
},
res{
wantErr: true,
},
},
{
"target, wrong request",
args{
ctx: context.Background(),
fullMethod: "/service/method",
executionTargets: []execution.Target{
&mockExecutionTarget{
InstanceID: "instance",
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
TargetID: "target",
TargetType: domain.TargetTypeCall,
Timeout: time.Second,
InterruptOnError: true,
SigningKey: "signingkey",
},
},
targets: []target{
{reqBody: newMockContextInfoRequest("/service/method", "wrong")},
},
req: newMockContentRequest("content"),
},
res{
wantErr: true,
},
},
{
"target, ok",
args{
ctx: context.Background(),
fullMethod: "/service/method",
executionTargets: []execution.Target{
&mockExecutionTarget{
InstanceID: "instance",
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
TargetID: "target",
TargetType: domain.TargetTypeCall,
Timeout: time.Minute,
InterruptOnError: true,
SigningKey: "signingkey",
},
},
targets: []target{
{
reqBody: newMockContextInfoRequest("/service/method", "content"),
respBody: newMockContentResponse("content1"),
sleep: 0,
statusCode: http.StatusOK,
},
},
req: newMockContentRequest("content"),
},
res{
want: newMockContentRequest("content1"),
},
},
{
"target async, timeout",
args{
ctx: context.Background(),
fullMethod: "/service/method",
executionTargets: []execution.Target{
&mockExecutionTarget{
InstanceID: "instance",
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
TargetID: "target",
TargetType: domain.TargetTypeAsync,
Timeout: time.Second,
SigningKey: "signingkey",
},
},
targets: []target{
{
reqBody: newMockContextInfoRequest("/service/method", "content"),
respBody: newMockContentResponse("content1"),
sleep: 5 * time.Second,
statusCode: http.StatusOK,
},
},
req: newMockContentRequest("content"),
},
res{
want: newMockContentRequest("content"),
},
},
{
"target async, ok",
args{
ctx: context.Background(),
fullMethod: "/service/method",
executionTargets: []execution.Target{
&mockExecutionTarget{
InstanceID: "instance",
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
TargetID: "target",
TargetType: domain.TargetTypeAsync,
Timeout: time.Minute,
SigningKey: "signingkey",
},
},
targets: []target{
{
reqBody: newMockContextInfoRequest("/service/method", "content"),
respBody: newMockContentResponse("content1"),
sleep: 0,
statusCode: http.StatusOK,
},
},
req: newMockContentRequest("content"),
},
res{
want: newMockContentRequest("content"),
},
},
{
"webhook, error",
args{
ctx: context.Background(),
fullMethod: "/service/method",
executionTargets: []execution.Target{
&mockExecutionTarget{
InstanceID: "instance",
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
TargetID: "target",
TargetType: domain.TargetTypeWebhook,
Timeout: time.Minute,
InterruptOnError: true,
SigningKey: "signingkey",
},
},
targets: []target{
{
reqBody: newMockContextInfoRequest("/service/method", "content"),
sleep: 0,
statusCode: http.StatusInternalServerError,
},
},
req: newMockContentRequest("content"),
},
res{
wantErr: true,
},
},
{
"webhook, timeout",
args{
ctx: context.Background(),
fullMethod: "/service/method",
executionTargets: []execution.Target{
&mockExecutionTarget{
InstanceID: "instance",
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
TargetID: "target",
TargetType: domain.TargetTypeWebhook,
Timeout: time.Second,
InterruptOnError: true,
SigningKey: "signingkey",
},
},
targets: []target{
{
reqBody: newMockContextInfoRequest("/service/method", "content"),
respBody: newMockContentResponse("content1"),
sleep: 5 * time.Second,
statusCode: http.StatusOK,
},
},
req: newMockContentRequest("content"),
},
res{
wantErr: true,
},
},
{
"webhook, ok",
args{
ctx: context.Background(),
fullMethod: "/service/method",
executionTargets: []execution.Target{
&mockExecutionTarget{
InstanceID: "instance",
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
TargetID: "target",
TargetType: domain.TargetTypeWebhook,
Timeout: time.Minute,
InterruptOnError: true,
SigningKey: "signingkey",
},
},
targets: []target{
{
reqBody: newMockContextInfoRequest("/service/method", "content"),
respBody: newMockContentResponse("content1"),
sleep: 0,
statusCode: http.StatusOK,
},
},
req: newMockContentRequest("content"),
},
res{
want: newMockContentRequest("content"),
},
},
{
"with includes, interruptOnError",
args{
ctx: context.Background(),
fullMethod: "/service/method",
executionTargets: []execution.Target{
&mockExecutionTarget{
InstanceID: "instance",
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
TargetID: "target1",
TargetType: domain.TargetTypeCall,
Timeout: time.Minute,
InterruptOnError: true,
SigningKey: "signingkey",
},
&mockExecutionTarget{
InstanceID: "instance",
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
TargetID: "target2",
TargetType: domain.TargetTypeCall,
Timeout: time.Minute,
InterruptOnError: true,
SigningKey: "signingkey",
},
&mockExecutionTarget{
InstanceID: "instance",
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
TargetID: "target3",
TargetType: domain.TargetTypeCall,
Timeout: time.Minute,
InterruptOnError: true,
SigningKey: "signingkey",
},
},
targets: []target{
{
reqBody: newMockContextInfoRequest("/service/method", "content"),
respBody: newMockContentResponse("content1"),
sleep: 0,
statusCode: http.StatusOK,
},
{
reqBody: newMockContextInfoRequest("/service/method", "content1"),
respBody: newMockContentResponse("content2"),
sleep: 0,
statusCode: http.StatusBadRequest,
},
{
reqBody: newMockContextInfoRequest("/service/method", "content2"),
respBody: newMockContentResponse("content3"),
sleep: 0,
statusCode: http.StatusOK,
},
},
req: newMockContentRequest("content"),
},
res{
wantErr: true,
},
},
{
"with includes, timeout",
args{
ctx: context.Background(),
fullMethod: "/service/method",
executionTargets: []execution.Target{
&mockExecutionTarget{
InstanceID: "instance",
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
TargetID: "target1",
TargetType: domain.TargetTypeCall,
Timeout: time.Minute,
InterruptOnError: true,
SigningKey: "signingkey",
},
&mockExecutionTarget{
InstanceID: "instance",
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
TargetID: "target2",
TargetType: domain.TargetTypeCall,
Timeout: time.Second,
InterruptOnError: true,
SigningKey: "signingkey",
},
&mockExecutionTarget{
InstanceID: "instance",
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
TargetID: "target3",
TargetType: domain.TargetTypeCall,
Timeout: time.Second,
InterruptOnError: true,
SigningKey: "signingkey",
},
},
targets: []target{
{
reqBody: newMockContextInfoRequest("/service/method", "content"),
respBody: newMockContentResponse("content1"),
sleep: 0,
statusCode: http.StatusOK,
},
{
reqBody: newMockContextInfoRequest("/service/method", "content1"),
respBody: newMockContentResponse("content2"),
sleep: 5 * time.Second,
statusCode: http.StatusBadRequest,
},
{
reqBody: newMockContextInfoRequest("/service/method", "content2"),
respBody: newMockContentResponse("content3"),
sleep: 5 * time.Second,
statusCode: http.StatusOK,
},
},
req: newMockContentRequest("content"),
},
res{
wantErr: true,
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
closeFuncs := make([]func(), len(tt.args.targets))
for i, target := range tt.args.targets {
url, closeF := testServerCall(
target.reqBody,
target.sleep,
target.statusCode,
target.respBody,
)
et := tt.args.executionTargets[i].(*mockExecutionTarget)
et.SetEndpoint(url)
closeFuncs[i] = closeF
}
resp, err := executeTargetsForRequest(
tt.args.ctx,
tt.args.executionTargets,
tt.args.fullMethod,
tt.args.req,
)
if tt.res.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
assert.EqualExportedValues(t, tt.res.want, resp)
for _, closeF := range closeFuncs {
closeF()
}
})
}
}
func testServerCall(
reqBody interface{},
sleep time.Duration,
statusCode int,
respBody connect.AnyResponse,
) (string, func()) {
handler := func(w http.ResponseWriter, r *http.Request) {
data, err := json.Marshal(reqBody)
if err != nil {
http.Error(w, "error", http.StatusInternalServerError)
return
}
sentBody, err := io.ReadAll(r.Body)
if err != nil {
http.Error(w, "error", http.StatusInternalServerError)
return
}
if !reflect.DeepEqual(data, sentBody) {
http.Error(w, "error", http.StatusInternalServerError)
return
}
if statusCode != http.StatusOK {
http.Error(w, "error", statusCode)
return
}
time.Sleep(sleep)
w.Header().Set("Content-Type", "application/json")
resp, err := protojson.Marshal(respBody.Any().(proto.Message))
if err != nil {
http.Error(w, "error", http.StatusInternalServerError)
return
}
if _, err := w.Write(resp); err != nil {
http.Error(w, "error", http.StatusInternalServerError)
return
}
}
server := httptest.NewServer(http.HandlerFunc(handler))
return server.URL, server.Close
}
func Test_executeTargetsForGRPCFullMethod_response(t *testing.T) {
type target struct {
reqBody execution.ContextInfo
sleep time.Duration
statusCode int
respBody connect.AnyResponse
}
type args struct {
ctx context.Context
executionTargets []execution.Target
targets []target
fullMethod string
req connect.AnyRequest
resp connect.AnyResponse
}
type res struct {
want interface{}
wantErr bool
}
tests := []struct {
name string
args args
res res
}{
{
"target, executionTargets nil",
args{
ctx: context.Background(),
fullMethod: "/service/method",
executionTargets: nil,
req: newMockContentRequest("request"),
resp: newMockContentResponse("response"),
},
res{
want: newMockContentResponse("response"),
},
},
{
"target, executionTargets empty",
args{
ctx: context.Background(),
fullMethod: "/service/method",
executionTargets: []execution.Target{},
req: newMockContentRequest("request"),
resp: newMockContentResponse("response"),
},
res{
want: newMockContentResponse("response"),
},
},
{
"target, empty response",
args{
ctx: context.Background(),
fullMethod: "/service/method",
executionTargets: []execution.Target{
&mockExecutionTarget{
InstanceID: "instance",
ExecutionID: "request./zitadel.session.v2.SessionService/SetSession",
TargetID: "target",
TargetType: domain.TargetTypeCall,
Timeout: time.Minute,
InterruptOnError: true,
SigningKey: "signingkey",
},
},
targets: []target{
{
reqBody: newMockContextInfoRequest("/service/method", "content"),
respBody: newMockContentResponse(""),
sleep: 0,
statusCode: http.StatusOK,
},
},
req: newMockContentRequest(""),
resp: newMockContentResponse(""),
},
res{
wantErr: true,
},
},
{
"target, ok",
args{
ctx: context.Background(),
fullMethod: "/service/method",
executionTargets: []execution.Target{
&mockExecutionTarget{
InstanceID: "instance",
ExecutionID: "response./zitadel.session.v2.SessionService/SetSession",
TargetID: "target",
TargetType: domain.TargetTypeCall,
Timeout: time.Minute,
InterruptOnError: true,
SigningKey: "signingkey",
},
},
targets: []target{
{
reqBody: newMockContextInfoResponse("/service/method", "request", "response"),
respBody: newMockContentResponse("response1"),
sleep: 0,
statusCode: http.StatusOK,
},
},
req: newMockContentRequest("request"),
resp: newMockContentResponse("response"),
},
res{
want: newMockContentResponse("response1"),
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
closeFuncs := make([]func(), len(tt.args.targets))
for i, target := range tt.args.targets {
url, closeF := testServerCall(
target.reqBody,
target.sleep,
target.statusCode,
target.respBody,
)
et := tt.args.executionTargets[i].(*mockExecutionTarget)
et.SetEndpoint(url)
closeFuncs[i] = closeF
}
resp, err := executeTargetsForResponse(
tt.args.ctx,
tt.args.executionTargets,
tt.args.fullMethod,
tt.args.req,
tt.args.resp,
)
if tt.res.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
assert.EqualExportedValues(t, tt.res.want, resp)
for _, closeF := range closeFuncs {
closeF()
}
})
}
}

View File

@@ -0,0 +1,107 @@
package connect_middleware
import (
"context"
"errors"
"fmt"
"strings"
"connectrpc.com/connect"
"github.com/zitadel/logging"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/api/authz"
zitadel_http "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/i18n"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
object_v3 "github.com/zitadel/zitadel/pkg/grpc/object/v3alpha"
)
func InstanceInterceptor(verifier authz.InstanceVerifier, externalDomain string, explicitInstanceIdServices ...string) connect.UnaryInterceptorFunc {
translator, err := i18n.NewZitadelTranslator(language.English)
logging.OnError(err).Panic("unable to get translator")
return func(handler connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
return setInstance(ctx, req, handler, verifier, externalDomain, translator, explicitInstanceIdServices...)
}
}
}
func setInstance(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, verifier authz.InstanceVerifier, externalDomain string, translator *i18n.Translator, idFromRequestsServices ...string) (_ connect.AnyResponse, err error) {
interceptorCtx, span := tracing.NewServerInterceptorSpan(ctx)
defer func() { span.EndWithError(err) }()
for _, service := range idFromRequestsServices {
if !strings.HasPrefix(service, "/") {
service = "/" + service
}
if strings.HasPrefix(req.Spec().Procedure, service) {
withInstanceIDProperty, ok := req.Any().(interface {
GetInstanceId() string
})
if !ok {
return handler(ctx, req)
}
return addInstanceByID(interceptorCtx, req, handler, verifier, translator, withInstanceIDProperty.GetInstanceId())
}
}
explicitInstanceRequest, ok := req.Any().(interface {
GetInstance() *object_v3.Instance
})
if ok {
instance := explicitInstanceRequest.GetInstance()
if id := instance.GetId(); id != "" {
return addInstanceByID(interceptorCtx, req, handler, verifier, translator, id)
}
if domain := instance.GetDomain(); domain != "" {
return addInstanceByDomain(interceptorCtx, req, handler, verifier, translator, domain)
}
}
return addInstanceByRequestedHost(interceptorCtx, req, handler, verifier, translator, externalDomain)
}
func addInstanceByID(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, verifier authz.InstanceVerifier, translator *i18n.Translator, id string) (connect.AnyResponse, error) {
instance, err := verifier.InstanceByID(ctx, id)
if err != nil {
notFoundErr := new(zerrors.ZitadelError)
if errors.As(err, &notFoundErr) {
notFoundErr.Message = translator.LocalizeFromCtx(ctx, notFoundErr.GetMessage(), nil)
}
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("unable to set instance using id %s: %w", id, notFoundErr))
}
return handler(authz.WithInstance(ctx, instance), req)
}
func addInstanceByDomain(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, verifier authz.InstanceVerifier, translator *i18n.Translator, domain string) (connect.AnyResponse, error) {
instance, err := verifier.InstanceByHost(ctx, domain, "")
if err != nil {
notFoundErr := new(zerrors.NotFoundError)
if errors.As(err, &notFoundErr) {
notFoundErr.Message = translator.LocalizeFromCtx(ctx, notFoundErr.GetMessage(), nil)
}
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("unable to set instance using domain %s: %w", domain, notFoundErr))
}
return handler(authz.WithInstance(ctx, instance), req)
}
func addInstanceByRequestedHost(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, verifier authz.InstanceVerifier, translator *i18n.Translator, externalDomain string) (connect.AnyResponse, error) {
requestContext := zitadel_http.DomainContext(ctx)
if requestContext.InstanceHost == "" {
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).Error("unable to set instance")
return nil, connect.NewError(connect.CodeNotFound, errors.New("no instanceHost specified"))
}
instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceHost, requestContext.PublicHost)
if err != nil {
origin := zitadel_http.DomainContext(ctx)
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).WithError(err).Error("unable to set instance")
zErr := new(zerrors.ZitadelError)
if errors.As(err, &zErr) {
zErr.SetMessage(translator.LocalizeFromCtx(ctx, zErr.GetMessage(), nil))
zErr.Parent = err
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("unable to set instance using origin %s (ExternalDomain is %s): %s", origin, externalDomain, zErr.Error()))
}
return nil, connect.NewError(connect.CodeNotFound, fmt.Errorf("unable to set instance using origin %s (ExternalDomain is %s)", origin, externalDomain))
}
return handler(authz.WithInstance(ctx, instance), req)
}

View File

@@ -0,0 +1,34 @@
package connect_middleware
import (
"context"
"strings"
"connectrpc.com/connect"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/zerrors"
)
func LimitsInterceptor(ignoreService ...string) connect.UnaryInterceptorFunc {
for idx, service := range ignoreService {
if !strings.HasPrefix(service, "/") {
ignoreService[idx] = "/" + service
}
}
return func(handler connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (_ connect.AnyResponse, err error) {
for _, service := range ignoreService {
if strings.HasPrefix(req.Spec().Procedure, service) {
return handler(ctx, req)
}
}
instance := authz.GetInstance(ctx)
if block := instance.Block(); block != nil && *block {
return nil, zerrors.ThrowResourceExhausted(nil, "LIMITS-molsj", "Errors.Limits.Instance.Blocked")
}
return handler(ctx, req)
}
}
}

View File

@@ -0,0 +1,96 @@
package connect_middleware
import (
"context"
"strings"
"connectrpc.com/connect"
"github.com/grpc-ecosystem/grpc-gateway/runtime"
"github.com/zitadel/logging"
"go.opentelemetry.io/otel/attribute"
"google.golang.org/grpc/codes"
_ "github.com/zitadel/zitadel/internal/statik"
"github.com/zitadel/zitadel/internal/telemetry/metrics"
)
const (
GrpcMethod = "grpc_method"
ReturnCode = "return_code"
GrpcRequestCounter = "grpc.server.request_counter"
GrpcRequestCounterDescription = "Grpc request counter"
TotalGrpcRequestCounter = "grpc.server.total_request_counter"
TotalGrpcRequestCounterDescription = "Total grpc request counter"
GrpcStatusCodeCounter = "grpc.server.grpc_status_code"
GrpcStatusCodeCounterDescription = "Grpc status code counter"
)
func MetricsHandler(metricTypes []metrics.MetricType, ignoredMethodSuffixes ...string) connect.UnaryInterceptorFunc {
return func(handler connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
return RegisterMetrics(ctx, req, handler, metricTypes, ignoredMethodSuffixes...)
}
}
}
func RegisterMetrics(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, metricTypes []metrics.MetricType, ignoredMethodSuffixes ...string) (_ connect.AnyResponse, err error) {
if len(metricTypes) == 0 {
return handler(ctx, req)
}
for _, ignore := range ignoredMethodSuffixes {
if strings.HasSuffix(req.Spec().Procedure, ignore) {
return handler(ctx, req)
}
}
resp, err := handler(ctx, req)
if containsMetricsMethod(metrics.MetricTypeRequestCount, metricTypes) {
RegisterGrpcRequestCounter(ctx, req.Spec().Procedure)
}
if containsMetricsMethod(metrics.MetricTypeTotalCount, metricTypes) {
RegisterGrpcTotalRequestCounter(ctx)
}
if containsMetricsMethod(metrics.MetricTypeStatusCode, metricTypes) {
RegisterGrpcRequestCodeCounter(ctx, req.Spec().Procedure, err)
}
return resp, err
}
func RegisterGrpcRequestCounter(ctx context.Context, path string) {
var labels = map[string]attribute.Value{
GrpcMethod: attribute.StringValue(path),
}
err := metrics.RegisterCounter(GrpcRequestCounter, GrpcRequestCounterDescription)
logging.OnError(err).Warn("failed to register grpc request counter")
err = metrics.AddCount(ctx, GrpcRequestCounter, 1, labels)
logging.OnError(err).Warn("failed to add grpc request count")
}
func RegisterGrpcTotalRequestCounter(ctx context.Context) {
err := metrics.RegisterCounter(TotalGrpcRequestCounter, TotalGrpcRequestCounterDescription)
logging.OnError(err).Warn("failed to register total grpc request counter")
err = metrics.AddCount(ctx, TotalGrpcRequestCounter, 1, nil)
logging.OnError(err).Warn("failed to add total grpc request count")
}
func RegisterGrpcRequestCodeCounter(ctx context.Context, path string, err error) {
statusCode := connect.CodeOf(err)
var labels = map[string]attribute.Value{
GrpcMethod: attribute.StringValue(path),
ReturnCode: attribute.IntValue(runtime.HTTPStatusFromCode(codes.Code(statusCode))),
}
err = metrics.RegisterCounter(GrpcStatusCodeCounter, GrpcStatusCodeCounterDescription)
logging.OnError(err).Warn("failed to register grpc status code counter")
err = metrics.AddCount(ctx, GrpcStatusCodeCounter, 1, labels)
logging.OnError(err).Warn("failed to add grpc status code count")
}
func containsMetricsMethod(metricType metrics.MetricType, metricTypes []metrics.MetricType) bool {
for _, m := range metricTypes {
if m == metricType {
return true
}
}
return false
}

View File

@@ -0,0 +1,50 @@
package connect_middleware
import (
"context"
"net/http"
"testing"
"connectrpc.com/connect"
"github.com/stretchr/testify/assert"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/zerrors"
)
func emptyMockHandler(resp connect.AnyResponse, expectedCtxData authz.CtxData) func(*testing.T) connect.UnaryFunc {
return func(t *testing.T) connect.UnaryFunc {
return func(ctx context.Context, _ connect.AnyRequest) (connect.AnyResponse, error) {
assert.Equal(t, expectedCtxData, authz.GetCtxData(ctx))
return resp, nil
}
}
}
func errorMockHandler() func(*testing.T) connect.UnaryFunc {
return func(t *testing.T) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
return nil, zerrors.ThrowInternal(nil, "test", "error")
}
}
}
type mockReq[t any] struct {
connect.Request[t]
procedure string
header http.Header
}
func (m *mockReq[T]) Spec() connect.Spec {
return connect.Spec{
Procedure: m.procedure,
}
}
func (m *mockReq[T]) Header() http.Header {
if m.header == nil {
m.header = make(http.Header)
}
return m.header
}

View File

@@ -0,0 +1,53 @@
package connect_middleware
import (
"context"
"strings"
"connectrpc.com/connect"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/logstore/record"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
"github.com/zitadel/zitadel/internal/zerrors"
)
func QuotaExhaustedInterceptor(svc *logstore.Service[*record.AccessLog], ignoreService ...string) connect.UnaryInterceptorFunc {
for idx, service := range ignoreService {
if !strings.HasPrefix(service, "/") {
ignoreService[idx] = "/" + service
}
}
return func(handler connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (_ connect.AnyResponse, err error) {
if !svc.Enabled() {
return handler(ctx, req)
}
interceptorCtx, span := tracing.NewServerInterceptorSpan(ctx)
defer func() { span.EndWithError(err) }()
// The auth interceptor will ensure that only authorized or public requests are allowed.
// So if there's no authorization context, we don't need to check for limitation
// Also, we don't limit calls with system user tokens
ctxData := authz.GetCtxData(ctx)
if ctxData.IsZero() || ctxData.SystemMemberships != nil {
return handler(ctx, req)
}
for _, service := range ignoreService {
if strings.HasPrefix(req.Spec().Procedure, service) {
return handler(ctx, req)
}
}
instance := authz.GetInstance(ctx)
remaining := svc.Limit(interceptorCtx, instance.InstanceID())
if remaining != nil && *remaining == 0 {
return nil, zerrors.ThrowResourceExhausted(nil, "QUOTA-vjAy8", "Quota.Access.Exhausted")
}
span.End()
return handler(ctx, req)
}
}
}

View File

@@ -0,0 +1,45 @@
package connect_middleware
import (
"context"
"strings"
"connectrpc.com/connect"
"github.com/zitadel/zitadel/internal/api/service"
_ "github.com/zitadel/zitadel/internal/statik"
)
const (
unknown = "UNKNOWN"
)
func ServiceHandler() connect.UnaryInterceptorFunc {
return func(handler connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
serviceName, _ := serviceAndMethod(req.Spec().Procedure)
if serviceName != unknown {
return handler(ctx, req)
}
ctx = service.WithService(ctx, serviceName)
return handler(ctx, req)
}
}
}
// serviceAndMethod returns the service and method from a procedure.
func serviceAndMethod(procedure string) (string, string) {
procedure = strings.TrimPrefix(procedure, "/")
serviceName, method := unknown, unknown
if strings.Contains(procedure, "/") {
long := strings.Split(procedure, "/")[0]
if strings.Contains(long, ".") {
split := strings.Split(long, ".")
serviceName = split[len(split)-1]
}
}
if strings.Contains(procedure, "/") {
method = strings.Split(procedure, "/")[1]
}
return serviceName, method
}

View File

@@ -0,0 +1,48 @@
package connect_middleware
import (
"context"
"connectrpc.com/connect"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/i18n"
_ "github.com/zitadel/zitadel/internal/statik"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
func TranslationHandler() connect.UnaryInterceptorFunc {
return func(handler connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
resp, err := handler(ctx, req)
ctx, span := tracing.NewSpan(ctx)
defer func() { span.EndWithError(err) }()
if err != nil {
translator, translatorError := getTranslator(ctx)
if translatorError != nil {
return resp, err
}
return resp, translateError(ctx, err, translator)
}
if loc, ok := resp.Any().(localizers); ok {
translator, translatorError := getTranslator(ctx)
if translatorError != nil {
return resp, err
}
translateFields(ctx, loc, translator)
}
return resp, nil
}
}
}
func getTranslator(ctx context.Context) (*i18n.Translator, error) {
translator, err := i18n.NewZitadelTranslator(authz.GetInstance(ctx).DefaultLanguage())
if err != nil {
logging.New().WithError(err).Error("could not load translator")
}
return translator, err
}

View File

@@ -0,0 +1,37 @@
package connect_middleware
import (
"context"
"errors"
"github.com/zitadel/zitadel/internal/i18n"
"github.com/zitadel/zitadel/internal/zerrors"
)
type localizers interface {
Localizers() []Localizer
}
type Localizer interface {
LocalizationKey() string
SetLocalizedMessage(string)
}
func translateFields(ctx context.Context, object localizers, translator *i18n.Translator) {
if translator == nil || object == nil {
return
}
for _, field := range object.Localizers() {
field.SetLocalizedMessage(translator.LocalizeFromCtx(ctx, field.LocalizationKey(), nil))
}
}
func translateError(ctx context.Context, err error, translator *i18n.Translator) error {
if translator == nil || err == nil {
return err
}
caosErr := new(zerrors.ZitadelError)
if errors.As(err, &caosErr) {
caosErr.SetMessage(translator.LocalizeFromCtx(ctx, caosErr.GetMessage(), nil))
}
return err
}

View File

@@ -0,0 +1,36 @@
package connect_middleware
import (
"context"
"connectrpc.com/connect"
// import to make sure go.mod does not lose it
// because dependency is only needed for generated code
_ "github.com/envoyproxy/protoc-gen-validate/validate"
)
func ValidationHandler() connect.UnaryInterceptorFunc {
return func(handler connect.UnaryFunc) connect.UnaryFunc {
return func(ctx context.Context, req connect.AnyRequest) (connect.AnyResponse, error) {
return validate(ctx, req, handler)
}
}
}
// validator interface needed for github.com/envoyproxy/protoc-gen-validate
// (it does not expose an interface itself)
type validator interface {
Validate() error
}
func validate(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc) (connect.AnyResponse, error) {
validate, ok := req.Any().(validator)
if !ok {
return handler(ctx, req)
}
err := validate.Validate()
if err != nil {
return nil, connect.NewError(connect.CodeInvalidArgument, err)
}
return handler(ctx, req)
}

View File

@@ -171,7 +171,7 @@ func CreateGateway(
}, nil
}
func RegisterGateway(ctx context.Context, gateway *Gateway, server Server) error {
func RegisterGateway(ctx context.Context, gateway *Gateway, server WithGateway) error {
err := server.RegisterGateway()(ctx, gateway.mux, gateway.connection)
if err != nil {
return fmt.Errorf("failed to register grpc gateway: %w", err)

View File

@@ -2,11 +2,14 @@ package server
import (
"crypto/tls"
"net/http"
"connectrpc.com/connect"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials"
healthpb "google.golang.org/grpc/health/grpc_health_v1"
protoreflect "google.golang.org/protobuf/reflect/protoreflect"
"github.com/zitadel/zitadel/internal/api/authz"
grpc_api "github.com/zitadel/zitadel/internal/api/grpc"
@@ -19,21 +22,36 @@ import (
)
type Server interface {
RegisterServer(*grpc.Server)
RegisterGateway() RegisterGatewayFunc
AppName() string
MethodPrefix() string
AuthMethods() authz.MethodMapping
}
type GrpcServer interface {
Server
RegisterServer(*grpc.Server)
}
type WithGateway interface {
Server
RegisterGateway() RegisterGatewayFunc
}
// WithGatewayPrefix extends the server interface with a prefix for the grpc gateway
//
// it's used for the System, Admin, Mgmt and Auth API
type WithGatewayPrefix interface {
Server
GrpcServer
WithGateway
GatewayPathPrefix() string
}
type ConnectServer interface {
Server
RegisterConnectServer(interceptors ...connect.Interceptor) (string, http.Handler)
FileDescriptor() protoreflect.FileDescriptor
}
func CreateServer(
verifier authz.APITokenVerifier,
systemAuthz authz.Config,