mirror of
https://github.com/zitadel/zitadel.git
synced 2025-12-23 14:27:40 +00:00
# Which Problems Are Solved Host headers used to identify the instance, and further used in public responses such as OIDC discovery endpoints, email links and more, were not correctly handled. While they were matched against existing instances, they were not properly sanitized. # How the Problems Are Solved Sanitize the host header, including port validation (if a port is provided). # Additional Changes None # Additional Context - requires backports (cherry picked from commit 72a5c33e6a) (cherry picked from commit 7520450e11)
126 lines
4.0 KiB
Go
126 lines
4.0 KiB
Go
package middleware
|
|
|
|
import (
|
|
"errors"
|
|
"net"
|
|
"net/http"
|
|
"slices"
|
|
"strconv"
|
|
|
|
"github.com/gorilla/mux"
|
|
"github.com/muhlemmer/httpforwarded"
|
|
"github.com/zitadel/logging"
|
|
|
|
http_util "github.com/zitadel/zitadel/internal/api/http"
|
|
)
|
|
|
|
func WithOrigin(enforceHttps bool, http1Header, http2Header string, instanceHostHeaders, publicDomainHeaders []string) mux.MiddlewareFunc {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
origin := composeDomainContext(
|
|
r,
|
|
enforceHttps,
|
|
// to make sure we don't break existing configurations, we append the existing checked headers as well
|
|
slices.Compact(append(instanceHostHeaders, http1Header, http2Header, http_util.Forwarded, http_util.ZitadelForwarded, http_util.ForwardedFor, http_util.ForwardedHost, http_util.ForwardedProto)),
|
|
publicDomainHeaders,
|
|
)
|
|
next.ServeHTTP(w, r.WithContext(http_util.WithDomainContext(r.Context(), origin)))
|
|
})
|
|
}
|
|
}
|
|
|
|
func composeDomainContext(r *http.Request, enforceHttps bool, instanceDomainHeaders, publicDomainHeaders []string) (_ *http_util.DomainCtx) {
|
|
instanceHost, instanceProto := hostFromRequest(r, instanceDomainHeaders)
|
|
publicHost, publicProto := hostFromRequest(r, publicDomainHeaders)
|
|
if instanceHost == "" {
|
|
instanceHost = r.Host
|
|
}
|
|
return http_util.NewDomainCtx(instanceHost, publicHost, protocolFromRequest(instanceProto, publicProto, enforceHttps))
|
|
}
|
|
|
|
// protocolFromRequest picks the protocol to use for the request.
//
// Precedence, highest first: "https" when enforceHttps is set, then a
// non-empty publicProto, then a non-empty instanceProto, and finally "http".
func protocolFromRequest(instanceProto, publicProto string, enforceHttps bool) string {
	switch {
	case enforceHttps:
		return "https"
	case publicProto != "":
		return publicProto
	case instanceProto != "":
		return instanceProto
	default:
		return "http"
	}
}
|
|
|
|
func hostFromRequest(r *http.Request, headers []string) (host, proto string) {
|
|
var hostFromHeader, protoFromHeader string
|
|
for _, header := range headers {
|
|
switch http.CanonicalHeaderKey(header) {
|
|
case http.CanonicalHeaderKey(http_util.Forwarded),
|
|
http.CanonicalHeaderKey(http_util.ForwardedFor),
|
|
http.CanonicalHeaderKey(http_util.ZitadelForwarded):
|
|
hostFromHeader, protoFromHeader = hostFromForwarded(r.Header.Values(header))
|
|
case http.CanonicalHeaderKey(http_util.ForwardedHost):
|
|
hostFromHeader = r.Header.Get(header)
|
|
case http.CanonicalHeaderKey(http_util.ForwardedProto):
|
|
protoFromHeader = r.Header.Get(header)
|
|
default:
|
|
hostFromHeader = r.Header.Get(header)
|
|
}
|
|
if host == "" {
|
|
host = sanitizeHost(hostFromHeader)
|
|
}
|
|
if proto == "" && (protoFromHeader == "http" || protoFromHeader == "https") {
|
|
proto = protoFromHeader
|
|
}
|
|
}
|
|
return host, proto
|
|
}
|
|
|
|
func sanitizeHost(rawHost string) (host string) {
|
|
if rawHost == "" {
|
|
return ""
|
|
}
|
|
host, port, err := net.SplitHostPort(rawHost)
|
|
if err != nil {
|
|
// if the error is about a missing port, the host is probably just "example.com", so we can return it
|
|
if isMissingPortError(err) {
|
|
return rawHost
|
|
}
|
|
// if the error is about something else, the host is probably invalid, so we log it and return an empty string
|
|
logging.WithFields("host", rawHost).Warning("invalid host header, ignoring header")
|
|
return ""
|
|
}
|
|
// if the port is not numeric, the host was probably something like "localhost:@attacker.com"
|
|
portNumber, err := strconv.Atoi(port)
|
|
if err != nil || portNumber < 1 || portNumber > 65535 {
|
|
logging.WithFields("host", rawHost).Warning("invalid port in host header, ignoring header")
|
|
return ""
|
|
}
|
|
// if we reach this point, the host contains a valid port, so we return the complete host
|
|
return rawHost
|
|
}
|
|
|
|
// isMissingPortError reports whether err is, or wraps, a *net.AddrError whose
// message is "missing port in address".
func isMissingPortError(err error) bool {
	var addrErr *net.AddrError
	if !errors.As(err, &addrErr) {
		return false
	}
	return addrErr.Err == "missing port in address"
}
|
|
|
|
func hostFromForwarded(values []string) (string, string) {
|
|
fwd, fwdErr := httpforwarded.Parse(values)
|
|
if fwdErr == nil {
|
|
return oldestForwardedValue(fwd, "host"), oldestForwardedValue(fwd, "proto")
|
|
}
|
|
return "", ""
|
|
}
|
|
|
|
// oldestForwardedValue returns the first value stored under key in forwarded,
// or an empty string when the map is nil, the key is absent or it has no
// values.
func oldestForwardedValue(forwarded map[string][]string, key string) string {
	// Indexing a nil map yields a nil slice, so one length check covers the
	// nil map, the missing key and the empty value list.
	if entries := forwarded[key]; len(entries) > 0 {
		return entries[0]
	}
	return ""
}
|