Merge commit from fork

* fix: sanitize host headers before use

* add additional test
This commit is contained in:
Livio Spring
2025-10-29 10:05:37 +01:00
committed by GitHub
parent f4503e07cd
commit 72a5c33e6a
10 changed files with 157 additions and 32 deletions

View File

@@ -28,7 +28,10 @@ type Instance interface {
}
type InstanceVerifier interface {
InstanceByHost(ctx context.Context, host, publicDomain string) (Instance, error)
// InstanceByHost returns the instance for the given instanceDomain or publicDomain.
// Previously it used the host (hostname[:port]) to find the instance, but is now using the domain (hostname) only.
// For preventing issues in backports, the name of the method is not changed.
InstanceByHost(ctx context.Context, instanceDomain, publicDomain string) (Instance, error)
InstanceByID(ctx context.Context, id string) (Instance, error)
}

View File

@@ -88,11 +88,11 @@ func addInstanceByDomain(ctx context.Context, req connect.AnyRequest, handler co
func addInstanceByRequestedHost(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, verifier authz.InstanceVerifier, translator *i18n.Translator, externalDomain string) (connect.AnyResponse, error) {
requestContext := zitadel_http.DomainContext(ctx)
if requestContext.InstanceHost == "" {
if requestContext.InstanceDomain() == "" {
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).Error("unable to set instance")
return nil, connect.NewError(connect.CodeNotFound, errors.New("no instanceHost specified"))
}
instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceHost, requestContext.PublicHost)
instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceDomain(), requestContext.RequestedDomain())
if err != nil {
origin := zitadel_http.DomainContext(ctx)
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).WithError(err).Error("unable to set instance")

View File

@@ -88,11 +88,11 @@ func addInstanceByDomain(ctx context.Context, req interface{}, handler grpc.Unar
func addInstanceByRequestedHost(ctx context.Context, req interface{}, handler grpc.UnaryHandler, verifier authz.InstanceVerifier, translator *i18n.Translator, externalDomain string) (interface{}, error) {
requestContext := zitadel_http.DomainContext(ctx)
if requestContext.InstanceHost == "" {
if requestContext.InstanceDomain() == "" {
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).Error("unable to set instance")
return nil, status.Error(codes.NotFound, "no instanceHost specified")
}
instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceHost, requestContext.PublicHost)
instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceDomain(), requestContext.RequestedDomain())
if err != nil {
origin := zitadel_http.DomainContext(ctx)
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).WithError(err).Error("unable to set instance")

View File

@@ -88,10 +88,10 @@ func setInstance(ctx context.Context, verifier authz.InstanceVerifier) (_ contex
defer func() { span.EndWithError(err) }()
requestContext := zitadel_http.DomainContext(ctx)
if requestContext.InstanceHost == "" {
if requestContext.InstanceDomain() == "" {
return nil, zerrors.ThrowNotFound(err, "INST-zWq7X", "Errors.IAM.NotFound")
}
instance, err := verifier.InstanceByHost(authCtx, requestContext.InstanceHost, requestContext.PublicHost)
instance, err := verifier.InstanceByHost(authCtx, requestContext.InstanceDomain(), requestContext.RequestedDomain())
if err != nil {
return nil, err
}

View File

@@ -1,11 +1,15 @@
package middleware
import (
"errors"
"net"
"net/http"
"slices"
"strconv"
"github.com/gorilla/mux"
"github.com/muhlemmer/httpforwarded"
"github.com/zitadel/logging"
http_util "github.com/zitadel/zitadel/internal/api/http"
)
@@ -16,7 +20,7 @@ func WithOrigin(enforceHttps bool, http1Header, http2Header string, instanceHost
origin := composeDomainContext(
r,
enforceHttps,
// to make sure we don't break existing configurations we append the existing checked headers as well
// to make sure we don't break existing configurations, we append the existing checked headers as well
slices.Compact(append(instanceHostHeaders, http1Header, http2Header, http_util.Forwarded, http_util.ZitadelForwarded, http_util.ForwardedFor, http_util.ForwardedHost, http_util.ForwardedProto)),
publicDomainHeaders,
)
@@ -25,17 +29,13 @@ func WithOrigin(enforceHttps bool, http1Header, http2Header string, instanceHost
}
}
func composeDomainContext(r *http.Request, enforceHttps bool, instanceDomainHeaders, publicDomainHeaders []string) *http_util.DomainCtx {
func composeDomainContext(r *http.Request, enforceHttps bool, instanceDomainHeaders, publicDomainHeaders []string) (_ *http_util.DomainCtx) {
instanceHost, instanceProto := hostFromRequest(r, instanceDomainHeaders)
publicHost, publicProto := hostFromRequest(r, publicDomainHeaders)
if instanceHost == "" {
instanceHost = r.Host
}
return &http_util.DomainCtx{
InstanceHost: instanceHost,
Protocol: protocolFromRequest(instanceProto, publicProto, enforceHttps),
PublicHost: publicHost,
}
return http_util.NewDomainCtx(instanceHost, publicHost, protocolFromRequest(instanceProto, publicProto, enforceHttps))
}
func protocolFromRequest(instanceProto, publicProto string, enforceHttps bool) string {
@@ -67,7 +67,7 @@ func hostFromRequest(r *http.Request, headers []string) (host, proto string) {
hostFromHeader = r.Header.Get(header)
}
if host == "" {
host = hostFromHeader
host = sanitizeHost(hostFromHeader)
}
if proto == "" && (protoFromHeader == "http" || protoFromHeader == "https") {
proto = protoFromHeader
@@ -76,6 +76,35 @@ func hostFromRequest(r *http.Request, headers []string) (host, proto string) {
return host, proto
}
func sanitizeHost(rawHost string) (host string) {
if rawHost == "" {
return ""
}
host, port, err := net.SplitHostPort(rawHost)
if err != nil {
// if the error is about a missing port, the host is probably just "example.com", so we can return it
if isMissingPortError(err) {
return rawHost
}
// if the error is about something else, the host is probably invalid, so we log it and return an empty string
logging.WithFields("host", rawHost).Warning("invalid host header, ignoring header")
return ""
}
// if the port is not numeric, the host was probably something like "localhost:@attacker.com"
portNumber, err := strconv.Atoi(port)
if err != nil || portNumber < 1 || portNumber > 65535 {
logging.WithFields("host", rawHost).Warning("invalid port in host header, ignoring header")
return ""
}
// if we reach this point, the host contains a valid port, so we return the complete host
return rawHost
}
func isMissingPortError(err error) bool {
var addrErr *net.AddrError
return errors.As(err, &addrErr) && (addrErr.Err == "missing port in address")
}
func hostFromForwarded(values []string) (string, string) {
fwd, fwdErr := httpforwarded.Parse(values)
if fwdErr == nil {

View File

@@ -205,3 +205,75 @@ func Test_composeOrigin(t *testing.T) {
})
}
}
func Test_sanitizeHost(t *testing.T) {
type args struct {
rawHost string
}
tests := []struct {
name string
args args
want string
}{
{
name: "normal host",
args: args{rawHost: "example.com"},
want: "example.com",
},
{
name: "host with port",
args: args{rawHost: "example.com:8080"},
want: "example.com:8080",
},
{
name: "ipv4",
args: args{rawHost: "192.168.1.1"},
want: "192.168.1.1",
},
{
name: "ipv4 with port",
args: args{rawHost: "192.168.1.1:8080"},
want: "192.168.1.1:8080",
},
{
name: "ipv6",
args: args{rawHost: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]"},
want: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]",
},
{
name: "ipv6 with port",
args: args{rawHost: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8080"},
want: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8080",
},
{
name: "host with trailing colon",
args: args{rawHost: "example.com:"},
want: "",
},
{
name: "host with invalid port",
args: args{rawHost: "example.com:port"},
want: "",
},
{
name: "invalid host",
args: args{rawHost: "localhost:@attacker.com"},
want: "",
},
{
name: "invalid host",
args: args{rawHost: "localhost:@attacker.com:8080"},
want: "",
},
{
name: "empty host",
args: args{rawHost: ""},
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equalf(t, tt.want, sanitizeHost(tt.args.rawHost), "sanitizeHost(%v)", tt.args.rawHost)
})
}
}

View File

@@ -3,7 +3,8 @@ package http
import (
"context"
"fmt"
"strings"
"net"
"net/url"
)
type DomainCtx struct {
@@ -12,6 +13,35 @@ type DomainCtx struct {
Protocol string
}
func NewDomainCtx(instanceHostname, publicHostname, protocol string) *DomainCtx {
return &DomainCtx{
InstanceHost: instanceHostname,
PublicHost: publicHostname,
Protocol: protocol,
}
}
func NewDomainCtxFromOrigin(origin *url.URL) *DomainCtx {
return &DomainCtx{
InstanceHost: origin.Host,
PublicHost: origin.Host,
Protocol: origin.Scheme,
}
}
// InstanceDomain returns the hostname for which the request was handled.
func (r *DomainCtx) InstanceDomain() string {
return hostnameFromHost(r.InstanceHost)
}
func hostnameFromHost(host string) string {
hostname, _, err := net.SplitHostPort(host)
if err != nil {
return host
}
return hostname
}
// RequestedHost returns the host (hostname[:port]) for which the request was handled.
// The instance host is returned if not public host was set.
func (r *DomainCtx) RequestedHost() string {
@@ -22,13 +52,13 @@ func (r *DomainCtx) RequestedHost() string {
}
// RequestedDomain returns the domain (hostname) for which the request was handled.
// The instance domain is returned if not public host / domain was set.
// The instance domain is returned if no public host / domain was set.
func (r *DomainCtx) RequestedDomain() string {
return strings.Split(r.RequestedHost(), ":")[0]
return hostnameFromHost(r.RequestedHost())
}
// Origin returns the origin (protocol://hostname[:port]) for which the request was handled.
// The instance host is used if not public host was set.
// The instance host is used if no public host was set.
func (r *DomainCtx) Origin() string {
host := r.PublicHost
if host == "" {

View File

@@ -52,10 +52,6 @@ func enrichCtx(ctx context.Context, origin string) (context.Context, error) {
if err != nil {
return nil, err
}
ctx = http_utils.WithDomainContext(ctx, &http_utils.DomainCtx{
InstanceHost: u.Host,
PublicHost: u.Host,
Protocol: u.Scheme,
})
ctx = http_utils.WithDomainContext(ctx, http_utils.NewDomainCtxFromOrigin(u))
return ctx, nil
}

View File

@@ -8,7 +8,6 @@ import (
"errors"
"fmt"
"slices"
"strings"
"time"
sq "github.com/Masterminds/squirrel"
@@ -201,21 +200,18 @@ var (
instanceByIDQuery string
)
func (q *Queries) InstanceByHost(ctx context.Context, instanceHost, publicHost string) (_ authz.Instance, err error) {
func (q *Queries) InstanceByHost(ctx context.Context, instanceDomain, publicDomain string) (_ authz.Instance, err error) {
var instance *authzInstance
ctx, span := tracing.NewSpan(ctx)
defer func() {
if err != nil {
err = fmt.Errorf("unable to get instance by host: instanceHost %s, publicHost %s: %w", instanceHost, publicHost, err)
err = fmt.Errorf("unable to get instance by domain: instanceDomain %s, publicHostname %s: %w", instanceDomain, publicDomain, err)
} else {
q.caches.activeInstances.Add(instance.ID, true)
}
span.EndWithError(err)
}()
instanceDomain := strings.Split(instanceHost, ":")[0] // remove possible port
publicDomain := strings.Split(publicHost, ":")[0] // remove possible port
instance, ok := q.caches.instance.Get(ctx, instanceIndexByHost, instanceDomain)
if ok {
return instance, instance.checkDomain(instanceDomain, publicDomain)

View File

@@ -2,7 +2,6 @@ package webauthn
import (
"context"
"strings"
"github.com/go-webauthn/webauthn/protocol"
"github.com/go-webauthn/webauthn/webauthn"
@@ -20,7 +19,7 @@ func WebAuthNsToCredentials(ctx context.Context, webAuthNs []*domain.WebAuthNTok
// then we check if the requested rpID matches the instance domain
if webAuthN.State == domain.MFAStateReady &&
(webAuthN.RPID == rpID ||
(webAuthN.RPID == "" && rpID == strings.Split(http.DomainContext(ctx).InstanceHost, ":")[0])) {
(webAuthN.RPID == "" && rpID == http.DomainContext(ctx).InstanceDomain())) {
creds = append(creds, webauthn.Credential{
ID: webAuthN.KeyID,
PublicKey: webAuthN.PublicKey,