mirror of
https://github.com/zitadel/zitadel.git
synced 2025-12-23 06:16:44 +00:00
Merge commit from fork
* fix: sanitize host headers before use * add additional test
This commit is contained in:
@@ -28,7 +28,10 @@ type Instance interface {
|
||||
}
|
||||
|
||||
type InstanceVerifier interface {
|
||||
InstanceByHost(ctx context.Context, host, publicDomain string) (Instance, error)
|
||||
// InstanceByHost returns the instance for the given instanceDomain or publicDomain.
|
||||
// Previously it used the host (hostname[:port]) to find the instance, but is now using the domain (hostname) only.
|
||||
// For preventing issues in backports, the name of the method is not changed.
|
||||
InstanceByHost(ctx context.Context, instanceDomain, publicDomain string) (Instance, error)
|
||||
InstanceByID(ctx context.Context, id string) (Instance, error)
|
||||
}
|
||||
|
||||
|
||||
@@ -88,11 +88,11 @@ func addInstanceByDomain(ctx context.Context, req connect.AnyRequest, handler co
|
||||
|
||||
func addInstanceByRequestedHost(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, verifier authz.InstanceVerifier, translator *i18n.Translator, externalDomain string) (connect.AnyResponse, error) {
|
||||
requestContext := zitadel_http.DomainContext(ctx)
|
||||
if requestContext.InstanceHost == "" {
|
||||
if requestContext.InstanceDomain() == "" {
|
||||
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).Error("unable to set instance")
|
||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("no instanceHost specified"))
|
||||
}
|
||||
instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceHost, requestContext.PublicHost)
|
||||
instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceDomain(), requestContext.RequestedDomain())
|
||||
if err != nil {
|
||||
origin := zitadel_http.DomainContext(ctx)
|
||||
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).WithError(err).Error("unable to set instance")
|
||||
|
||||
@@ -88,11 +88,11 @@ func addInstanceByDomain(ctx context.Context, req interface{}, handler grpc.Unar
|
||||
|
||||
func addInstanceByRequestedHost(ctx context.Context, req interface{}, handler grpc.UnaryHandler, verifier authz.InstanceVerifier, translator *i18n.Translator, externalDomain string) (interface{}, error) {
|
||||
requestContext := zitadel_http.DomainContext(ctx)
|
||||
if requestContext.InstanceHost == "" {
|
||||
if requestContext.InstanceDomain() == "" {
|
||||
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).Error("unable to set instance")
|
||||
return nil, status.Error(codes.NotFound, "no instanceHost specified")
|
||||
}
|
||||
instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceHost, requestContext.PublicHost)
|
||||
instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceDomain(), requestContext.RequestedDomain())
|
||||
if err != nil {
|
||||
origin := zitadel_http.DomainContext(ctx)
|
||||
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).WithError(err).Error("unable to set instance")
|
||||
|
||||
@@ -88,10 +88,10 @@ func setInstance(ctx context.Context, verifier authz.InstanceVerifier) (_ contex
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
requestContext := zitadel_http.DomainContext(ctx)
|
||||
if requestContext.InstanceHost == "" {
|
||||
if requestContext.InstanceDomain() == "" {
|
||||
return nil, zerrors.ThrowNotFound(err, "INST-zWq7X", "Errors.IAM.NotFound")
|
||||
}
|
||||
instance, err := verifier.InstanceByHost(authCtx, requestContext.InstanceHost, requestContext.PublicHost)
|
||||
instance, err := verifier.InstanceByHost(authCtx, requestContext.InstanceDomain(), requestContext.RequestedDomain())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
@@ -1,11 +1,15 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"slices"
|
||||
"strconv"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/muhlemmer/httpforwarded"
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
http_util "github.com/zitadel/zitadel/internal/api/http"
|
||||
)
|
||||
@@ -16,7 +20,7 @@ func WithOrigin(enforceHttps bool, http1Header, http2Header string, instanceHost
|
||||
origin := composeDomainContext(
|
||||
r,
|
||||
enforceHttps,
|
||||
// to make sure we don't break existing configurations we append the existing checked headers as well
|
||||
// to make sure we don't break existing configurations, we append the existing checked headers as well
|
||||
slices.Compact(append(instanceHostHeaders, http1Header, http2Header, http_util.Forwarded, http_util.ZitadelForwarded, http_util.ForwardedFor, http_util.ForwardedHost, http_util.ForwardedProto)),
|
||||
publicDomainHeaders,
|
||||
)
|
||||
@@ -25,17 +29,13 @@ func WithOrigin(enforceHttps bool, http1Header, http2Header string, instanceHost
|
||||
}
|
||||
}
|
||||
|
||||
func composeDomainContext(r *http.Request, enforceHttps bool, instanceDomainHeaders, publicDomainHeaders []string) *http_util.DomainCtx {
|
||||
func composeDomainContext(r *http.Request, enforceHttps bool, instanceDomainHeaders, publicDomainHeaders []string) (_ *http_util.DomainCtx) {
|
||||
instanceHost, instanceProto := hostFromRequest(r, instanceDomainHeaders)
|
||||
publicHost, publicProto := hostFromRequest(r, publicDomainHeaders)
|
||||
if instanceHost == "" {
|
||||
instanceHost = r.Host
|
||||
}
|
||||
return &http_util.DomainCtx{
|
||||
InstanceHost: instanceHost,
|
||||
Protocol: protocolFromRequest(instanceProto, publicProto, enforceHttps),
|
||||
PublicHost: publicHost,
|
||||
}
|
||||
return http_util.NewDomainCtx(instanceHost, publicHost, protocolFromRequest(instanceProto, publicProto, enforceHttps))
|
||||
}
|
||||
|
||||
func protocolFromRequest(instanceProto, publicProto string, enforceHttps bool) string {
|
||||
@@ -67,7 +67,7 @@ func hostFromRequest(r *http.Request, headers []string) (host, proto string) {
|
||||
hostFromHeader = r.Header.Get(header)
|
||||
}
|
||||
if host == "" {
|
||||
host = hostFromHeader
|
||||
host = sanitizeHost(hostFromHeader)
|
||||
}
|
||||
if proto == "" && (protoFromHeader == "http" || protoFromHeader == "https") {
|
||||
proto = protoFromHeader
|
||||
@@ -76,6 +76,35 @@ func hostFromRequest(r *http.Request, headers []string) (host, proto string) {
|
||||
return host, proto
|
||||
}
|
||||
|
||||
func sanitizeHost(rawHost string) (host string) {
|
||||
if rawHost == "" {
|
||||
return ""
|
||||
}
|
||||
host, port, err := net.SplitHostPort(rawHost)
|
||||
if err != nil {
|
||||
// if the error is about a missing port, the host is probably just "example.com", so we can return it
|
||||
if isMissingPortError(err) {
|
||||
return rawHost
|
||||
}
|
||||
// if the error is about something else, the host is probably invalid, so we log it and return an empty string
|
||||
logging.WithFields("host", rawHost).Warning("invalid host header, ignoring header")
|
||||
return ""
|
||||
}
|
||||
// if the port is not numeric, the host was probably something like "localhost:@attacker.com"
|
||||
portNumber, err := strconv.Atoi(port)
|
||||
if err != nil || portNumber < 1 || portNumber > 65535 {
|
||||
logging.WithFields("host", rawHost).Warning("invalid port in host header, ignoring header")
|
||||
return ""
|
||||
}
|
||||
// if we reach this point, the host contains a valid port, so we return the complete host
|
||||
return rawHost
|
||||
}
|
||||
|
||||
func isMissingPortError(err error) bool {
|
||||
var addrErr *net.AddrError
|
||||
return errors.As(err, &addrErr) && (addrErr.Err == "missing port in address")
|
||||
}
|
||||
|
||||
func hostFromForwarded(values []string) (string, string) {
|
||||
fwd, fwdErr := httpforwarded.Parse(values)
|
||||
if fwdErr == nil {
|
||||
|
||||
@@ -205,3 +205,75 @@ func Test_composeOrigin(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_sanitizeHost(t *testing.T) {
|
||||
type args struct {
|
||||
rawHost string
|
||||
}
|
||||
tests := []struct {
|
||||
name string
|
||||
args args
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "normal host",
|
||||
args: args{rawHost: "example.com"},
|
||||
want: "example.com",
|
||||
},
|
||||
{
|
||||
name: "host with port",
|
||||
args: args{rawHost: "example.com:8080"},
|
||||
want: "example.com:8080",
|
||||
},
|
||||
{
|
||||
name: "ipv4",
|
||||
args: args{rawHost: "192.168.1.1"},
|
||||
want: "192.168.1.1",
|
||||
},
|
||||
{
|
||||
name: "ipv4 with port",
|
||||
args: args{rawHost: "192.168.1.1:8080"},
|
||||
want: "192.168.1.1:8080",
|
||||
},
|
||||
{
|
||||
name: "ipv6",
|
||||
args: args{rawHost: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]"},
|
||||
want: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]",
|
||||
},
|
||||
{
|
||||
name: "ipv6 with port",
|
||||
args: args{rawHost: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8080"},
|
||||
want: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8080",
|
||||
},
|
||||
{
|
||||
name: "host with trailing colon",
|
||||
args: args{rawHost: "example.com:"},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "host with invalid port",
|
||||
args: args{rawHost: "example.com:port"},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "invalid host",
|
||||
args: args{rawHost: "localhost:@attacker.com"},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "invalid host",
|
||||
args: args{rawHost: "localhost:@attacker.com:8080"},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty host",
|
||||
args: args{rawHost: ""},
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equalf(t, tt.want, sanitizeHost(tt.args.rawHost), "sanitizeHost(%v)", tt.args.rawHost)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,8 @@ package http
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"net"
|
||||
"net/url"
|
||||
)
|
||||
|
||||
type DomainCtx struct {
|
||||
@@ -12,6 +13,35 @@ type DomainCtx struct {
|
||||
Protocol string
|
||||
}
|
||||
|
||||
func NewDomainCtx(instanceHostname, publicHostname, protocol string) *DomainCtx {
|
||||
return &DomainCtx{
|
||||
InstanceHost: instanceHostname,
|
||||
PublicHost: publicHostname,
|
||||
Protocol: protocol,
|
||||
}
|
||||
}
|
||||
|
||||
func NewDomainCtxFromOrigin(origin *url.URL) *DomainCtx {
|
||||
return &DomainCtx{
|
||||
InstanceHost: origin.Host,
|
||||
PublicHost: origin.Host,
|
||||
Protocol: origin.Scheme,
|
||||
}
|
||||
}
|
||||
|
||||
// InstanceDomain returns the hostname for which the request was handled.
|
||||
func (r *DomainCtx) InstanceDomain() string {
|
||||
return hostnameFromHost(r.InstanceHost)
|
||||
}
|
||||
|
||||
func hostnameFromHost(host string) string {
|
||||
hostname, _, err := net.SplitHostPort(host)
|
||||
if err != nil {
|
||||
return host
|
||||
}
|
||||
return hostname
|
||||
}
|
||||
|
||||
// RequestedHost returns the host (hostname[:port]) for which the request was handled.
|
||||
// The instance host is returned if not public host was set.
|
||||
func (r *DomainCtx) RequestedHost() string {
|
||||
@@ -22,13 +52,13 @@ func (r *DomainCtx) RequestedHost() string {
|
||||
}
|
||||
|
||||
// RequestedDomain returns the domain (hostname) for which the request was handled.
|
||||
// The instance domain is returned if not public host / domain was set.
|
||||
// The instance domain is returned if no public host / domain was set.
|
||||
func (r *DomainCtx) RequestedDomain() string {
|
||||
return strings.Split(r.RequestedHost(), ":")[0]
|
||||
return hostnameFromHost(r.RequestedHost())
|
||||
}
|
||||
|
||||
// Origin returns the origin (protocol://hostname[:port]) for which the request was handled.
|
||||
// The instance host is used if not public host was set.
|
||||
// The instance host is used if no public host was set.
|
||||
func (r *DomainCtx) Origin() string {
|
||||
host := r.PublicHost
|
||||
if host == "" {
|
||||
|
||||
@@ -52,10 +52,6 @@ func enrichCtx(ctx context.Context, origin string) (context.Context, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ctx = http_utils.WithDomainContext(ctx, &http_utils.DomainCtx{
|
||||
InstanceHost: u.Host,
|
||||
PublicHost: u.Host,
|
||||
Protocol: u.Scheme,
|
||||
})
|
||||
ctx = http_utils.WithDomainContext(ctx, http_utils.NewDomainCtxFromOrigin(u))
|
||||
return ctx, nil
|
||||
}
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
@@ -201,21 +200,18 @@ var (
|
||||
instanceByIDQuery string
|
||||
)
|
||||
|
||||
func (q *Queries) InstanceByHost(ctx context.Context, instanceHost, publicHost string) (_ authz.Instance, err error) {
|
||||
func (q *Queries) InstanceByHost(ctx context.Context, instanceDomain, publicDomain string) (_ authz.Instance, err error) {
|
||||
var instance *authzInstance
|
||||
ctx, span := tracing.NewSpan(ctx)
|
||||
defer func() {
|
||||
if err != nil {
|
||||
err = fmt.Errorf("unable to get instance by host: instanceHost %s, publicHost %s: %w", instanceHost, publicHost, err)
|
||||
err = fmt.Errorf("unable to get instance by domain: instanceDomain %s, publicHostname %s: %w", instanceDomain, publicDomain, err)
|
||||
} else {
|
||||
q.caches.activeInstances.Add(instance.ID, true)
|
||||
}
|
||||
span.EndWithError(err)
|
||||
}()
|
||||
|
||||
instanceDomain := strings.Split(instanceHost, ":")[0] // remove possible port
|
||||
publicDomain := strings.Split(publicHost, ":")[0] // remove possible port
|
||||
|
||||
instance, ok := q.caches.instance.Get(ctx, instanceIndexByHost, instanceDomain)
|
||||
if ok {
|
||||
return instance, instance.checkDomain(instanceDomain, publicDomain)
|
||||
|
||||
@@ -2,7 +2,6 @@ package webauthn
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"github.com/go-webauthn/webauthn/protocol"
|
||||
"github.com/go-webauthn/webauthn/webauthn"
|
||||
@@ -20,7 +19,7 @@ func WebAuthNsToCredentials(ctx context.Context, webAuthNs []*domain.WebAuthNTok
|
||||
// then we check if the requested rpID matches the instance domain
|
||||
if webAuthN.State == domain.MFAStateReady &&
|
||||
(webAuthN.RPID == rpID ||
|
||||
(webAuthN.RPID == "" && rpID == strings.Split(http.DomainContext(ctx).InstanceHost, ":")[0])) {
|
||||
(webAuthN.RPID == "" && rpID == http.DomainContext(ctx).InstanceDomain())) {
|
||||
creds = append(creds, webauthn.Credential{
|
||||
ID: webAuthN.KeyID,
|
||||
PublicKey: webAuthN.PublicKey,
|
||||
|
||||
Reference in New Issue
Block a user