Files
zitadel/internal/api/http/middleware/origin_interceptor.go
Livio Spring 7520450e11 fix: sanitize host headers before use
# Which Problems Are Solved

Host headers, which are used to identify the instance and are also embedded in public responses such as OIDC discovery endpoints and email links, were not handled correctly: although they were matched against existing instances, they were not properly sanitized.

# How the Problems Are Solved

Sanitize host headers before use, including validation of the port (if one is provided).

# Additional Changes

None

# Additional Context

- requires backports

(cherry picked from commit 72a5c33e6a)
2025-10-29 10:07:05 +01:00

126 lines
4.0 KiB
Go

package middleware
import (
"errors"
"net"
"net/http"
"slices"
"strconv"
"github.com/gorilla/mux"
"github.com/muhlemmer/httpforwarded"
"github.com/zitadel/logging"
http_util "github.com/zitadel/zitadel/internal/api/http"
)
// WithOrigin returns a middleware that determines the host and protocol a
// request was addressed to and stores them on the request's context
// (http_util.WithDomainContext) before calling the next handler.
//
// The instance host is looked up in instanceHostHeaders, followed by
// http1Header, http2Header and the forwarding headers known to http_util.
// The public host is looked up in publicDomainHeaders. If enforceHttps is set,
// the protocol is always "https".
func WithOrigin(enforceHttps bool, http1Header, http2Header string, instanceHostHeaders, publicDomainHeaders []string) mux.MiddlewareFunc {
	// The header list is built once, on a copy. Appending to instanceHostHeaders
	// itself would write into the caller's backing array whenever it has spare
	// capacity, and doing that per request races between concurrent requests.
	candidates := append(slices.Clone(instanceHostHeaders),
		// to make sure we don't break existing configurations, we append the existing checked headers as well
		http1Header, http2Header,
		http_util.Forwarded, http_util.ZitadelForwarded, http_util.ForwardedFor,
		http_util.ForwardedHost, http_util.ForwardedProto,
	)

	// Drop empty names (e.g. an unset http1Header) and every repeated name,
	// keeping the first occurrence so the lookup priority is unchanged
	// (slices.Compact only removed adjacent duplicates). Names are compared in
	// canonical form, the same form http.Header lookups use.
	instanceHeaders := candidates[:0]
	seen := make(map[string]struct{}, len(candidates))
	for _, header := range candidates {
		key := http.CanonicalHeaderKey(header)
		if key == "" {
			continue
		}
		if _, duplicate := seen[key]; duplicate {
			continue
		}
		seen[key] = struct{}{}
		instanceHeaders = append(instanceHeaders, header)
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			origin := composeDomainContext(r, enforceHttps, instanceHeaders, publicDomainHeaders)
			next.ServeHTTP(w, r.WithContext(http_util.WithDomainContext(r.Context(), origin)))
		})
	}
}
// composeDomainContext builds the domain context for r.
//
// The instance host and protocol come from instanceDomainHeaders, and the
// public host and protocol from publicDomainHeaders, both via hostFromRequest.
// When none of the instance headers yields a host, r.Host is used instead.
// The final protocol is chosen by protocolFromRequest.
func composeDomainContext(r *http.Request, enforceHttps bool, instanceDomainHeaders, publicDomainHeaders []string) (_ *http_util.DomainCtx) {
	instanceHost, instanceProto := hostFromRequest(r, instanceDomainHeaders)
	publicHost, publicProto := hostFromRequest(r, publicDomainHeaders)

	host := r.Host
	if instanceHost != "" {
		host = instanceHost
	}
	protocol := protocolFromRequest(instanceProto, publicProto, enforceHttps)
	return http_util.NewDomainCtx(host, publicHost, protocol)
}
// protocolFromRequest picks the protocol for the domain context. Precedence:
// "https" when enforceHttps is set, then publicProto, then instanceProto, and
// "http" when neither protocol is known.
func protocolFromRequest(instanceProto, publicProto string, enforceHttps bool) string {
	switch {
	case enforceHttps:
		return "https"
	case publicProto != "":
		return publicProto
	case instanceProto != "":
		return instanceProto
	default:
		return "http"
	}
}
// hostFromRequest looks for a host and a protocol in the given request headers,
// checked in the order they are listed. It returns the first usable value of
// each, or "" for a value that no header provided.
//
// How a header is read depends on its (canonicalized) name:
//   - http_util.Forwarded, http_util.ForwardedFor and http_util.ZitadelForwarded
//     are parsed by hostFromForwarded and may provide both host and proto,
//   - http_util.ForwardedHost provides only a host,
//   - http_util.ForwardedProto provides only a proto,
//   - any other header is treated as containing a host.
//
// A host is used only if sanitizeHost accepts it; a proto is used only if it
// is exactly "http" or "https".
func hostFromRequest(r *http.Request, headers []string) (host, proto string) {
	for _, header := range headers {
		// Declared per header. With loop-wide variables, a header that sets only
		// one of them (e.g. ForwardedProto) left the other holding the previous
		// header's value, which was then sanitized again and, if invalid,
		// logged a second time.
		var hostFromHeader, protoFromHeader string
		switch http.CanonicalHeaderKey(header) {
		case http.CanonicalHeaderKey(http_util.Forwarded),
			// NOTE(review): if ForwardedFor is X-Forwarded-For (a list of client
			// addresses), Forwarded-syntax parsing will normally yield neither
			// host nor proto - confirm this is intended.
			http.CanonicalHeaderKey(http_util.ForwardedFor),
			http.CanonicalHeaderKey(http_util.ZitadelForwarded):
			hostFromHeader, protoFromHeader = hostFromForwarded(r.Header.Values(header))
		case http.CanonicalHeaderKey(http_util.ForwardedHost):
			hostFromHeader = r.Header.Get(header)
		case http.CanonicalHeaderKey(http_util.ForwardedProto):
			protoFromHeader = r.Header.Get(header)
		default:
			hostFromHeader = r.Header.Get(header)
		}

		if host == "" && hostFromHeader != "" {
			host = sanitizeHost(hostFromHeader)
		}
		if proto == "" && (protoFromHeader == "http" || protoFromHeader == "https") {
			proto = protoFromHeader
		}
		if host != "" && proto != "" {
			// neither value can change any more
			break
		}
	}
	return host, proto
}
func sanitizeHost(rawHost string) (host string) {
if rawHost == "" {
return ""
}
host, port, err := net.SplitHostPort(rawHost)
if err != nil {
// if the error is about a missing port, the host is probably just "example.com", so we can return it
if isMissingPortError(err) {
return rawHost
}
// if the error is about something else, the host is probably invalid, so we log it and return an empty string
logging.WithFields("host", rawHost).Warning("invalid host header, ignoring header")
return ""
}
// if the port is not numeric, the host was probably something like "localhost:@attacker.com"
portNumber, err := strconv.Atoi(port)
if err != nil || portNumber < 1 || portNumber > 65535 {
logging.WithFields("host", rawHost).Warning("invalid port in host header, ignoring header")
return ""
}
// if we reach this point, the host contains a valid port, so we return the complete host
return rawHost
}
// isMissingPortError reports whether err (or an error it wraps) is the
// *net.AddrError that net.SplitHostPort returns for an address without a port.
// The net package exports no sentinel for this case, so the message is compared.
func isMissingPortError(err error) bool {
	var addrErr *net.AddrError
	if !errors.As(err, &addrErr) {
		return false
	}
	return addrErr.Err == "missing port in address"
}
// hostFromForwarded parses Forwarded-style header values with httpforwarded.
// It returns the first listed "host" and "proto" parameters ("" when absent).
// If the values cannot be parsed, both results are "".
func hostFromForwarded(values []string) (string, string) {
	parsed, err := httpforwarded.Parse(values)
	if err != nil {
		return "", ""
	}
	host := oldestForwardedValue(parsed, "host")
	proto := oldestForwardedValue(parsed, "proto")
	return host, proto
}
// oldestForwardedValue returns the first value stored under key in forwarded,
// or "" if the map is nil, the key is missing, or its value list is empty.
func oldestForwardedValue(forwarded map[string][]string, key string) string {
	// indexing a nil map is safe and yields a nil slice
	if values := forwarded[key]; len(values) > 0 {
		return values[0]
	}
	return ""
}