package middleware

import (
	"net/http"
	"slices"

	"github.com/gorilla/mux"
	"github.com/muhlemmer/httpforwarded"

	http_util "github.com/zitadel/zitadel/internal/api/http"
)

// Canonical forms of the proxy headers that hostFromRequest treats specially.
// They are computed once at package initialization instead of on every loop
// iteration of every request.
var (
	originForwardedKey      = http.CanonicalHeaderKey(http_util.Forwarded)
	originForwardedForKey   = http.CanonicalHeaderKey(http_util.ForwardedFor)
	originForwardedHostKey  = http.CanonicalHeaderKey(http_util.ForwardedHost)
	originForwardedProtoKey = http.CanonicalHeaderKey(http_util.ForwardedProto)
)

// WithOrigin returns a middleware that derives a *http_util.DomainCtx from the
// incoming request (see composeDomainContext) and stores it in the request
// context via http_util.WithDomainContext before calling the next handler.
//
// The instance host is looked up in instanceHostHeaders, followed by
// http1Header, http2Header and the http_util Forwarded, ForwardedFor,
// ForwardedHost and ForwardedProto headers. The public host is looked up in
// publicDomainHeaders. If no protocol can be determined from any header,
// "https" is used when fallBackToHttps is true and "http" otherwise.
func WithOrigin(fallBackToHttps bool, http1Header, http2Header string, instanceHostHeaders, publicDomainHeaders []string) mux.MiddlewareFunc {
	// Build the header lists once, on backing arrays owned by this middleware.
	// Doing the append inside the handler ran on every request and, whenever
	// the caller's slice had spare capacity, wrote into the caller's backing
	// array from concurrent requests — a data race that slices.Compact (which
	// shifts and zeroes elements in place) could turn into wrong header names.
	//
	// To make sure we don't break existing configurations we append the
	// previously checked headers as well. Compact only drops consecutive
	// duplicates; any remaining duplicates are harmless because
	// hostFromRequest keeps the first non-empty value it finds.
	instanceDomainHeaders := slices.Compact(append(slices.Clone(instanceHostHeaders),
		http1Header,
		http2Header,
		http_util.Forwarded,
		http_util.ForwardedFor,
		http_util.ForwardedHost,
		http_util.ForwardedProto,
	))
	publicHeaders := slices.Clone(publicDomainHeaders)

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			origin := composeDomainContext(r, fallBackToHttps, instanceDomainHeaders, publicHeaders)
			next.ServeHTTP(w, r.WithContext(http_util.WithDomainContext(r.Context(), origin)))
		})
	}
}

// composeDomainContext builds the DomainCtx for r.
//
// InstanceHost is the first non-empty host found in instanceDomainHeaders,
// falling back to r.Host. PublicHost is the first non-empty host found in
// publicDomainHeaders and may be empty. Protocol is taken from the public
// headers first, then from the instance headers, and finally defaults to
// "https" when fallBackToHttps is set or "http" otherwise.
func composeDomainContext(r *http.Request, fallBackToHttps bool, instanceDomainHeaders, publicDomainHeaders []string) *http_util.DomainCtx {
	instanceHost, instanceProto := hostFromRequest(r, instanceDomainHeaders)
	publicHost, publicProto := hostFromRequest(r, publicDomainHeaders)
	if publicProto == "" {
		publicProto = instanceProto
	}
	if publicProto == "" {
		publicProto = "http"
		if fallBackToHttps {
			publicProto = "https"
		}
	}
	if instanceHost == "" {
		instanceHost = r.Host
	}
	return &http_util.DomainCtx{
		InstanceHost: instanceHost,
		Protocol:     publicProto,
		PublicHost:   publicHost,
	}
}

// hostFromRequest scans headers in order and returns the first non-empty host
// and the first non-empty protocol found; host and proto are resolved
// independently and may come from different headers. Header names are
// matched case-insensitively:
//   - Forwarded / ForwardedFor values are parsed by hostFromForwarded and can
//     supply both host and proto,
//   - ForwardedHost supplies a host,
//   - ForwardedProto supplies a protocol,
//   - any other header's value is taken verbatim as a host.
//
// NOTE(review): ForwardedFor is parsed with Forwarded syntax here; confirm
// that is intended for the header http_util.ForwardedFor names.
func hostFromRequest(r *http.Request, headers []string) (host, proto string) {
	for _, header := range headers {
		// Later headers can no longer change the result.
		if host != "" && proto != "" {
			break
		}
		var headerHost, headerProto string
		switch http.CanonicalHeaderKey(header) {
		case originForwardedKey, originForwardedForKey:
			headerHost, headerProto = hostFromForwarded(r.Header.Values(header))
		case originForwardedHostKey:
			headerHost = r.Header.Get(header)
		case originForwardedProtoKey:
			headerProto = r.Header.Get(header)
		default:
			headerHost = r.Header.Get(header)
		}
		if host == "" {
			host = headerHost
		}
		if proto == "" {
			proto = headerProto
		}
	}
	return host, proto
}

// hostFromForwarded parses values with httpforwarded.Parse and returns the
// first "host" and "proto" parameters. The values come from request headers;
// if they cannot be parsed, both results are empty so hostFromRequest moves
// on to the remaining headers instead of failing the request.
func hostFromForwarded(values []string) (host, proto string) {
	fwd, err := httpforwarded.Parse(values)
	if err != nil {
		return "", ""
	}
	return oldestForwardedValue(fwd, "host"), oldestForwardedValue(fwd, "proto")
}

// oldestForwardedValue returns the first value recorded for key in forwarded,
// or "" if there is none. Presumably the first entry is the one added by the
// proxy closest to the client (hence "oldest"), assuming httpforwarded.Parse
// preserves header order — confirm.
func oldestForwardedValue(forwarded map[string][]string, key string) string {
	// Indexing a nil map is safe and yields a nil slice, so no nil check is needed.
	if values := forwarded[key]; len(values) > 0 {
		return values[0]
	}
	return ""
}