mirror of
https://github.com/zitadel/zitadel.git
synced 2025-12-23 06:06:42 +00:00
Merge commit from fork
* fix: sanitize host headers before use * add additional test
This commit is contained in:
@@ -28,7 +28,10 @@ type Instance interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type InstanceVerifier interface {
|
type InstanceVerifier interface {
|
||||||
InstanceByHost(ctx context.Context, host, publicDomain string) (Instance, error)
|
// InstanceByHost returns the instance for the given instanceDomain or publicDomain.
|
||||||
|
// Previously it used the host (hostname[:port]) to find the instance, but is now using the domain (hostname) only.
|
||||||
|
// For preventing issues in backports, the name of the method is not changed.
|
||||||
|
InstanceByHost(ctx context.Context, instanceDomain, publicDomain string) (Instance, error)
|
||||||
InstanceByID(ctx context.Context, id string) (Instance, error)
|
InstanceByID(ctx context.Context, id string) (Instance, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -88,11 +88,11 @@ func addInstanceByDomain(ctx context.Context, req connect.AnyRequest, handler co
|
|||||||
|
|
||||||
func addInstanceByRequestedHost(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, verifier authz.InstanceVerifier, translator *i18n.Translator, externalDomain string) (connect.AnyResponse, error) {
|
func addInstanceByRequestedHost(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, verifier authz.InstanceVerifier, translator *i18n.Translator, externalDomain string) (connect.AnyResponse, error) {
|
||||||
requestContext := zitadel_http.DomainContext(ctx)
|
requestContext := zitadel_http.DomainContext(ctx)
|
||||||
if requestContext.InstanceHost == "" {
|
if requestContext.InstanceDomain() == "" {
|
||||||
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).Error("unable to set instance")
|
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).Error("unable to set instance")
|
||||||
return nil, connect.NewError(connect.CodeNotFound, errors.New("no instanceHost specified"))
|
return nil, connect.NewError(connect.CodeNotFound, errors.New("no instanceHost specified"))
|
||||||
}
|
}
|
||||||
instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceHost, requestContext.PublicHost)
|
instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceDomain(), requestContext.RequestedDomain())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
origin := zitadel_http.DomainContext(ctx)
|
origin := zitadel_http.DomainContext(ctx)
|
||||||
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).WithError(err).Error("unable to set instance")
|
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).WithError(err).Error("unable to set instance")
|
||||||
|
|||||||
@@ -88,11 +88,11 @@ func addInstanceByDomain(ctx context.Context, req interface{}, handler grpc.Unar
|
|||||||
|
|
||||||
func addInstanceByRequestedHost(ctx context.Context, req interface{}, handler grpc.UnaryHandler, verifier authz.InstanceVerifier, translator *i18n.Translator, externalDomain string) (interface{}, error) {
|
func addInstanceByRequestedHost(ctx context.Context, req interface{}, handler grpc.UnaryHandler, verifier authz.InstanceVerifier, translator *i18n.Translator, externalDomain string) (interface{}, error) {
|
||||||
requestContext := zitadel_http.DomainContext(ctx)
|
requestContext := zitadel_http.DomainContext(ctx)
|
||||||
if requestContext.InstanceHost == "" {
|
if requestContext.InstanceDomain() == "" {
|
||||||
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).Error("unable to set instance")
|
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).Error("unable to set instance")
|
||||||
return nil, status.Error(codes.NotFound, "no instanceHost specified")
|
return nil, status.Error(codes.NotFound, "no instanceHost specified")
|
||||||
}
|
}
|
||||||
instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceHost, requestContext.PublicHost)
|
instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceDomain(), requestContext.RequestedDomain())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
origin := zitadel_http.DomainContext(ctx)
|
origin := zitadel_http.DomainContext(ctx)
|
||||||
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).WithError(err).Error("unable to set instance")
|
logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).WithError(err).Error("unable to set instance")
|
||||||
|
|||||||
@@ -88,10 +88,10 @@ func setInstance(ctx context.Context, verifier authz.InstanceVerifier) (_ contex
|
|||||||
defer func() { span.EndWithError(err) }()
|
defer func() { span.EndWithError(err) }()
|
||||||
|
|
||||||
requestContext := zitadel_http.DomainContext(ctx)
|
requestContext := zitadel_http.DomainContext(ctx)
|
||||||
if requestContext.InstanceHost == "" {
|
if requestContext.InstanceDomain() == "" {
|
||||||
return nil, zerrors.ThrowNotFound(err, "INST-zWq7X", "Errors.IAM.NotFound")
|
return nil, zerrors.ThrowNotFound(err, "INST-zWq7X", "Errors.IAM.NotFound")
|
||||||
}
|
}
|
||||||
instance, err := verifier.InstanceByHost(authCtx, requestContext.InstanceHost, requestContext.PublicHost)
|
instance, err := verifier.InstanceByHost(authCtx, requestContext.InstanceDomain(), requestContext.RequestedDomain())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,11 +1,15 @@
|
|||||||
package middleware
|
package middleware
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"errors"
|
||||||
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"github.com/gorilla/mux"
|
"github.com/gorilla/mux"
|
||||||
"github.com/muhlemmer/httpforwarded"
|
"github.com/muhlemmer/httpforwarded"
|
||||||
|
"github.com/zitadel/logging"
|
||||||
|
|
||||||
http_util "github.com/zitadel/zitadel/internal/api/http"
|
http_util "github.com/zitadel/zitadel/internal/api/http"
|
||||||
)
|
)
|
||||||
@@ -16,7 +20,7 @@ func WithOrigin(enforceHttps bool, http1Header, http2Header string, instanceHost
|
|||||||
origin := composeDomainContext(
|
origin := composeDomainContext(
|
||||||
r,
|
r,
|
||||||
enforceHttps,
|
enforceHttps,
|
||||||
// to make sure we don't break existing configurations we append the existing checked headers as well
|
// to make sure we don't break existing configurations, we append the existing checked headers as well
|
||||||
slices.Compact(append(instanceHostHeaders, http1Header, http2Header, http_util.Forwarded, http_util.ZitadelForwarded, http_util.ForwardedFor, http_util.ForwardedHost, http_util.ForwardedProto)),
|
slices.Compact(append(instanceHostHeaders, http1Header, http2Header, http_util.Forwarded, http_util.ZitadelForwarded, http_util.ForwardedFor, http_util.ForwardedHost, http_util.ForwardedProto)),
|
||||||
publicDomainHeaders,
|
publicDomainHeaders,
|
||||||
)
|
)
|
||||||
@@ -25,17 +29,13 @@ func WithOrigin(enforceHttps bool, http1Header, http2Header string, instanceHost
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func composeDomainContext(r *http.Request, enforceHttps bool, instanceDomainHeaders, publicDomainHeaders []string) *http_util.DomainCtx {
|
func composeDomainContext(r *http.Request, enforceHttps bool, instanceDomainHeaders, publicDomainHeaders []string) (_ *http_util.DomainCtx) {
|
||||||
instanceHost, instanceProto := hostFromRequest(r, instanceDomainHeaders)
|
instanceHost, instanceProto := hostFromRequest(r, instanceDomainHeaders)
|
||||||
publicHost, publicProto := hostFromRequest(r, publicDomainHeaders)
|
publicHost, publicProto := hostFromRequest(r, publicDomainHeaders)
|
||||||
if instanceHost == "" {
|
if instanceHost == "" {
|
||||||
instanceHost = r.Host
|
instanceHost = r.Host
|
||||||
}
|
}
|
||||||
return &http_util.DomainCtx{
|
return http_util.NewDomainCtx(instanceHost, publicHost, protocolFromRequest(instanceProto, publicProto, enforceHttps))
|
||||||
InstanceHost: instanceHost,
|
|
||||||
Protocol: protocolFromRequest(instanceProto, publicProto, enforceHttps),
|
|
||||||
PublicHost: publicHost,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func protocolFromRequest(instanceProto, publicProto string, enforceHttps bool) string {
|
func protocolFromRequest(instanceProto, publicProto string, enforceHttps bool) string {
|
||||||
@@ -67,7 +67,7 @@ func hostFromRequest(r *http.Request, headers []string) (host, proto string) {
|
|||||||
hostFromHeader = r.Header.Get(header)
|
hostFromHeader = r.Header.Get(header)
|
||||||
}
|
}
|
||||||
if host == "" {
|
if host == "" {
|
||||||
host = hostFromHeader
|
host = sanitizeHost(hostFromHeader)
|
||||||
}
|
}
|
||||||
if proto == "" && (protoFromHeader == "http" || protoFromHeader == "https") {
|
if proto == "" && (protoFromHeader == "http" || protoFromHeader == "https") {
|
||||||
proto = protoFromHeader
|
proto = protoFromHeader
|
||||||
@@ -76,6 +76,35 @@ func hostFromRequest(r *http.Request, headers []string) (host, proto string) {
|
|||||||
return host, proto
|
return host, proto
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func sanitizeHost(rawHost string) (host string) {
|
||||||
|
if rawHost == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
host, port, err := net.SplitHostPort(rawHost)
|
||||||
|
if err != nil {
|
||||||
|
// if the error is about a missing port, the host is probably just "example.com", so we can return it
|
||||||
|
if isMissingPortError(err) {
|
||||||
|
return rawHost
|
||||||
|
}
|
||||||
|
// if the error is about something else, the host is probably invalid, so we log it and return an empty string
|
||||||
|
logging.WithFields("host", rawHost).Warning("invalid host header, ignoring header")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
// if the port is not numeric, the host was probably something like "localhost:@attacker.com"
|
||||||
|
portNumber, err := strconv.Atoi(port)
|
||||||
|
if err != nil || portNumber < 1 || portNumber > 65535 {
|
||||||
|
logging.WithFields("host", rawHost).Warning("invalid port in host header, ignoring header")
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
// if we reach this point, the host contains a valid port, so we return the complete host
|
||||||
|
return rawHost
|
||||||
|
}
|
||||||
|
|
||||||
|
func isMissingPortError(err error) bool {
|
||||||
|
var addrErr *net.AddrError
|
||||||
|
return errors.As(err, &addrErr) && (addrErr.Err == "missing port in address")
|
||||||
|
}
|
||||||
|
|
||||||
func hostFromForwarded(values []string) (string, string) {
|
func hostFromForwarded(values []string) (string, string) {
|
||||||
fwd, fwdErr := httpforwarded.Parse(values)
|
fwd, fwdErr := httpforwarded.Parse(values)
|
||||||
if fwdErr == nil {
|
if fwdErr == nil {
|
||||||
|
|||||||
@@ -205,3 +205,75 @@ func Test_composeOrigin(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func Test_sanitizeHost(t *testing.T) {
|
||||||
|
type args struct {
|
||||||
|
rawHost string
|
||||||
|
}
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
args args
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "normal host",
|
||||||
|
args: args{rawHost: "example.com"},
|
||||||
|
want: "example.com",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "host with port",
|
||||||
|
args: args{rawHost: "example.com:8080"},
|
||||||
|
want: "example.com:8080",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ipv4",
|
||||||
|
args: args{rawHost: "192.168.1.1"},
|
||||||
|
want: "192.168.1.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ipv4 with port",
|
||||||
|
args: args{rawHost: "192.168.1.1:8080"},
|
||||||
|
want: "192.168.1.1:8080",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ipv6",
|
||||||
|
args: args{rawHost: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]"},
|
||||||
|
want: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ipv6 with port",
|
||||||
|
args: args{rawHost: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8080"},
|
||||||
|
want: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8080",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "host with trailing colon",
|
||||||
|
args: args{rawHost: "example.com:"},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "host with invalid port",
|
||||||
|
args: args{rawHost: "example.com:port"},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid host",
|
||||||
|
args: args{rawHost: "localhost:@attacker.com"},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid host",
|
||||||
|
args: args{rawHost: "localhost:@attacker.com:8080"},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty host",
|
||||||
|
args: args{rawHost: ""},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
assert.Equalf(t, tt.want, sanitizeHost(tt.args.rawHost), "sanitizeHost(%v)", tt.args.rawHost)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,7 +3,8 @@ package http
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"net"
|
||||||
|
"net/url"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DomainCtx struct {
|
type DomainCtx struct {
|
||||||
@@ -12,6 +13,35 @@ type DomainCtx struct {
|
|||||||
Protocol string
|
Protocol string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewDomainCtx(instanceHostname, publicHostname, protocol string) *DomainCtx {
|
||||||
|
return &DomainCtx{
|
||||||
|
InstanceHost: instanceHostname,
|
||||||
|
PublicHost: publicHostname,
|
||||||
|
Protocol: protocol,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDomainCtxFromOrigin(origin *url.URL) *DomainCtx {
|
||||||
|
return &DomainCtx{
|
||||||
|
InstanceHost: origin.Host,
|
||||||
|
PublicHost: origin.Host,
|
||||||
|
Protocol: origin.Scheme,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// InstanceDomain returns the hostname for which the request was handled.
|
||||||
|
func (r *DomainCtx) InstanceDomain() string {
|
||||||
|
return hostnameFromHost(r.InstanceHost)
|
||||||
|
}
|
||||||
|
|
||||||
|
func hostnameFromHost(host string) string {
|
||||||
|
hostname, _, err := net.SplitHostPort(host)
|
||||||
|
if err != nil {
|
||||||
|
return host
|
||||||
|
}
|
||||||
|
return hostname
|
||||||
|
}
|
||||||
|
|
||||||
// RequestedHost returns the host (hostname[:port]) for which the request was handled.
|
// RequestedHost returns the host (hostname[:port]) for which the request was handled.
|
||||||
// The instance host is returned if not public host was set.
|
// The instance host is returned if not public host was set.
|
||||||
func (r *DomainCtx) RequestedHost() string {
|
func (r *DomainCtx) RequestedHost() string {
|
||||||
@@ -22,13 +52,13 @@ func (r *DomainCtx) RequestedHost() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RequestedDomain returns the domain (hostname) for which the request was handled.
|
// RequestedDomain returns the domain (hostname) for which the request was handled.
|
||||||
// The instance domain is returned if not public host / domain was set.
|
// The instance domain is returned if no public host / domain was set.
|
||||||
func (r *DomainCtx) RequestedDomain() string {
|
func (r *DomainCtx) RequestedDomain() string {
|
||||||
return strings.Split(r.RequestedHost(), ":")[0]
|
return hostnameFromHost(r.RequestedHost())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Origin returns the origin (protocol://hostname[:port]) for which the request was handled.
|
// Origin returns the origin (protocol://hostname[:port]) for which the request was handled.
|
||||||
// The instance host is used if not public host was set.
|
// The instance host is used if no public host was set.
|
||||||
func (r *DomainCtx) Origin() string {
|
func (r *DomainCtx) Origin() string {
|
||||||
host := r.PublicHost
|
host := r.PublicHost
|
||||||
if host == "" {
|
if host == "" {
|
||||||
|
|||||||
@@ -52,10 +52,6 @@ func enrichCtx(ctx context.Context, origin string) (context.Context, error) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
ctx = http_utils.WithDomainContext(ctx, &http_utils.DomainCtx{
|
ctx = http_utils.WithDomainContext(ctx, http_utils.NewDomainCtxFromOrigin(u))
|
||||||
InstanceHost: u.Host,
|
|
||||||
PublicHost: u.Host,
|
|
||||||
Protocol: u.Scheme,
|
|
||||||
})
|
|
||||||
return ctx, nil
|
return ctx, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
sq "github.com/Masterminds/squirrel"
|
sq "github.com/Masterminds/squirrel"
|
||||||
@@ -201,21 +200,18 @@ var (
|
|||||||
instanceByIDQuery string
|
instanceByIDQuery string
|
||||||
)
|
)
|
||||||
|
|
||||||
func (q *Queries) InstanceByHost(ctx context.Context, instanceHost, publicHost string) (_ authz.Instance, err error) {
|
func (q *Queries) InstanceByHost(ctx context.Context, instanceDomain, publicDomain string) (_ authz.Instance, err error) {
|
||||||
var instance *authzInstance
|
var instance *authzInstance
|
||||||
ctx, span := tracing.NewSpan(ctx)
|
ctx, span := tracing.NewSpan(ctx)
|
||||||
defer func() {
|
defer func() {
|
||||||
if err != nil {
|
if err != nil {
|
||||||
err = fmt.Errorf("unable to get instance by host: instanceHost %s, publicHost %s: %w", instanceHost, publicHost, err)
|
err = fmt.Errorf("unable to get instance by domain: instanceDomain %s, publicHostname %s: %w", instanceDomain, publicDomain, err)
|
||||||
} else {
|
} else {
|
||||||
q.caches.activeInstances.Add(instance.ID, true)
|
q.caches.activeInstances.Add(instance.ID, true)
|
||||||
}
|
}
|
||||||
span.EndWithError(err)
|
span.EndWithError(err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
instanceDomain := strings.Split(instanceHost, ":")[0] // remove possible port
|
|
||||||
publicDomain := strings.Split(publicHost, ":")[0] // remove possible port
|
|
||||||
|
|
||||||
instance, ok := q.caches.instance.Get(ctx, instanceIndexByHost, instanceDomain)
|
instance, ok := q.caches.instance.Get(ctx, instanceIndexByHost, instanceDomain)
|
||||||
if ok {
|
if ok {
|
||||||
return instance, instance.checkDomain(instanceDomain, publicDomain)
|
return instance, instance.checkDomain(instanceDomain, publicDomain)
|
||||||
|
|||||||
@@ -2,7 +2,6 @@ package webauthn
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/go-webauthn/webauthn/protocol"
|
"github.com/go-webauthn/webauthn/protocol"
|
||||||
"github.com/go-webauthn/webauthn/webauthn"
|
"github.com/go-webauthn/webauthn/webauthn"
|
||||||
@@ -20,7 +19,7 @@ func WebAuthNsToCredentials(ctx context.Context, webAuthNs []*domain.WebAuthNTok
|
|||||||
// then we check if the requested rpID matches the instance domain
|
// then we check if the requested rpID matches the instance domain
|
||||||
if webAuthN.State == domain.MFAStateReady &&
|
if webAuthN.State == domain.MFAStateReady &&
|
||||||
(webAuthN.RPID == rpID ||
|
(webAuthN.RPID == rpID ||
|
||||||
(webAuthN.RPID == "" && rpID == strings.Split(http.DomainContext(ctx).InstanceHost, ":")[0])) {
|
(webAuthN.RPID == "" && rpID == http.DomainContext(ctx).InstanceDomain())) {
|
||||||
creds = append(creds, webauthn.Credential{
|
creds = append(creds, webauthn.Credential{
|
||||||
ID: webAuthN.KeyID,
|
ID: webAuthN.KeyID,
|
||||||
PublicKey: webAuthN.PublicKey,
|
PublicKey: webAuthN.PublicKey,
|
||||||
|
|||||||
Reference in New Issue
Block a user