zitadel/internal/api/http/middleware/access_interceptor.go
Elio Bischof 0e251a29c8
fix: set exhausted cookie with env json (#5868)
* fix: set exhausted cookie with env json

* lint
2023-05-15 08:51:02 +02:00

140 lines
4.1 KiB
Go

package middleware
import (
"context"
"net"
"net/http"
"net/url"
"strings"
"time"
"github.com/zitadel/logging"
"github.com/zitadel/zitadel/internal/api/authz"
"github.com/zitadel/zitadel/internal/api/grpc/server/middleware"
http_utils "github.com/zitadel/zitadel/internal/api/http"
"github.com/zitadel/zitadel/internal/logstore"
"github.com/zitadel/zitadel/internal/logstore/emitters/access"
"github.com/zitadel/zitadel/internal/telemetry/tracing"
)
// AccessInterceptor is an HTTP middleware that records requests through a
// logstore.Service and, unless it is store-only, rejects requests once the
// service reports that no requests remain for the instance.
type AccessInterceptor struct {
// svc receives the access records and is asked for the remaining quota.
svc *logstore.Service
// cookieHandler sets and deletes the cookie that marks an exhausted quota.
cookieHandler *http_utils.CookieHandler
// limitConfig holds the settings of the exhausted cookie.
limitConfig *AccessConfig
// storeOnly disables limiting: Limit always reports false when it is set.
storeOnly bool
}
// AccessConfig configures the cookie that signals an exhausted quota.
type AccessConfig struct {
// ExhaustedCookieKey is the key under which the exhausted cookie is set and deleted.
ExhaustedCookieKey string
// ExhaustedCookieMaxAge is not read in this file.
// NOTE(review): presumably the lifetime of the exhausted cookie — confirm where the cookie handler is configured.
ExhaustedCookieMaxAge time.Duration
}
// NewAccessInterceptor builds an AccessInterceptor that stores every request
// in the given logstore service. The returned interceptor has limiting
// enabled: once the quota is exhausted it answers with
// http.StatusTooManyRequests and sets the exhausted cookie described by
// cookieConfig. Use WithoutLimiting to obtain a store-only variant.
func NewAccessInterceptor(svc *logstore.Service, cookieHandler *http_utils.CookieHandler, cookieConfig *AccessConfig) *AccessInterceptor {
	interceptor := new(AccessInterceptor)
	interceptor.svc = svc
	interceptor.cookieHandler = cookieHandler
	interceptor.limitConfig = cookieConfig
	// storeOnly keeps its zero value (false), so limiting stays active.
	return interceptor
}
// WithoutLimiting returns a copy of the interceptor that only stores
// requests and never limits them. The receiver itself is left unchanged.
func (a *AccessInterceptor) WithoutLimiting() *AccessInterceptor {
	// Shallow copy: the service, cookie handler and config are shared.
	storeOnlyCopy := *a
	storeOnlyCopy.storeOnly = true
	return &storeOnlyCopy
}
// AccessService exposes the logstore service the interceptor writes to.
func (a *AccessInterceptor) AccessService() *logstore.Service {
	service := a.svc
	return service
}
// Limit reports whether the request belonging to ctx must be rejected.
//
// It returns false when the logstore service is disabled or the interceptor
// is store-only. Otherwise it asks the service for the remaining requests of
// the instance found in ctx and returns true only if a value is reported and
// that value is zero or negative.
func (a *AccessInterceptor) Limit(ctx context.Context) bool {
	if !a.svc.Enabled() {
		return false
	}
	if a.storeOnly {
		return false
	}
	instanceID := authz.GetInstance(ctx).InstanceID()
	left := a.svc.Limit(ctx, instanceID)
	if left == nil {
		// No remaining count reported: nothing to enforce.
		return false
	}
	return *left <= 0
}
// SetExhaustedCookie sets the exhausted cookie with the value "true" on the
// response.
//
// The cookie domain is read from the middleware.HTTP1Host request header. If
// that value contains a colon it is split with net.SplitHostPort and only the
// host part is used. A failed split is logged as a warning; the cookie is
// still set, using the host value returned by the failed split.
func (a *AccessInterceptor) SetExhaustedCookie(writer http.ResponseWriter, request *http.Request) {
	const exhausted = "true"

	requestHost := request.Header.Get(middleware.HTTP1Host)
	cookieDomain := requestHost
	if strings.Contains(requestHost, ":") {
		hostPart, _, splitErr := net.SplitHostPort(requestHost)
		if splitErr != nil {
			logging.WithError(splitErr).WithField("host", requestHost).Warning("failed to extract cookie domain from request host")
		}
		cookieDomain = hostPart
	}

	a.cookieHandler.SetCookie(writer, a.limitConfig.ExhaustedCookieKey, cookieDomain, exhausted)
}
// DeleteExhaustedCookie asks the cookie handler to delete the exhausted
// cookie, identified by the configured ExhaustedCookieKey.
func (a *AccessInterceptor) DeleteExhaustedCookie(writer http.ResponseWriter, request *http.Request) {
	cookieKey := a.limitConfig.ExhaustedCookieKey
	a.cookieHandler.DeleteCookie(writer, request, cookieKey)
}
// Handle wraps next so that every request is written to the logstore and,
// unless the interceptor is store-only, rejected once the quota is exhausted.
//
// If the logstore service is disabled, next is returned unwrapped.
//
// For a limited request the exhausted cookie is set and the response is
// http.StatusTooManyRequests; next is not called. For any other request the
// exhausted cookie is deleted (unless the interceptor is store-only) and next
// serves the request. In both cases an access.Record is handed to the service
// afterwards.
func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
	if !a.svc.Enabled() {
		return next
	}
	return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
		ctx := request.Context()
		tracingCtx, checkSpan := tracing.NewNamedSpan(ctx, "checkAccess")
		wrappedWriter := &statusRecorder{ResponseWriter: writer, status: 0}
		limited := a.Limit(tracingCtx)
		checkSpan.End()

		if limited {
			a.SetExhaustedCookie(wrappedWriter, request)
			http.Error(wrappedWriter, "quota for authenticated requests is exhausted", http.StatusTooManyRequests)
		} else {
			if !a.storeOnly {
				a.DeleteExhaustedCookie(wrappedWriter, request)
			}
			next.ServeHTTP(wrappedWriter, request)
		}

		tracingCtx, writeSpan := tracing.NewNamedSpan(tracingCtx, "writeAccess")
		defer writeSpan.End()

		requestURL := request.RequestURI
		unescapedURL, err := url.QueryUnescape(requestURL)
		if err != nil {
			logging.WithError(err).WithField("url", requestURL).Warning("failed to unescape request url")
			// Keep the raw URI rather than storing an empty URL in the record.
			unescapedURL = requestURL
		}

		// A recorded status of 0 means WriteHeader was never called through
		// the recorder. net/http answers such requests with 200 OK, so that
		// is the status the client saw.
		status := wrappedWriter.status
		if status == 0 {
			status = http.StatusOK
		}

		instance := authz.GetInstance(tracingCtx)
		a.svc.Handle(tracingCtx, &access.Record{
			LogDate:         time.Now(),
			Protocol:        access.HTTP,
			RequestURL:      unescapedURL,
			ResponseStatus:  uint32(status),
			RequestHeaders:  request.Header,
			ResponseHeaders: writer.Header(),
			InstanceID:      instance.InstanceID(),
			ProjectID:       instance.ProjectID(),
			RequestedDomain: instance.RequestedDomain(),
			RequestedHost:   instance.RequestedHost(),
		})
	})
}
// statusRecorder wraps an http.ResponseWriter and remembers the status code
// of the response written through it.
type statusRecorder struct {
	http.ResponseWriter
	// status is the recorded status code; 0 means nothing has been recorded yet.
	status int
	// ignoreWrites turns WriteHeader into a no-op when set.
	ignoreWrites bool
}

// WriteHeader records status and forwards it to the wrapped writer.
// It does nothing when ignoreWrites is set.
func (r *statusRecorder) WriteHeader(status int) {
	if r.ignoreWrites {
		return
	}
	r.status = status
	r.ResponseWriter.WriteHeader(status)
}

// Write forwards body to the wrapped writer unchanged.
//
// net/http treats a Write without a preceding WriteHeader as an implicit
// WriteHeader(http.StatusOK). That implicit call does not pass through
// WriteHeader above, so it is recorded here to keep status from staying 0
// for handlers that only write a body.
func (r *statusRecorder) Write(body []byte) (int, error) {
	if r.status == 0 {
		r.status = http.StatusOK
	}
	return r.ResponseWriter.Write(body)
}