mirror of
https://github.com/zitadel/zitadel.git
synced 2024-12-13 19:44:21 +00:00
106 lines
2.8 KiB
Go
106 lines
2.8 KiB
Go
package middleware
|
|
|
|
import (
|
|
"math"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
"time"
|
|
|
|
"github.com/zitadel/logging"
|
|
|
|
"github.com/zitadel/zitadel/internal/api/authz"
|
|
http_utils "github.com/zitadel/zitadel/internal/api/http"
|
|
"github.com/zitadel/zitadel/internal/logstore"
|
|
"github.com/zitadel/zitadel/internal/logstore/emitters/access"
|
|
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
|
)
|
|
|
|
type AccessInterceptor struct {
|
|
svc *logstore.Service
|
|
cookieHandler *http_utils.CookieHandler
|
|
limitConfig *AccessConfig
|
|
}
|
|
|
|
// AccessConfig configures the cookie an AccessInterceptor uses to tell
// clients whether the instance's access quota is exhausted.
type AccessConfig struct {
	// ExhaustedCookieKey is the name of the cookie set on every handled
	// request, with value "true" or "false".
	ExhaustedCookieKey string
	// ExhaustedCookieMaxAge is the cookie lifetime. NewAccessInterceptor
	// rounds it down to whole seconds.
	ExhaustedCookieMaxAge time.Duration
}
|
|
|
|
func NewAccessInterceptor(svc *logstore.Service, cookieConfig *AccessConfig) *AccessInterceptor {
|
|
return &AccessInterceptor{
|
|
svc: svc,
|
|
cookieHandler: http_utils.NewCookieHandler(
|
|
http_utils.WithUnsecure(),
|
|
http_utils.WithMaxAge(int(math.Floor(cookieConfig.ExhaustedCookieMaxAge.Seconds()))),
|
|
),
|
|
limitConfig: cookieConfig,
|
|
}
|
|
}
|
|
|
|
func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
|
|
if !a.svc.Enabled() {
|
|
return next
|
|
}
|
|
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
|
ctx := request.Context()
|
|
var err error
|
|
|
|
tracingCtx, checkSpan := tracing.NewNamedSpan(ctx, "checkAccess")
|
|
|
|
wrappedWriter := &statusRecorder{ResponseWriter: writer, status: 0}
|
|
|
|
instance := authz.GetInstance(ctx)
|
|
remaining := a.svc.Limit(tracingCtx, instance.InstanceID())
|
|
limit := remaining != nil && *remaining == 0
|
|
|
|
a.cookieHandler.SetCookie(wrappedWriter, a.limitConfig.ExhaustedCookieKey, request.Host, strconv.FormatBool(limit))
|
|
|
|
if limit {
|
|
wrappedWriter.WriteHeader(http.StatusTooManyRequests)
|
|
wrappedWriter.ignoreWrites = true
|
|
}
|
|
|
|
checkSpan.End()
|
|
|
|
next.ServeHTTP(wrappedWriter, request)
|
|
|
|
tracingCtx, writeSpan := tracing.NewNamedSpan(tracingCtx, "writeAccess")
|
|
defer writeSpan.End()
|
|
|
|
requestURL := request.RequestURI
|
|
unescapedURL, err := url.QueryUnescape(requestURL)
|
|
if err != nil {
|
|
logging.WithError(err).WithField("url", requestURL).Warning("failed to unescape request url")
|
|
// err = nil is effective because of deferred tracing span end
|
|
err = nil
|
|
}
|
|
a.svc.Handle(tracingCtx, &access.Record{
|
|
LogDate: time.Now(),
|
|
Protocol: access.HTTP,
|
|
RequestURL: unescapedURL,
|
|
ResponseStatus: uint32(wrappedWriter.status),
|
|
RequestHeaders: request.Header,
|
|
ResponseHeaders: writer.Header(),
|
|
InstanceID: instance.InstanceID(),
|
|
ProjectID: instance.ProjectID(),
|
|
RequestedDomain: instance.RequestedDomain(),
|
|
RequestedHost: instance.RequestedHost(),
|
|
})
|
|
})
|
|
}
|
|
|
|
// statusRecorder wraps an http.ResponseWriter so the middleware can read back
// the status code that was written and include it in the access log record.
type statusRecorder struct {
	// ResponseWriter is the wrapped writer; calls not overridden on
	// statusRecorder go straight to it.
	http.ResponseWriter
	// status is the code recorded by WriteHeader; 0 until one is recorded.
	status int
	// ignoreWrites, once true, makes WriteHeader drop its call entirely.
	ignoreWrites bool
}
|
|
|
|
func (r *statusRecorder) WriteHeader(status int) {
|
|
if r.ignoreWrites {
|
|
return
|
|
}
|
|
r.status = status
|
|
r.ResponseWriter.WriteHeader(status)
|
|
}
|