2022-06-03 12:44:04 +00:00

91 lines
2.7 KiB
Go

package server
import (
"context"
"fmt"
"net/http"
"strings"
"github.com/grpc-ecosystem/grpc-gateway/v2/runtime"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"google.golang.org/protobuf/encoding/protojson"
client_middleware "github.com/zitadel/zitadel/internal/api/grpc/client/middleware"
"github.com/zitadel/zitadel/internal/api/grpc/server/middleware"
http_mw "github.com/zitadel/zitadel/internal/api/http/middleware"
)
// mimeWildcard is the "*/*" media range. It is used as one of the keys under
// which the JSON marshaler is registered in serveMuxOptions.
const mimeWildcard = "*/*"
var (
	// customHeaders lists the lower-case header name prefixes that
	// headerMatcher accepts before falling back to the gateway default.
	customHeaders = []string{"x-zitadel-"}

	// jsonMarshaler is the JSON marshaler registered on the gateway mux. Its
	// unmarshal options are configured with DiscardUnknown enabled.
	jsonMarshaler = &runtime.JSONPb{
		UnmarshalOptions: protojson.UnmarshalOptions{DiscardUnknown: true},
	}

	// headerMatcher decides which incoming HTTP headers are accepted by the
	// gateway. A header whose lower-cased name starts with one of the
	// customHeaders prefixes is accepted and returned with its original
	// spelling; every other header is delegated to
	// runtime.DefaultHeaderMatcher.
	headerMatcher = runtime.HeaderMatcherFunc(func(header string) (string, bool) {
		lowered := strings.ToLower(header)
		for i := range customHeaders {
			if strings.HasPrefix(lowered, customHeaders[i]) {
				return header, true
			}
		}
		return runtime.DefaultHeaderMatcher(header)
	})

	// serveMuxOptions is the option set passed to runtime.NewServeMux by
	// CreateGateway. It registers jsonMarshaler under its own content type and
	// under both wildcard keys, installs headerMatcher for incoming headers and
	// runtime.DefaultHeaderMatcher for outgoing ones.
	serveMuxOptions = []runtime.ServeMuxOption{
		runtime.WithMarshalerOption(jsonMarshaler.ContentType(nil), jsonMarshaler),
		runtime.WithMarshalerOption(mimeWildcard, jsonMarshaler),
		runtime.WithMarshalerOption(runtime.MIMEWildcard, jsonMarshaler),
		runtime.WithIncomingHeaderMatcher(headerMatcher),
		runtime.WithOutgoingHeaderMatcher(runtime.DefaultHeaderMatcher),
	}
)
// Gateway is the contract CreateGateway needs from a service in order to
// expose it through the HTTP gateway.
type Gateway interface {
	// RegisterGateway returns the function that CreateGateway invokes to
	// register the service on the gateway mux.
	RegisterGateway() GatewayFunc

	// GatewayPathPrefix returns the path prefix that CreateGateway hands back
	// to its caller together with the created handler.
	GatewayPathPrefix() string
}
// GatewayFunc is the signature of a gateway registration function as returned
// by Gateway.RegisterGateway. CreateGateway calls it with the gateway mux, the
// "localhost:<port>" endpoint of the gRPC server and the client dial options.
type GatewayFunc func(
	ctx context.Context,
	mux *runtime.ServeMux,
	endpoint string,
	opts []grpc.DialOption,
) error
// CreateGateway builds the HTTP handler that serves g through the gateway.
//
// It creates a mux from serveMuxOptions, lets g register itself against the
// endpoint "localhost:<port>" using insecure transport credentials and the
// default tracing client interceptor, and wraps the mux with addInterceptors.
//
// It returns the wrapped handler together with g.GatewayPathPrefix(). If the
// registration fails, the handler is nil, the prefix is empty and the
// registration error is returned wrapped.
func CreateGateway(ctx context.Context, g Gateway, port uint16, http1HostName string) (http.Handler, string, error) {
	mux := runtime.NewServeMux(serveMuxOptions...)

	dialOpts := []grpc.DialOption{
		grpc.WithTransportCredentials(insecure.NewCredentials()),
		grpc.WithUnaryInterceptor(client_middleware.DefaultTracingClient()),
	}

	register := g.RegisterGateway()
	endpoint := fmt.Sprintf("localhost:%d", port)
	if err := register(ctx, mux, endpoint, dialOpts); err != nil {
		return nil, "", fmt.Errorf("failed to register grpc gateway: %w", err)
	}

	return addInterceptors(mux, http1HostName), g.GatewayPathPrefix(), nil
}
// addInterceptors wraps handler with the gateway's HTTP middleware chain.
//
// The layers are applied from the inside out: http1Host is closest to handler,
// followed by the CORS interceptor, the default telemetry handler and finally
// the default metrics handler, which is the outermost layer that is returned.
func addInterceptors(handler http.Handler, http1HostName string) http.Handler {
	withHost := http1Host(handler, http1HostName)
	withCORS := http_mw.CORSInterceptor(withHost)
	withTelemetry := http_mw.DefaultTelemetryHandler(withCORS)
	return http_mw.DefaultMetricsHandler(withTelemetry)
}
// http1Host returns a handler that resolves the request host through
// http_mw.HostFromRequest (given http1HostName) before delegating to next.
//
// On success the resolved host is written to the middleware.HTTP1Host request
// header and next is invoked. If the host cannot be resolved, the request is
// answered with 403 Forbidden carrying the error text and next is not called.
func http1Host(next http.Handler, http1HostName string) http.Handler {
	serve := func(w http.ResponseWriter, r *http.Request) {
		resolved, err := http_mw.HostFromRequest(r, http1HostName)
		if err == nil {
			r.Header.Set(middleware.HTTP1Host, resolved)
			next.ServeHTTP(w, r)
			return
		}
		http.Error(w, err.Error(), http.StatusForbidden)
	}
	return http.HandlerFunc(serve)
}