diff --git a/cmd/start/start.go b/cmd/start/start.go index 0d2d973690..efa0e9326f 100644 --- a/cmd/start/start.go +++ b/cmd/start/start.go @@ -306,9 +306,8 @@ func startAPIs( http_util.WithNonHttpOnly(), http_util.WithMaxAge(int(math.Floor(config.Quotas.Access.ExhaustedCookieMaxAge.Seconds()))), ) - limitingAccessInterceptor := middleware.NewAccessInterceptor(accessSvc, exhaustedCookieHandler, config.Quotas.Access, false) - nonLimitingAccessInterceptor := middleware.NewAccessInterceptor(accessSvc, nil, config.Quotas.Access, true) - apis, err := api.New(ctx, config.Port, router, queries, verifier, config.InternalAuthZ, tlsConfig, config.HTTP2HostHeader, config.HTTP1HostHeader, accessSvc, exhaustedCookieHandler, config.Quotas.Access) + limitingAccessInterceptor := middleware.NewAccessInterceptor(accessSvc, exhaustedCookieHandler, config.Quotas.Access) + apis, err := api.New(ctx, config.Port, router, queries, verifier, config.InternalAuthZ, tlsConfig, config.HTTP2HostHeader, config.HTTP1HostHeader, limitingAccessInterceptor) if err != nil { return fmt.Errorf("error creating api %w", err) } @@ -376,7 +375,7 @@ func startAPIs( } apis.RegisterHandlerOnPrefix(saml.HandlerPrefix, samlProvider.HttpHandler()) - c, err := console.Start(config.Console, config.ExternalSecure, oidcProvider.IssuerFromRequest, middleware.CallDurationHandler, instanceInterceptor.Handler, nonLimitingAccessInterceptor.Handle, config.CustomerPortal) + c, err := console.Start(config.Console, config.ExternalSecure, oidcProvider.IssuerFromRequest, middleware.CallDurationHandler, instanceInterceptor.Handler, limitingAccessInterceptor, config.CustomerPortal) if err != nil { return fmt.Errorf("unable to start console: %w", err) } diff --git a/internal/api/api.go b/internal/api/api.go index 768f0f1a2e..578bae3a45 100644 --- a/internal/api/api.go +++ b/internal/api/api.go @@ -19,24 +19,22 @@ import ( http_mw "github.com/zitadel/zitadel/internal/api/http/middleware" "github.com/zitadel/zitadel/internal/api/ui/login" "github.com/zitadel/zitadel/internal/errors" - "github.com/zitadel/zitadel/internal/logstore" "github.com/zitadel/zitadel/internal/query" "github.com/zitadel/zitadel/internal/telemetry/metrics" "github.com/zitadel/zitadel/internal/telemetry/tracing" ) type API struct { - port uint16 - grpcServer *grpc.Server - verifier *internal_authz.TokenVerifier - health healthCheck - router *mux.Router - http1HostName string - grpcGateway *server.Gateway - healthServer *health.Server - cookieHandler *http_util.CookieHandler - cookieConfig *http_mw.AccessConfig - queries *query.Queries + port uint16 + grpcServer *grpc.Server + verifier *internal_authz.TokenVerifier + health healthCheck + router *mux.Router + http1HostName string + grpcGateway *server.Gateway + healthServer *health.Server + accessInterceptor *http_mw.AccessInterceptor + queries *query.Queries } type healthCheck interface { @@ -51,23 +49,20 @@ func New( verifier *internal_authz.TokenVerifier, authZ internal_authz.Config, tlsConfig *tls.Config, http2HostName, http1HostName string, - accessSvc *logstore.Service, - cookieHandler *http_util.CookieHandler, - cookieConfig *http_mw.AccessConfig, + accessInterceptor *http_mw.AccessInterceptor, ) (_ *API, err error) { api := &API{ - port: port, - verifier: verifier, - health: queries, - router: router, - http1HostName: http1HostName, - cookieConfig: cookieConfig, - cookieHandler: cookieHandler, - queries: queries, + port: port, + verifier: verifier, + health: queries, + router: router, + http1HostName: http1HostName, + queries: queries, + accessInterceptor: accessInterceptor, } - api.grpcServer = server.CreateServer(api.verifier, authZ, queries, http2HostName, tlsConfig, accessSvc) - api.grpcGateway, err = server.CreateGateway(ctx, port, http1HostName, cookieHandler, cookieConfig) + api.grpcServer = server.CreateServer(api.verifier, authZ, queries, http2HostName, tlsConfig, accessInterceptor.AccessService()) + api.grpcGateway, err = server.CreateGateway(ctx, port, http1HostName, accessInterceptor) if err != nil { return nil, err } @@ -90,8 +85,7 @@ func (a *API) RegisterServer(ctx context.Context, grpcServer server.WithGatewayP grpcServer, a.port, a.http1HostName, - a.cookieHandler, - a.cookieConfig, + a.accessInterceptor, a.queries, ) if err != nil { diff --git a/internal/api/grpc/server/gateway.go b/internal/api/grpc/server/gateway.go index 9798e5dbd0..eed0234be9 100644 --- a/internal/api/grpc/server/gateway.go +++ b/internal/api/grpc/server/gateway.go @@ -16,7 +16,6 @@ import ( client_middleware "github.com/zitadel/zitadel/internal/api/grpc/client/middleware" "github.com/zitadel/zitadel/internal/api/grpc/server/middleware" - http_utils "github.com/zitadel/zitadel/internal/api/http" http_mw "github.com/zitadel/zitadel/internal/api/http/middleware" "github.com/zitadel/zitadel/internal/query" ) @@ -66,16 +65,15 @@ var ( ) type Gateway struct { - mux *runtime.ServeMux - http1HostName string - connection *grpc.ClientConn - cookieHandler *http_utils.CookieHandler - cookieConfig *http_mw.AccessConfig - queries *query.Queries + mux *runtime.ServeMux + http1HostName string + connection *grpc.ClientConn + accessInterceptor *http_mw.AccessInterceptor + queries *query.Queries } func (g *Gateway) Handler() http.Handler { - return addInterceptors(g.mux, g.http1HostName, g.cookieHandler, g.cookieConfig, g.queries) + return addInterceptors(g.mux, g.http1HostName, g.accessInterceptor, g.queries) } type CustomHTTPResponse interface { @@ -89,8 +87,7 @@ func CreateGatewayWithPrefix( g WithGatewayPrefix, port uint16, http1HostName string, - cookieHandler *http_utils.CookieHandler, - cookieConfig *http_mw.AccessConfig, + accessInterceptor *http_mw.AccessInterceptor, queries *query.Queries, ) (http.Handler, string, error) { runtimeMux := runtime.NewServeMux(serveMuxOptions...) @@ -106,10 +103,10 @@ func CreateGatewayWithPrefix( if err != nil { return nil, "", fmt.Errorf("failed to register grpc gateway: %w", err) } - return addInterceptors(runtimeMux, http1HostName, cookieHandler, cookieConfig, queries), g.GatewayPathPrefix(), nil + return addInterceptors(runtimeMux, http1HostName, accessInterceptor, queries), g.GatewayPathPrefix(), nil } -func CreateGateway(ctx context.Context, port uint16, http1HostName string, cookieHandler *http_utils.CookieHandler, cookieConfig *http_mw.AccessConfig) (*Gateway, error) { +func CreateGateway(ctx context.Context, port uint16, http1HostName string, accessInterceptor *http_mw.AccessInterceptor) (*Gateway, error) { connection, err := dial(ctx, port, []grpc.DialOption{ @@ -121,11 +118,10 @@ func CreateGateway(ctx context.Context, port uint16, http1HostName string, cooki } runtimeMux := runtime.NewServeMux(append(serveMuxOptions, runtime.WithHealthzEndpoint(healthpb.NewHealthClient(connection)))...) return &Gateway{ - mux: runtimeMux, - http1HostName: http1HostName, - connection: connection, - cookieHandler: cookieHandler, - cookieConfig: cookieConfig, + mux: runtimeMux, + http1HostName: http1HostName, + connection: connection, + accessInterceptor: accessInterceptor, }, nil } @@ -163,8 +159,7 @@ func dial(ctx context.Context, port uint16, opts []grpc.DialOption) (*grpc.Clien func addInterceptors( handler http.Handler, http1HostName string, - cookieHandler *http_utils.CookieHandler, - cookieConfig *http_mw.AccessConfig, + accessInterceptor *http_mw.AccessInterceptor, queries *query.Queries, ) http.Handler { handler = http_mw.CallDurationHandler(handler) @@ -174,7 +169,7 @@ func addInterceptors( handler = http_mw.DefaultTelemetryHandler(handler) // For some non-obvious reason, the exhaustedCookieInterceptor sends the SetCookie header // only if it follows the http_mw.DefaultTelemetryHandler - handler = exhaustedCookieInterceptor(handler, cookieHandler, cookieConfig, queries) + handler = exhaustedCookieInterceptor(handler, accessInterceptor, queries) handler = http_mw.DefaultMetricsHandler(handler) return handler } @@ -193,35 +188,32 @@ func http1Host(next http.Handler, http1HostName string) http.Handler { func exhaustedCookieInterceptor( next http.Handler, - cookieHandler *http_utils.CookieHandler, - cookieConfig *http_mw.AccessConfig, + accessInterceptor *http_mw.AccessInterceptor, queries *query.Queries, ) http.Handler { return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) { next.ServeHTTP(&cookieResponseWriter{ - ResponseWriter: writer, - cookieHandler: cookieHandler, - cookieConfig: cookieConfig, - request: request, - queries: queries, + ResponseWriter: writer, + accessInterceptor: accessInterceptor, + request: request, + queries: queries, }, request) }) } type cookieResponseWriter struct { http.ResponseWriter - cookieHandler *http_utils.CookieHandler - cookieConfig *http_mw.AccessConfig - request *http.Request - queries *query.Queries + accessInterceptor *http_mw.AccessInterceptor + request *http.Request + queries *query.Queries } func (r *cookieResponseWriter) WriteHeader(status int) { if status >= 200 && status < 300 { - http_mw.DeleteExhaustedCookie(r.cookieHandler, r.ResponseWriter, r.request, r.cookieConfig) + r.accessInterceptor.DeleteExhaustedCookie(r.ResponseWriter, r.request) } if status == http.StatusTooManyRequests { - http_mw.SetExhaustedCookie(r.cookieHandler, r.ResponseWriter, r.cookieConfig, r.request) + r.accessInterceptor.SetExhaustedCookie(r.ResponseWriter, r.request) } r.ResponseWriter.WriteHeader(status) } diff --git a/internal/api/http/middleware/access_interceptor.go b/internal/api/http/middleware/access_interceptor.go index cf52a597d6..469fdd16d7 100644 --- a/internal/api/http/middleware/access_interceptor.go +++ b/internal/api/http/middleware/access_interceptor.go @@ -1,6 +1,7 @@ package middleware import ( + "context" "net" "net/http" "net/url" @@ -32,15 +33,54 @@ type AccessConfig struct { // NewAccessInterceptor intercepts all requests and stores them to the logstore. // If storeOnly is false, it also checks if requests are exhausted. // If requests are exhausted, it also returns http.StatusTooManyRequests and sets a cookie -func NewAccessInterceptor(svc *logstore.Service, cookieHandler *http_utils.CookieHandler, cookieConfig *AccessConfig, storeOnly bool) *AccessInterceptor { +func NewAccessInterceptor(svc *logstore.Service, cookieHandler *http_utils.CookieHandler, cookieConfig *AccessConfig) *AccessInterceptor { return &AccessInterceptor{ svc: svc, cookieHandler: cookieHandler, limitConfig: cookieConfig, - storeOnly: storeOnly, } } +func (a *AccessInterceptor) WithoutLimiting() *AccessInterceptor { + return &AccessInterceptor{ + svc: a.svc, + cookieHandler: a.cookieHandler, + limitConfig: a.limitConfig, + storeOnly: true, + } +} + +func (a *AccessInterceptor) AccessService() *logstore.Service { + return a.svc +} + +func (a *AccessInterceptor) Limit(ctx context.Context) bool { + if !a.svc.Enabled() || a.storeOnly { + return false + } + instance := authz.GetInstance(ctx) + remaining := a.svc.Limit(ctx, instance.InstanceID()) + return remaining != nil && *remaining <= 0 +} + +func (a *AccessInterceptor) SetExhaustedCookie(writer http.ResponseWriter, request *http.Request) { + cookieValue := "true" + host := request.Header.Get(middleware.HTTP1Host) + domain := host + if strings.ContainsAny(host, ":") { + var err error + domain, _, err = net.SplitHostPort(host) + if err != nil { + logging.WithError(err).WithField("host", host).Warning("failed to extract cookie domain from request host") + } + } + a.cookieHandler.SetCookie(writer, a.limitConfig.ExhaustedCookieKey, domain, cookieValue) +} + +func (a *AccessInterceptor) DeleteExhaustedCookie(writer http.ResponseWriter, request *http.Request) { + a.cookieHandler.DeleteCookie(writer, request, a.limitConfig.ExhaustedCookieKey) +} + func (a *AccessInterceptor) Handle(next http.Handler) http.Handler { if !a.svc.Enabled() { return next @@ -49,23 +89,16 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler { ctx := request.Context() tracingCtx, checkSpan := tracing.NewNamedSpan(ctx, "checkAccess") wrappedWriter := &statusRecorder{ResponseWriter: writer, status: 0} - instance := authz.GetInstance(ctx) - limit := false - if !a.storeOnly { - remaining := a.svc.Limit(tracingCtx, instance.InstanceID()) - limit = remaining != nil && *remaining == 0 - } + limited := a.Limit(tracingCtx) checkSpan.End() - if limit { - // Limit can only be true when storeOnly is false, so set the cookie and the response code - SetExhaustedCookie(a.cookieHandler, wrappedWriter, a.limitConfig, request) + if limited { + a.SetExhaustedCookie(wrappedWriter, request) http.Error(wrappedWriter, "quota for authenticated requests is exhausted", http.StatusTooManyRequests) - } else { - if !a.storeOnly { - // If not limited and not storeOnly, ensure the cookie is deleted - DeleteExhaustedCookie(a.cookieHandler, wrappedWriter, request, a.limitConfig) - } - // Always serve if not limited + } + if !limited && !a.storeOnly { + a.DeleteExhaustedCookie(wrappedWriter, request) + } + if !limited { next.ServeHTTP(wrappedWriter, request) } tracingCtx, writeSpan := tracing.NewNamedSpan(tracingCtx, "writeAccess") @@ -75,6 +108,7 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler { if err != nil { logging.WithError(err).WithField("url", requestURL).Warning("failed to unescape request url") } + instance := authz.GetInstance(tracingCtx) a.svc.Handle(tracingCtx, &access.Record{ LogDate: time.Now(), Protocol: access.HTTP, @@ -90,24 +124,6 @@ func (a *AccessInterceptor) Handle(next http.Handler) http.Handler { }) } -func SetExhaustedCookie(cookieHandler *http_utils.CookieHandler, writer http.ResponseWriter, cookieConfig *AccessConfig, request *http.Request) { - cookieValue := "true" - host := request.Header.Get(middleware.HTTP1Host) - domain := host - if strings.ContainsAny(host, ":") { - var err error - domain, _, err = net.SplitHostPort(host) - if err != nil { - logging.WithError(err).WithField("host", host).Warning("failed to extract cookie domain from request host") - } - } - cookieHandler.SetCookie(writer, cookieConfig.ExhaustedCookieKey, domain, cookieValue) -} - -func DeleteExhaustedCookie(cookieHandler *http_utils.CookieHandler, writer http.ResponseWriter, request *http.Request, cookieConfig *AccessConfig) { - cookieHandler.DeleteCookie(writer, request, cookieConfig.ExhaustedCookieKey) -} - type statusRecorder struct { http.ResponseWriter status int diff --git a/internal/api/ui/console/console.go b/internal/api/ui/console/console.go index 45503ca0c1..1980f6cc5f 100644 --- a/internal/api/ui/console/console.go +++ b/internal/api/ui/console/console.go @@ -91,7 +91,7 @@ func (f *file) Stat() (_ fs.FileInfo, err error) { return f, nil } -func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, callDurationInterceptor, instanceHandler, accessInterceptor func(http.Handler) http.Handler, customerPortal string) (http.Handler, error) { +func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, callDurationInterceptor, instanceHandler func(http.Handler) http.Handler, limitingAccessInterceptor *middleware.AccessInterceptor, customerPortal string) (http.Handler, error) { fSys, err := fs.Sub(static, "static") if err != nil { return nil, err @@ -106,10 +106,11 @@ func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, call handler := mux.NewRouter() - handler.Use(callDurationInterceptor, instanceHandler, security, accessInterceptor) + handler.Use(callDurationInterceptor, instanceHandler, security, limitingAccessInterceptor.WithoutLimiting().Handle) handler.Handle(envRequestPath, middleware.TelemetryHandler()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { url := http_util.BuildOrigin(r.Host, externalSecure) - instance := authz.GetInstance(r.Context()) + ctx := r.Context() + instance := authz.GetInstance(ctx) instanceMgmtURL, err := templateInstanceManagementURL(config.InstanceManagementURL, instance) if err != nil { http.Error(w, fmt.Sprintf("unable to template instance management url for console: %v", err), http.StatusInternalServerError) @@ -120,6 +121,11 @@ func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, call http.Error(w, fmt.Sprintf("unable to marshal env for console: %v", err), http.StatusInternalServerError) return } + if limitingAccessInterceptor.Limit(ctx) { + limitingAccessInterceptor.SetExhaustedCookie(w, r) + } else { + limitingAccessInterceptor.DeleteExhaustedCookie(w, r) + } _, err = w.Write(environmentJSON) logging.OnError(err).Error("error serving environment.json") })))