fix: discover instance by original host

Merge pull request from GHSA-2wmj-46rj-qm2w

* fix: find instance by original domain

* return instance not found on invalid origin

* test: ensure correct host validation

* test: instance not found is translated

(cherry picked from commit 11d7a8ce61)
This commit is contained in:
Elio Bischof 2023-11-29 11:57:47 +01:00 committed by Livio Spring
parent 31df28380c
commit 6d812137b7
No known key found for this signature in database
GPG Key ID: 26BB1C2FA5952CF0
2 changed files with 80 additions and 7 deletions

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"strings" "strings"
"github.com/rakyll/statik/fs" "github.com/rakyll/statik/fs"
@ -12,6 +13,7 @@ import (
"golang.org/x/text/language" "golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
zitadel_http "github.com/zitadel/zitadel/internal/api/http"
caos_errors "github.com/zitadel/zitadel/internal/errors" caos_errors "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/i18n" "github.com/zitadel/zitadel/internal/i18n"
"github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/telemetry/tracing"
@ -73,7 +75,7 @@ func setInstance(r *http.Request, verifier authz.InstanceVerifier, headerName st
host, err := HostFromRequest(r, headerName) host, err := HostFromRequest(r, headerName)
if err != nil { if err != nil {
return nil, err return nil, caos_errors.ThrowNotFound(err, "INST-zWq7X", "Errors.Instance.NotFound")
} }
instance, err := verifier.InstanceByHost(authCtx, host) instance, err := verifier.InstanceByHost(authCtx, host)
@ -84,17 +86,39 @@ func setInstance(r *http.Request, verifier authz.InstanceVerifier, headerName st
return authz.WithInstance(ctx, instance), nil return authz.WithInstance(ctx, instance), nil
} }
func HostFromRequest(r *http.Request, headerName string) (string, error) { func HostFromRequest(r *http.Request, headerName string) (host string, err error) {
host := r.Host
if headerName != "host" { if headerName != "host" {
host = r.Header.Get(headerName) return hostFromSpecialHeader(r, headerName)
} }
return hostFromOrigin(r.Context())
}
func hostFromSpecialHeader(r *http.Request, headerName string) (host string, err error) {
host = r.Header.Get(headerName)
if host == "" { if host == "" {
return "", fmt.Errorf("host header `%s` not found", headerName) return "", fmt.Errorf("host header `%s` not found", headerName)
} }
return host, nil return host, nil
} }
func hostFromOrigin(ctx context.Context) (host string, err error) {
defer func() {
if err != nil {
err = fmt.Errorf("invalid origin: %w", err)
}
}()
origin := zitadel_http.ComposedOrigin(ctx)
u, err := url.Parse(origin)
if err != nil {
return "", err
}
host = u.Hostname()
if host == "" {
err = errors.New("empty host")
}
return host, err
}
func newZitadelTranslator() *i18n.Translator { func newZitadelTranslator() *i18n.Translator {
dir, err := fs.NewWithNamespace("zitadel") dir, err := fs.NewWithNamespace("zitadel")
logging.WithFields("namespace", "zitadel").OnError(err).Panic("unable to get namespace") logging.WithFields("namespace", "zitadel").OnError(err).Panic("unable to get namespace")

View File

@ -12,6 +12,7 @@ import (
"golang.org/x/text/language" "golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
zitadel_http "github.com/zitadel/zitadel/internal/api/http"
) )
func Test_instanceInterceptor_Handler(t *testing.T) { func Test_instanceInterceptor_Handler(t *testing.T) {
@ -70,6 +71,7 @@ func Test_instanceInterceptor_Handler(t *testing.T) {
a := &instanceInterceptor{ a := &instanceInterceptor{
verifier: tt.fields.verifier, verifier: tt.fields.verifier,
headerName: tt.fields.headerName, headerName: tt.fields.headerName,
translator: newZitadelTranslator(),
} }
next := &testHandler{} next := &testHandler{}
got := a.HandlerFunc(next.ServeHTTP) got := a.HandlerFunc(next.ServeHTTP)
@ -137,6 +139,7 @@ func Test_instanceInterceptor_HandlerFunc(t *testing.T) {
a := &instanceInterceptor{ a := &instanceInterceptor{
verifier: tt.fields.verifier, verifier: tt.fields.verifier,
headerName: tt.fields.headerName, headerName: tt.fields.headerName,
translator: newZitadelTranslator(),
} }
next := &testHandler{} next := &testHandler{}
got := a.HandlerFunc(next.ServeHTTP) got := a.HandlerFunc(next.ServeHTTP)
@ -164,7 +167,7 @@ func Test_setInstance(t *testing.T) {
res res res res
}{ }{
{ {
"hostname not found, error", "special host header not found, error",
args{ args{
r: func() *http.Request { r: func() *http.Request {
r := httptest.NewRequest("", "/url", nil) r := httptest.NewRequest("", "/url", nil)
@ -179,7 +182,7 @@ func Test_setInstance(t *testing.T) {
}, },
}, },
{ {
"invalid host, error", "special host header invalid, error",
args{ args{
r: func() *http.Request { r: func() *http.Request {
r := httptest.NewRequest("", "/url", nil) r := httptest.NewRequest("", "/url", nil)
@ -195,7 +198,7 @@ func Test_setInstance(t *testing.T) {
}, },
}, },
{ {
"valid host", "special host header valid, ok",
args{ args{
r: func() *http.Request { r: func() *http.Request {
r := httptest.NewRequest("", "/url", nil) r := httptest.NewRequest("", "/url", nil)
@ -210,6 +213,52 @@ func Test_setInstance(t *testing.T) {
err: false, err: false,
}, },
}, },
{
"host from origin if header is not special, ok",
args{
r: func() *http.Request {
r := httptest.NewRequest("", "/url", nil)
r.Header.Set("host", "fromrequest")
return r.WithContext(zitadel_http.WithComposedOrigin(r.Context(), "https://fromorigin:9999"))
}(),
verifier: &mockInstanceVerifier{"fromorigin"},
headerName: "host",
},
res{
want: authz.WithInstance(zitadel_http.WithComposedOrigin(context.Background(), "https://fromorigin:9999"), &mockInstance{}),
err: false,
},
},
{
"host from origin, instance not found",
args{
r: func() *http.Request {
r := httptest.NewRequest("", "/url", nil)
return r.WithContext(zitadel_http.WithComposedOrigin(r.Context(), "https://fromorigin:9999"))
}(),
verifier: &mockInstanceVerifier{"unknowndomain"},
headerName: "host",
},
res{
want: nil,
err: true,
},
},
{
"host from origin invalid, err",
args{
r: func() *http.Request {
r := httptest.NewRequest("", "/url", nil)
return r.WithContext(zitadel_http.WithComposedOrigin(r.Context(), "https://from origin:9999"))
}(),
verifier: &mockInstanceVerifier{"from origin"},
headerName: "host",
},
res{
want: nil,
err: true,
},
},
} }
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {