From 3a1569bd945caf46e6449726b7bf5d8fc54f49f4 Mon Sep 17 00:00:00 2001 From: Livio Amstutz Date: Fri, 3 Jun 2022 14:44:04 +0200 Subject: [PATCH] fix: grpc gateway interceptors (#3767) --- cmd/admin/start/start.go | 2 +- internal/api/api.go | 14 +++-------- internal/api/grpc/server/gateway.go | 24 ++++++++++++++---- .../server/middleware/instance_interceptor.go | 25 ++++++++++++++++--- .../middleware/instance_interceptor_test.go | 6 ++--- .../http/middleware/instance_interceptor.go | 4 +-- 6 files changed, 51 insertions(+), 24 deletions(-) diff --git a/cmd/admin/start/start.go b/cmd/admin/start/start.go index df58a1d7a2..a5272eabe8 100644 --- a/cmd/admin/start/start.go +++ b/cmd/admin/start/start.go @@ -153,7 +153,7 @@ func startAPIs(ctx context.Context, router *mux.Router, commands *command.Comman } verifier := internal_authz.Start(repo, http_util.BuildHTTP(config.ExternalDomain, config.ExternalPort, config.ExternalSecure)+oidc.HandlerPrefix, systemAPIKeys) - apis := api.New(config.Port, router, queries, verifier, config.InternalAuthZ, config.ExternalSecure, config.HTTP2HostHeader) + apis := api.New(config.Port, router, queries, verifier, config.InternalAuthZ, config.ExternalSecure, config.HTTP2HostHeader, config.HTTP1HostHeader) authRepo, err := auth_es.Start(config.Auth, config.SystemDefaults, commands, queries, dbClient, keys.OIDC, keys.User) if err != nil { return fmt.Errorf("error starting auth repo: %w", err) diff --git a/internal/api/api.go b/internal/api/api.go index dad6623b87..31bbc8b541 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -26,6 +26,7 @@ type API struct { health health router *mux.Router externalSecure bool + http1HostName string } type health interface { @@ -33,21 +34,14 @@ type health interface { Instance(ctx context.Context) (*query.Instance, error) } -func New( - port uint16, - router *mux.Router, - queries *query.Queries, - verifier *internal_authz.TokenVerifier, - authZ internal_authz.Config, - externalSecure bool, - http2HostName string, -) *API { +func New(port uint16, router *mux.Router, queries *query.Queries, verifier *internal_authz.TokenVerifier, authZ internal_authz.Config, externalSecure bool, http2HostName, http1HostName string) *API { api := &API{ port: port, verifier: verifier, health: queries, router: router, externalSecure: externalSecure, + http1HostName: http1HostName, } api.grpcServer = server.CreateServer(api.verifier, authZ, queries, http2HostName) api.routeGRPC() @@ -59,7 +53,7 @@ func New( func (a *API) RegisterServer(ctx context.Context, grpcServer server.Server) error { grpcServer.RegisterServer(a.grpcServer) - handler, prefix, err := server.CreateGateway(ctx, grpcServer, a.port) + handler, prefix, err := server.CreateGateway(ctx, grpcServer, a.port, a.http1HostName) if err != nil { return err } diff --git a/internal/api/grpc/server/gateway.go b/internal/api/grpc/server/gateway.go index 4310c52db0..c0e32685fe 100644 --- a/internal/api/grpc/server/gateway.go +++ b/internal/api/grpc/server/gateway.go @@ -12,6 +12,7 @@ import ( "google.golang.org/protobuf/encoding/protojson" client_middleware "github.com/zitadel/zitadel/internal/api/grpc/client/middleware" + "github.com/zitadel/zitadel/internal/api/grpc/server/middleware" http_mw "github.com/zitadel/zitadel/internal/api/http/middleware" ) @@ -56,7 +57,7 @@ type Gateway interface { type GatewayFunc func(ctx context.Context, mux *runtime.ServeMux, endpoint string, opts []grpc.DialOption) error -func CreateGateway(ctx context.Context, g Gateway, port uint16) (http.Handler, string, error) { +func CreateGateway(ctx context.Context, g Gateway, port uint16, http1HostName string) (http.Handler, string, error) { runtimeMux := runtime.NewServeMux(serveMuxOptions...) opts := []grpc.DialOption{ grpc.WithTransportCredentials(insecure.NewCredentials()), @@ -66,11 +67,24 @@ func CreateGateway(ctx context.Context, g Gateway, port uint16) (http.Handler, s if err != nil { return nil, "", fmt.Errorf("failed to register grpc gateway: %w", err) } - return addInterceptors(runtimeMux), g.GatewayPathPrefix(), nil + return addInterceptors(runtimeMux, http1HostName), g.GatewayPathPrefix(), nil } -func addInterceptors(handler http.Handler) http.Handler { - handler = http_mw.DefaultMetricsHandler(handler) +func addInterceptors(handler http.Handler, http1HostName string) http.Handler { + handler = http1Host(handler, http1HostName) + handler = http_mw.CORSInterceptor(handler) handler = http_mw.DefaultTelemetryHandler(handler) - return http_mw.CORSInterceptor(handler) + return http_mw.DefaultMetricsHandler(handler) +} + +func http1Host(next http.Handler, http1HostName string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + host, err := http_mw.HostFromRequest(r, http1HostName) + if err != nil { + http.Error(w, err.Error(), http.StatusForbidden) + return + } + r.Header.Set(middleware.HTTP1Host, host) + next.ServeHTTP(w, r) + }) } diff --git a/internal/api/grpc/server/middleware/instance_interceptor.go b/internal/api/grpc/server/middleware/instance_interceptor.go index e5b7aea7ce..73a417fcb1 100644 --- a/internal/api/grpc/server/middleware/instance_interceptor.go +++ b/internal/api/grpc/server/middleware/instance_interceptor.go @@ -14,6 +14,10 @@ import ( "github.com/zitadel/zitadel/internal/telemetry/tracing" ) +const ( + HTTP1Host = "x-zitadel-http1-host" +) + type InstanceVerifier interface { GetInstance(ctx context.Context) } @@ -36,7 +40,7 @@ func setInstance(ctx context.Context, req interface{}, info *grpc.UnaryServerInf } } - host, err := hostNameFromContext(interceptorCtx, headerName) + host, err := hostFromContext(interceptorCtx, headerName) if err != nil { return nil, status.Error(codes.PermissionDenied, err.Error()) } @@ -48,12 +52,19 @@ func setInstance(ctx context.Context, req interface{}, info *grpc.UnaryServerInf return handler(authz.WithInstance(ctx, instance), req) } -func hostNameFromContext(ctx context.Context, headerName string) (string, error) { +func hostFromContext(ctx context.Context, headerName string) (string, error) { md, ok := metadata.FromIncomingContext(ctx) if !ok { return "", fmt.Errorf("cannot read metadata") } - host, ok := md[headerName] + host, ok := md[HTTP1Host] + if ok && len(host) == 1 { + if !isAllowedToSendHTTP1Header(md) { + return "", fmt.Errorf("no valid host header") + } + return host[0], nil + } + host, ok = md[headerName] if !ok { return "", fmt.Errorf("cannot find header: %v", headerName) } @@ -62,3 +73,11 @@ func hostNameFromContext(ctx context.Context, headerName string) (string, error) } return host[0], nil } + +//isAllowedToSendHTTP1Header check if the gRPC call was sent to `localhost` +//this is only possible when calling the server directly running on localhost +//or through the gRPC gateway +func isAllowedToSendHTTP1Header(md metadata.MD) bool { + authority, ok := md[":authority"] + return ok && len(authority) == 1 && strings.Split(authority[0], ":")[0] == "localhost" +} diff --git a/internal/api/grpc/server/middleware/instance_interceptor_test.go b/internal/api/grpc/server/middleware/instance_interceptor_test.go index 8074f1fca4..b8a9342a01 100644 --- a/internal/api/grpc/server/middleware/instance_interceptor_test.go +++ b/internal/api/grpc/server/middleware/instance_interceptor_test.go @@ -63,13 +63,13 @@ func Test_hostNameFromContext(t *testing.T) { } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got, err := hostNameFromContext(tt.args.ctx, tt.args.headerName) + got, err := hostFromContext(tt.args.ctx, tt.args.headerName) if (err != nil) != tt.res.err { - t.Errorf("hostNameFromContext() error = %v, wantErr %v", err, tt.res.err) + t.Errorf("hostFromContext() error = %v, wantErr %v", err, tt.res.err) return } if got != tt.res.want { - t.Errorf("hostNameFromContext() got = %v, want %v", got, tt.res.want) + t.Errorf("hostFromContext() got = %v, want %v", got, tt.res.want) } }) } diff --git a/internal/api/http/middleware/instance_interceptor.go b/internal/api/http/middleware/instance_interceptor.go index cf0d72d0a5..17e08e64ad 100644 --- a/internal/api/http/middleware/instance_interceptor.go +++ b/internal/api/http/middleware/instance_interceptor.go @@ -66,7 +66,7 @@ func setInstance(r *http.Request, verifier authz.InstanceVerifier, headerName st authCtx, span := tracing.NewServerInterceptorSpan(ctx) defer func() { span.EndWithError(err) }() - host, err := getHost(r, headerName) + host, err := HostFromRequest(r, headerName) if err != nil { return nil, err } @@ -79,7 +79,7 @@ func setInstance(r *http.Request, verifier authz.InstanceVerifier, headerName st return authz.WithInstance(ctx, instance), nil } -func getHost(r *http.Request, headerName string) (string, error) { +func HostFromRequest(r *http.Request, headerName string) (string, error) { host := r.Host if headerName != "host" { host = r.Header.Get(headerName)