fix: sanitize host headers before use

# Which Problems Are Solved

Host headers used to identify the instance and further used in public responses such as OIDC discovery endpoints, email links and more were not correctly handled. While they were matched against existing instances, they were not properly sanitized.

# How the Problems Are Solved

Sanitize host header including port validation (if provided).

# Additional Changes

None

# Additional Context

- requires backports

(cherry picked from commit 72a5c33e6a)
This commit is contained in:
Livio Spring
2025-10-29 10:05:37 +01:00
parent 2535f43e69
commit 7520450e11
10 changed files with 157 additions and 32 deletions

View File

@@ -28,7 +28,10 @@ type Instance interface {
} }
type InstanceVerifier interface { type InstanceVerifier interface {
InstanceByHost(ctx context.Context, host, publicDomain string) (Instance, error) // InstanceByHost returns the instance for the given instanceDomain or publicDomain.
// Previously it used the host (hostname[:port]) to find the instance, but is now using the domain (hostname) only.
// For preventing issues in backports, the name of the method is not changed.
InstanceByHost(ctx context.Context, instanceDomain, publicDomain string) (Instance, error)
InstanceByID(ctx context.Context, id string) (Instance, error) InstanceByID(ctx context.Context, id string) (Instance, error)
} }

View File

@@ -88,11 +88,11 @@ func addInstanceByDomain(ctx context.Context, req connect.AnyRequest, handler co
func addInstanceByRequestedHost(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, verifier authz.InstanceVerifier, translator *i18n.Translator, externalDomain string) (connect.AnyResponse, error) { func addInstanceByRequestedHost(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, verifier authz.InstanceVerifier, translator *i18n.Translator, externalDomain string) (connect.AnyResponse, error) {
requestContext := zitadel_http.DomainContext(ctx) requestContext := zitadel_http.DomainContext(ctx)
if requestContext.InstanceHost == "" { if requestContext.InstanceDomain() == "" {
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).Error("unable to set instance") logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).Error("unable to set instance")
return nil, connect.NewError(connect.CodeNotFound, errors.New("no instanceHost specified")) return nil, connect.NewError(connect.CodeNotFound, errors.New("no instanceHost specified"))
} }
instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceHost, requestContext.PublicHost) instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceDomain(), requestContext.RequestedDomain())
if err != nil { if err != nil {
origin := zitadel_http.DomainContext(ctx) origin := zitadel_http.DomainContext(ctx)
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).WithError(err).Error("unable to set instance") logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).WithError(err).Error("unable to set instance")

View File

@@ -88,11 +88,11 @@ func addInstanceByDomain(ctx context.Context, req interface{}, handler grpc.Unar
func addInstanceByRequestedHost(ctx context.Context, req interface{}, handler grpc.UnaryHandler, verifier authz.InstanceVerifier, translator *i18n.Translator, externalDomain string) (interface{}, error) { func addInstanceByRequestedHost(ctx context.Context, req interface{}, handler grpc.UnaryHandler, verifier authz.InstanceVerifier, translator *i18n.Translator, externalDomain string) (interface{}, error) {
requestContext := zitadel_http.DomainContext(ctx) requestContext := zitadel_http.DomainContext(ctx)
if requestContext.InstanceHost == "" { if requestContext.InstanceDomain() == "" {
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).Error("unable to set instance") logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).Error("unable to set instance")
return nil, status.Error(codes.NotFound, "no instanceHost specified") return nil, status.Error(codes.NotFound, "no instanceHost specified")
} }
instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceHost, requestContext.PublicHost) instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceDomain(), requestContext.RequestedDomain())
if err != nil { if err != nil {
origin := zitadel_http.DomainContext(ctx) origin := zitadel_http.DomainContext(ctx)
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).WithError(err).Error("unable to set instance") logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).WithError(err).Error("unable to set instance")

View File

@@ -88,10 +88,10 @@ func setInstance(ctx context.Context, verifier authz.InstanceVerifier) (_ contex
defer func() { span.EndWithError(err) }() defer func() { span.EndWithError(err) }()
requestContext := zitadel_http.DomainContext(ctx) requestContext := zitadel_http.DomainContext(ctx)
if requestContext.InstanceHost == "" { if requestContext.InstanceDomain() == "" {
return nil, zerrors.ThrowNotFound(err, "INST-zWq7X", "Errors.IAM.NotFound") return nil, zerrors.ThrowNotFound(err, "INST-zWq7X", "Errors.IAM.NotFound")
} }
instance, err := verifier.InstanceByHost(authCtx, requestContext.InstanceHost, requestContext.PublicHost) instance, err := verifier.InstanceByHost(authCtx, requestContext.InstanceDomain(), requestContext.RequestedDomain())
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -1,11 +1,15 @@
package middleware package middleware
import ( import (
"errors"
"net"
"net/http" "net/http"
"slices" "slices"
"strconv"
"github.com/gorilla/mux" "github.com/gorilla/mux"
"github.com/muhlemmer/httpforwarded" "github.com/muhlemmer/httpforwarded"
"github.com/zitadel/logging"
http_util "github.com/zitadel/zitadel/internal/api/http" http_util "github.com/zitadel/zitadel/internal/api/http"
) )
@@ -16,7 +20,7 @@ func WithOrigin(enforceHttps bool, http1Header, http2Header string, instanceHost
origin := composeDomainContext( origin := composeDomainContext(
r, r,
enforceHttps, enforceHttps,
// to make sure we don't break existing configurations we append the existing checked headers as well // to make sure we don't break existing configurations, we append the existing checked headers as well
slices.Compact(append(instanceHostHeaders, http1Header, http2Header, http_util.Forwarded, http_util.ZitadelForwarded, http_util.ForwardedFor, http_util.ForwardedHost, http_util.ForwardedProto)), slices.Compact(append(instanceHostHeaders, http1Header, http2Header, http_util.Forwarded, http_util.ZitadelForwarded, http_util.ForwardedFor, http_util.ForwardedHost, http_util.ForwardedProto)),
publicDomainHeaders, publicDomainHeaders,
) )
@@ -25,17 +29,13 @@ func WithOrigin(enforceHttps bool, http1Header, http2Header string, instanceHost
} }
} }
func composeDomainContext(r *http.Request, enforceHttps bool, instanceDomainHeaders, publicDomainHeaders []string) *http_util.DomainCtx { func composeDomainContext(r *http.Request, enforceHttps bool, instanceDomainHeaders, publicDomainHeaders []string) (_ *http_util.DomainCtx) {
instanceHost, instanceProto := hostFromRequest(r, instanceDomainHeaders) instanceHost, instanceProto := hostFromRequest(r, instanceDomainHeaders)
publicHost, publicProto := hostFromRequest(r, publicDomainHeaders) publicHost, publicProto := hostFromRequest(r, publicDomainHeaders)
if instanceHost == "" { if instanceHost == "" {
instanceHost = r.Host instanceHost = r.Host
} }
return &http_util.DomainCtx{ return http_util.NewDomainCtx(instanceHost, publicHost, protocolFromRequest(instanceProto, publicProto, enforceHttps))
InstanceHost: instanceHost,
Protocol: protocolFromRequest(instanceProto, publicProto, enforceHttps),
PublicHost: publicHost,
}
} }
func protocolFromRequest(instanceProto, publicProto string, enforceHttps bool) string { func protocolFromRequest(instanceProto, publicProto string, enforceHttps bool) string {
@@ -67,7 +67,7 @@ func hostFromRequest(r *http.Request, headers []string) (host, proto string) {
hostFromHeader = r.Header.Get(header) hostFromHeader = r.Header.Get(header)
} }
if host == "" { if host == "" {
host = hostFromHeader host = sanitizeHost(hostFromHeader)
} }
if proto == "" && (protoFromHeader == "http" || protoFromHeader == "https") { if proto == "" && (protoFromHeader == "http" || protoFromHeader == "https") {
proto = protoFromHeader proto = protoFromHeader
@@ -76,6 +76,35 @@ func hostFromRequest(r *http.Request, headers []string) (host, proto string) {
return host, proto return host, proto
} }
func sanitizeHost(rawHost string) (host string) {
if rawHost == "" {
return ""
}
host, port, err := net.SplitHostPort(rawHost)
if err != nil {
// if the error is about a missing port, the host is probably just "example.com", so we can return it
if isMissingPortError(err) {
return rawHost
}
// if the error is about something else, the host is probably invalid, so we log it and return an empty string
logging.WithFields("host", rawHost).Warning("invalid host header, ignoring header")
return ""
}
// if the port is not numeric, the host was probably something like "localhost:@attacker.com"
portNumber, err := strconv.Atoi(port)
if err != nil || portNumber < 1 || portNumber > 65535 {
logging.WithFields("host", rawHost).Warning("invalid port in host header, ignoring header")
return ""
}
// if we reach this point, the host contains a valid port, so we return the complete host
return rawHost
}
func isMissingPortError(err error) bool {
var addrErr *net.AddrError
return errors.As(err, &addrErr) && (addrErr.Err == "missing port in address")
}
func hostFromForwarded(values []string) (string, string) { func hostFromForwarded(values []string) (string, string) {
fwd, fwdErr := httpforwarded.Parse(values) fwd, fwdErr := httpforwarded.Parse(values)
if fwdErr == nil { if fwdErr == nil {

View File

@@ -205,3 +205,75 @@ func Test_composeOrigin(t *testing.T) {
}) })
} }
} }
func Test_sanitizeHost(t *testing.T) {
type args struct {
rawHost string
}
tests := []struct {
name string
args args
want string
}{
{
name: "normal host",
args: args{rawHost: "example.com"},
want: "example.com",
},
{
name: "host with port",
args: args{rawHost: "example.com:8080"},
want: "example.com:8080",
},
{
name: "ipv4",
args: args{rawHost: "192.168.1.1"},
want: "192.168.1.1",
},
{
name: "ipv4 with port",
args: args{rawHost: "192.168.1.1:8080"},
want: "192.168.1.1:8080",
},
{
name: "ipv6",
args: args{rawHost: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]"},
want: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]",
},
{
name: "ipv6 with port",
args: args{rawHost: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8080"},
want: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8080",
},
{
name: "host with trailing colon",
args: args{rawHost: "example.com:"},
want: "",
},
{
name: "host with invalid port",
args: args{rawHost: "example.com:port"},
want: "",
},
{
name: "invalid host",
args: args{rawHost: "localhost:@attacker.com"},
want: "",
},
{
name: "invalid host",
args: args{rawHost: "localhost:@attacker.com:8080"},
want: "",
},
{
name: "empty host",
args: args{rawHost: ""},
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equalf(t, tt.want, sanitizeHost(tt.args.rawHost), "sanitizeHost(%v)", tt.args.rawHost)
})
}
}

View File

@@ -3,7 +3,8 @@ package http
import ( import (
"context" "context"
"fmt" "fmt"
"strings" "net"
"net/url"
) )
type DomainCtx struct { type DomainCtx struct {
@@ -12,6 +13,35 @@ type DomainCtx struct {
Protocol string Protocol string
} }
func NewDomainCtx(instanceHostname, publicHostname, protocol string) *DomainCtx {
return &DomainCtx{
InstanceHost: instanceHostname,
PublicHost: publicHostname,
Protocol: protocol,
}
}
func NewDomainCtxFromOrigin(origin *url.URL) *DomainCtx {
return &DomainCtx{
InstanceHost: origin.Host,
PublicHost: origin.Host,
Protocol: origin.Scheme,
}
}
// InstanceDomain returns the hostname for which the request was handled.
func (r *DomainCtx) InstanceDomain() string {
return hostnameFromHost(r.InstanceHost)
}
func hostnameFromHost(host string) string {
hostname, _, err := net.SplitHostPort(host)
if err != nil {
return host
}
return hostname
}
// RequestedHost returns the host (hostname[:port]) for which the request was handled. // RequestedHost returns the host (hostname[:port]) for which the request was handled.
// The instance host is returned if not public host was set. // The instance host is returned if not public host was set.
func (r *DomainCtx) RequestedHost() string { func (r *DomainCtx) RequestedHost() string {
@@ -22,13 +52,13 @@ func (r *DomainCtx) RequestedHost() string {
} }
// RequestedDomain returns the domain (hostname) for which the request was handled. // RequestedDomain returns the domain (hostname) for which the request was handled.
// The instance domain is returned if not public host / domain was set. // The instance domain is returned if no public host / domain was set.
func (r *DomainCtx) RequestedDomain() string { func (r *DomainCtx) RequestedDomain() string {
return strings.Split(r.RequestedHost(), ":")[0] return hostnameFromHost(r.RequestedHost())
} }
// Origin returns the origin (protocol://hostname[:port]) for which the request was handled. // Origin returns the origin (protocol://hostname[:port]) for which the request was handled.
// The instance host is used if not public host was set. // The instance host is used if no public host was set.
func (r *DomainCtx) Origin() string { func (r *DomainCtx) Origin() string {
host := r.PublicHost host := r.PublicHost
if host == "" { if host == "" {

View File

@@ -52,10 +52,6 @@ func enrichCtx(ctx context.Context, origin string) (context.Context, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
ctx = http_utils.WithDomainContext(ctx, &http_utils.DomainCtx{ ctx = http_utils.WithDomainContext(ctx, http_utils.NewDomainCtxFromOrigin(u))
InstanceHost: u.Host,
PublicHost: u.Host,
Protocol: u.Scheme,
})
return ctx, nil return ctx, nil
} }

View File

@@ -8,7 +8,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"slices" "slices"
"strings"
"time" "time"
sq "github.com/Masterminds/squirrel" sq "github.com/Masterminds/squirrel"
@@ -201,21 +200,18 @@ var (
instanceByIDQuery string instanceByIDQuery string
) )
func (q *Queries) InstanceByHost(ctx context.Context, instanceHost, publicHost string) (_ authz.Instance, err error) { func (q *Queries) InstanceByHost(ctx context.Context, instanceDomain, publicDomain string) (_ authz.Instance, err error) {
var instance *authzInstance var instance *authzInstance
ctx, span := tracing.NewSpan(ctx) ctx, span := tracing.NewSpan(ctx)
defer func() { defer func() {
if err != nil { if err != nil {
err = fmt.Errorf("unable to get instance by host: instanceHost %s, publicHost %s: %w", instanceHost, publicHost, err) err = fmt.Errorf("unable to get instance by domain: instanceDomain %s, publicHostname %s: %w", instanceDomain, publicDomain, err)
} else { } else {
q.caches.activeInstances.Add(instance.ID, true) q.caches.activeInstances.Add(instance.ID, true)
} }
span.EndWithError(err) span.EndWithError(err)
}() }()
instanceDomain := strings.Split(instanceHost, ":")[0] // remove possible port
publicDomain := strings.Split(publicHost, ":")[0] // remove possible port
instance, ok := q.caches.instance.Get(ctx, instanceIndexByHost, instanceDomain) instance, ok := q.caches.instance.Get(ctx, instanceIndexByHost, instanceDomain)
if ok { if ok {
return instance, instance.checkDomain(instanceDomain, publicDomain) return instance, instance.checkDomain(instanceDomain, publicDomain)

View File

@@ -2,7 +2,6 @@ package webauthn
import ( import (
"context" "context"
"strings"
"github.com/go-webauthn/webauthn/protocol" "github.com/go-webauthn/webauthn/protocol"
"github.com/go-webauthn/webauthn/webauthn" "github.com/go-webauthn/webauthn/webauthn"
@@ -20,7 +19,7 @@ func WebAuthNsToCredentials(ctx context.Context, webAuthNs []*domain.WebAuthNTok
// then we check if the requested rpID matches the instance domain // then we check if the requested rpID matches the instance domain
if webAuthN.State == domain.MFAStateReady && if webAuthN.State == domain.MFAStateReady &&
(webAuthN.RPID == rpID || (webAuthN.RPID == rpID ||
(webAuthN.RPID == "" && rpID == strings.Split(http.DomainContext(ctx).InstanceHost, ":")[0])) { (webAuthN.RPID == "" && rpID == http.DomainContext(ctx).InstanceDomain())) {
creds = append(creds, webauthn.Credential{ creds = append(creds, webauthn.Credential{
ID: webAuthN.KeyID, ID: webAuthN.KeyID,
PublicKey: webAuthN.PublicKey, PublicKey: webAuthN.PublicKey,