mirror of
https://github.com/zitadel/zitadel.git
synced 2025-12-02 13:32:26 +00:00
fix: sanitize host headers before use
# Which Problems Are Solved
Host headers used to identify the instance and further used in public responses such as OIDC discovery endpoints, email links and more were not correctly handled. While they were matched against existing instances, they were not properly sanitized.
# How the Problems Are Solved
Sanitize host header including port validation (if provided).
# Additional Changes
None
# Additional Context
- requires backports
(cherry picked from commit 72a5c33e6a)
This commit is contained in:
@@ -28,7 +28,10 @@ type Instance interface {
|
||||
}
|
||||
|
||||
type InstanceVerifier interface {
|
||||
InstanceByHost(ctx context.Context, host, publicDomain string) (Instance, error)
|
||||
// InstanceByHost returns the instance for the given instanceDomain or publicDomain.
|
||||
// Previously it used the host (hostname[:port]) to find the instance, but is now using the domain (hostname) only.
|
||||
// For preventing issues in backports, the name of the method is not changed.
|
||||
InstanceByHost(ctx context.Context, instanceDomain, publicDomain string) (Instance, error)
|
||||
InstanceByID(ctx context.Context, id string) (Instance, error)
|
||||
}
|
||||
|
||||
|
||||
@@ -88,11 +88,11 @@ func addInstanceByDomain(ctx context.Context, req connect.AnyRequest, handler co
|
||||
|
||||
func addInstanceByRequestedHost(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, verifier authz.InstanceVerifier, translator *i18n.Translator, externalDomain string) (connect.AnyResponse, error) {
|
||||
requestContext := zitadel_http.DomainContext(ctx)
|
||||
if requestContext.InstanceHost == "" {
|
||||
if requestContext.InstanceDomain() == "" {
|
||||
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).Error("unable to set instance")
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("no instanceHost specified"))
|
||||
}
|
||||
instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceHost, requestContext.PublicHost)
|
||||
instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceDomain(), requestContext.RequestedDomain())
|
||||
if err != nil {
|
||||
origin := zitadel_http.DomainContext(ctx)
|
||||
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).WithError(err).Error("unable to set instance")
|
||||
|
||||
@@ -88,11 +88,11 @@ func addInstanceByDomain(ctx context.Context, req interface{}, handler grpc.Unar
|
||||
|
||||
func addInstanceByRequestedHost(ctx context.Context, req interface{}, handler grpc.UnaryHandler, verifier authz.InstanceVerifier, translator *i18n.Translator, externalDomain string) (interface{}, error) {
|
||||
requestContext := zitadel_http.DomainContext(ctx)
|
||||
if requestContext.InstanceHost == "" {
|
||||
if requestContext.InstanceDomain() == "" {
|
||||
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).Error("unable to set instance")
|
||||
return nil, status.Error(codes.NotFound, "no instanceHost specified")
|
||||
}
|
||||
instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceHost, requestContext.PublicHost)
|
||||
instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceDomain(), requestContext.RequestedDomain())
|
||||
if err != nil {
|
||||
origin := zitadel_http.DomainContext(ctx)
|
||||
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).WithError(err).Error("unable to set instance")
|
||||
|
||||
@@ -88,10 +88,10 @@ func setInstance(ctx context.Context, verifier authz.InstanceVerifier) (_ contex
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
requestContext := zitadel_http.DomainContext(ctx)
|
||||
if requestContext.InstanceHost == "" {
|
||||
if requestContext.InstanceDomain() == "" {
|
||||
return nil, zerrors.ThrowNotFound(err, "INST-zWq7X", "Errors.IAM.NotFound")
|
||||
}
|
||||
instance, err := verifier.InstanceByHost(authCtx, requestContext.InstanceHost, requestContext.PublicHost)
|
||||
instance, err := verifier.InstanceByHost(authCtx, requestContext.InstanceDomain(), requestContext.RequestedDomain())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/muhlemmer/httpforwarded"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
http_util "github.com/zitadel/zitadel/internal/api/http"
|
||||
)
|
||||
@@ -16,7 +20,7 @@ func WithOrigin(enforceHttps bool, http1Header, http2Header string, instanceHost
|
||||
origin := composeDomainContext(
|
||||
r,
|
||||
enforceHttps,
|
||||
// to make sure we don't break existing configurations we append the existing checked headers as well
|
||||
// to make sure we don't break existing configurations, we append the existing checked headers as well
|
||||
slices.Compact(append(instanceHostHeaders, http1Header, http2Header, http_util.Forwarded, http_util.ZitadelForwarded, http_util.ForwardedFor, http_util.ForwardedHost, http_util.ForwardedProto)),
|
||||
publicDomainHeaders,
|
||||
)
|
||||
@@ -25,17 +29,13 @@ func WithOrigin(enforceHttps bool, http1Header, http2Header string, instanceHost
|
||||
}
|
||||
}
|
||||
|
||||
func composeDomainContext(r *http.Request, enforceHttps bool, instanceDomainHeaders, publicDomainHeaders []string) *http_util.DomainCtx {
|
||||
func composeDomainContext(r *http.Request, enforceHttps bool, instanceDomainHeaders, publicDomainHeaders []string) (_ *http_util.DomainCtx) {
|
||||
instanceHost, instanceProto := hostFromRequest(r, instanceDomainHeaders)
|
||||
publicHost, publicProto := hostFromRequest(r, publicDomainHeaders)
|
||||
if instanceHost == "" {
|
||||
instanceHost = r.Host
|
||||
}
|
||||
return &http_util.DomainCtx{
|
||||
InstanceHost: instanceHost,
|
||||
Protocol: protocolFromRequest(instanceProto, publicProto, enforceHttps),
|
||||
PublicHost: publicHost,
|
||||
}
|
||||
return http_util.NewDomainCtx(instanceHost, publicHost, protocolFromRequest(instanceProto, publicProto, enforceHttps))
|
||||
}
|
||||
|
||||
func protocolFromRequest(instanceProto, publicProto string, enforceHttps bool) string {
|
||||
@@ -67,7 +67,7 @@ func hostFromRequest(r *http.Request, headers []string) (host, proto string) {
|
||||
hostFromHeader = r.Header.Get(header)
|
||||
}
|
||||
if host == "" {
|
||||
host = hostFromHeader
|
||||
host = sanitizeHost(hostFromHeader)
|
||||
}
|
||||
if proto == "" && (protoFromHeader == "http" || protoFromHeader == "https") {
|
||||
proto = protoFromHeader
|
||||
@@ -76,6 +76,35 @@ func hostFromRequest(r *http.Request, headers []string) (host, proto string) {
|
||||
return host, proto
|
||||
}
|
||||
|
||||
func sanitizeHost(rawHost string) (host string) {
|
||||
if rawHost == "" {
|
||||
return ""
|
||||
}
|
||||
host, port, err := net.SplitHostPort(rawHost)
|
||||
if err != nil {
|
||||
// if the error is about a missing port, the host is probably just "example.com", so we can return it
|
||||
if isMissingPortError(err) {
|
||||
return rawHost
|
||||
}
|
||||
// if the error is about something else, the host is probably invalid, so we log it and return an empty string
|
||||
logging.WithFields("host", rawHost).Warning("invalid host header, ignoring header")
|
||||
return ""
|
||||
}
|
||||
// if the port is not numeric, the host was probably something like "localhost:@attacker.com"
|
||||
portNumber, err := strconv.Atoi(port)
|
||||
if err != nil || portNumber < 1 || portNumber > 65535 {
|
||||
logging.WithFields("host", rawHost).Warning("invalid port in host header, ignoring header")
|
||||
return ""
|
||||
}
|
||||
// if we reach this point, the host contains a valid port, so we return the complete host
|
||||
return rawHost
|
||||
}
|
||||
|
||||
func isMissingPortError(err error) bool {
|
||||
var addrErr *net.AddrError
|
||||
return errors.As(err, &addrErr) && (addrErr.Err == "missing port in address")
|
||||
}
|
||||
|
||||
func hostFromForwarded(values []string) (string, string) {
|
||||
fwd, fwdErr := httpforwarded.Parse(values)
|
||||
if fwdErr == nil {
|
||||
|
||||
@@ -205,3 +205,75 @@ func Test_composeOrigin(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sanitizeHost(t *testing.T) {
|
||||
type args struct {
|
||||
rawHost string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "normal host",
|
||||
args: args{rawHost: "example.com"},
|
||||
want: "example.com",
|
||||
},
|
||||
{
|
||||
name: "host with port",
|
||||
args: args{rawHost: "example.com:8080"},
|
||||
want: "example.com:8080",
|
||||
},
|
||||
{
|
||||
name: "ipv4",
|
||||
args: args{rawHost: "192.168.1.1"},
|
||||
want: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
name: "ipv4 with port",
|
||||
args: args{rawHost: "192.168.1.1:8080"},
|
||||
want: "192.168.1.1:8080",
|
||||
},
|
||||
{
|
||||
name: "ipv6",
|
||||
args: args{rawHost: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]"},
|
||||
want: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]",
|
||||
},
|
||||
{
|
||||
name: "ipv6 with port",
|
||||
args: args{rawHost: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8080"},
|
||||
want: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8080",
|
||||
},
|
||||
{
|
||||
name: "host with trailing colon",
|
||||
args: args{rawHost: "example.com:"},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "host with invalid port",
|
||||
args: args{rawHost: "example.com:port"},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "invalid host",
|
||||
args: args{rawHost: "localhost:@attacker.com"},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "invalid host",
|
||||
args: args{rawHost: "localhost:@attacker.com:8080"},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty host",
|
||||
args: args{rawHost: ""},
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equalf(t, tt.want, sanitizeHost(tt.args.rawHost), "sanitizeHost(%v)", tt.args.rawHost)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,8 @@ package http
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"net"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type DomainCtx struct {
|
||||
@@ -12,6 +13,35 @@ type DomainCtx struct {
|
||||
Protocol string
|
||||
}
|
||||
|
||||
func NewDomainCtx(instanceHostname, publicHostname, protocol string) *DomainCtx {
|
||||
return &DomainCtx{
|
||||
InstanceHost: instanceHostname,
|
||||
PublicHost: publicHostname,
|
||||
Protocol: protocol,
|
||||
}
|
||||
}
|
||||
|
||||
func NewDomainCtxFromOrigin(origin *url.URL) *DomainCtx {
|
||||
return &DomainCtx{
|
||||
InstanceHost: origin.Host,
|
||||
PublicHost: origin.Host,
|
||||
Protocol: origin.Scheme,
|
||||
}
|
||||
}
|
||||
|
||||
// InstanceDomain returns the hostname for which the request was handled.
|
||||
func (r *DomainCtx) InstanceDomain() string {
|
||||
return hostnameFromHost(r.InstanceHost)
|
||||
}
|
||||
|
||||
func hostnameFromHost(host string) string {
|
||||
hostname, _, err := net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return host
|
||||
}
|
||||
return hostname
|
||||
}
|
||||
|
||||
// RequestedHost returns the host (hostname[:port]) for which the request was handled.
|
||||
// The instance host is returned if not public host was set.
|
||||
func (r *DomainCtx) RequestedHost() string {
|
||||
@@ -22,13 +52,13 @@ func (r *DomainCtx) RequestedHost() string {
|
||||
}
|
||||
|
||||
// RequestedDomain returns the domain (hostname) for which the request was handled.
|
||||
// The instance domain is returned if not public host / domain was set.
|
||||
// The instance domain is returned if no public host / domain was set.
|
||||
func (r *DomainCtx) RequestedDomain() string {
|
||||
return strings.Split(r.RequestedHost(), ":")[0]
|
||||
return hostnameFromHost(r.RequestedHost())
|
||||
}
|
||||
|
||||
// Origin returns the origin (protocol://hostname[:port]) for which the request was handled.
|
||||
// The instance host is used if not public host was set.
|
||||
// The instance host is used if no public host was set.
|
||||
func (r *DomainCtx) Origin() string {
|
||||
host := r.PublicHost
|
||||
if host == "" {
|
||||
|
||||
@@ -52,10 +52,6 @@ func enrichCtx(ctx context.Context, origin string) (context.Context, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx = http_utils.WithDomainContext(ctx, &http_utils.DomainCtx{
|
||||
InstanceHost: u.Host,
|
||||
PublicHost: u.Host,
|
||||
Protocol: u.Scheme,
|
||||
})
|
||||
ctx = http_utils.WithDomainContext(ctx, http_utils.NewDomainCtxFromOrigin(u))
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
@@ -201,21 +200,18 @@ var (
|
||||
instanceByIDQuery string
|
||||
)
|
||||
|
||||
func (q *Queries) InstanceByHost(ctx context.Context, instanceHost, publicHost string) (_ authz.Instance, err error) {
|
||||
func (q *Queries) InstanceByHost(ctx context.Context, instanceDomain, publicDomain string) (_ authz.Instance, err error) {
|
||||
var instance *authzInstance
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to get instance by host: instanceHost %s, publicHost %s: %w", instanceHost, publicHost, err)
|
||||
err = fmt.Errorf("unable to get instance by domain: instanceDomain %s, publicHostname %s: %w", instanceDomain, publicDomain, err)
|
||||
} else {
|
||||
q.caches.activeInstances.Add(instance.ID, true)
|
||||
}
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
instanceDomain := strings.Split(instanceHost, ":")[0] // remove possible port
|
||||
publicDomain := strings.Split(publicHost, ":")[0] // remove possible port
|
||||
|
||||
instance, ok := q.caches.instance.Get(ctx, instanceIndexByHost, instanceDomain)
|
||||
if ok {
|
||||
return instance, instance.checkDomain(instanceDomain, publicDomain)
|
||||
|
||||
@@ -2,7 +2,6 @@ package webauthn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/go-webauthn/webauthn/protocol"
|
||||
"github.com/go-webauthn/webauthn/webauthn"
|
||||
@@ -20,7 +19,7 @@ func WebAuthNsToCredentials(ctx context.Context, webAuthNs []*domain.WebAuthNTok
|
||||
// then we check if the requested rpID matches the instance domain
|
||||
if webAuthN.State == domain.MFAStateReady &&
|
||||
(webAuthN.RPID == rpID ||
|
||||
(webAuthN.RPID == "" && rpID == strings.Split(http.DomainContext(ctx).InstanceHost, ":")[0])) {
|
||||
(webAuthN.RPID == "" && rpID == http.DomainContext(ctx).InstanceDomain())) {
|
||||
creds = append(creds, webauthn.Credential{
|
||||
ID: webAuthN.KeyID,
|
||||
PublicKey: webAuthN.PublicKey,
|
||||
|
||||
Reference in New Issue
Block a user