diff --git a/cmd/start/start.go b/cmd/start/start.go index 1cb6bab94f..c5153a3ffc 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -223,7 +223,7 @@ func startAPIs( logging.Warn("access logs are currently in beta") } accessInterceptor := middleware.NewAccessInterceptor(accessSvc, config.Quotas.Access) - apis := api.New(config.Port, router, queries, verifier, config.InternalAuthZ, config.ExternalSecure, tlsConfig, config.HTTP2HostHeader, config.HTTP1HostHeader, accessSvc) + apis := api.New(config.Port, router, queries, verifier, config.InternalAuthZ, tlsConfig, config.HTTP2HostHeader, config.HTTP1HostHeader, accessSvc) authRepo, err := auth_es.Start(ctx, config.Auth, config.SystemDefaults, commands, queries, dbClient, eventstore, keys.OIDC, keys.User) if err != nil { return fmt.Errorf("error starting auth repo: %w", err) diff --git a/internal/api/api.go b/internal/api/api.go index ec51d7fffd..78044e5b6b 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -23,13 +23,12 @@ import ( ) type API struct { - port uint16 - grpcServer *grpc.Server - verifier *internal_authz.TokenVerifier - health health - router *mux.Router - externalSecure bool - http1HostName string + port uint16 + grpcServer *grpc.Server + verifier *internal_authz.TokenVerifier + health health + router *mux.Router + http1HostName string } type health interface { @@ -43,19 +42,15 @@ func New( queries *query.Queries, verifier *internal_authz.TokenVerifier, authZ internal_authz.Config, - externalSecure bool, - tlsConfig *tls.Config, - http2HostName, - http1HostName string, + tlsConfig *tls.Config, http2HostName, http1HostName string, accessSvc *logstore.Service, ) *API { api := &API{ - port: port, - verifier: verifier, - health: queries, - router: router, - externalSecure: externalSecure, - http1HostName: http1HostName, + port: port, + verifier: verifier, + health: queries, + router: router, + http1HostName: http1HostName, } api.grpcServer = server.CreateServer(api.verifier, authZ, queries, http2HostName, tlsConfig, accessSvc) @@ -95,42 +90,34 @@ func (a *API) routeGRPC() { Headers("Content-Type", "application/grpc"). Handler(a.grpcServer) - if !a.externalSecure { - a.routeGRPCWeb(a.router) - return - } - a.routeGRPCWeb(http2Route) + a.routeGRPCWeb() } -func (a *API) routeGRPCWeb(router *mux.Router) { - router.NewRoute(). +func (a *API) routeGRPCWeb() { + grpcWebServer := grpcweb.WrapServer(a.grpcServer, + grpcweb.WithAllowedRequestHeaders( + []string{ + http_util.Origin, + http_util.ContentType, + http_util.Accept, + http_util.AcceptLanguage, + http_util.Authorization, + http_util.ZitadelOrgID, + http_util.XUserAgent, + http_util.XGrpcWeb, + }, + ), + grpcweb.WithOriginFunc(func(_ string) bool { + return true + }), + ) + a.router.NewRoute(). Methods(http.MethodPost, http.MethodOptions). MatcherFunc( func(r *http.Request, _ *mux.RouteMatch) bool { - if strings.Contains(strings.ToLower(r.Header.Get("content-type")), "application/grpc-web+") { - return true - } - return strings.Contains(strings.ToLower(r.Header.Get("access-control-request-headers")), "x-grpc-web") + return grpcWebServer.IsGrpcWebRequest(r) || grpcWebServer.IsAcceptableGrpcCorsRequest(r) }). - Handler( - grpcweb.WrapServer(a.grpcServer, - grpcweb.WithAllowedRequestHeaders( - []string{ - http_util.Origin, - http_util.ContentType, - http_util.Accept, - http_util.AcceptLanguage, - http_util.Authorization, - http_util.ZitadelOrgID, - http_util.XUserAgent, - http_util.XGrpcWeb, - }, - ), - grpcweb.WithOriginFunc(func(_ string) bool { - return true - }), - ), - ) + Handler(grpcWebServer) } func (a *API) healthHandler() http.Handler {