mirror of
https://github.com/zitadel/zitadel.git
synced 2024-12-15 04:18:01 +00:00
fix: discover instance by original host
Merge pull request from GHSA-2wmj-46rj-qm2w
* fix: find instance by original domain
* return instance not found on invalid origin
* test: ensure correct host validation
* test: instance not found is translated
(cherry picked from commit 11d7a8ce61
)
This commit is contained in:
parent
31df28380c
commit
6d812137b7
@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/rakyll/statik/fs"
|
"github.com/rakyll/statik/fs"
|
||||||
@ -12,6 +13,7 @@ import (
|
|||||||
"golang.org/x/text/language"
|
"golang.org/x/text/language"
|
||||||
|
|
||||||
"github.com/zitadel/zitadel/internal/api/authz"
|
"github.com/zitadel/zitadel/internal/api/authz"
|
||||||
|
zitadel_http "github.com/zitadel/zitadel/internal/api/http"
|
||||||
caos_errors "github.com/zitadel/zitadel/internal/errors"
|
caos_errors "github.com/zitadel/zitadel/internal/errors"
|
||||||
"github.com/zitadel/zitadel/internal/i18n"
|
"github.com/zitadel/zitadel/internal/i18n"
|
||||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||||
@ -73,7 +75,7 @@ func setInstance(r *http.Request, verifier authz.InstanceVerifier, headerName st
|
|||||||
|
|
||||||
host, err := HostFromRequest(r, headerName)
|
host, err := HostFromRequest(r, headerName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, caos_errors.ThrowNotFound(err, "INST-zWq7X", "Errors.Instance.NotFound")
|
||||||
}
|
}
|
||||||
|
|
||||||
instance, err := verifier.InstanceByHost(authCtx, host)
|
instance, err := verifier.InstanceByHost(authCtx, host)
|
||||||
@ -84,17 +86,39 @@ func setInstance(r *http.Request, verifier authz.InstanceVerifier, headerName st
|
|||||||
return authz.WithInstance(ctx, instance), nil
|
return authz.WithInstance(ctx, instance), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func HostFromRequest(r *http.Request, headerName string) (string, error) {
|
func HostFromRequest(r *http.Request, headerName string) (host string, err error) {
|
||||||
host := r.Host
|
|
||||||
if headerName != "host" {
|
if headerName != "host" {
|
||||||
host = r.Header.Get(headerName)
|
return hostFromSpecialHeader(r, headerName)
|
||||||
}
|
}
|
||||||
|
return hostFromOrigin(r.Context())
|
||||||
|
}
|
||||||
|
|
||||||
|
func hostFromSpecialHeader(r *http.Request, headerName string) (host string, err error) {
|
||||||
|
host = r.Header.Get(headerName)
|
||||||
if host == "" {
|
if host == "" {
|
||||||
return "", fmt.Errorf("host header `%s` not found", headerName)
|
return "", fmt.Errorf("host header `%s` not found", headerName)
|
||||||
}
|
}
|
||||||
return host, nil
|
return host, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hostFromOrigin(ctx context.Context) (host string, err error) {
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
err = fmt.Errorf("invalid origin: %w", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
origin := zitadel_http.ComposedOrigin(ctx)
|
||||||
|
u, err := url.Parse(origin)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
host = u.Hostname()
|
||||||
|
if host == "" {
|
||||||
|
err = errors.New("empty host")
|
||||||
|
}
|
||||||
|
return host, err
|
||||||
|
}
|
||||||
|
|
||||||
func newZitadelTranslator() *i18n.Translator {
|
func newZitadelTranslator() *i18n.Translator {
|
||||||
dir, err := fs.NewWithNamespace("zitadel")
|
dir, err := fs.NewWithNamespace("zitadel")
|
||||||
logging.WithFields("namespace", "zitadel").OnError(err).Panic("unable to get namespace")
|
logging.WithFields("namespace", "zitadel").OnError(err).Panic("unable to get namespace")
|
||||||
|
@ -12,6 +12,7 @@ import (
|
|||||||
"golang.org/x/text/language"
|
"golang.org/x/text/language"
|
||||||
|
|
||||||
"github.com/zitadel/zitadel/internal/api/authz"
|
"github.com/zitadel/zitadel/internal/api/authz"
|
||||||
|
zitadel_http "github.com/zitadel/zitadel/internal/api/http"
|
||||||
)
|
)
|
||||||
|
|
||||||
func Test_instanceInterceptor_Handler(t *testing.T) {
|
func Test_instanceInterceptor_Handler(t *testing.T) {
|
||||||
@ -70,6 +71,7 @@ func Test_instanceInterceptor_Handler(t *testing.T) {
|
|||||||
a := &instanceInterceptor{
|
a := &instanceInterceptor{
|
||||||
verifier: tt.fields.verifier,
|
verifier: tt.fields.verifier,
|
||||||
headerName: tt.fields.headerName,
|
headerName: tt.fields.headerName,
|
||||||
|
translator: newZitadelTranslator(),
|
||||||
}
|
}
|
||||||
next := &testHandler{}
|
next := &testHandler{}
|
||||||
got := a.HandlerFunc(next.ServeHTTP)
|
got := a.HandlerFunc(next.ServeHTTP)
|
||||||
@ -137,6 +139,7 @@ func Test_instanceInterceptor_HandlerFunc(t *testing.T) {
|
|||||||
a := &instanceInterceptor{
|
a := &instanceInterceptor{
|
||||||
verifier: tt.fields.verifier,
|
verifier: tt.fields.verifier,
|
||||||
headerName: tt.fields.headerName,
|
headerName: tt.fields.headerName,
|
||||||
|
translator: newZitadelTranslator(),
|
||||||
}
|
}
|
||||||
next := &testHandler{}
|
next := &testHandler{}
|
||||||
got := a.HandlerFunc(next.ServeHTTP)
|
got := a.HandlerFunc(next.ServeHTTP)
|
||||||
@ -164,7 +167,7 @@ func Test_setInstance(t *testing.T) {
|
|||||||
res res
|
res res
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
"hostname not found, error",
|
"special host header not found, error",
|
||||||
args{
|
args{
|
||||||
r: func() *http.Request {
|
r: func() *http.Request {
|
||||||
r := httptest.NewRequest("", "/url", nil)
|
r := httptest.NewRequest("", "/url", nil)
|
||||||
@ -179,7 +182,7 @@ func Test_setInstance(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"invalid host, error",
|
"special host header invalid, error",
|
||||||
args{
|
args{
|
||||||
r: func() *http.Request {
|
r: func() *http.Request {
|
||||||
r := httptest.NewRequest("", "/url", nil)
|
r := httptest.NewRequest("", "/url", nil)
|
||||||
@ -195,7 +198,7 @@ func Test_setInstance(t *testing.T) {
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"valid host",
|
"special host header valid, ok",
|
||||||
args{
|
args{
|
||||||
r: func() *http.Request {
|
r: func() *http.Request {
|
||||||
r := httptest.NewRequest("", "/url", nil)
|
r := httptest.NewRequest("", "/url", nil)
|
||||||
@ -210,6 +213,52 @@ func Test_setInstance(t *testing.T) {
|
|||||||
err: false,
|
err: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"host from origin if header is not special, ok",
|
||||||
|
args{
|
||||||
|
r: func() *http.Request {
|
||||||
|
r := httptest.NewRequest("", "/url", nil)
|
||||||
|
r.Header.Set("host", "fromrequest")
|
||||||
|
return r.WithContext(zitadel_http.WithComposedOrigin(r.Context(), "https://fromorigin:9999"))
|
||||||
|
}(),
|
||||||
|
verifier: &mockInstanceVerifier{"fromorigin"},
|
||||||
|
headerName: "host",
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
want: authz.WithInstance(zitadel_http.WithComposedOrigin(context.Background(), "https://fromorigin:9999"), &mockInstance{}),
|
||||||
|
err: false,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"host from origin, instance not found",
|
||||||
|
args{
|
||||||
|
r: func() *http.Request {
|
||||||
|
r := httptest.NewRequest("", "/url", nil)
|
||||||
|
return r.WithContext(zitadel_http.WithComposedOrigin(r.Context(), "https://fromorigin:9999"))
|
||||||
|
}(),
|
||||||
|
verifier: &mockInstanceVerifier{"unknowndomain"},
|
||||||
|
headerName: "host",
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
want: nil,
|
||||||
|
err: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"host from origin invalid, err",
|
||||||
|
args{
|
||||||
|
r: func() *http.Request {
|
||||||
|
r := httptest.NewRequest("", "/url", nil)
|
||||||
|
return r.WithContext(zitadel_http.WithComposedOrigin(r.Context(), "https://from origin:9999"))
|
||||||
|
}(),
|
||||||
|
verifier: &mockInstanceVerifier{"from origin"},
|
||||||
|
headerName: "host",
|
||||||
|
},
|
||||||
|
res{
|
||||||
|
want: nil,
|
||||||
|
err: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
Loading…
Reference in New Issue
Block a user