mirror of
https://github.com/zitadel/zitadel.git
synced 2025-12-03 05:32:18 +00:00
feat: add quotas (#4779)
adds possibilities to cap authenticated requests and execution seconds of actions on a defined intervall
This commit is contained in:
@@ -16,6 +16,7 @@ import (
|
||||
http_util "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/api/ui/login"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/logstore"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/metrics"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
@@ -36,7 +37,18 @@ type health interface {
|
||||
Instance(ctx context.Context, shouldTriggerBulk bool) (*query.Instance, error)
|
||||
}
|
||||
|
||||
func New(port uint16, router *mux.Router, queries *query.Queries, verifier *internal_authz.TokenVerifier, authZ internal_authz.Config, externalSecure bool, tlsConfig *tls.Config, http2HostName, http1HostName string) *API {
|
||||
func New(
|
||||
port uint16,
|
||||
router *mux.Router,
|
||||
queries *query.Queries,
|
||||
verifier *internal_authz.TokenVerifier,
|
||||
authZ internal_authz.Config,
|
||||
externalSecure bool,
|
||||
tlsConfig *tls.Config,
|
||||
http2HostName,
|
||||
http1HostName string,
|
||||
accessSvc *logstore.Service,
|
||||
) *API {
|
||||
api := &API{
|
||||
port: port,
|
||||
verifier: verifier,
|
||||
@@ -45,7 +57,8 @@ func New(port uint16, router *mux.Router, queries *query.Queries, verifier *inte
|
||||
externalSecure: externalSecure,
|
||||
http1HostName: http1HostName,
|
||||
}
|
||||
api.grpcServer = server.CreateServer(api.verifier, authZ, queries, http2HostName, tlsConfig)
|
||||
|
||||
api.grpcServer = server.CreateServer(api.verifier, authZ, queries, http2HostName, tlsConfig, accessSvc)
|
||||
api.routeGRPC()
|
||||
|
||||
api.RegisterHandler("/debug", api.healthHandler())
|
||||
|
||||
@@ -82,7 +82,7 @@ func DefaultErrorHandler(w http.ResponseWriter, r *http.Request, err error, code
|
||||
http.Error(w, err.Error(), code)
|
||||
}
|
||||
|
||||
func NewHandler(commands *command.Commands, verifier *authz.TokenVerifier, authConfig authz.Config, idGenerator id.Generator, storage static.Storage, queries *query.Queries, instanceInterceptor, assetCacheInterceptor func(handler http.Handler) http.Handler) http.Handler {
|
||||
func NewHandler(commands *command.Commands, verifier *authz.TokenVerifier, authConfig authz.Config, idGenerator id.Generator, storage static.Storage, queries *query.Queries, instanceInterceptor, assetCacheInterceptor, accessInterceptor func(handler http.Handler) http.Handler) http.Handler {
|
||||
h := &Handler{
|
||||
commands: commands,
|
||||
errorHandler: DefaultErrorHandler,
|
||||
@@ -94,7 +94,7 @@ func NewHandler(commands *command.Commands, verifier *authz.TokenVerifier, authC
|
||||
|
||||
verifier.RegisterServer("Assets-API", "assets", AssetsService_AuthMethods)
|
||||
router := mux.NewRouter()
|
||||
router.Use(instanceInterceptor, assetCacheInterceptor)
|
||||
router.Use(instanceInterceptor, assetCacheInterceptor, accessInterceptor)
|
||||
RegisterRoutes(router, h)
|
||||
router.PathPrefix("/{owner}").Methods("GET").HandlerFunc(DownloadHandleFunc(h, h.GetFile()))
|
||||
return http_util.CopyHeadersToContext(http_mw.CORSInterceptor(router))
|
||||
|
||||
@@ -23,7 +23,8 @@ type Instance interface {
|
||||
}
|
||||
|
||||
type InstanceVerifier interface {
|
||||
InstanceByHost(context.Context, string) (Instance, error)
|
||||
InstanceByHost(ctx context.Context, host string) (Instance, error)
|
||||
InstanceByID(ctx context.Context) (Instance, error)
|
||||
}
|
||||
|
||||
type instance struct {
|
||||
|
||||
@@ -55,6 +55,8 @@ func ExtractCaosError(err error) (c codes.Code, msg, id string, ok bool) {
|
||||
return codes.Unavailable, caosErr.GetMessage(), caosErr.GetID(), true
|
||||
case *caos_errs.UnimplementedError:
|
||||
return codes.Unimplemented, caosErr.GetMessage(), caosErr.GetID(), true
|
||||
case *caos_errs.ResourceExhaustedError:
|
||||
return codes.ResourceExhausted, caosErr.GetMessage(), caosErr.GetID(), true
|
||||
default:
|
||||
return codes.Unknown, err.Error(), "", false
|
||||
}
|
||||
|
||||
@@ -136,6 +136,14 @@ func Test_Extract(t *testing.T) {
|
||||
"id",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"exhausted",
|
||||
args{caos_errs.ThrowResourceExhausted(nil, "id", "exhausted")},
|
||||
codes.ResourceExhausted,
|
||||
"exhausted",
|
||||
"id",
|
||||
true,
|
||||
},
|
||||
{
|
||||
"unknown",
|
||||
args{errors.New("unknown")},
|
||||
|
||||
55
internal/api/grpc/server/middleware/access_interceptor.go
Normal file
55
internal/api/grpc/server/middleware/access_interceptor.go
Normal file
@@ -0,0 +1,55 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/logstore"
|
||||
"github.com/zitadel/zitadel/internal/logstore/emitters/access"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
func AccessStorageInterceptor(svc *logstore.Service) grpc.UnaryServerInterceptor {
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) {
|
||||
if !svc.Enabled() {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
|
||||
reqMd, _ := metadata.FromIncomingContext(ctx)
|
||||
|
||||
resp, handlerErr := handler(ctx, req)
|
||||
|
||||
interceptorCtx, span := tracing.NewServerInterceptorSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
var respStatus uint32
|
||||
grpcStatus, ok := status.FromError(handlerErr)
|
||||
if ok {
|
||||
respStatus = uint32(grpcStatus.Code())
|
||||
}
|
||||
|
||||
resMd, _ := metadata.FromOutgoingContext(ctx)
|
||||
instance := authz.GetInstance(ctx)
|
||||
|
||||
record := &access.Record{
|
||||
LogDate: time.Now(),
|
||||
Protocol: access.GRPC,
|
||||
RequestURL: info.FullMethod,
|
||||
ResponseStatus: respStatus,
|
||||
RequestHeaders: reqMd,
|
||||
ResponseHeaders: resMd,
|
||||
InstanceID: instance.InstanceID(),
|
||||
ProjectID: instance.ProjectID(),
|
||||
RequestedDomain: instance.RequestedDomain(),
|
||||
RequestedHost: instance.RequestedHost(),
|
||||
}
|
||||
|
||||
svc.Handle(interceptorCtx, record)
|
||||
return resp, handlerErr
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
errs "errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -14,7 +14,7 @@ import (
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
caos_errors "github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/i18n"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
@@ -23,27 +23,36 @@ const (
|
||||
HTTP1Host = "x-zitadel-http1-host"
|
||||
)
|
||||
|
||||
type InstanceVerifier interface {
|
||||
GetInstance(ctx context.Context)
|
||||
}
|
||||
|
||||
func InstanceInterceptor(verifier authz.InstanceVerifier, headerName string, ignoredServices ...string) grpc.UnaryServerInterceptor {
|
||||
func InstanceInterceptor(verifier authz.InstanceVerifier, headerName string, explicitInstanceIdServices ...string) grpc.UnaryServerInterceptor {
|
||||
translator, err := newZitadelTranslator(language.English)
|
||||
logging.OnError(err).Panic("unable to get translator")
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
|
||||
return setInstance(ctx, req, info, handler, verifier, headerName, translator, ignoredServices...)
|
||||
return setInstance(ctx, req, info, handler, verifier, headerName, translator, explicitInstanceIdServices...)
|
||||
}
|
||||
}
|
||||
|
||||
func setInstance(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, verifier authz.InstanceVerifier, headerName string, translator *i18n.Translator, ignoredServices ...string) (_ interface{}, err error) {
|
||||
func setInstance(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler, verifier authz.InstanceVerifier, headerName string, translator *i18n.Translator, idFromRequestsServices ...string) (_ interface{}, err error) {
|
||||
interceptorCtx, span := tracing.NewServerInterceptorSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
for _, service := range ignoredServices {
|
||||
for _, service := range idFromRequestsServices {
|
||||
if !strings.HasPrefix(service, "/") {
|
||||
service = "/" + service
|
||||
}
|
||||
if strings.HasPrefix(info.FullMethod, service) {
|
||||
return handler(ctx, req)
|
||||
withInstanceIDProperty, ok := req.(interface{ GetInstanceId() string })
|
||||
if !ok {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
ctx = authz.WithInstanceID(ctx, withInstanceIDProperty.GetInstanceId())
|
||||
instance, err := verifier.InstanceByID(ctx)
|
||||
if err != nil {
|
||||
notFoundErr := new(errors.NotFoundError)
|
||||
if errs.As(err, ¬FoundErr) {
|
||||
notFoundErr.Message = translator.LocalizeFromCtx(ctx, notFoundErr.GetMessage(), nil)
|
||||
}
|
||||
return nil, status.Error(codes.NotFound, err.Error())
|
||||
}
|
||||
return handler(authz.WithInstance(ctx, instance), req)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -53,9 +62,9 @@ func setInstance(ctx context.Context, req interface{}, info *grpc.UnaryServerInf
|
||||
}
|
||||
instance, err := verifier.InstanceByHost(interceptorCtx, host)
|
||||
if err != nil {
|
||||
caosErr := new(caos_errors.NotFoundError)
|
||||
if errors.As(err, &caosErr) {
|
||||
caosErr.Message = translator.LocalizeFromCtx(ctx, caosErr.GetMessage(), nil)
|
||||
notFoundErr := new(errors.NotFoundError)
|
||||
if errs.As(err, ¬FoundErr) {
|
||||
notFoundErr.Message = translator.LocalizeFromCtx(ctx, notFoundErr.GetMessage(), nil)
|
||||
}
|
||||
return nil, status.Error(codes.NotFound, err.Error())
|
||||
}
|
||||
|
||||
@@ -153,13 +153,15 @@ type mockInstanceVerifier struct {
|
||||
host string
|
||||
}
|
||||
|
||||
func (m *mockInstanceVerifier) InstanceByHost(ctx context.Context, host string) (authz.Instance, error) {
|
||||
func (m *mockInstanceVerifier) InstanceByHost(_ context.Context, host string) (authz.Instance, error) {
|
||||
if host != m.host {
|
||||
return nil, fmt.Errorf("invalid host")
|
||||
}
|
||||
return &mockInstance{}, nil
|
||||
}
|
||||
|
||||
func (m *mockInstanceVerifier) InstanceByID(context.Context) (authz.Instance, error) { return nil, nil }
|
||||
|
||||
type mockInstance struct{}
|
||||
|
||||
func (m *mockInstance) InstanceID() string {
|
||||
|
||||
46
internal/api/grpc/server/middleware/quota_interceptor.go
Normal file
46
internal/api/grpc/server/middleware/quota_interceptor.go
Normal file
@@ -0,0 +1,46 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
"github.com/zitadel/zitadel/internal/errors"
|
||||
"github.com/zitadel/zitadel/internal/logstore"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
func QuotaExhaustedInterceptor(svc *logstore.Service, ignoreService ...string) grpc.UnaryServerInterceptor {
|
||||
|
||||
prunedIgnoredServices := make([]string, len(ignoreService))
|
||||
for idx, service := range ignoreService {
|
||||
if !strings.HasPrefix(service, "/") {
|
||||
service = "/" + service
|
||||
}
|
||||
prunedIgnoredServices[idx] = service
|
||||
}
|
||||
|
||||
return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (_ interface{}, err error) {
|
||||
if !svc.Enabled() {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
interceptorCtx, span := tracing.NewServerInterceptorSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
for _, service := range prunedIgnoredServices {
|
||||
if strings.HasPrefix(info.FullMethod, service) {
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
|
||||
instance := authz.GetInstance(ctx)
|
||||
remaining := svc.Limit(interceptorCtx, instance.InstanceID())
|
||||
if remaining != nil && *remaining == 0 {
|
||||
return nil, errors.ThrowResourceExhausted(nil, "QUOTA-vjAy8", "Quota.Access.Exhausted")
|
||||
}
|
||||
span.End()
|
||||
return handler(ctx, req)
|
||||
}
|
||||
}
|
||||
@@ -2,7 +2,6 @@ package server
|
||||
|
||||
import (
|
||||
"crypto/tls"
|
||||
|
||||
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials"
|
||||
@@ -10,6 +9,7 @@ import (
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
grpc_api "github.com/zitadel/zitadel/internal/api/grpc"
|
||||
"github.com/zitadel/zitadel/internal/api/grpc/server/middleware"
|
||||
"github.com/zitadel/zitadel/internal/logstore"
|
||||
"github.com/zitadel/zitadel/internal/query"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/metrics"
|
||||
system_pb "github.com/zitadel/zitadel/pkg/grpc/system"
|
||||
@@ -23,7 +23,14 @@ type Server interface {
|
||||
AuthMethods() authz.MethodMapping
|
||||
}
|
||||
|
||||
func CreateServer(verifier *authz.TokenVerifier, authConfig authz.Config, queries *query.Queries, hostHeaderName string, tlsConfig *tls.Config) *grpc.Server {
|
||||
func CreateServer(
|
||||
verifier *authz.TokenVerifier,
|
||||
authConfig authz.Config,
|
||||
queries *query.Queries,
|
||||
hostHeaderName string,
|
||||
tlsConfig *tls.Config,
|
||||
accessSvc *logstore.Service,
|
||||
) *grpc.Server {
|
||||
metricTypes := []metrics.MetricType{metrics.MetricTypeTotalCount, metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode}
|
||||
serverOptions := []grpc.ServerOption{
|
||||
grpc.UnaryInterceptor(
|
||||
@@ -33,10 +40,12 @@ func CreateServer(verifier *authz.TokenVerifier, authConfig authz.Config, querie
|
||||
middleware.NoCacheInterceptor(),
|
||||
middleware.ErrorHandler(),
|
||||
middleware.InstanceInterceptor(queries, hostHeaderName, system_pb.SystemService_MethodPrefix),
|
||||
middleware.AccessStorageInterceptor(accessSvc),
|
||||
middleware.AuthorizationInterceptor(verifier, authConfig),
|
||||
middleware.TranslationHandler(),
|
||||
middleware.ValidationHandler(),
|
||||
middleware.ServiceHandler(),
|
||||
middleware.QuotaExhaustedInterceptor(accessSvc, system_pb.SystemService_MethodPrefix),
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
@@ -31,7 +31,6 @@ func (s *Server) ListInstances(ctx context.Context, req *system_pb.ListInstances
|
||||
}
|
||||
|
||||
func (s *Server) GetInstance(ctx context.Context, req *system_pb.GetInstanceRequest) (*system_pb.GetInstanceResponse, error) {
|
||||
ctx = authz.WithInstanceID(ctx, req.InstanceId)
|
||||
instance, err := s.query.Instance(ctx, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -53,7 +52,6 @@ func (s *Server) AddInstance(ctx context.Context, req *system_pb.AddInstanceRequ
|
||||
}
|
||||
|
||||
func (s *Server) UpdateInstance(ctx context.Context, req *system_pb.UpdateInstanceRequest) (*system_pb.UpdateInstanceResponse, error) {
|
||||
ctx = authz.WithInstanceID(ctx, req.InstanceId)
|
||||
details, err := s.command.UpdateInstance(ctx, req.InstanceName)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -86,7 +84,6 @@ func (s *Server) CreateInstance(ctx context.Context, req *system_pb.CreateInstan
|
||||
}
|
||||
|
||||
func (s *Server) RemoveInstance(ctx context.Context, req *system_pb.RemoveInstanceRequest) (*system_pb.RemoveInstanceResponse, error) {
|
||||
ctx = authz.WithInstanceID(ctx, req.InstanceId)
|
||||
details, err := s.command.RemoveInstance(ctx, req.InstanceId)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -97,7 +94,6 @@ func (s *Server) RemoveInstance(ctx context.Context, req *system_pb.RemoveInstan
|
||||
}
|
||||
|
||||
func (s *Server) ListIAMMembers(ctx context.Context, req *system_pb.ListIAMMembersRequest) (*system_pb.ListIAMMembersResponse, error) {
|
||||
ctx = authz.WithInstanceID(ctx, req.InstanceId)
|
||||
queries, err := ListIAMMembersRequestToQuery(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -139,7 +135,6 @@ func (s *Server) ExistsDomain(ctx context.Context, req *system_pb.ExistsDomainRe
|
||||
}
|
||||
|
||||
func (s *Server) ListDomains(ctx context.Context, req *system_pb.ListDomainsRequest) (*system_pb.ListDomainsResponse, error) {
|
||||
ctx = authz.WithInstanceID(ctx, req.InstanceId)
|
||||
queries, err := ListInstanceDomainsRequestToModel(req)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -156,8 +151,6 @@ func (s *Server) ListDomains(ctx context.Context, req *system_pb.ListDomainsRequ
|
||||
}
|
||||
|
||||
func (s *Server) AddDomain(ctx context.Context, req *system_pb.AddDomainRequest) (*system_pb.AddDomainResponse, error) {
|
||||
//TODO: should be solved in interceptor
|
||||
ctx = authz.WithInstanceID(ctx, req.InstanceId)
|
||||
instance, err := s.query.Instance(ctx, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -174,7 +167,6 @@ func (s *Server) AddDomain(ctx context.Context, req *system_pb.AddDomainRequest)
|
||||
}
|
||||
|
||||
func (s *Server) RemoveDomain(ctx context.Context, req *system_pb.RemoveDomainRequest) (*system_pb.RemoveDomainResponse, error) {
|
||||
ctx = authz.WithInstanceID(ctx, req.InstanceId)
|
||||
details, err := s.command.RemoveInstanceDomain(ctx, req.Domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -185,7 +177,6 @@ func (s *Server) RemoveDomain(ctx context.Context, req *system_pb.RemoveDomainRe
|
||||
}
|
||||
|
||||
func (s *Server) SetPrimaryDomain(ctx context.Context, req *system_pb.SetPrimaryDomainRequest) (*system_pb.SetPrimaryDomainResponse, error) {
|
||||
ctx = authz.WithInstanceID(ctx, req.InstanceId)
|
||||
details, err := s.command.SetPrimaryInstanceDomain(ctx, req.Domain)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
32
internal/api/grpc/system/quota.go
Normal file
32
internal/api/grpc/system/quota.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/grpc/object"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/system"
|
||||
system_pb "github.com/zitadel/zitadel/pkg/grpc/system"
|
||||
)
|
||||
|
||||
func (s *Server) AddQuota(ctx context.Context, req *system.AddQuotaRequest) (*system.AddQuotaResponse, error) {
|
||||
details, err := s.command.AddQuota(
|
||||
ctx,
|
||||
instanceQuotaPbToCommand(req),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &system_pb.AddQuotaResponse{
|
||||
Details: object.AddToDetailsPb(details.Sequence, details.EventDate, details.ResourceOwner),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (s *Server) RemoveQuota(ctx context.Context, req *system.RemoveQuotaRequest) (*system.RemoveQuotaResponse, error) {
|
||||
details, err := s.command.RemoveQuota(ctx, instanceQuotaUnitPbToCommand(req.Unit))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &system_pb.RemoveQuotaResponse{
|
||||
Details: object.ChangeToDetailsPb(details.Sequence, details.EventDate, details.ResourceOwner),
|
||||
}, nil
|
||||
}
|
||||
43
internal/api/grpc/system/quota_converter.go
Normal file
43
internal/api/grpc/system/quota_converter.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package system
|
||||
|
||||
import (
|
||||
"github.com/zitadel/zitadel/internal/command"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/quota"
|
||||
"github.com/zitadel/zitadel/pkg/grpc/system"
|
||||
)
|
||||
|
||||
func instanceQuotaPbToCommand(req *system.AddQuotaRequest) *command.AddQuota {
|
||||
return &command.AddQuota{
|
||||
Unit: instanceQuotaUnitPbToCommand(req.Unit),
|
||||
From: req.From.AsTime(),
|
||||
ResetInterval: req.ResetInterval.AsDuration(),
|
||||
Amount: req.Amount,
|
||||
Limit: req.Limit,
|
||||
Notifications: instanceQuotaNotificationsPbToCommand(req.Notifications),
|
||||
}
|
||||
}
|
||||
|
||||
func instanceQuotaUnitPbToCommand(unit quota.Unit) command.QuotaUnit {
|
||||
switch unit {
|
||||
case quota.Unit_UNIT_REQUESTS_ALL_AUTHENTICATED:
|
||||
return command.QuotaRequestsAllAuthenticated
|
||||
case quota.Unit_UNIT_ACTIONS_ALL_RUN_SECONDS:
|
||||
return command.QuotaActionsAllRunsSeconds
|
||||
case quota.Unit_UNIT_UNIMPLEMENTED:
|
||||
fallthrough
|
||||
default:
|
||||
return command.QuotaUnit(unit.String())
|
||||
}
|
||||
}
|
||||
|
||||
func instanceQuotaNotificationsPbToCommand(req []*quota.Notification) command.QuotaNotifications {
|
||||
notifications := make([]*command.QuotaNotification, len(req))
|
||||
for idx, item := range req {
|
||||
notifications[idx] = &command.QuotaNotification{
|
||||
Percent: uint16(item.Percent),
|
||||
Repeat: item.Repeat,
|
||||
CallURL: item.CallUrl,
|
||||
}
|
||||
}
|
||||
return notifications
|
||||
}
|
||||
@@ -72,7 +72,9 @@ func WithPath(path string) CookieHandlerOpt {
|
||||
func WithMaxAge(maxAge int) CookieHandlerOpt {
|
||||
return func(c *CookieHandler) {
|
||||
c.maxAge = maxAge
|
||||
c.securecookie.MaxAge(maxAge)
|
||||
if c.securecookie != nil {
|
||||
c.securecookie.MaxAge(maxAge)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
102
internal/api/http/middleware/access_interceptor.go
Normal file
102
internal/api/http/middleware/access_interceptor.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"math"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
"github.com/zitadel/logging"
|
||||
|
||||
"github.com/zitadel/zitadel/internal/api/authz"
|
||||
http_utils "github.com/zitadel/zitadel/internal/api/http"
|
||||
"github.com/zitadel/zitadel/internal/logstore"
|
||||
"github.com/zitadel/zitadel/internal/logstore/emitters/access"
|
||||
"github.com/zitadel/zitadel/internal/telemetry/tracing"
|
||||
)
|
||||
|
||||
type AccessInterceptor struct {
|
||||
svc *logstore.Service
|
||||
cookieHandler *http_utils.CookieHandler
|
||||
limitConfig *AccessConfig
|
||||
}
|
||||
|
||||
type AccessConfig struct {
|
||||
ExhaustedCookieKey string
|
||||
ExhaustedCookieMaxAge time.Duration
|
||||
}
|
||||
|
||||
func NewAccessInterceptor(svc *logstore.Service, cookieConfig *AccessConfig) *AccessInterceptor {
|
||||
return &AccessInterceptor{
|
||||
svc: svc,
|
||||
cookieHandler: http_utils.NewCookieHandler(
|
||||
http_utils.WithUnsecure(),
|
||||
http_utils.WithMaxAge(int(math.Floor(cookieConfig.ExhaustedCookieMaxAge.Seconds()))),
|
||||
),
|
||||
limitConfig: cookieConfig,
|
||||
}
|
||||
}
|
||||
|
||||
func (a *AccessInterceptor) Handle(next http.Handler) http.Handler {
|
||||
if !a.svc.Enabled() {
|
||||
return next
|
||||
}
|
||||
return http.HandlerFunc(func(writer http.ResponseWriter, request *http.Request) {
|
||||
|
||||
ctx := request.Context()
|
||||
var err error
|
||||
|
||||
tracingCtx, span := tracing.NewServerInterceptorSpan(ctx)
|
||||
defer func() { span.EndWithError(err) }()
|
||||
|
||||
wrappedWriter := &statusRecorder{ResponseWriter: writer, status: 0}
|
||||
|
||||
instance := authz.GetInstance(ctx)
|
||||
remaining := a.svc.Limit(tracingCtx, instance.InstanceID())
|
||||
limit := remaining != nil && *remaining == 0
|
||||
|
||||
a.cookieHandler.SetCookie(wrappedWriter, a.limitConfig.ExhaustedCookieKey, request.Host, strconv.FormatBool(limit))
|
||||
|
||||
if limit {
|
||||
wrappedWriter.WriteHeader(http.StatusTooManyRequests)
|
||||
wrappedWriter.ignoreWrites = true
|
||||
}
|
||||
|
||||
next.ServeHTTP(wrappedWriter, request)
|
||||
|
||||
requestURL := request.RequestURI
|
||||
unescapedURL, err := url.QueryUnescape(requestURL)
|
||||
if err != nil {
|
||||
logging.WithError(err).WithField("url", requestURL).Warning("failed to unescape request url")
|
||||
// err = nil is effective because of deferred tracing span end
|
||||
err = nil
|
||||
}
|
||||
a.svc.Handle(tracingCtx, &access.Record{
|
||||
LogDate: time.Now(),
|
||||
Protocol: access.HTTP,
|
||||
RequestURL: unescapedURL,
|
||||
ResponseStatus: uint32(wrappedWriter.status),
|
||||
RequestHeaders: request.Header,
|
||||
ResponseHeaders: writer.Header(),
|
||||
InstanceID: instance.InstanceID(),
|
||||
ProjectID: instance.ProjectID(),
|
||||
RequestedDomain: instance.RequestedDomain(),
|
||||
RequestedHost: instance.RequestedHost(),
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
type statusRecorder struct {
|
||||
http.ResponseWriter
|
||||
status int
|
||||
ignoreWrites bool
|
||||
}
|
||||
|
||||
func (r *statusRecorder) WriteHeader(status int) {
|
||||
if r.ignoreWrites {
|
||||
return
|
||||
}
|
||||
r.status = status
|
||||
r.ResponseWriter.WriteHeader(status)
|
||||
}
|
||||
@@ -244,6 +244,10 @@ func (m *mockInstanceVerifier) InstanceByHost(_ context.Context, host string) (a
|
||||
return &mockInstance{}, nil
|
||||
}
|
||||
|
||||
func (m *mockInstanceVerifier) InstanceByID(context.Context) (authz.Instance, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
type mockInstance struct{}
|
||||
|
||||
func (m *mockInstance) InstanceID() string {
|
||||
|
||||
@@ -427,7 +427,7 @@ func (o *OPStorage) userinfoFlows(ctx context.Context, resourceOwner string, use
|
||||
apiFields,
|
||||
action.Script,
|
||||
action.Name,
|
||||
append(actions.ActionToOptions(action), actions.WithHTTP(actionCtx), actions.WithLogger(actions.ServerLog))...,
|
||||
append(actions.ActionToOptions(action), actions.WithHTTP(actionCtx))...,
|
||||
)
|
||||
cancel()
|
||||
if err != nil {
|
||||
@@ -583,7 +583,7 @@ func (o *OPStorage) privateClaimsFlows(ctx context.Context, userID string, claim
|
||||
apiFields,
|
||||
action.Script,
|
||||
action.Name,
|
||||
append(actions.ActionToOptions(action), actions.WithHTTP(actionCtx), actions.WithLogger(actions.ServerLog))...,
|
||||
append(actions.ActionToOptions(action), actions.WithHTTP(actionCtx))...,
|
||||
)
|
||||
cancel()
|
||||
if err != nil {
|
||||
|
||||
@@ -73,13 +73,13 @@ type OPStorage struct {
|
||||
assetAPIPrefix func(ctx context.Context) string
|
||||
}
|
||||
|
||||
func NewProvider(ctx context.Context, config Config, defaultLogoutRedirectURI string, externalSecure bool, command *command.Commands, query *query.Queries, repo repository.Repository, encryptionAlg crypto.EncryptionAlgorithm, cryptoKey []byte, es *eventstore.Eventstore, projections *sql.DB, userAgentCookie, instanceHandler func(http.Handler) http.Handler) (op.OpenIDProvider, error) {
|
||||
func NewProvider(ctx context.Context, config Config, defaultLogoutRedirectURI string, externalSecure bool, command *command.Commands, query *query.Queries, repo repository.Repository, encryptionAlg crypto.EncryptionAlgorithm, cryptoKey []byte, es *eventstore.Eventstore, projections *sql.DB, userAgentCookie, instanceHandler, accessHandler func(http.Handler) http.Handler) (op.OpenIDProvider, error) {
|
||||
opConfig, err := createOPConfig(config, defaultLogoutRedirectURI, cryptoKey)
|
||||
if err != nil {
|
||||
return nil, caos_errs.ThrowInternal(err, "OIDC-EGrqd", "cannot create op config: %w")
|
||||
}
|
||||
storage := newStorage(config, command, query, repo, encryptionAlg, es, projections, externalSecure)
|
||||
options, err := createOptions(config, externalSecure, userAgentCookie, instanceHandler)
|
||||
options, err := createOptions(config, externalSecure, userAgentCookie, instanceHandler, accessHandler)
|
||||
if err != nil {
|
||||
return nil, caos_errs.ThrowInternal(err, "OIDC-D3gq1", "cannot create options: %w")
|
||||
}
|
||||
@@ -117,7 +117,7 @@ func createOPConfig(config Config, defaultLogoutRedirectURI string, cryptoKey []
|
||||
return opConfig, nil
|
||||
}
|
||||
|
||||
func createOptions(config Config, externalSecure bool, userAgentCookie, instanceHandler func(http.Handler) http.Handler) ([]op.Option, error) {
|
||||
func createOptions(config Config, externalSecure bool, userAgentCookie, instanceHandler, accessHandler func(http.Handler) http.Handler) ([]op.Option, error) {
|
||||
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
|
||||
options := []op.Option{
|
||||
op.WithHttpInterceptors(
|
||||
@@ -127,6 +127,7 @@ func createOptions(config Config, externalSecure bool, userAgentCookie, instance
|
||||
instanceHandler,
|
||||
userAgentCookie,
|
||||
http_utils.CopyHeadersToContext,
|
||||
accessHandler,
|
||||
),
|
||||
}
|
||||
if !externalSecure {
|
||||
|
||||
@@ -40,7 +40,8 @@ func NewProvider(
|
||||
es *eventstore.Eventstore,
|
||||
projections *sql.DB,
|
||||
instanceHandler,
|
||||
userAgentCookie func(http.Handler) http.Handler,
|
||||
userAgentCookie,
|
||||
accessHandler func(http.Handler) http.Handler,
|
||||
) (*provider.Provider, error) {
|
||||
metricTypes := []metrics.MetricType{metrics.MetricTypeRequestCount, metrics.MetricTypeStatusCode, metrics.MetricTypeTotalCount}
|
||||
|
||||
@@ -64,6 +65,7 @@ func NewProvider(
|
||||
middleware.NoCacheInterceptor().Handler,
|
||||
instanceHandler,
|
||||
userAgentCookie,
|
||||
accessHandler,
|
||||
http_utils.CopyHeadersToContext,
|
||||
),
|
||||
}
|
||||
|
||||
@@ -88,7 +88,7 @@ func (f *file) Stat() (_ fs.FileInfo, err error) {
|
||||
return f, nil
|
||||
}
|
||||
|
||||
func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, instanceHandler func(http.Handler) http.Handler, customerPortal string) (http.Handler, error) {
|
||||
func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, instanceHandler, accessInterceptor func(http.Handler) http.Handler, customerPortal string) (http.Handler, error) {
|
||||
fSys, err := fs.Sub(static, "static")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -103,7 +103,7 @@ func Start(config Config, externalSecure bool, issuer op.IssuerFromRequest, inst
|
||||
|
||||
handler := mux.NewRouter()
|
||||
|
||||
handler.Use(instanceHandler, security)
|
||||
handler.Use(instanceHandler, security, accessInterceptor)
|
||||
handler.Handle(envRequestPath, middleware.TelemetryHandler()(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
url := http_util.BuildOrigin(r.Host, externalSecure)
|
||||
environmentJSON, err := createEnvironmentJSON(url, issuer(r), authz.GetInstance(r.Context()).ConsoleClientID(), customerPortal)
|
||||
|
||||
@@ -106,7 +106,7 @@ func (l *Login) runPostExternalAuthenticationActions(
|
||||
apiFields,
|
||||
a.Script,
|
||||
a.Name,
|
||||
append(actions.ActionToOptions(a), actions.WithHTTP(actionCtx), actions.WithLogger(actions.ServerLog))...,
|
||||
append(actions.ActionToOptions(a), actions.WithHTTP(actionCtx))...,
|
||||
)
|
||||
cancel()
|
||||
if err != nil {
|
||||
@@ -175,7 +175,7 @@ func (l *Login) runPostInternalAuthenticationActions(
|
||||
apiFields,
|
||||
a.Script,
|
||||
a.Name,
|
||||
append(actions.ActionToOptions(a), actions.WithHTTP(actionCtx), actions.WithLogger(actions.ServerLog))...,
|
||||
append(actions.ActionToOptions(a), actions.WithHTTP(actionCtx))...,
|
||||
)
|
||||
cancel()
|
||||
if err != nil {
|
||||
@@ -274,7 +274,7 @@ func (l *Login) runPreCreationActions(
|
||||
apiFields,
|
||||
a.Script,
|
||||
a.Name,
|
||||
append(actions.ActionToOptions(a), actions.WithHTTP(actionCtx), actions.WithLogger(actions.ServerLog))...,
|
||||
append(actions.ActionToOptions(a), actions.WithHTTP(actionCtx))...,
|
||||
)
|
||||
cancel()
|
||||
if err != nil {
|
||||
@@ -332,7 +332,7 @@ func (l *Login) runPostCreationActions(
|
||||
apiFields,
|
||||
a.Script,
|
||||
a.Name,
|
||||
append(actions.ActionToOptions(a), actions.WithHTTP(actionCtx), actions.WithLogger(actions.ServerLog))...,
|
||||
append(actions.ActionToOptions(a), actions.WithHTTP(actionCtx))...,
|
||||
)
|
||||
cancel()
|
||||
if err != nil {
|
||||
|
||||
@@ -69,6 +69,7 @@ func CreateLogin(config Config,
|
||||
oidcInstanceHandler,
|
||||
samlInstanceHandler mux.MiddlewareFunc,
|
||||
assetCache mux.MiddlewareFunc,
|
||||
accessHandler mux.MiddlewareFunc,
|
||||
userCodeAlg crypto.EncryptionAlgorithm,
|
||||
idpConfigAlg crypto.EncryptionAlgorithm,
|
||||
csrfCookieKey []byte,
|
||||
@@ -94,7 +95,7 @@ func CreateLogin(config Config,
|
||||
cacheInterceptor := createCacheInterceptor(config.Cache.MaxAge, config.Cache.SharedMaxAge, assetCache)
|
||||
security := middleware.SecurityHeaders(csp(), login.cspErrorHandler)
|
||||
|
||||
login.router = CreateRouter(login, statikFS, middleware.TelemetryHandler(IgnoreInstanceEndpoints...), oidcInstanceHandler, samlInstanceHandler, csrfInterceptor, cacheInterceptor, security, userAgentCookie, issuerInterceptor)
|
||||
login.router = CreateRouter(login, statikFS, middleware.TelemetryHandler(IgnoreInstanceEndpoints...), oidcInstanceHandler, samlInstanceHandler, csrfInterceptor, cacheInterceptor, security, userAgentCookie, issuerInterceptor, accessHandler)
|
||||
login.renderer = CreateRenderer(HandlerPrefix, statikFS, staticStorage, config.LanguageCookieName)
|
||||
login.parser = form.NewParser()
|
||||
return login, nil
|
||||
|
||||
Reference in New Issue
Block a user