diff --git a/internal/api/http/middleware/instance_interceptor.go b/internal/api/http/middleware/instance_interceptor.go index 70edd7b5c1..276037301d 100644 --- a/internal/api/http/middleware/instance_interceptor.go +++ b/internal/api/http/middleware/instance_interceptor.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" "net/http" + "net/url" "strings" "github.com/rakyll/statik/fs" @@ -12,6 +13,7 @@ import ( "golang.org/x/text/language" "github.com/zitadel/zitadel/internal/api/authz" + zitadel_http "github.com/zitadel/zitadel/internal/api/http" caos_errors "github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/i18n" "github.com/zitadel/zitadel/internal/telemetry/tracing" @@ -73,7 +75,7 @@ func setInstance(r *http.Request, verifier authz.InstanceVerifier, headerName st host, err := HostFromRequest(r, headerName) if err != nil { - return nil, err + return nil, caos_errors.ThrowNotFound(err, "INST-zWq7X", "Errors.Instance.NotFound") } instance, err := verifier.InstanceByHost(authCtx, host) @@ -84,17 +86,39 @@ func setInstance(r *http.Request, verifier authz.InstanceVerifier, headerName st return authz.WithInstance(ctx, instance), nil } -func HostFromRequest(r *http.Request, headerName string) (string, error) { - host := r.Host +func HostFromRequest(r *http.Request, headerName string) (host string, err error) { if headerName != "host" { - host = r.Header.Get(headerName) + return hostFromSpecialHeader(r, headerName) } + return hostFromOrigin(r.Context()) +} + +func hostFromSpecialHeader(r *http.Request, headerName string) (host string, err error) { + host = r.Header.Get(headerName) if host == "" { return "", fmt.Errorf("host header `%s` not found", headerName) } return host, nil } +func hostFromOrigin(ctx context.Context) (host string, err error) { + defer func() { + if err != nil { + err = fmt.Errorf("invalid origin: %w", err) + } + }() + origin := zitadel_http.ComposedOrigin(ctx) + u, err := url.Parse(origin) + if err != nil { + return "", err + } + host = u.Hostname() + if host == "" { + err = errors.New("empty host") + } + return host, err +} + func newZitadelTranslator() *i18n.Translator { dir, err := fs.NewWithNamespace("zitadel") logging.WithFields("namespace", "zitadel").OnError(err).Panic("unable to get namespace") diff --git a/internal/api/http/middleware/instance_interceptor_test.go b/internal/api/http/middleware/instance_interceptor_test.go index 639cede84a..e61fade72d 100644 --- a/internal/api/http/middleware/instance_interceptor_test.go +++ b/internal/api/http/middleware/instance_interceptor_test.go @@ -12,6 +12,7 @@ import ( "golang.org/x/text/language" "github.com/zitadel/zitadel/internal/api/authz" + zitadel_http "github.com/zitadel/zitadel/internal/api/http" ) func Test_instanceInterceptor_Handler(t *testing.T) { @@ -70,6 +71,7 @@ func Test_instanceInterceptor_Handler(t *testing.T) { a := &instanceInterceptor{ verifier: tt.fields.verifier, headerName: tt.fields.headerName, + translator: newZitadelTranslator(), } next := &testHandler{} got := a.HandlerFunc(next.ServeHTTP) @@ -137,6 +139,7 @@ func Test_instanceInterceptor_HandlerFunc(t *testing.T) { a := &instanceInterceptor{ verifier: tt.fields.verifier, headerName: tt.fields.headerName, + translator: newZitadelTranslator(), } next := &testHandler{} got := a.HandlerFunc(next.ServeHTTP) @@ -164,7 +167,7 @@ func Test_setInstance(t *testing.T) { res res }{ { - "hostname not found, error", + "special host header not found, error", args{ r: func() *http.Request { r := httptest.NewRequest("", "/url", nil) @@ -179,7 +182,7 @@ func Test_setInstance(t *testing.T) { }, }, { - "invalid host, error", + "special host header invalid, error", args{ r: func() *http.Request { r := httptest.NewRequest("", "/url", nil) @@ -195,7 +198,7 @@ func Test_setInstance(t *testing.T) { }, }, { - "valid host", + "special host header valid, ok", args{ r: func() *http.Request { r := httptest.NewRequest("", "/url", nil) @@ -210,6 +213,52 @@ func Test_setInstance(t *testing.T) { err: false, }, }, + { + "host from origin if header is not special, ok", + args{ + r: func() *http.Request { + r := httptest.NewRequest("", "/url", nil) + r.Header.Set("host", "fromrequest") + return r.WithContext(zitadel_http.WithComposedOrigin(r.Context(), "https://fromorigin:9999")) + }(), + verifier: &mockInstanceVerifier{"fromorigin"}, + headerName: "host", + }, + res{ + want: authz.WithInstance(zitadel_http.WithComposedOrigin(context.Background(), "https://fromorigin:9999"), &mockInstance{}), + err: false, + }, + }, + { + "host from origin, instance not found", + args{ + r: func() *http.Request { + r := httptest.NewRequest("", "/url", nil) + return r.WithContext(zitadel_http.WithComposedOrigin(r.Context(), "https://fromorigin:9999")) + }(), + verifier: &mockInstanceVerifier{"unknowndomain"}, + headerName: "host", + }, + res{ + want: nil, + err: true, + }, + }, + { + "host from origin invalid, err", + args{ + r: func() *http.Request { + r := httptest.NewRequest("", "/url", nil) + return r.WithContext(zitadel_http.WithComposedOrigin(r.Context(), "https://from origin:9999")) + }(), + verifier: &mockInstanceVerifier{"from origin"}, + headerName: "host", + }, + res{ + want: nil, + err: true, + }, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) {