fix: instance interceptors return NotFound (404) error for unknown hosts (#4184)

* fix: instance interceptors return "NotFound" (404) error for unknown hosts

* fix tests
This commit is contained in:
Livio Spring
2022-08-17 08:07:41 +02:00
committed by GitHub
parent d0733b3185
commit d656b3f3c9
11 changed files with 71 additions and 46 deletions

View File

@@ -2,11 +2,18 @@ package middleware
import (
"context"
"errors"
"fmt"
"net/http"
"strings"
"github.com/rakyll/statik/fs"
"github.com/zitadel/logging"
"golang.org/x/text/language"
"github.com/zitadel/zitadel/internal/api/authz"
caos_errors "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/i18n"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
@@ -14,6 +21,7 @@ type instanceInterceptor struct {
verifier authz.InstanceVerifier
headerName string
ignoredPrefixes []string
translator *i18n.Translator
}
func InstanceInterceptor(verifier authz.InstanceVerifier, headerName string, ignoredPrefixes ...string) *instanceInterceptor {
@@ -21,43 +29,40 @@ func InstanceInterceptor(verifier authz.InstanceVerifier, headerName string, ign
verifier: verifier,
headerName: headerName,
ignoredPrefixes: ignoredPrefixes,
translator: newZitadelTranslator(),
}
}
func (a *instanceInterceptor) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for _, prefix := range a.ignoredPrefixes {
if strings.HasPrefix(r.URL.Path, prefix) {
next.ServeHTTP(w, r)
return
}
}
ctx, err := setInstance(r, a.verifier, a.headerName)
if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized)
return
}
r = r.WithContext(ctx)
next.ServeHTTP(w, r)
a.handleInstance(w, r, next)
})
}
func (a *instanceInterceptor) HandlerFunc(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
for _, prefix := range a.ignoredPrefixes {
if strings.HasPrefix(r.URL.Path, prefix) {
next.ServeHTTP(w, r)
return
}
}
ctx, err := setInstance(r, a.verifier, a.headerName)
if err != nil {
http.Error(w, err.Error(), http.StatusForbidden)
a.handleInstance(w, r, next)
}
}
func (a *instanceInterceptor) handleInstance(w http.ResponseWriter, r *http.Request, next http.Handler) {
for _, prefix := range a.ignoredPrefixes {
if strings.HasPrefix(r.URL.Path, prefix) {
next.ServeHTTP(w, r)
return
}
r = r.WithContext(ctx)
next.ServeHTTP(w, r)
}
ctx, err := setInstance(r, a.verifier, a.headerName)
if err != nil {
caosErr := new(caos_errors.NotFoundError)
if errors.As(err, &caosErr) {
caosErr.Message = a.translator.LocalizeFromRequest(r, caosErr.GetMessage(), nil)
}
http.Error(w, err.Error(), http.StatusNotFound)
return
}
r = r.WithContext(ctx)
next.ServeHTTP(w, r)
}
func setInstance(r *http.Request, verifier authz.InstanceVerifier, headerName string) (_ context.Context, err error) {
@@ -89,3 +94,12 @@ func HostFromRequest(r *http.Request, headerName string) (string, error) {
}
return host, nil
}
func newZitadelTranslator() *i18n.Translator {
dir, err := fs.NewWithNamespace("zitadel")
logging.WithFields("namespace", "zitadel").OnError(err).Panic("unable to get namespace")
translator, err := i18n.NewTranslator(dir, language.English, "")
logging.OnError(err).Panic("unable to get translator")
return translator
}

View File

@@ -42,7 +42,7 @@ func Test_instanceInterceptor_Handler(t *testing.T) {
request: httptest.NewRequest("", "/url", nil),
},
res{
statusCode: 403,
statusCode: 404,
context: nil,
},
},
@@ -109,7 +109,7 @@ func Test_instanceInterceptor_HandlerFunc(t *testing.T) {
request: httptest.NewRequest("", "/url", nil),
},
res{
statusCode: 403,
statusCode: 404,
context: nil,
},
},
@@ -229,7 +229,7 @@ type testHandler struct {
context context.Context
}
func (t *testHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
func (t *testHandler) ServeHTTP(_ http.ResponseWriter, r *http.Request) {
t.context = r.Context()
}
@@ -237,7 +237,7 @@ type mockInstanceVerifier struct {
host string
}
func (m *mockInstanceVerifier) InstanceByHost(ctx context.Context, host string) (authz.Instance, error) {
func (m *mockInstanceVerifier) InstanceByHost(_ context.Context, host string) (authz.Instance, error) {
if host != m.host {
return nil, fmt.Errorf("invalid host")
}