fix: set exhausted cookie with env json (#5868)

* fix: set exhausted cookie with env json

* lint
This commit is contained in:
Elio Bischof 2023-05-15 08:51:02 +02:00 committed by GitHub
parent b449762aed
commit 0e251a29c8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 109 additions and 102 deletions

View File

@ -306,9 +306,8 @@ func startAPIs(
http_util.WithNonHttpOnly(),
http_util.WithMaxAge(int(math.Floor(config.Quotas.Access.ExhaustedCookieMaxAge.Seconds()))),
)
limitingAccessInterceptor := middleware.NewAccessInterceptor(accessSvc, exhaustedCookieHandler, config.Quotas.Access, false)
nonLimitingAccessInterceptor := middleware.NewAccessInterceptor(accessSvc, nil, config.Quotas.Access, true)
apis, err := api.New(ctx, config.Port, router, queries, verifier, config.InternalAuthZ, tlsConfig, config.HTTP2HostHeader, config.HTTP1HostHeader, accessSvc, exhaustedCookieHandler, config.Quotas.Access)
limitingAccessInterceptor := middleware.NewAccessInterceptor(accessSvc, exhaustedCookieHandler, config.Quotas.Access)
apis, err := api.New(ctx, config.Port, router, queries, verifier, config.InternalAuthZ, tlsConfig, config.HTTP2HostHeader, config.HTTP1HostHeader, limitingAccessInterceptor)
if err != nil {
return fmt.Errorf("error creating api %w", err)
}
@ -376,7 +375,7 @@ func startAPIs(
}
apis.RegisterHandlerOnPrefix(saml.HandlerPrefix, samlProvider.HttpHandler())
c, err := console.Start(config.Console, config.ExternalSecure, oidcProvider.IssuerFromRequest, middleware.CallDurationHandler, instanceInterceptor.Handler, nonLimitingAccessInterceptor.Handle, config.CustomerPortal)
c, err := console.Start(config.Console, config.ExternalSecure, oidcProvider.IssuerFromRequest, middleware.CallDurationHandler, instanceInterceptor.Handler, limitingAccessInterceptor, config.CustomerPortal)
if err != nil {
return fmt.Errorf("unable to start console: %w", err)
}

View File

@ -19,24 +19,22 @@ import (
http_mw "github.com/zitadel/zitadel/internal/api/http/middleware"
"github.com/zitadel/zitadel/internal/api/ui/login"
"github.com/zitadel/zitadel/internal/errors"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/query"
"github.com/zitadel/zitadel/internal/telemetry/metrics"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
type API struct {
port uint16
grpcServer *grpc.Server
verifier *internal_authz.TokenVerifier
health healthCheck
router *mux.Router
http1HostName string
grpcGateway *server.Gateway
healthServer *health.Server
cookieHandler *http_util.CookieHandler
cookieConfig *http_mw.AccessConfig
queries *query.Queries
port uint16
grpcServer *grpc.Server
verifier *internal_authz.TokenVerifier
health healthCheck
router *mux.Router
http1HostName string
grpcGateway *server.Gateway
healthServer *health.Server
accessInterceptor *http_mw.AccessInterceptor
queries *query.Queries
}
type healthCheck interface {
@ -51,23 +49,20 @@ func New(
verifier *internal_authz.TokenVerifier,
authZ internal_authz.Config,
tlsConfig *tls.Config, http2HostName, http1HostName string,
accessSvc *logstore.Service,
cookieHandler *http_util.CookieHandler,
cookieConfig *http_mw.AccessConfig,
accessInterceptor *http_mw.AccessInterceptor,
) (_ *API, err error) {
api := &API{
port: port,
verifier: verifier,
health: queries,
router: router,
http1HostName: http1HostName,
cookieConfig: cookieConfig,
cookieHandler: cookieHandler,
queries: queries,
port: port,
verifier: verifier,
health: queries,
router: router,
http1HostName: http1HostName,
queries: queries,
accessInterceptor: accessInterceptor,
}
api.grpcServer = server.CreateServer(api.verifier, authZ, queries, http2HostName, tlsConfig, accessSvc)
api.grpcGateway, err = server.CreateGateway(ctx, port, http1HostName, cookieHandler, cookieConfig)
api.grpcServer = server.CreateServer(api.verifier, authZ, queries, http2HostName, tlsConfig, accessInterceptor.AccessService())
api.grpcGateway, err = server.CreateGateway(ctx, port, http1HostName, accessInterceptor)
if err != nil {
return nil, err
}
@ -90,8 +85,7 @@ func (a *API) RegisterServer(ctx context.Context, grpcServer server.WithGatewayP
grpcServer,
a.port,
a.http1HostName,
a.cookieHandler,
a.cookieConfig,
a.accessInterceptor,
a.queries,
)
if err != nil {

View File

@ -16,7 +16,6 @@ import (
client_middleware "github.com/zitadel/zitadel/internal/api/grpc/client/middleware"
"github.com/zitadel/zitadel/internal/api/grpc/server/middleware"
http_utils "github.com/zitadel/zitadel/internal/api/http"
http_mw "github.com/zitadel/zitadel/internal/api/http/middleware"
"github.com/zitadel/zitadel/internal/query"
)
@ -66,16 +65,15 @@ var (
)
type Gateway struct {
mux *runtime.ServeMux
http1HostName string
connection *grpc.ClientConn
cookieHandler *http_utils.CookieHandler
cookieConfig *http_mw.AccessConfig
queries *query.Queries
mux *runtime.ServeMux
http1HostName string
connection *grpc.ClientConn
accessInterceptor *http_mw.AccessInterceptor
queries *query.Queries
}
func (g *Gateway) Handler() http.Handler {
return addInterceptors(g.mux, g.http1HostName, g.cookieHandler, g.cookieConfig, g.queries)
return addInterceptors(g.mux, g.http1HostName, g.accessInterceptor, g.queries)
}
type CustomHTTPResponse interface {
@ -89,8 +87,7 @@ func CreateGatewayWithPrefix(
g WithGatewayPrefix,
port uint16,
http1HostName string,
cookieHandler *http_utils.CookieHandler,
cookieConfig *http_mw.AccessConfig,
accessInterceptor *http_mw.AccessInterceptor,
queries *query.Queries,
) (http.Handler, string, error) {
runtimeMux := runtime.NewServeMux(serveMuxOptions...)
@ -106,10 +103,10 @@ func CreateGatewayWithPrefix(
if err != nil {
return nil, "", fmt.Errorf("failed to register grpc gateway: %w", err)
}
return addInterceptors(runtimeMux, http1HostName, cookieHandler, cookieConfig, queries), g.GatewayPathPrefix(), nil
return addInterceptors(runtimeMux, http1HostName, accessInterceptor, queries), g.GatewayPathPrefix(), nil
}
func CreateGateway(ctx context.Context, port uint16, http1HostName string, cookieHandler *http_utils.CookieHandler, cookieConfig *http_mw.AccessConfig) (*Gateway, error) {
func CreateGateway(ctx context.Context, port uint16, http1HostName string, accessInterceptor *http_mw.AccessInterceptor) (*Gateway, error) {
connection, err := dial(ctx,
port,
[]grpc.DialOption{
@ -121,11 +118,10 @@ func CreateGateway(ctx context.Context, port uint16, http1HostName string, cooki
}
runtimeMux := runtime.NewServeMux(append(serveMuxOptions, runtime.WithHealthzEndpoint(healthpb.NewHealthClient(connection)))...)
return &Gateway{
mux: runtimeMux,
http1HostName: http1HostName,
connection: connection,
cookieHandler: cookieHandler,
cookieConfig: cookieConfig,
mux: runtimeMux,
http1HostName: http1HostName,
connection: connection,
accessInterceptor: accessInterceptor,
}, nil
}
@ -163,8 +159,7 @@ func dial(ctx context.Context, port uint16, opts []grpc.DialOption) (*grpc.Clien
func addInterceptors(
handler http.Handler,
http1HostName string,
cookieHandler *http_utils.CookieHandler,
cookieConfig *http_mw.AccessConfig,
accessInterceptor *http_mw.AccessInterceptor,
queries *query.Queries,
) http.Handler {
handler = http_mw.CallDurationHandler(handler)
@ -174,7 +169,7 @@ func addInterceptors(
handler = http_mw.DefaultTelemetryHandler(handler)
// For some non-obvious reason, the exhaustedCookieInterceptor sends the SetCookie header
// only if it follows the http_mw.DefaultTelemetryHandler
handler = exhaustedCookieInterceptor(handler, cookieHandler, cookieConfig, queries)
handler = exhaustedCookieInterceptor(handler, accessInterceptor, queries)
handler = http_mw.DefaultMetricsHandler(handler)
return handler
}
@ -193,35 +188,32 @@ func http1Host(next http.Handler, http1HostName string) http.Handler {
func exhaustedCookieInterceptor(
next http.Handler,
cookieHandler *http_utils.CookieHandler,
cookieConfig *http_mw.AccessConfig,
accessInterceptor *http_mw.AccessInterceptor,
queries *query.Queries,
) http.Handler {
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
next.ServeHTTP(&cookieResponseWriter{
ResponseWriter: writer,
cookieHandler: cookieHandler,
cookieConfig: cookieConfig,
request: request,
queries: queries,
ResponseWriter: writer,
accessInterceptor: accessInterceptor,
request: request,
queries: queries,
}, request)
})
}
type cookieResponseWriter struct {
http.ResponseWriter
cookieHandler *http_utils.CookieHandler
cookieConfig *http_mw.AccessConfig
request *http.Request
queries *query.Queries
accessInterceptor *http_mw.AccessInterceptor
request *http.Request
queries *query.Queries
}
func (r *cookieResponseWriter) WriteHeader(status int) {
if status >= 200 && status < 300 {
http_mw.DeleteExhaustedCookie(r.cookieHandler, r.ResponseWriter, r.request, r.cookieConfig)
r.accessInterceptor.DeleteExhaustedCookie(r.ResponseWriter, r.request)
}
if status == http.StatusTooManyRequests {
http_mw.SetExhaustedCookie(r.cookieHandler, r.ResponseWriter, r.cookieConfig, r.request)
r.accessInterceptor.SetExhaustedCookie(r.ResponseWriter, r.request)
}
r.ResponseWriter.WriteHeader(status)
}

View File

@ -1,6 +1,7 @@
package middleware
import (
"context"
"net"
"net/http"
"net/url"
@ -32,15 +33,54 @@ type AccessConfig struct {
// NewAccessInterceptor intercepts all requests and stores them to the logstore.
// If storeOnly is false, it also checks if requests are exhausted.
// If requests are exhausted, it also returns http.StatusTooManyRequests and sets a cookie
func NewAccessInterceptor(svc *logstore.Service, cookieHandler *http_utils.CookieHandler, cookieConfig *AccessConfig, storeOnly bool) *AccessInterceptor {
func NewAccessInterceptor(svc *logstore.Service, cookieHandler *http_utils.CookieHandler, cookieConfig *AccessConfig) *AccessInterceptor {
return &AccessInterceptor{
svc: svc,
cookieHandler: cookieHandler,
limitConfig: cookieConfig,
storeOnly: storeOnly,
}
}
func (a *AccessInterceptor) WithoutLimiting() *AccessInterceptor {
return &AccessInterceptor{
svc: a.svc,
cookieHandler: a.cookieHandler,
limitConfig: a.limitConfig,
storeOnly: true,
}
}
func (a *AccessInterceptor) AccessService() *logstore.Service {
return a.svc
}
func (a *AccessInterceptor) Limit(ctx context.Context) bool {
if !a.svc.Enabled() || a.storeOnly {
return false
}
instance := authz.GetInstance(ctx)
remaining := a.svc.Limit(ctx, instance.InstanceID())
return remaining != nil && *remaining <= 0
}
func (a *AccessInterceptor) SetExhaustedCookie(writer http.ResponseWriter, request *http.Request) {
cookieValue := "true"
host := request.Header.Get(middleware.HTTP1Host)
domain := host
if strings.ContainsAny(host, ":") {
var err error
domain, _, err = net.SplitHostPort(host)
if err != nil {
logging.WithError(err).WithField("host", host).Warning("failed to extract cookie domain from request host")
}
}
a.cookieHandler.SetCookie(writer, a.limitConfig.ExhaustedCookieKey, domain, cookieValue)
}
func (a *AccessInterceptor) DeleteExhaustedCookie(writer http.ResponseWriter, request *http.Request) {
a.cookieHandler.DeleteCookie(writer, request, a.limitConfig.ExhaustedCookieKey)
}
func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
if !a.svc.Enabled() {
return next
@ -49,23 +89,16 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
ctx := request.Context()
tracingCtx, checkSpan := tracing.NewNamedSpan(ctx, "checkAccess")
wrappedWriter := &statusRecorder{ResponseWriter: writer, status: 0}
instance := authz.GetInstance(ctx)
limit := false
if !a.storeOnly {
remaining := a.svc.Limit(tracingCtx, instance.InstanceID())
limit = remaining != nil && *remaining == 0
}
limited := a.Limit(tracingCtx)
checkSpan.End()
if limit {
// Limit can only be true when storeOnly is false, so set the cookie and the response code
SetExhaustedCookie(a.cookieHandler, wrappedWriter, a.limitConfig, request)
if limited {
a.SetExhaustedCookie(wrappedWriter, request)
http.Error(wrappedWriter, "quota for authenticated requests is exhausted", http.StatusTooManyRequests)
} else {
if !a.storeOnly {
// If not limited and not storeOnly, ensure the cookie is deleted
DeleteExhaustedCookie(a.cookieHandler, wrappedWriter, request, a.limitConfig)
}
// Always serve if not limited
}
if !limited && !a.storeOnly {
a.DeleteExhaustedCookie(wrappedWriter, request)
}
if !limited {
next.ServeHTTP(wrappedWriter, request)
}
tracingCtx, writeSpan := tracing.NewNamedSpan(tracingCtx, "writeAccess")
@ -75,6 +108,7 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
if err != nil {
logging.WithError(err).WithField("url", requestURL).Warning("failed to unescape request url")
}
instance := authz.GetInstance(tracingCtx)
a.svc.Handle(tracingCtx, &access.Record{
LogDate: time.Now(),
Protocol: access.HTTP,
@ -90,24 +124,6 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
})
}
func SetExhaustedCookie(cookieHandler *http_utils.CookieHandler, writer http.ResponseWriter, cookieConfig *AccessConfig, request *http.Request) {
cookieValue := "true"
host := request.Header.Get(middleware.HTTP1Host)
domain := host
if strings.ContainsAny(host, ":") {
var err error
domain, _, err = net.SplitHostPort(host)
if err != nil {
logging.WithError(err).WithField("host", host).Warning("failed to extract cookie domain from request host")
}
}
cookieHandler.SetCookie(writer, cookieConfig.ExhaustedCookieKey, domain, cookieValue)
}
func DeleteExhaustedCookie(cookieHandler *http_utils.CookieHandler, writer http.ResponseWriter, request *http.Request, cookieConfig *AccessConfig) {
cookieHandler.DeleteCookie(writer, request, cookieConfig.ExhaustedCookieKey)
}
type statusRecorder struct {
http.ResponseWriter
status int

View File

@ -91,7 +91,7 @@ func (f *file) Stat() (_ fs.FileInfo, err error) {
return f, nil
}
func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, callDurationInterceptor, instanceHandler, accessInterceptor func(http.Handler) http.Handler, customerPortal string) (http.Handler, error) {
func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, callDurationInterceptor, instanceHandler func(http.Handler) http.Handler, limitingAccessInterceptor *middleware.AccessInterceptor, customerPortal string) (http.Handler, error) {
fSys, err := fs.Sub(static, "static")
if err != nil {
return nil, err
@ -106,10 +106,11 @@ func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, call
handler := mux.NewRouter()
handler.Use(callDurationInterceptor, instanceHandler, security, accessInterceptor)
handler.Use(callDurationInterceptor, instanceHandler, security, limitingAccessInterceptor.WithoutLimiting().Handle)
handler.Handle(envRequestPath, middleware.TelemetryHandler()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
url := http_util.BuildOrigin(r.Host, externalSecure)
instance := authz.GetInstance(r.Context())
ctx := r.Context()
instance := authz.GetInstance(ctx)
instanceMgmtURL, err := templateInstanceManagementURL(config.InstanceManagementURL, instance)
if err != nil {
http.Error(w, fmt.Sprintf("unable to template instance management url for console: %v", err), http.StatusInternalServerError)
@ -120,6 +121,11 @@ func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, call
http.Error(w, fmt.Sprintf("unable to marshal env for console: %v", err), http.StatusInternalServerError)
return
}
if limitingAccessInterceptor.Limit(ctx) {
limitingAccessInterceptor.SetExhaustedCookie(w, r)
} else {
limitingAccessInterceptor.DeleteExhaustedCookie(w, r)
}
_, err = w.Write(environmentJSON)
logging.OnError(err).Error("error serving environment.json")
})))