mirror of
https://github.com/zitadel/zitadel.git
synced 2025-12-23 04:06:46 +00:00
126 lines
4.0 KiB
Go
126 lines
4.0 KiB
Go
package middleware
|
|
|
|
import (
|
|
"errors"
|
|
"net"
|
|
"net/http"
|
|
"slices"
|
|
"strconv"
|
|
|
|
"github.com/gorilla/mux"
|
|
"github.com/muhlemmer/httpforwarded"
|
|
"github.com/zitadel/logging"
|
|
|
|
http_util "github.com/zitadel/zitadel/internal/api/http"
|
|
)
|
|
|
|
func WithOrigin(enforceHttps bool, http1Header, http2Header string, instanceHostHeaders, publicDomainHeaders []string) mux.MiddlewareFunc {
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
origin := composeDomainContext(
|
|
r,
|
|
enforceHttps,
|
|
// to make sure we don't break existing configurations, we append the existing checked headers as well
|
|
slices.Compact(append(instanceHostHeaders, http1Header, http2Header, http_util.Forwarded, http_util.ZitadelForwarded, http_util.ForwardedFor, http_util.ForwardedHost, http_util.ForwardedProto)),
|
|
publicDomainHeaders,
|
|
)
|
|
next.ServeHTTP(w, r.WithContext(http_util.WithDomainContext(r.Context(), origin)))
|
|
})
|
|
}
|
|
}
|
|
|
|
func composeDomainContext(r *http.Request, enforceHttps bool, instanceDomainHeaders, publicDomainHeaders []string) (_ *http_util.DomainCtx) {
|
|
instanceHost, instanceProto := hostFromRequest(r, instanceDomainHeaders)
|
|
publicHost, publicProto := hostFromRequest(r, publicDomainHeaders)
|
|
if instanceHost == "" {
|
|
instanceHost = r.Host
|
|
}
|
|
return http_util.NewDomainCtx(instanceHost, publicHost, protocolFromRequest(instanceProto, publicProto, enforceHttps))
|
|
}
|
|
|
|
// protocolFromRequest selects the request protocol by precedence.
//
// The order is:
//  1. "https" whenever enforceHttps is set.
//  2. Otherwise the protocol found in the public domain headers.
//  3. Otherwise the one found in the instance headers.
//  4. Otherwise "http".
func protocolFromRequest(instanceProto, publicProto string, enforceHttps bool) string {
	switch {
	case enforceHttps:
		return "https"
	case publicProto != "":
		return publicProto
	case instanceProto != "":
		return instanceProto
	default:
		return "http"
	}
}
|
|
|
|
func hostFromRequest(r *http.Request, headers []string) (host, proto string) {
|
|
var hostFromHeader, protoFromHeader string
|
|
for _, header := range headers {
|
|
switch http.CanonicalHeaderKey(header) {
|
|
case http.CanonicalHeaderKey(http_util.Forwarded),
|
|
http.CanonicalHeaderKey(http_util.ForwardedFor),
|
|
http.CanonicalHeaderKey(http_util.ZitadelForwarded):
|
|
hostFromHeader, protoFromHeader = hostFromForwarded(r.Header.Values(header))
|
|
case http.CanonicalHeaderKey(http_util.ForwardedHost):
|
|
hostFromHeader = r.Header.Get(header)
|
|
case http.CanonicalHeaderKey(http_util.ForwardedProto):
|
|
protoFromHeader = r.Header.Get(header)
|
|
default:
|
|
hostFromHeader = r.Header.Get(header)
|
|
}
|
|
if host == "" {
|
|
host = sanitizeHost(hostFromHeader)
|
|
}
|
|
if proto == "" && (protoFromHeader == "http" || protoFromHeader == "https") {
|
|
proto = protoFromHeader
|
|
}
|
|
}
|
|
return host, proto
|
|
}
|
|
|
|
func sanitizeHost(rawHost string) (host string) {
|
|
if rawHost == "" {
|
|
return ""
|
|
}
|
|
host, port, err := net.SplitHostPort(rawHost)
|
|
if err != nil {
|
|
// if the error is about a missing port, the host is probably just "example.com", so we can return it
|
|
if isMissingPortError(err) {
|
|
return rawHost
|
|
}
|
|
// if the error is about something else, the host is probably invalid, so we log it and return an empty string
|
|
logging.WithFields("host", rawHost).Warning("invalid host header, ignoring header")
|
|
return ""
|
|
}
|
|
// if the port is not numeric, the host was probably something like "localhost:@attacker.com"
|
|
portNumber, err := strconv.Atoi(port)
|
|
if err != nil || portNumber < 1 || portNumber > 65535 {
|
|
logging.WithFields("host", rawHost).Warning("invalid port in host header, ignoring header")
|
|
return ""
|
|
}
|
|
// if we reach this point, the host contains a valid port, so we return the complete host
|
|
return rawHost
|
|
}
|
|
|
|
// isMissingPortError reports whether err is, or wraps, a *net.AddrError whose
// message is "missing port in address". net.SplitHostPort reports that error
// for an input such as "example.com".
func isMissingPortError(err error) bool {
	const missingPort = "missing port in address"
	if addrErr := (*net.AddrError)(nil); errors.As(err, &addrErr) {
		return addrErr.Err == missingPort
	}
	return false
}
|
|
|
|
func hostFromForwarded(values []string) (string, string) {
|
|
fwd, fwdErr := httpforwarded.Parse(values)
|
|
if fwdErr == nil {
|
|
return oldestForwardedValue(fwd, "host"), oldestForwardedValue(fwd, "proto")
|
|
}
|
|
return "", ""
|
|
}
|
|
|
|
// oldestForwardedValue returns the first value recorded under key in a parsed
// Forwarded map, or "" when there is none. Indexing a nil map is safe in Go, so
// a nil forwarded map also yields "".
// NOTE(review): "oldest" presumes the list keeps the order in which proxies
// appended entries (client-nearest first) — confirm against httpforwarded.
func oldestForwardedValue(forwarded map[string][]string, key string) string {
	if values := forwarded[key]; len(values) > 0 {
		return values[0]
	}
	return ""
}
|