fix: set exhausted cookie with env json (#5868)

* fix: set exhausted cookie with env json

* lint
This commit is contained in:
Elio Bischof 2023-05-15 08:51:02 +02:00 committed by GitHub
parent b449762aed
commit 0e251a29c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 109 additions and 102 deletions

View File

@ -306,9 +306,8 @@ func startAPIs(
http_util.WithNonHttpOnly(), http_util.WithNonHttpOnly(),
http_util.WithMaxAge(int(math.Floor(config.Quotas.Access.ExhaustedCookieMaxAge.Seconds()))), http_util.WithMaxAge(int(math.Floor(config.Quotas.Access.ExhaustedCookieMaxAge.Seconds()))),
) )
limitingAccessInterceptor := middleware.NewAccessInterceptor(accessSvc, exhaustedCookieHandler, config.Quotas.Access, false) limitingAccessInterceptor := middleware.NewAccessInterceptor(accessSvc, exhaustedCookieHandler, config.Quotas.Access)
nonLimitingAccessInterceptor := middleware.NewAccessInterceptor(accessSvc, nil, config.Quotas.Access, true) apis, err := api.New(ctx, config.Port, router, queries, verifier, config.InternalAuthZ, tlsConfig, config.HTTP2HostHeader, config.HTTP1HostHeader, limitingAccessInterceptor)
apis, err := api.New(ctx, config.Port, router, queries, verifier, config.InternalAuthZ, tlsConfig, config.HTTP2HostHeader, config.HTTP1HostHeader, accessSvc, exhaustedCookieHandler, config.Quotas.Access)
if err != nil { if err != nil {
return fmt.Errorf("error creating api %w", err) return fmt.Errorf("error creating api %w", err)
} }
@ -376,7 +375,7 @@ func startAPIs(
} }
apis.RegisterHandlerOnPrefix(saml.HandlerPrefix, samlProvider.HttpHandler()) apis.RegisterHandlerOnPrefix(saml.HandlerPrefix, samlProvider.HttpHandler())
c, err := console.Start(config.Console, config.ExternalSecure, oidcProvider.IssuerFromRequest, middleware.CallDurationHandler, instanceInterceptor.Handler, nonLimitingAccessInterceptor.Handle, config.CustomerPortal) c, err := console.Start(config.Console, config.ExternalSecure, oidcProvider.IssuerFromRequest, middleware.CallDurationHandler, instanceInterceptor.Handler, limitingAccessInterceptor, config.CustomerPortal)
if err != nil { if err != nil {
return fmt.Errorf("unable to start console: %w", err) return fmt.Errorf("unable to start console: %w", err)
} }

View File

@ -19,24 +19,22 @@ import (
http_mw "github.com/zitadel/zitadel/internal/api/http/middleware" http_mw "github.com/zitadel/zitadel/internal/api/http/middleware"
"github.com/zitadel/zitadel/internal/api/ui/login" "github.com/zitadel/zitadel/internal/api/ui/login"
"github.com/zitadel/zitadel/internal/errors" "github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/metrics" "github.com/zitadel/zitadel/internal/telemetry/metrics"
"github.com/zitadel/zitadel/internal/telemetry/tracing" "github.com/zitadel/zitadel/internal/telemetry/tracing"
) )
type API struct { type API struct {
port uint16 port uint16
grpcServer *grpc.Server grpcServer *grpc.Server
verifier *internal_authz.TokenVerifier verifier *internal_authz.TokenVerifier
health healthCheck health healthCheck
router *mux.Router router *mux.Router
http1HostName string http1HostName string
grpcGateway *server.Gateway grpcGateway *server.Gateway
healthServer *health.Server healthServer *health.Server
cookieHandler *http_util.CookieHandler accessInterceptor *http_mw.AccessInterceptor
cookieConfig *http_mw.AccessConfig queries *query.Queries
queries *query.Queries
} }
type healthCheck interface { type healthCheck interface {
@ -51,23 +49,20 @@ func New(
verifier *internal_authz.TokenVerifier, verifier *internal_authz.TokenVerifier,
authZ internal_authz.Config, authZ internal_authz.Config,
tlsConfig *tls.Config, http2HostName, http1HostName string, tlsConfig *tls.Config, http2HostName, http1HostName string,
accessSvc *logstore.Service, accessInterceptor *http_mw.AccessInterceptor,
cookieHandler *http_util.CookieHandler,
cookieConfig *http_mw.AccessConfig,
) (_ *API, err error) { ) (_ *API, err error) {
api := &API{ api := &API{
port: port, port: port,
verifier: verifier, verifier: verifier,
health: queries, health: queries,
router: router, router: router,
http1HostName: http1HostName, http1HostName: http1HostName,
cookieConfig: cookieConfig, queries: queries,
cookieHandler: cookieHandler, accessInterceptor: accessInterceptor,
queries: queries,
} }
api.grpcServer = server.CreateServer(api.verifier, authZ, queries, http2HostName, tlsConfig, accessSvc) api.grpcServer = server.CreateServer(api.verifier, authZ, queries, http2HostName, tlsConfig, accessInterceptor.AccessService())
api.grpcGateway, err = server.CreateGateway(ctx, port, http1HostName, cookieHandler, cookieConfig) api.grpcGateway, err = server.CreateGateway(ctx, port, http1HostName, accessInterceptor)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -90,8 +85,7 @@ func (a *API) RegisterServer(ctx context.Context, grpcServer server.WithGatewayP
grpcServer, grpcServer,
a.port, a.port,
a.http1HostName, a.http1HostName,
a.cookieHandler, a.accessInterceptor,
a.cookieConfig,
a.queries, a.queries,
) )
if err != nil { if err != nil {

View File

@ -16,7 +16,6 @@ import (
client_middleware "github.com/zitadel/zitadel/internal/api/grpc/client/middleware" client_middleware "github.com/zitadel/zitadel/internal/api/grpc/client/middleware"
"github.com/zitadel/zitadel/internal/api/grpc/server/middleware" "github.com/zitadel/zitadel/internal/api/grpc/server/middleware"
http_utils "github.com/zitadel/zitadel/internal/api/http"
http_mw "github.com/zitadel/zitadel/internal/api/http/middleware" http_mw "github.com/zitadel/zitadel/internal/api/http/middleware"
"github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/query"
) )
@ -66,16 +65,15 @@ var (
) )
type Gateway struct { type Gateway struct {
mux *runtime.ServeMux mux *runtime.ServeMux
http1HostName string http1HostName string
connection *grpc.ClientConn connection *grpc.ClientConn
cookieHandler *http_utils.CookieHandler accessInterceptor *http_mw.AccessInterceptor
cookieConfig *http_mw.AccessConfig queries *query.Queries
queries *query.Queries
} }
func (g *Gateway) Handler() http.Handler { func (g *Gateway) Handler() http.Handler {
return addInterceptors(g.mux, g.http1HostName, g.cookieHandler, g.cookieConfig, g.queries) return addInterceptors(g.mux, g.http1HostName, g.accessInterceptor, g.queries)
} }
type CustomHTTPResponse interface { type CustomHTTPResponse interface {
@ -89,8 +87,7 @@ func CreateGatewayWithPrefix(
g WithGatewayPrefix, g WithGatewayPrefix,
port uint16, port uint16,
http1HostName string, http1HostName string,
cookieHandler *http_utils.CookieHandler, accessInterceptor *http_mw.AccessInterceptor,
cookieConfig *http_mw.AccessConfig,
queries *query.Queries, queries *query.Queries,
) (http.Handler, string, error) { ) (http.Handler, string, error) {
runtimeMux := runtime.NewServeMux(serveMuxOptions...) runtimeMux := runtime.NewServeMux(serveMuxOptions...)
@ -106,10 +103,10 @@ func CreateGatewayWithPrefix(
if err != nil { if err != nil {
return nil, "", fmt.Errorf("failed to register grpc gateway: %w", err) return nil, "", fmt.Errorf("failed to register grpc gateway: %w", err)
} }
return addInterceptors(runtimeMux, http1HostName, cookieHandler, cookieConfig, queries), g.GatewayPathPrefix(), nil return addInterceptors(runtimeMux, http1HostName, accessInterceptor, queries), g.GatewayPathPrefix(), nil
} }
func CreateGateway(ctx context.Context, port uint16, http1HostName string, cookieHandler *http_utils.CookieHandler, cookieConfig *http_mw.AccessConfig) (*Gateway, error) { func CreateGateway(ctx context.Context, port uint16, http1HostName string, accessInterceptor *http_mw.AccessInterceptor) (*Gateway, error) {
connection, err := dial(ctx, connection, err := dial(ctx,
port, port,
[]grpc.DialOption{ []grpc.DialOption{
@ -121,11 +118,10 @@ func CreateGateway(ctx context.Context, port uint16, http1HostName string, cooki
} }
runtimeMux := runtime.NewServeMux(append(serveMuxOptions, runtime.WithHealthzEndpoint(healthpb.NewHealthClient(connection)))...) runtimeMux := runtime.NewServeMux(append(serveMuxOptions, runtime.WithHealthzEndpoint(healthpb.NewHealthClient(connection)))...)
return &Gateway{ return &Gateway{
mux: runtimeMux, mux: runtimeMux,
http1HostName: http1HostName, http1HostName: http1HostName,
connection: connection, connection: connection,
cookieHandler: cookieHandler, accessInterceptor: accessInterceptor,
cookieConfig: cookieConfig,
}, nil }, nil
} }
@ -163,8 +159,7 @@ func dial(ctx context.Context, port uint16, opts []grpc.DialOption) (*grpc.Clien
func addInterceptors( func addInterceptors(
handler http.Handler, handler http.Handler,
http1HostName string, http1HostName string,
cookieHandler *http_utils.CookieHandler, accessInterceptor *http_mw.AccessInterceptor,
cookieConfig *http_mw.AccessConfig,
queries *query.Queries, queries *query.Queries,
) http.Handler { ) http.Handler {
handler = http_mw.CallDurationHandler(handler) handler = http_mw.CallDurationHandler(handler)
@ -174,7 +169,7 @@ func addInterceptors(
handler = http_mw.DefaultTelemetryHandler(handler) handler = http_mw.DefaultTelemetryHandler(handler)
// For some non-obvious reason, the exhaustedCookieInterceptor sends the SetCookie header // For some non-obvious reason, the exhaustedCookieInterceptor sends the SetCookie header
// only if it follows the http_mw.DefaultTelemetryHandler // only if it follows the http_mw.DefaultTelemetryHandler
handler = exhaustedCookieInterceptor(handler, cookieHandler, cookieConfig, queries) handler = exhaustedCookieInterceptor(handler, accessInterceptor, queries)
handler = http_mw.DefaultMetricsHandler(handler) handler = http_mw.DefaultMetricsHandler(handler)
return handler return handler
} }
@ -193,35 +188,32 @@ func http1Host(next http.Handler, http1HostName string) http.Handler {
func exhaustedCookieInterceptor( func exhaustedCookieInterceptor(
next http.Handler, next http.Handler,
cookieHandler *http_utils.CookieHandler, accessInterceptor *http_mw.AccessInterceptor,
cookieConfig *http_mw.AccessConfig,
queries *query.Queries, queries *query.Queries,
) http.Handler { ) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
next.ServeHTTP(&cookieResponseWriter{ next.ServeHTTP(&cookieResponseWriter{
ResponseWriter: writer, ResponseWriter: writer,
cookieHandler: cookieHandler, accessInterceptor: accessInterceptor,
cookieConfig: cookieConfig, request: request,
request: request, queries: queries,
queries: queries,
}, request) }, request)
}) })
} }
type cookieResponseWriter struct { type cookieResponseWriter struct {
http.ResponseWriter http.ResponseWriter
cookieHandler *http_utils.CookieHandler accessInterceptor *http_mw.AccessInterceptor
cookieConfig *http_mw.AccessConfig request *http.Request
request *http.Request queries *query.Queries
queries *query.Queries
} }
func (r *cookieResponseWriter) WriteHeader(status int) { func (r *cookieResponseWriter) WriteHeader(status int) {
if status >= 200 && status < 300 { if status >= 200 && status < 300 {
http_mw.DeleteExhaustedCookie(r.cookieHandler, r.ResponseWriter, r.request, r.cookieConfig) r.accessInterceptor.DeleteExhaustedCookie(r.ResponseWriter, r.request)
} }
if status == http.StatusTooManyRequests { if status == http.StatusTooManyRequests {
http_mw.SetExhaustedCookie(r.cookieHandler, r.ResponseWriter, r.cookieConfig, r.request) r.accessInterceptor.SetExhaustedCookie(r.ResponseWriter, r.request)
} }
r.ResponseWriter.WriteHeader(status) r.ResponseWriter.WriteHeader(status)
} }

View File

@ -1,6 +1,7 @@
package middleware package middleware
import ( import (
"context"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
@ -32,15 +33,54 @@ type AccessConfig struct {
// NewAccessInterceptor intercepts all requests and stores them to the logstore. // NewAccessInterceptor intercepts all requests and stores them to the logstore.
// If storeOnly is false, it also checks if requests are exhausted. // If storeOnly is false, it also checks if requests are exhausted.
// If requests are exhausted, it also returns http.StatusTooManyRequests and sets a cookie // If requests are exhausted, it also returns http.StatusTooManyRequests and sets a cookie
func NewAccessInterceptor(svc *logstore.Service, cookieHandler *http_utils.CookieHandler, cookieConfig *AccessConfig, storeOnly bool) *AccessInterceptor { func NewAccessInterceptor(svc *logstore.Service, cookieHandler *http_utils.CookieHandler, cookieConfig *AccessConfig) *AccessInterceptor {
return &AccessInterceptor{ return &AccessInterceptor{
svc: svc, svc: svc,
cookieHandler: cookieHandler, cookieHandler: cookieHandler,
limitConfig: cookieConfig, limitConfig: cookieConfig,
storeOnly: storeOnly,
} }
} }
func (a *AccessInterceptor) WithoutLimiting() *AccessInterceptor {
return &AccessInterceptor{
svc: a.svc,
cookieHandler: a.cookieHandler,
limitConfig: a.limitConfig,
storeOnly: true,
}
}
func (a *AccessInterceptor) AccessService() *logstore.Service {
return a.svc
}
func (a *AccessInterceptor) Limit(ctx context.Context) bool {
if !a.svc.Enabled() || a.storeOnly {
return false
}
instance := authz.GetInstance(ctx)
remaining := a.svc.Limit(ctx, instance.InstanceID())
return remaining != nil && *remaining <= 0
}
func (a *AccessInterceptor) SetExhaustedCookie(writer http.ResponseWriter, request *http.Request) {
cookieValue := "true"
host := request.Header.Get(middleware.HTTP1Host)
domain := host
if strings.ContainsAny(host, ":") {
var err error
domain, _, err = net.SplitHostPort(host)
if err != nil {
logging.WithError(err).WithField("host", host).Warning("failed to extract cookie domain from request host")
}
}
a.cookieHandler.SetCookie(writer, a.limitConfig.ExhaustedCookieKey, domain, cookieValue)
}
func (a *AccessInterceptor) DeleteExhaustedCookie(writer http.ResponseWriter, request *http.Request) {
a.cookieHandler.DeleteCookie(writer, request, a.limitConfig.ExhaustedCookieKey)
}
func (a *AccessInterceptor) Handle(next http.Handler) http.Handler { func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
if !a.svc.Enabled() { if !a.svc.Enabled() {
return next return next
@ -49,23 +89,16 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
ctx := request.Context() ctx := request.Context()
tracingCtx, checkSpan := tracing.NewNamedSpan(ctx, "checkAccess") tracingCtx, checkSpan := tracing.NewNamedSpan(ctx, "checkAccess")
wrappedWriter := &statusRecorder{ResponseWriter: writer, status: 0} wrappedWriter := &statusRecorder{ResponseWriter: writer, status: 0}
instance := authz.GetInstance(ctx) limited := a.Limit(tracingCtx)
limit := false
if !a.storeOnly {
remaining := a.svc.Limit(tracingCtx, instance.InstanceID())
limit = remaining != nil && *remaining == 0
}
checkSpan.End() checkSpan.End()
if limit { if limited {
// Limit can only be true when storeOnly is false, so set the cookie and the response code a.SetExhaustedCookie(wrappedWriter, request)
SetExhaustedCookie(a.cookieHandler, wrappedWriter, a.limitConfig, request)
http.Error(wrappedWriter, "quota for authenticated requests is exhausted", http.StatusTooManyRequests) http.Error(wrappedWriter, "quota for authenticated requests is exhausted", http.StatusTooManyRequests)
} else { }
if !a.storeOnly { if !limited && !a.storeOnly {
// If not limited and not storeOnly, ensure the cookie is deleted a.DeleteExhaustedCookie(wrappedWriter, request)
DeleteExhaustedCookie(a.cookieHandler, wrappedWriter, request, a.limitConfig) }
} if !limited {
// Always serve if not limited
next.ServeHTTP(wrappedWriter, request) next.ServeHTTP(wrappedWriter, request)
} }
tracingCtx, writeSpan := tracing.NewNamedSpan(tracingCtx, "writeAccess") tracingCtx, writeSpan := tracing.NewNamedSpan(tracingCtx, "writeAccess")
@ -75,6 +108,7 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
if err != nil { if err != nil {
logging.WithError(err).WithField("url", requestURL).Warning("failed to unescape request url") logging.WithError(err).WithField("url", requestURL).Warning("failed to unescape request url")
} }
instance := authz.GetInstance(tracingCtx)
a.svc.Handle(tracingCtx, &access.Record{ a.svc.Handle(tracingCtx, &access.Record{
LogDate: time.Now(), LogDate: time.Now(),
Protocol: access.HTTP, Protocol: access.HTTP,
@ -90,24 +124,6 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
}) })
} }
func SetExhaustedCookie(cookieHandler *http_utils.CookieHandler, writer http.ResponseWriter, cookieConfig *AccessConfig, request *http.Request) {
cookieValue := "true"
host := request.Header.Get(middleware.HTTP1Host)
domain := host
if strings.ContainsAny(host, ":") {
var err error
domain, _, err = net.SplitHostPort(host)
if err != nil {
logging.WithError(err).WithField("host", host).Warning("failed to extract cookie domain from request host")
}
}
cookieHandler.SetCookie(writer, cookieConfig.ExhaustedCookieKey, domain, cookieValue)
}
func DeleteExhaustedCookie(cookieHandler *http_utils.CookieHandler, writer http.ResponseWriter, request *http.Request, cookieConfig *AccessConfig) {
cookieHandler.DeleteCookie(writer, request, cookieConfig.ExhaustedCookieKey)
}
type statusRecorder struct { type statusRecorder struct {
http.ResponseWriter http.ResponseWriter
status int status int

View File

@ -91,7 +91,7 @@ func (f *file) Stat() (_ fs.FileInfo, err error) {
return f, nil return f, nil
} }
func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, callDurationInterceptor, instanceHandler, accessInterceptor func(http.Handler) http.Handler, customerPortal string) (http.Handler, error) { func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, callDurationInterceptor, instanceHandler func(http.Handler) http.Handler, limitingAccessInterceptor *middleware.AccessInterceptor, customerPortal string) (http.Handler, error) {
fSys, err := fs.Sub(static, "static") fSys, err := fs.Sub(static, "static")
if err != nil { if err != nil {
return nil, err return nil, err
@ -106,10 +106,11 @@ func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, call
handler := mux.NewRouter() handler := mux.NewRouter()
handler.Use(callDurationInterceptor, instanceHandler, security, accessInterceptor) handler.Use(callDurationInterceptor, instanceHandler, security, limitingAccessInterceptor.WithoutLimiting().Handle)
handler.Handle(envRequestPath, middleware.TelemetryHandler()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { handler.Handle(envRequestPath, middleware.TelemetryHandler()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
url := http_util.BuildOrigin(r.Host, externalSecure) url := http_util.BuildOrigin(r.Host, externalSecure)
instance := authz.GetInstance(r.Context()) ctx := r.Context()
instance := authz.GetInstance(ctx)
instanceMgmtURL, err := templateInstanceManagementURL(config.InstanceManagementURL, instance) instanceMgmtURL, err := templateInstanceManagementURL(config.InstanceManagementURL, instance)
if err != nil { if err != nil {
http.Error(w, fmt.Sprintf("unable to template instance management url for console: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("unable to template instance management url for console: %v", err), http.StatusInternalServerError)
@ -120,6 +121,11 @@ func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, call
http.Error(w, fmt.Sprintf("unable to marshal env for console: %v", err), http.StatusInternalServerError) http.Error(w, fmt.Sprintf("unable to marshal env for console: %v", err), http.StatusInternalServerError)
return return
} }
if limitingAccessInterceptor.Limit(ctx) {
limitingAccessInterceptor.SetExhaustedCookie(w, r)
} else {
limitingAccessInterceptor.DeleteExhaustedCookie(w, r)
}
_, err = w.Write(environmentJSON) _, err = w.Write(environmentJSON)
logging.OnError(err).Error("error serving environment.json") logging.OnError(err).Error("error serving environment.json")
}))) })))