diff --git a/internal/api/authz/instance.go b/internal/api/authz/instance.go index 1eacb2383c5..368753c0e26 100644 --- a/internal/api/authz/instance.go +++ b/internal/api/authz/instance.go @@ -28,7 +28,10 @@ type Instance interface { } type InstanceVerifier interface { - InstanceByHost(ctx context.Context, host, publicDomain string) (Instance, error) + // InstanceByHost returns the instance for the given instanceDomain or publicDomain. + // Previously it used the host (hostname[:port]) to find the instance, but is now using the domain (hostname) only. + // For preventing issues in backports, the name of the method is not changed. + InstanceByHost(ctx context.Context, instanceDomain, publicDomain string) (Instance, error) InstanceByID(ctx context.Context, id string) (Instance, error) } diff --git a/internal/api/grpc/server/connect_middleware/instance_interceptor.go b/internal/api/grpc/server/connect_middleware/instance_interceptor.go index c8f517857e1..4855dfd47a1 100644 --- a/internal/api/grpc/server/connect_middleware/instance_interceptor.go +++ b/internal/api/grpc/server/connect_middleware/instance_interceptor.go @@ -88,11 +88,11 @@ func addInstanceByDomain(ctx context.Context, req connect.AnyRequest, handler co func addInstanceByRequestedHost(ctx context.Context, req connect.AnyRequest, handler connect.UnaryFunc, verifier authz.InstanceVerifier, translator *i18n.Translator, externalDomain string) (connect.AnyResponse, error) { requestContext := zitadel_http.DomainContext(ctx) - if requestContext.InstanceHost == "" { + if requestContext.InstanceDomain() == "" { logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).Error("unable to set instance") return nil, connect.NewError(connect.CodeNotFound, errors.New("no instanceHost specified")) } - instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceHost, requestContext.PublicHost) + instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceDomain(), requestContext.RequestedDomain()) if err != nil { origin := zitadel_http.DomainContext(ctx) logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).WithError(err).Error("unable to set instance") diff --git a/internal/api/grpc/server/middleware/instance_interceptor.go b/internal/api/grpc/server/middleware/instance_interceptor.go index 30925a99caa..e9971d233aa 100644 --- a/internal/api/grpc/server/middleware/instance_interceptor.go +++ b/internal/api/grpc/server/middleware/instance_interceptor.go @@ -88,11 +88,11 @@ func addInstanceByDomain(ctx context.Context, req interface{}, handler grpc.Unar func addInstanceByRequestedHost(ctx context.Context, req interface{}, handler grpc.UnaryHandler, verifier authz.InstanceVerifier, translator *i18n.Translator, externalDomain string) (interface{}, error) { requestContext := zitadel_http.DomainContext(ctx) - if requestContext.InstanceHost == "" { + if requestContext.InstanceDomain() == "" { logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).Error("unable to set instance") return nil, status.Error(codes.NotFound, "no instanceHost specified") } - instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceHost, requestContext.PublicHost) + instance, err := verifier.InstanceByHost(ctx, requestContext.InstanceDomain(), requestContext.RequestedDomain()) if err != nil { origin := zitadel_http.DomainContext(ctx) logging.WithFields("origin", requestContext.Origin(), "externalDomain", externalDomain).WithError(err).Error("unable to set instance") diff --git a/internal/api/http/middleware/instance_interceptor.go b/internal/api/http/middleware/instance_interceptor.go index 321eab5cb02..587e2ad22f2 100644 --- a/internal/api/http/middleware/instance_interceptor.go +++ b/internal/api/http/middleware/instance_interceptor.go @@ -88,10 +88,10 @@ func setInstance(ctx context.Context, verifier authz.InstanceVerifier) (_ contex defer func() { span.EndWithError(err) }() requestContext := zitadel_http.DomainContext(ctx) - if requestContext.InstanceHost == "" { + if requestContext.InstanceDomain() == "" { return nil, zerrors.ThrowNotFound(err, "INST-zWq7X", "Errors.IAM.NotFound") } - instance, err := verifier.InstanceByHost(authCtx, requestContext.InstanceHost, requestContext.PublicHost) + instance, err := verifier.InstanceByHost(authCtx, requestContext.InstanceDomain(), requestContext.RequestedDomain()) if err != nil { return nil, err } diff --git a/internal/api/http/middleware/origin_interceptor.go b/internal/api/http/middleware/origin_interceptor.go index 607855b80f0..8528fb2a1af 100644 --- a/internal/api/http/middleware/origin_interceptor.go +++ b/internal/api/http/middleware/origin_interceptor.go @@ -1,11 +1,15 @@ package middleware import ( + "errors" + "net" "net/http" "slices" + "strconv" "github.com/gorilla/mux" "github.com/muhlemmer/httpforwarded" + "github.com/zitadel/logging" http_util "github.com/zitadel/zitadel/internal/api/http" ) @@ -16,7 +20,7 @@ func WithOrigin(enforceHttps bool, http1Header, http2Header string, instanceHost origin := composeDomainContext( r, enforceHttps, - // to make sure we don't break existing configurations we append the existing checked headers as well + // to make sure we don't break existing configurations, we append the existing checked headers as well slices.Compact(append(instanceHostHeaders, http1Header, http2Header, http_util.Forwarded, http_util.ZitadelForwarded, http_util.ForwardedFor, http_util.ForwardedHost, http_util.ForwardedProto)), publicDomainHeaders, ) @@ -25,17 +29,13 @@ func WithOrigin(enforceHttps bool, http1Header, http2Header string, instanceHost } } -func composeDomainContext(r *http.Request, enforceHttps bool, instanceDomainHeaders, publicDomainHeaders []string) *http_util.DomainCtx { +func composeDomainContext(r *http.Request, enforceHttps bool, instanceDomainHeaders, publicDomainHeaders []string) (_ *http_util.DomainCtx) { instanceHost, instanceProto := hostFromRequest(r, instanceDomainHeaders) publicHost, publicProto := hostFromRequest(r, publicDomainHeaders) if instanceHost == "" { instanceHost = r.Host } - return &http_util.DomainCtx{ - InstanceHost: instanceHost, - Protocol: protocolFromRequest(instanceProto, publicProto, enforceHttps), - PublicHost: publicHost, - } + return http_util.NewDomainCtx(instanceHost, publicHost, protocolFromRequest(instanceProto, publicProto, enforceHttps)) } func protocolFromRequest(instanceProto, publicProto string, enforceHttps bool) string { @@ -67,7 +67,7 @@ func hostFromRequest(r *http.Request, headers []string) (host, proto string) { hostFromHeader = r.Header.Get(header) } if host == "" { - host = hostFromHeader + host = sanitizeHost(hostFromHeader) } if proto == "" && (protoFromHeader == "http" || protoFromHeader == "https") { proto = protoFromHeader @@ -76,6 +76,35 @@ func hostFromRequest(r *http.Request, headers []string) (host, proto string) { return host, proto } +func sanitizeHost(rawHost string) (host string) { + if rawHost == "" { + return "" + } + host, port, err := net.SplitHostPort(rawHost) + if err != nil { + // if the error is about a missing port, the host is probably just "example.com", so we can return it + if isMissingPortError(err) { + return rawHost + } + // if the error is about something else, the host is probably invalid, so we log it and return an empty string + logging.WithFields("host", rawHost).Warning("invalid host header, ignoring header") + return "" + } + // if the port is not numeric, the host was probably something like "localhost:@attacker.com" + portNumber, err := strconv.Atoi(port) + if err != nil || portNumber < 1 || portNumber > 65535 { + logging.WithFields("host", rawHost).Warning("invalid port in host header, ignoring header") + return "" + } + // if we reach this point, the host contains a valid port, so we return the complete host + return rawHost +} + +func isMissingPortError(err error) bool { + var addrErr *net.AddrError + return errors.As(err, &addrErr) && (addrErr.Err == "missing port in address") +} + func hostFromForwarded(values []string) (string, string) { fwd, fwdErr := httpforwarded.Parse(values) if fwdErr == nil { diff --git a/internal/api/http/middleware/origin_interceptor_test.go b/internal/api/http/middleware/origin_interceptor_test.go index 7419c91aba0..9dcc9b63927 100644 --- a/internal/api/http/middleware/origin_interceptor_test.go +++ b/internal/api/http/middleware/origin_interceptor_test.go @@ -205,3 +205,75 @@ func Test_composeOrigin(t *testing.T) { }) } } + +func Test_sanitizeHost(t *testing.T) { + type args struct { + rawHost string + } + tests := []struct { + name string + args args + want string + }{ + { + name: "normal host", + args: args{rawHost: "example.com"}, + want: "example.com", + }, + { + name: "host with port", + args: args{rawHost: "example.com:8080"}, + want: "example.com:8080", + }, + { + name: "ipv4", + args: args{rawHost: "192.168.1.1"}, + want: "192.168.1.1", + }, + { + name: "ipv4 with port", + args: args{rawHost: "192.168.1.1:8080"}, + want: "192.168.1.1:8080", + }, + { + name: "ipv6", + args: args{rawHost: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]"}, + want: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]", + }, + { + name: "ipv6 with port", + args: args{rawHost: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8080"}, + want: "[2001:0db8:85a3:0000:0000:8a2e:0370:7334]:8080", + }, + { + name: "host with trailing colon", + args: args{rawHost: "example.com:"}, + want: "", + }, + { + name: "host with invalid port", + args: args{rawHost: "example.com:port"}, + want: "", + }, + { + name: "invalid host", + args: args{rawHost: "localhost:@attacker.com"}, + want: "", + }, + { + name: "invalid host", + args: args{rawHost: "localhost:@attacker.com:8080"}, + want: "", + }, + { + name: "empty host", + args: args{rawHost: ""}, + want: "", + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + assert.Equalf(t, tt.want, sanitizeHost(tt.args.rawHost), "sanitizeHost(%v)", tt.args.rawHost) + }) + } +} diff --git a/internal/api/http/request_context.go b/internal/api/http/request_context.go index 9ed345ed880..5cec32d91d8 100644 --- a/internal/api/http/request_context.go +++ b/internal/api/http/request_context.go @@ -3,7 +3,8 @@ package http import ( "context" "fmt" - "strings" + "net" + "net/url" ) type DomainCtx struct { @@ -12,6 +13,35 @@ type DomainCtx struct { Protocol string } +func NewDomainCtx(instanceHostname, publicHostname, protocol string) *DomainCtx { + return &DomainCtx{ + InstanceHost: instanceHostname, + PublicHost: publicHostname, + Protocol: protocol, + } +} + +func NewDomainCtxFromOrigin(origin *url.URL) *DomainCtx { + return &DomainCtx{ + InstanceHost: origin.Host, + PublicHost: origin.Host, + Protocol: origin.Scheme, + } +} + +// InstanceDomain returns the hostname for which the request was handled. +func (r *DomainCtx) InstanceDomain() string { + return hostnameFromHost(r.InstanceHost) +} + +func hostnameFromHost(host string) string { + hostname, _, err := net.SplitHostPort(host) + if err != nil { + return host + } + return hostname +} + // RequestedHost returns the host (hostname[:port]) for which the request was handled. // The instance host is returned if not public host was set. func (r *DomainCtx) RequestedHost() string { @@ -22,13 +52,13 @@ func (r *DomainCtx) RequestedHost() string { } // RequestedDomain returns the domain (hostname) for which the request was handled. -// The instance domain is returned if not public host / domain was set. +// The instance domain is returned if no public host / domain was set. func (r *DomainCtx) RequestedDomain() string { - return strings.Split(r.RequestedHost(), ":")[0] + return hostnameFromHost(r.RequestedHost()) } // Origin returns the origin (protocol://hostname[:port]) for which the request was handled. -// The instance host is used if not public host was set. +// The instance host is used if no public host was set. func (r *DomainCtx) Origin() string { host := r.PublicHost if host == "" { diff --git a/internal/notification/handlers/origin.go b/internal/notification/handlers/origin.go index 8846f5e2dc0..6b5af15d3f2 100644 --- a/internal/notification/handlers/origin.go +++ b/internal/notification/handlers/origin.go @@ -52,10 +52,6 @@ func enrichCtx(ctx context.Context, origin string) (context.Context, error) { if err != nil { return nil, err } - ctx = http_utils.WithDomainContext(ctx, &http_utils.DomainCtx{ - InstanceHost: u.Host, - PublicHost: u.Host, - Protocol: u.Scheme, - }) + ctx = http_utils.WithDomainContext(ctx, http_utils.NewDomainCtxFromOrigin(u)) return ctx, nil } diff --git a/internal/query/instance.go b/internal/query/instance.go index aab76ad5fc7..626d561477e 100644 --- a/internal/query/instance.go +++ b/internal/query/instance.go @@ -8,7 +8,6 @@ import ( "errors" "fmt" "slices" - "strings" "time" sq "github.com/Masterminds/squirrel" @@ -201,21 +200,18 @@ var ( instanceByIDQuery string ) -func (q *Queries) InstanceByHost(ctx context.Context, instanceHost, publicHost string) (_ authz.Instance, err error) { +func (q *Queries) InstanceByHost(ctx context.Context, instanceDomain, publicDomain string) (_ authz.Instance, err error) { var instance *authzInstance ctx, span := tracing.NewSpan(ctx) defer func() { if err != nil { - err = fmt.Errorf("unable to get instance by host: instanceHost %s, publicHost %s: %w", instanceHost, publicHost, err) + err = fmt.Errorf("unable to get instance by domain: instanceDomain %s, publicHostname %s: %w", instanceDomain, publicDomain, err) } else { q.caches.activeInstances.Add(instance.ID, true) } span.EndWithError(err) }() - instanceDomain := strings.Split(instanceHost, ":")[0] // remove possible port - publicDomain := strings.Split(publicHost, ":")[0] // remove possible port - instance, ok := q.caches.instance.Get(ctx, instanceIndexByHost, instanceDomain) if ok { return instance, instance.checkDomain(instanceDomain, publicDomain) diff --git a/internal/webauthn/converter.go b/internal/webauthn/converter.go index c914bb8bf9b..0b220582663 100644 --- a/internal/webauthn/converter.go +++ b/internal/webauthn/converter.go @@ -2,7 +2,6 @@ package webauthn import ( "context" - "strings" "github.com/go-webauthn/webauthn/protocol" "github.com/go-webauthn/webauthn/webauthn" @@ -20,7 +19,7 @@ func WebAuthNsToCredentials(ctx context.Context, webAuthNs []*domain.WebAuthNTok // then we check if the requested rpID matches the instance domain if webAuthN.State == domain.MFAStateReady && (webAuthN.RPID == rpID || - (webAuthN.RPID == "" && rpID == strings.Split(http.DomainContext(ctx).InstanceHost, ":")[0])) { + (webAuthN.RPID == "" && rpID == http.DomainContext(ctx).InstanceDomain())) { creds = append(creds, webauthn.Credential{ ID: webAuthN.KeyID, PublicKey: webAuthN.PublicKey,