fix: improve interceptor handling (#3578)

* fix: improve interceptor handling

* fix: improve interceptor handling

Co-authored-by: Florian Forster <florian@caos.ch>
This commit is contained in:
Livio Amstutz 2022-05-02 17:26:54 +02:00 committed by GitHub
parent 20f275f178
commit 06a1b52adf
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 81 additions and 28 deletions

View File

@ -174,10 +174,10 @@ func startAPIs(ctx context.Context, router *mux.Router, commands *command.Comman
return err return err
} }
instanceInterceptor := middleware.InstanceInterceptor(queries, config.HTTP1HostHeader) instanceInterceptor := middleware.InstanceInterceptor(queries, config.HTTP1HostHeader, login.IgnoreInstanceEndpoints...)
authenticatedAPIs.RegisterHandler(assets.HandlerPrefix, assets.NewHandler(commands, verifier, config.InternalAuthZ, id.SonyFlakeGenerator, store, queries, instanceInterceptor.Handler)) authenticatedAPIs.RegisterHandler(assets.HandlerPrefix, assets.NewHandler(commands, verifier, config.InternalAuthZ, id.SonyFlakeGenerator, store, queries, instanceInterceptor.Handler))
userAgentInterceptor, err := middleware.NewUserAgentHandler(config.UserAgentCookie, keys.UserAgentCookieKey, id.SonyFlakeGenerator, config.ExternalSecure) userAgentInterceptor, err := middleware.NewUserAgentHandler(config.UserAgentCookie, keys.UserAgentCookieKey, id.SonyFlakeGenerator, config.ExternalSecure, login.EndpointResources)
if err != nil { if err != nil {
return err return err
} }

View File

@ -11,6 +11,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
) )
type InstanceVerifier interface { type InstanceVerifier interface {
@ -24,20 +25,23 @@ func InstanceInterceptor(verifier authz.InstanceVerifier, headerName string, ign
} }
func setInstance(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, verifier authz.InstanceVerifier, headerName string, ignoredServices ...string) (_ interface{}, err error) { func setInstance(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, verifier authz.InstanceVerifier, headerName string, ignoredServices ...string) (_ interface{}, err error) {
interceptorCtx, span := tracing.NewServerInterceptorSpan(ctx)
defer func() { span.EndWithError(err) }()
for _, service := range ignoredServices { for _, service := range ignoredServices {
if strings.HasPrefix(info.FullMethod, service) { if strings.HasPrefix(info.FullMethod, service) {
return handler(ctx, req) return handler(ctx, req)
} }
} }
host, err := hostNameFromContext(ctx, headerName) host, err := hostNameFromContext(interceptorCtx, headerName)
if err != nil { if err != nil {
return nil, status.Error(codes.PermissionDenied, err.Error()) return nil, status.Error(codes.PermissionDenied, err.Error())
} }
instance, err := verifier.InstanceByHost(ctx, host) instance, err := verifier.InstanceByHost(interceptorCtx, host)
if err != nil { if err != nil {
return nil, status.Error(codes.PermissionDenied, err.Error()) return nil, status.Error(codes.PermissionDenied, err.Error())
} }
span.End()
return handler(authz.WithInstance(ctx, instance), req) return handler(authz.WithInstance(ctx, instance), req)
} }

View File

@ -136,4 +136,11 @@ func (c *CookieHandler) httpSet(w http.ResponseWriter, name, domain, value strin
Secure: c.secureOnly, Secure: c.secureOnly,
SameSite: c.sameSite, SameSite: c.sameSite,
}) })
varyValues := w.Header().Values("vary")
for _, vary := range varyValues {
if vary == "Cookie" {
return
}
}
w.Header().Add("vary", "Cookie")
} }

View File

@ -4,25 +4,34 @@ import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"strings"
"github.com/zitadel/zitadel/internal/api/authz" "github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/telemetry/tracing"
) )
type instanceInterceptor struct { type instanceInterceptor struct {
verifier authz.InstanceVerifier verifier authz.InstanceVerifier
headerName string headerName string
ignoredPrefixes []string
} }
func InstanceInterceptor(verifier authz.InstanceVerifier, headerName string) *instanceInterceptor { func InstanceInterceptor(verifier authz.InstanceVerifier, headerName string, ignoredPrefixes ...string) *instanceInterceptor {
return &instanceInterceptor{ return &instanceInterceptor{
verifier: verifier, verifier: verifier,
headerName: headerName, headerName: headerName,
ignoredPrefixes: ignoredPrefixes,
} }
} }
func (a *instanceInterceptor) Handler(next http.Handler) http.Handler { func (a *instanceInterceptor) Handler(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
for _, prefix := range a.ignoredPrefixes {
if strings.HasPrefix(r.URL.Path, prefix) {
next.ServeHTTP(w, r)
return
}
}
ctx, err := setInstance(r, a.verifier, a.headerName) ctx, err := setInstance(r, a.verifier, a.headerName)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusUnauthorized) http.Error(w, err.Error(), http.StatusUnauthorized)
@ -35,6 +44,12 @@ func (a *instanceInterceptor) Handler(next http.Handler) http.Handler {
func (a *instanceInterceptor) HandlerFunc(next http.HandlerFunc) http.HandlerFunc { func (a *instanceInterceptor) HandlerFunc(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
for _, prefix := range a.ignoredPrefixes {
if strings.HasPrefix(r.URL.Path, prefix) {
next.ServeHTTP(w, r)
return
}
}
ctx, err := setInstance(r, a.verifier, a.headerName) ctx, err := setInstance(r, a.verifier, a.headerName)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusForbidden) http.Error(w, err.Error(), http.StatusForbidden)

View File

@ -3,6 +3,7 @@ package middleware
import ( import (
"context" "context"
"net/http" "net/http"
"strings"
"time" "time"
http_utils "github.com/zitadel/zitadel/internal/api/http" http_utils "github.com/zitadel/zitadel/internal/api/http"
@ -26,10 +27,11 @@ type UserAgent struct {
} }
type userAgentHandler struct { type userAgentHandler struct {
cookieHandler *http_utils.CookieHandler cookieHandler *http_utils.CookieHandler
cookieName string cookieName string
idGenerator id.Generator idGenerator id.Generator
nextHandler http.Handler nextHandler http.Handler
ignoredPrefixes []string
} }
type UserAgentCookieConfig struct { type UserAgentCookieConfig struct {
@ -37,7 +39,7 @@ type UserAgentCookieConfig struct {
MaxAge time.Duration MaxAge time.Duration
} }
func NewUserAgentHandler(config *UserAgentCookieConfig, cookieKey []byte, idGenerator id.Generator, externalSecure bool) (func(http.Handler) http.Handler, error) { func NewUserAgentHandler(config *UserAgentCookieConfig, cookieKey []byte, idGenerator id.Generator, externalSecure bool, ignoredPrefixes ...string) (func(http.Handler) http.Handler, error) {
opts := []http_utils.CookieHandlerOpt{ opts := []http_utils.CookieHandlerOpt{
http_utils.WithEncryption(cookieKey, cookieKey), http_utils.WithEncryption(cookieKey, cookieKey),
http_utils.WithMaxAge(int(config.MaxAge.Seconds())), http_utils.WithMaxAge(int(config.MaxAge.Seconds())),
@ -47,15 +49,22 @@ func NewUserAgentHandler(config *UserAgentCookieConfig, cookieKey []byte, idGene
} }
return func(handler http.Handler) http.Handler { return func(handler http.Handler) http.Handler {
return &userAgentHandler{ return &userAgentHandler{
nextHandler: handler, nextHandler: handler,
cookieName: config.Name, cookieName: config.Name,
cookieHandler: http_utils.NewCookieHandler(opts...), cookieHandler: http_utils.NewCookieHandler(opts...),
idGenerator: idGenerator, idGenerator: idGenerator,
ignoredPrefixes: ignoredPrefixes,
} }
}, nil }, nil
} }
func (ua *userAgentHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (ua *userAgentHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
for _, prefix := range ua.ignoredPrefixes {
if strings.HasPrefix(r.URL.Path, prefix) {
ua.nextHandler.ServeHTTP(w, r)
return
}
}
agent, err := ua.getUserAgent(r) agent, err := ua.getUserAgent(r)
if err != nil { if err != nil {
agent, err = ua.newUserAgent() agent, err = ua.newUserAgent()

View File

@ -201,11 +201,11 @@ func setOIDCCtx(ctx context.Context) context.Context {
func retry(retryable func() error) (err error) { func retry(retryable func() error) (err error) {
for i := 0; i < retryCount; i++ { for i := 0; i < retryCount; i++ {
time.Sleep(retryBackoff)
err = retryable() err = retryable()
if err == nil { if err == nil {
return nil return nil
} }
time.Sleep(retryBackoff)
} }
return err return err
} }

View File

@ -74,7 +74,7 @@ func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, inst
handler := mux.NewRouter() handler := mux.NewRouter()
handler.Use(cache, security) handler.Use(cache, security)
handler.Handle(envRequestPath, instanceHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler.Handle(envRequestPath, middleware.TelemetryHandler()(instanceHandler(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
instance := authz.GetInstance(r.Context()) instance := authz.GetInstance(r.Context())
if instance.InstanceID() == "" { if instance.InstanceID() == "" {
http.Error(w, "empty instanceID", http.StatusInternalServerError) http.Error(w, "empty instanceID", http.StatusInternalServerError)
@ -88,7 +88,7 @@ func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, inst
} }
_, err = w.Write(environmentJSON) _, err = w.Write(environmentJSON)
logging.OnError(err).Error("error serving environment.json") logging.OnError(err).Error("error serving environment.json")
}))) }))))
handler.SkipClean(true).PathPrefix("").Handler(http.FileServer(&spaHandler{http.FS(fSys)})) handler.SkipClean(true).PathPrefix("").Handler(http.FileServer(&spaHandler{http.FS(fSys)}))
return handler, nil return handler, nil
} }

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"fmt" "fmt"
"net/http" "net/http"
"strings"
"github.com/gorilla/csrf" "github.com/gorilla/csrf"
"github.com/gorilla/mux" "github.com/gorilla/mux"
@ -93,7 +94,7 @@ func CreateLogin(config Config,
} }
security := middleware.SecurityHeaders(csp(), login.cspErrorHandler) security := middleware.SecurityHeaders(csp(), login.cspErrorHandler)
login.router = CreateRouter(login, statikFS, instanceHandler, csrfInterceptor, cacheInterceptor, security, userAgentCookie, middleware.TelemetryHandler(EndpointResources), issuerInterceptor) login.router = CreateRouter(login, statikFS, middleware.TelemetryHandler(IgnoreInstanceEndpoints...), instanceHandler, csrfInterceptor, cacheInterceptor, security, userAgentCookie, issuerInterceptor)
login.renderer = CreateRenderer(HandlerPrefix, statikFS, staticStorage, config.LanguageCookieName) login.renderer = CreateRenderer(HandlerPrefix, statikFS, staticStorage, config.LanguageCookieName)
login.parser = form.NewParser() login.parser = form.NewParser()
return login, nil return login, nil
@ -109,12 +110,20 @@ func csp() *middleware.CSP {
func createCSRFInterceptor(cookieName string, csrfCookieKey []byte, externalSecure bool, errorHandler http.Handler) (func(http.Handler) http.Handler, error) { func createCSRFInterceptor(cookieName string, csrfCookieKey []byte, externalSecure bool, errorHandler http.Handler) (func(http.Handler) http.Handler, error) {
path := "/" path := "/"
return csrf.Protect(csrfCookieKey, return func(handler http.Handler) http.Handler {
csrf.Secure(externalSecure), return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
csrf.CookieName(http_utils.SetCookiePrefix(cookieName, "", path, externalSecure)), if strings.HasPrefix(r.URL.Path, EndpointResources) {
csrf.Path(path), handler.ServeHTTP(w, r)
csrf.ErrorHandler(errorHandler), return
), nil }
csrf.Protect(csrfCookieKey,
csrf.Secure(externalSecure),
csrf.CookieName(http_utils.SetCookiePrefix(cookieName, "", path, externalSecure)),
csrf.Path(path),
csrf.ErrorHandler(errorHandler),
)(handler).ServeHTTP(w, r)
})
}, nil
} }
func (l *Login) Handler() http.Handler { func (l *Login) Handler() http.Handler {

View File

@ -46,6 +46,15 @@ const (
EndpointDynamicResources = "/resources/dynamic" EndpointDynamicResources = "/resources/dynamic"
) )
var (
IgnoreInstanceEndpoints = []string{
EndpointResources + "/fonts",
EndpointResources + "/images",
EndpointResources + "/scripts",
EndpointResources + "/themes",
}
)
func CreateRouter(login *Login, staticDir http.FileSystem, interceptors ...mux.MiddlewareFunc) *mux.Router { func CreateRouter(login *Login, staticDir http.FileSystem, interceptors ...mux.MiddlewareFunc) *mux.Router {
router := mux.NewRouter() router := mux.NewRouter()
router.Use(interceptors...) router.Use(interceptors...)